Merged
1 change: 0 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -3,7 +3,6 @@

import numpy as np
import torch
-import torch.compiler

from pytensor import In
from pytensor.compile import PYTORCH
39 changes: 24 additions & 15 deletions pytensor/tensor/blas_c.py
@@ -417,20 +417,6 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
}
}

-if (%(must_initialize_y)d && dbeta == 0)
-{
-// Most BLAS implementations of GEMV ignore y=nan when beta=0
-// PyTensor considers that the correct behavior,
-// and even exploits it to avoid copying or initializing outputs.
-// By deciding to exploit this, however, it becomes our responsibility
-// to ensure the behavior even in the rare cases BLAS deviates,
-// or users will get errors, even for graphs that had no nan to begin with.
-PyObject *zero = PyFloat_FromDouble(0.);
-if (zero == NULL) %(fail)s;
-if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
-Py_DECREF(zero);
-}
-
{
int NA0 = PyArray_DIMS(%(A)s)[0];
int NA1 = PyArray_DIMS(%(A)s)[1];
@@ -439,6 +425,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
{
// Non-empty A matrix

+if (%(must_initialize_y)d && dbeta == 0)
+{
+// Most BLAS implementations of GEMV ignore y=nan when beta=0
+// PyTensor considers that the correct behavior,
+// and even exploits it to avoid copying or initializing outputs.
+// By deciding to exploit this, however, it becomes our responsibility
+// to ensure the behavior even in the rare cases BLAS deviates,
+// or users will get errors, even for graphs that had no nan to begin with.
+PyArray_FILLWBYTE(%(z)s, 0);
+}
+
/* In the case where A is actually a row or column matrix,
* the strides corresponding to the dummy dimension don't matter,
* but BLAS requires these to be no smaller than the number of elements in the array.
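The contract the relocated block enforces can be sketched in plain NumPy (an illustration of the intended semantics, not PyTensor code; `gemv_reference` is a hypothetical helper):

```python
import numpy as np

# Reference semantics of gemv: z = beta * y + alpha * (A @ x).
# When beta == 0, whatever sits in the output buffer (possibly NaNs from an
# uninitialized allocation) must be discarded, mirroring PyArray_FILLWBYTE(z, 0).
def gemv_reference(y, alpha, A, x, beta):
    z = y.copy()
    if beta == 0.0:
        z[...] = 0.0  # overwrite buffer contents instead of computing NaN * 0
    else:
        z *= beta
    z += alpha * (A @ x)
    return z

y = np.full(3, np.nan)  # stands in for an uninitialized output buffer
A = np.ones((3, 2))
x = np.ones(2)
assert not np.isnan(gemv_reference(y, 1.0, A, x, beta=0.0)).any()
```

Note why the block moved into the non-empty-A branch: the explicit fill is only needed where BLAS is actually called; the new empty-A fallback below handles its own initialization.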
@@ -567,6 +564,18 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N
"A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;");
%(fail)s
}
+} else
+{
+// Empty A matrix, just scale y by beta
+if (dbeta != 1.0)
[Review thread on this line]

Member: Why not `if dbeta == 1`?

@ricardoV94 (Member, Author), Nov 11, 2025: Then you get the initial y. We don't create gemv with uninitialized y when `dbeta != 0`; that would show up in `y = y * beta + A @ x * alpha`. You'd get whatever the user's y was, and multiplying by 1 wouldn't change it.

@ricardoV94, Nov 11, 2025: Or I may have missed the motivation for your question. This is an eager optimization the Theano guys had; do you think it's silly?

Member: Ah, of course. 🤦 Looks good to me, except for the test failure then.

@ricardoV94: It's just the unrelated torch module, or did I miss another?

Member: I think it's just that.
+{
+npy_intp Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
+dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
+for (npy_intp i = 0; i < NA0; ++i)
+{
+z_data[i * Sz] = (dbeta == 0.0) ? 0 : z_data[i * Sz] * dbeta;
+}
+}
}
}
"""
@@ -598,7 +607,7 @@ def c_code(self, node, name, inp, out, sub):
return code

def c_code_cache_version(self):
-return (17, blas_header_version(), must_initialize_y_gemv())
+return (18, blas_header_version(), must_initialize_y_gemv())


cgemv_inplace = CGemv(inplace=True)
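Note: bumping the leading element of `c_code_cache_version` from 17 to 18 is what invalidates previously compiled C modules for `CGemv`, forcing a recompile that picks up the relocated initialization and the new empty-A branch instead of reusing stale cached binaries.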
48 changes: 39 additions & 9 deletions tests/tensor/test_blas_c.py
@@ -8,7 +8,15 @@
from pytensor.tensor.basic import AllocEmpty
from pytensor.tensor.blas import Ger
from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
-from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
+from pytensor.tensor.type import (
+dmatrix,
+dscalar,
+dvector,
+matrix,
+scalar,
+tensor,
+vector,
+)
from tests import unittest_tools
from tests.tensor.test_blas import BaseGemv, TestBlasStrides
from tests.unittest_tools import OptimizationTestMixin
@@ -143,19 +151,21 @@ def setup_method(self):
def test_nan_beta_0(self, inplace):
mode = self.mode.including()
mode.check_isfinite = False
+beta = self.a.type("beta")
f = pytensor.function(
-[self.A, self.x, pytensor.In(self.y, mutable=inplace), self.a],
-self.a * self.y + pt.dot(self.A, self.x),
+[self.A, self.x, pytensor.In(self.y, mutable=inplace), beta],
+beta * self.y + pt.dot(self.A, self.x),
mode=mode,
)
[node] = f.maker.fgraph.apply_nodes
assert isinstance(node.op, CGemv) and node.op.inplace == inplace
-for rows in (3, 1):
-Aval = np.ones((rows, 1), dtype=self.dtype)
-xval = np.ones((1,), dtype=self.dtype)
-yval = np.full((rows,), np.nan, dtype=self.dtype)
-zval = f(Aval, xval, yval, 0)
-assert not np.isnan(zval).any()
+for rows in (3, 1, 0):
+for cols in (1, 0):
+Aval = np.ones((rows, cols), dtype=self.dtype)
+xval = np.ones((cols,), dtype=self.dtype)
+yval = np.full((rows,), np.nan, dtype=self.dtype)
+zval = f(Aval, xval, yval, beta=0)
+assert not np.isnan(zval).any(), f"{rows=}, {cols=}"

def test_optimizations_vm(self):
skip_if_blas_ldflags_empty()
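(The widened loops above now also exercise degenerate shapes: rows or cols of 0 route execution through the new empty-A branch while y is filled with NaN, which is exactly the case the relocated initialization no longer covers and the fallback loop must.)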
@@ -294,6 +304,26 @@ def test_multiple_inplace(self):
== 2
)

+def test_empty_A(self):
+A = dmatrix("A")
+x = dvector("x")
+y = dvector("y")
+alpha = 1.0
+beta = dscalar("beta")
+gemv = CGemv(inplace=True)(y, alpha, A, x, beta)
+fn = pytensor.function(
+[A, x, y, beta],
+gemv,
+accept_inplace=True,
+)
+test_A = np.empty((10, 0))
+test_x = np.empty((0,))
+test_y = np.random.random((10,))
+for test_beta in [0.0, 1.0, 2.0]:
+out = fn(test_A, test_x, test_y.copy(), test_beta)
+expected = test_beta * test_y
+np.testing.assert_allclose(out, expected)
+
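For reference, the expectation in this test reduces to an identity that holds in plain NumPy: with a zero-length inner dimension, `A @ x` is exactly zeros, so gemv collapses to `beta * y`. A quick check, independent of PyTensor:

```python
import numpy as np

A = np.empty((10, 0))  # 10 rows, zero columns
x = np.empty((0,))
y = np.random.random(10)
for beta in (0.0, 1.0, 2.0):
    # A @ x has shape (10,) and is all zeros; only the beta term remains.
    np.testing.assert_allclose(beta * y + A @ x, beta * y)
```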

class TestCGemvFloat32(BaseGemv, OptimizationTestMixin):
mode = mode_blas_opt