Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions pytensor/scan/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,46 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
mode=mode,
name=name,
)


def filter(
fn,
sequences,
non_sequences=None,
go_backwards=False,
mode=None,
name=None,
):
"""Construct a `Scan` `Op` that functions like `filter`.

Parameters
----------
fn : callable
Predicate function returning a boolean tensor.
sequences : list
Sequences to filter.
non_sequences : list
Non-iterated arguments passed to `fn`.
go_backwards : bool
Whether to iterate in reverse.
mode : str or None
See ``scan``.
name : str or None
See ``scan``.
"""
mask, _ = scan(
fn=fn,
sequences=sequences,
outputs_info=None,
non_sequences=non_sequences,
go_backwards=go_backwards,
mode=mode,
name=f"{name or ''}_mask",
)

if isinstance(sequences, (list, tuple)):
filtered_sequences = [seq[mask] for seq in sequences]
else:
filtered_sequences = sequences[mask]

return filtered_sequences
20 changes: 20 additions & 0 deletions tests/scan/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,23 @@ def test_foldr_memory_consumption():
gx = grad(o, x)
f2 = function([], gx)
utt.assert_allclose(f2(), np.ones((10,)))


def test_filter():
import pytensor.tensor as pt

v = pt.vector("v")

def fn(x):
return pt.eq(x % 2, 0)

from pytensor.scan.views import filter as pt_filter

filtered = pt_filter(fn, v)
f = function([v], filtered, allow_input_downcast=True)

rng = np.random.default_rng(utt.fetch_seed())
vals = rng.integers(0, 10, size=(10,))
expected = vals[vals % 2 == 0]
result = f(vals)
utt.assert_allclose(expected, result)