
Commit e991d04

Merge branch 'master' into master
2 parents: 1ae8472 + d3867c6

6 files changed: +66 -10 lines

.github/workflows/build_tests.yml

Lines changed: 15 additions & 1 deletion

@@ -27,7 +27,7 @@ jobs:
       - name: Checking Out Repository
-        uses: actions/checkout@v2
+        uses: actions/checkout@v4
       # Install Python & Packages
       - uses: actions/setup-python@v4
         with:
@@ -39,6 +39,20 @@ jobs:
           pre-commit install --install-hooks
           pre-commit run --all-files
 
+  build_from_source:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+      - name: Build from source
+        run: |
+          python -m pip install --upgrade pip setuptools wheel
+          python -m pip install cython numpy
+          python setup.py sdist bdist_wheel
+          pip install dist/*.tar.gz
 
   linux:
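The new `build_from_source` job builds an sdist and wheel with the minimal build dependencies (cython, numpy) and then installs the tarball, exercising the packaging path behind Issue #764. A hypothetical post-install sanity check, not part of this diff, that could run after `pip install dist/*.tar.gz`:

    # confirm the freshly built package imports and works end to end
    import numpy as np
    import ot

    print(ot.__version__)                        # expected "0.9.7.dev0" at this commit
    M = ot.dist(np.arange(3.0).reshape(-1, 1))   # small call into the installed package
    print(M.shape)                               # (3, 3)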

RELEASES.md

Lines changed: 7 additions & 0 deletions

@@ -1,5 +1,12 @@
 # Releases
 
+## 0.9.7.dev0
+
+#### Closed issues
+- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
+- Add test for build from source (PR #772, Issue #764)
+- Stable `ot.TorchBackend.sqrtm` around repeated eigvals (PR #774, Issue #773)
+
 ## 0.9.6.post1
 
 *September 2025*

ot/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -77,7 +77,7 @@
 # utils functions
 from .utils import dist, unif, tic, toc, toq
 
-__version__ = "0.9.6.post1"
+__version__ = "0.9.7.dev0"
 
 __all__ = [
     "emd",

ot/backend.py

Lines changed: 28 additions & 8 deletions

@@ -119,7 +119,7 @@
     import jax
     import jax.numpy as jnp
     import jax.scipy.special as jspecial
-    from jax.lib import xla_bridge
+    from jax.extend.backend import get_backend as _jax_get_backend
 
     jax_type = jax.numpy.ndarray
     jax_new_version = float(".".join(jax.__version__.split(".")[1:])) > 4.24
@@ -1509,7 +1509,7 @@ def __init__(self):
         self.__type_list__ = []
         # available_devices = jax.devices("cpu")
         available_devices = []
-        if xla_bridge.get_backend().platform == "gpu":
+        if _jax_get_backend().platform == "gpu":
            available_devices += jax.devices("gpu")
         for d in available_devices:
             self.__type_list__ += [
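The two JAX hunks above replace the deprecated `jax.lib.xla_bridge` entry point with `jax.extend.backend.get_backend`, which exposes the same `.platform` attribute used for GPU detection (Issue #770). A standalone sketch of the replacement call, assuming a recent JAX where `jax.extend.backend` is available:

    # query the active JAX platform through the new import path
    from jax.extend.backend import get_backend as _jax_get_backend

    print(_jax_get_backend().platform)  # "cpu", "gpu", or "tpu" depending on the install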
@@ -1938,6 +1938,7 @@ def __init__(self):
         self.rng_cuda_ = torch.Generator("cpu")
 
         from torch.autograd import Function
+        from torch.autograd.function import once_differentiable
 
         # define a function that takes inputs val and grads
         # ad returns a val tensor with proper gradients
@@ -1952,7 +1953,31 @@ def backward(ctx, grad_output):
                 # the gradients are grad
                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
 
+        # define a differentiable SPD matrix sqrt
+        # with closed-form VJP
+        class MatrixSqrtFunction(Function):
+            @staticmethod
+            def forward(ctx, a):
+                a_sym = 0.5 * (a + a.transpose(-2, -1))
+                L, V = torch.linalg.eigh(a_sym)
+                s = L.clamp_min(0).sqrt()
+                y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
+                ctx.save_for_backward(s, V)
+                return y
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, g):
+                s, V = ctx.saved_tensors
+                g_sym = 0.5 * (g + g.transpose(-2, -1))
+                ghat = V.transpose(-2, -1) @ g_sym @ V
+                d = s.unsqueeze(-1) + s.unsqueeze(-2)
+                xhat = ghat / d
+                xhat = xhat.masked_fill(d == 0, 0)
+                return V @ xhat @ V.transpose(-2, -1)
+
         self.ValFunction = ValFunction
+        self.MatrixSqrtFunction = MatrixSqrtFunction
 
     def _to_numpy(self, a):
         if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
@@ -2395,12 +2420,7 @@ def pinv(self, a, hermitian=False):
         return torch.linalg.pinv(a, hermitian=hermitian)
 
     def sqrtm(self, a):
-        L, V = torch.linalg.eigh(a)
-        L = torch.sqrt(L)
-        # Q[...] = V[...] @ diag(L[...])
-        Q = torch.einsum("...jk,...k->...jk", V, L)
-        # R[...] = Q[...] @ V[...].T
-        return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
+        return self.MatrixSqrtFunction.apply(a)
 
     def eigh(self, a):
         return torch.linalg.eigh(a)
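`MatrixSqrtFunction` computes the square root from the eigendecomposition of the symmetrized input and supplies its own backward: in the eigenbasis the incoming gradient is divided entrywise by s_i + s_j (the eigenvalues of the square root), i.e. the VJP solves the Sylvester equation S·X + X·S = ḡ_sym with S = A^{1/2}. The previous `sqrtm` differentiated through `torch.linalg.eigh`, whose backward contains 1/(λ_i − λ_j) factors and becomes unstable when eigenvalues repeat (Issue #773). A minimal sketch of that degenerate case, assuming torch and this revision of POT are installed (not part of the diff):

    import torch
    import ot

    nx = ot.backend.TorchBackend()

    # the identity has a fully repeated eigenvalue, the case the old backward could not handle
    A = torch.eye(3, dtype=torch.float64, requires_grad=True)
    nx.sqrtm(A).sum().backward()
    print(torch.isfinite(A.grad).all())  # expected: tensor(True) with the closed-form VJP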

ot/stochastic.py

Lines changed: 2 additions & 0 deletions

@@ -178,6 +178,8 @@ def averaged_sgd_entropic_transport(
 
     Parameters
     ----------
+    a : ndarray, shape (ns,)
+        source measure
     b : ndarray, shape (nt,)
         target measure
     M : ndarray, shape (ns, nt)

test/test_backend.py

Lines changed: 13 additions & 0 deletions

@@ -822,6 +822,19 @@ def fun(a, b, d):
     assert nx.allclose(dl_db, b)
 
 
+def test_sqrtm_backward_torch():
+    if not torch:
+        pytest.skip("Torch not available")
+    nx = ot.backend.TorchBackend()
+    torch.manual_seed(42)
+    d = 5
+    A = torch.randn(d, d, dtype=torch.float64, device="cpu")
+    A = A @ A.T
+    A.requires_grad_(True)
+    func = lambda x: nx.sqrtm(x).sum()
+    assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
+
+
 def test_get_backend_none():
     a, b = np.zeros((2, 3)), None
     nx = get_backend(a, b)
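The new test relies on `torch.autograd.gradcheck`, which compares the analytic backward of `nx.sqrtm` against finite differences, hence the double-precision SPD input. A quick forward-pass check in the same spirit (hypothetical, not part of the test suite):

    import torch
    import ot

    nx = ot.backend.TorchBackend()
    A = torch.randn(4, 4, dtype=torch.float64)
    A = A @ A.T                                  # symmetric positive semi-definite, as in the test
    S = nx.sqrtm(A)
    print(torch.allclose(S @ S, A, atol=1e-10))  # the square root squares back to A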
