Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions pytensor/scan/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,63 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
mode=mode,
name=name,
)


def filter(
fn,
sequences,
non_sequences=None,
go_backwards=False,
mode=None,
name=None,
):
"""Construct a `Scan` `Op` that functions like `filter`.

Parameters
----------
fn : callable
Predicate function returning a boolean tensor.
sequences : list
Sequences to filter.
non_sequences : list
Non-iterated arguments passed to `fn`.
go_backwards : bool
Whether to iterate in reverse.
mode : str or None
See ``scan``.
name : str or None
See ``scan``.

Notes
-----
If the predicate function `fn` returns multiple boolean masks (one per sequence),
each mask will be applied to its corresponding sequence. If it returns a single mask,
that mask will be broadcast to all sequences.
"""
mask, _ = scan(
fn=fn,
sequences=sequences,
outputs_info=None,
non_sequences=non_sequences,
go_backwards=go_backwards,
mode=mode,
name=name,
)

if isinstance(mask, (list, tuple)):
# One mask per sequence
if not isinstance(sequences, (list, tuple)):
raise TypeError(
"If multiple masks are returned, sequences must be a list or tuple."
)
if len(mask) != len(sequences):
raise ValueError("Number of masks must match number of sequences.")
filtered_sequences = [seq[m] for seq, m in zip(sequences, mask)]
else:
# Single mask applied to all sequences
if isinstance(sequences, (list, tuple)):
filtered_sequences = [seq[mask] for seq in sequences]
else:
filtered_sequences = sequences[mask]

return filtered_sequences
40 changes: 40 additions & 0 deletions tests/scan/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytensor.tensor as pt
from pytensor import config, function, grad, shared
from pytensor.compile.mode import FAST_RUN
from pytensor.scan.views import filter as pt_filter
from pytensor.scan.views import foldl, foldr
from pytensor.scan.views import map as pt_map
from pytensor.scan.views import reduce as pt_reduce
Expand Down Expand Up @@ -133,3 +134,42 @@ def test_foldr_memory_consumption():
gx = grad(o, x)
f2 = function([], gx)
utt.assert_allclose(f2(), np.ones((10,)))


def test_filter():
v = pt.vector("v")

def fn(x):
return pt.eq(x % 2, 0)

filtered = pt_filter(fn, v)
f = function([v], filtered, allow_input_downcast=True)

rng = np.random.default_rng(utt.fetch_seed())
vals = rng.integers(0, 10, size=(10,))
expected = vals[vals % 2 == 0]
result = f(vals)
utt.assert_allclose(expected, result)


def test_filter_multiple_masks():
v1 = pt.vector("v1")
v2 = pt.vector("v2")

def fn(x1, x2):
# Mask v1 for even numbers, mask v2 for numbers > 5
return pt.eq(x1 % 2, 0), pt.gt(x2, 5)

filtered_v1, filtered_v2 = pt_filter(fn, [v1, v2])
f = function([v1, v2], [filtered_v1, filtered_v2], allow_input_downcast=True)

vals1 = np.arange(10)
vals2 = np.arange(10)

expected_v1 = vals1[vals1 % 2 == 0]
expected_v2 = vals2[vals2 > 5]

result_v1, result_v2 = f(vals1, vals2)

utt.assert_allclose(expected_v1, result_v1)
utt.assert_allclose(expected_v2, result_v2)
Loading