Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion pytensor/link/mlx/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,25 @@ def mlx_fn(x, indices, y):
return x

def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
def get_slice_int(element):
if element is None:
return None
try:
return int(element)
except Exception:
Copy link

Copilot AI Oct 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a bare except Exception is too broad. This should catch specific exceptions like TypeError or ValueError that would occur when trying to convert a non-integer value. The current implementation could mask unexpected errors.

Suggested change
except Exception:
except (TypeError, ValueError):

Copilot uses AI. Check for mistakes.
return element

indices = tuple(
[
slice(
get_slice_int(s.start), get_slice_int(s.stop), get_slice_int(s.step)
)
if isinstance(s, slice)
else s
for s in indices_from_subtensor(ilist, idx_list)
]
)

if len(indices) == 1:
indices = indices[0]

Expand Down
13 changes: 13 additions & 0 deletions tests/link/mlx/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,19 @@ def test_mlx_IncSubtensor_increment():
assert not out_pt.owner.op.set_instead_of_inc
compare_mlx_and_py([], [out_pt], [])

# Increment slice
out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, 2:], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, -3:], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[::2, ::2, ::2], st_pt)
compare_mlx_and_py([], [out_pt], [])

out_pt = pt_subtensor.inc_subtensor(x_pt[:, :, :], st_pt)
compare_mlx_and_py([], [out_pt], [])


def test_mlx_AdvancedIncSubtensor_set():
"""Test advanced set operations using AdvancedIncSubtensor."""
Expand Down
Loading