Skip to content

Commit 38bef3d

Browse files
committed
Merge branch 'dev' of https://github.com/stefanradev93/BayesFlow into dev
2 parents 8c9e6b7 + 9853265 commit 38bef3d

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

bayesflow/adapters/adapter.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,11 @@ def __len__(self):
234234

235235
def apply(
236236
self,
237+
include: str | Sequence[str] = None,
237238
*,
238239
forward: np.ufunc | str,
239240
inverse: np.ufunc | str = None,
240241
predicate: Predicate = None,
241-
include: str | Sequence[str] = None,
242242
exclude: str | Sequence[str] = None,
243243
**kwargs,
244244
):
@@ -389,6 +389,7 @@ def convert_dtype(
389389
exclude: str | Sequence[str] = None,
390390
):
391391
"""Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
392+
See also :py:meth:`~bayesflow.adapters.Adapter.map_dtype`.
392393
393394
Parameters
394395
----------
@@ -526,6 +527,24 @@ def log(self, keys: str | Sequence[str], *, p1: bool = False):
526527
self.transforms.append(transform)
527528
return self
528529

530+
def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
531+
"""Append a :py:class:`~transforms.ConvertDType` transform to the adapter.
532+
See also :py:meth:`~bayesflow.adapters.Adapter.convert_dtype`.
533+
534+
Parameters
535+
----------
536+
keys : str or Sequence of str
537+
The names of the variables to transform.
538+
to_dtype : str
539+
Target dtype
540+
"""
541+
if isinstance(keys, str):
542+
keys = [keys]
543+
544+
transform = MapTransform({key: ConvertDType(to_dtype) for key in keys})
545+
self.transforms.append(transform)
546+
return self
547+
529548
def one_hot(self, keys: str | Sequence[str], num_classes: int):
530549
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.
531550
@@ -591,9 +610,9 @@ def sqrt(self, keys: str | Sequence[str]):
591610

592611
def standardize(
593612
self,
613+
include: str | Sequence[str] = None,
594614
*,
595615
predicate: Predicate = None,
596-
include: str | Sequence[str] = None,
597616
exclude: str | Sequence[str] = None,
598617
**kwargs,
599618
):
@@ -622,9 +641,9 @@ def standardize(
622641

623642
def to_array(
624643
self,
644+
include: str | Sequence[str] = None,
625645
*,
626646
predicate: Predicate = None,
627-
include: str | Sequence[str] = None,
628647
exclude: str | Sequence[str] = None,
629648
**kwargs,
630649
):

bayesflow/adapters/transforms/filter_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ class FilterTransform(Transform):
2929

3030
def __init__(
3131
self,
32+
include: str | Sequence[str] = None,
3233
*,
3334
transform_constructor: Callable[..., ElementwiseTransform],
3435
predicate: Predicate = None,
35-
include: str | Sequence[str] = None,
3636
exclude: str | Sequence[str] = None,
3737
**kwargs,
3838
):

bayesflow/utils/empty.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class empty:
2+
"""
3+
Placeholder value for arguments left empty
4+
5+
Usage:
6+
7+
def f(x=empty):
8+
if x is empty:
9+
# we know the user did not pass x
10+
if x is None:
11+
# the user could have passed None explicitly
12+
13+
"""

0 commit comments

Comments
 (0)