Skip to content

Conversation

@ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Nov 27, 2025

This is a spinoff of #811

Background

Some inplace tests started failing when we switched to numba backend by default. These tests were defining Elemwise/ ScalarOp variants that didn't make much sense.

The decorator @scalar_elemwise when called on a function with inplace suffix would create a inplace version of Elemwise that would write the first output on the first input. This isn't always safe as we can only do that when the dtypes match. To work around this the ScalarOps were being re-created with transfer_type which would override regular make_node logic to force the output type to match one of the first input type, even if this doesn't make sense like add(int8, int64)->(int8).

These are the sort of cases that failed in numba backend, in part because of missing #1747, in part because it would require a lot of logic on custom ScalarOp implementation (those that don't fallback to either cython or python obj mode), because we would always have to go back and check if the output dtype matched the promised one from the ScalarOp, even when it makes no sense.

Changes

  • This PR removes the special inplace behavior of @scalar_elemwise and default inplace Elemwise Ops
  • Get rid of the so many redundant inplace tests.
    • Again the logic of inplacing should be mostly detached from the ScalarOp in question. It doesn't make sense to test a gazillion scalar Ops unless there was a good reason for it.
  • It then removes transfer_type which was mostly used for those dummy Op/tests and
    • SecondOp (replaced by a custom method)
    • InplaceElemwise rewriter. This must have been a zealous safety thing? We only ever try to inplace on inputs that match the output. Again the ScalarOp shouldn't have to be modified because the Elemwise is working inplace. With this change we can also reuse the original ScalarOps, which saves quite some time when inplacing Composite, which would otherwise involve graph cloning
    • Some other inplace tests, where it wasn't really needed

@ricardoV94 ricardoV94 changed the title Remove transfer_type and stop defining inplace Elemwise Ops ahead of time Remove transfer_type and stop creating inplace default Elemwise Ops Nov 27, 2025
@ricardoV94 ricardoV94 changed the title Remove transfer_type and stop creating inplace default Elemwise Ops Remove transfer_type and default inplace Elemwise Ops Nov 27, 2025
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
- "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This removes the tensor/test_inplace.py and reorders the ignore / parts to be more readible. They ignore show up in the same order where each part is that reintroduced.

May still change after the PR runs to rebalance the workload so they take more or less the same time

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jobs look well balanced

complex2=(random_complex(2, 3, rng=rng), random(2, 3, rng=rng)),
empty=(np.asarray([], dtype=config.floatX), np.asarray([1], dtype=config.floatX)),
)
TestAddInplaceBroadcast = makeBroadcastTester(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kept one just for good measure

Comment on lines +261 to +264
Zt = Z.transpose()
assert isinstance(Zt.owner.op, DimShuffle) and Zt.owner.op.view_map == {0: [0]}
with pytest.raises(InconsistencyError, match=Gemm.E_z_uniq):
gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
gemm_inplace(Z, 1.0, A, Zt, 1.0)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the very old times DimShuffle wasn't a view by default, only after compilation. This transpose_inplace would create one that was a view by default. Now that it is always, it wasn't needed. I added an explicit assert

@ricardoV94
Copy link
Member Author

Just some minor failure, ready for review

@ricardoV94 ricardoV94 force-pushed the transfer_type branch 2 times, most recently from 63f2dab to 7fd9ea4 Compare November 27, 2025 15:40
try:
return type(op)(scalar_op, inplace_pattern).make_node(*node.inputs)
except TypeError:
# Elemwise raises TypeError if we try to inplace an output on an input of a different dtype
Copy link
Member Author

@ricardoV94 ricardoV94 Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No prior test caused this to happen, and I suspect it would be rare, but added an explicit handling and test.

In general it would be nice that node.op.make_node(*node.inputs) would always return the same type, but global config variables can change the behavior of Scalar.make_node. The transfer_type was working as a kind of freeze_output_type here I guess.

We can revisit this approach later, and perhaps put the override behavior at the level of Elemwise, not ScalarOp. But for now I'm inclined to go with simplicity.

Note that with the default config.cast_policy="custom" or optional config.cast_policy="numpy" floatX plays no role in Scalar.make_node, it only affects the initial creation of tensor/constants. Only config.cast_policy="numpy+floatX" could lead to this edge case.

Copy link
Member Author

@ricardoV94 ricardoV94 Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PS: This sort of issue could show up also during rewrites, since it is invalid to return a replacement of a different dtype. So even something like log(add(1, x)) -> log1p(x) could be problematic, and always requires an explicit cast for safety. This just shows how the behavior puts complexity everywhere, and this was but one instance of that.

…acing

This helper could arbitrarily override the default output_type from `ScalarOp.make_node` so that the output type matched one of the input types.
This can be used to create artificial Op signatures that don't make sense or can't be cleanly implemented in other backends. For instance an Add with signature (int8,int64)->int8.

This helper was historically used in:
 1. Elemwise inplace rewrite. I assume as a preventive measure. However, regular use should never require changing the ScalarOp signature, as we only try to inplace on inputs that match the output dtype and recreating the same Op with the same input types should always return the same output type. ScalarOp don't have a concept of inplace, only the Elemwise wrapper does, and it shouldn't require recreating/mutating the original Op.
 2. SecondOp. Here it makes sense, but a custom static_method works just as well
 3. Inplace tests with the inplace variants of `@scalar_elemwise` decorator. These test Classes were removed. It still didn't make sense to test/force Ops to have an artifical signature for the sake of tests. They were removed anyway
@ricardoV94 ricardoV94 merged commit ae499a4 into pymc-devs:main Nov 28, 2025
56 checks passed
@ricardoV94 ricardoV94 deleted the transfer_type branch November 28, 2025 14:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Get rid of pre-defined inplace Elemwise Operators?

2 participants