 import jax
 import jax.numpy as jnp
 import jax.scipy.special as jspecial
-from jax.lib import xla_bridge
+from jax.extend.backend import get_backend as _jax_get_backend

 jax_type = jax.numpy.ndarray
 jax_new_version = float(".".join(jax.__version__.split(".")[1:])) > 4.24
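
Note: `jax.lib.xla_bridge` is deprecated and has been removed from recent JAX releases; `jax.extend.backend.get_backend` is the documented replacement. A minimal sketch of the new call, assuming a JAX version that ships `jax.extend`:

```python
# get_backend() returns the default XLA backend; .platform is "cpu", "gpu" or "tpu"
import jax
from jax.extend.backend import get_backend

if get_backend().platform == "gpu":
    print("GPU devices:", jax.devices("gpu"))
else:
    print("Running on", get_backend().platform)
```
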
@@ -178,7 +178,16 @@ def _get_backend_instance(backend_impl):


 def _check_args_backend(backend_impl, args):
-    is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
+    # Get backend instance to use issparse method
+    backend = _get_backend_instance(backend_impl)
+
+    # Check if each arg is either:
+    # 1. An instance of backend.__type__ (e.g., np.ndarray for NumPy)
+    # 2. A sparse matrix recognized by backend.issparse() (e.g., scipy.sparse for NumPy)
+    is_instance = set(
+        isinstance(arg, backend_impl.__type__) or backend.issparse(arg) for arg in args
+    )
+
     # check that all arguments matched or not the type
     if len(is_instance) == 1:
         return is_instance.pop()
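
Collecting one boolean per argument into a set keeps the check compact: a single-element set means the arguments either all matched or all failed, while mixed inputs fall through to the error path. A standalone sketch of the same idea for the NumPy case (`check_args_numpy` is a hypothetical name, not part of the patch):

```python
import numpy as np
from scipy.sparse import coo_matrix, issparse

def check_args_numpy(*args):
    # one boolean per argument: dense ndarrays and scipy sparse both count as "NumPy"
    matches = set(isinstance(a, np.ndarray) or issparse(a) for a in args)
    if len(matches) == 1:  # all matched, or none did
        return matches.pop()
    raise ValueError("all array arguments must come from the same backend")

print(check_args_numpy(np.eye(2), coo_matrix(np.eye(2))))  # True: dense + sparse mix is fine
print(check_args_numpy([1, 2], (3, 4)))                    # False: neither is a NumPy type
```
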
@@ -839,6 +848,31 @@ def todense(self, a):
         """
         raise NotImplementedError()

+    def sparse_coo_data(self, a):
+        r"""
+        Extracts COO format data (row, col, data, shape) from a sparse matrix.
+
+        Returns row indices, column indices, data values, and shape as numpy arrays/tuple.
+        This is used to interface with C++ solvers that require explicit edge lists.
+
+        Parameters
+        ----------
+        a : sparse matrix
+            Sparse matrix in the backend's COO format
+
+        Returns
+        -------
+        row : numpy.ndarray
+            Row indices (1D array)
+        col : numpy.ndarray
+            Column indices (1D array)
+        data : numpy.ndarray
+            Data values (1D array)
+        shape : tuple
+            Shape of the matrix (n_rows, n_cols)
+        """
+        raise NotImplementedError()
+
     def where(self, condition, x, y):
         r"""
         Returns elements chosen from x or y depending on condition.
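
To make the documented contract concrete, here is what the four return values look like on a tiny matrix, shown with scipy's COO type for illustration (each backend maps its own sparse type to the same arrays):

```python
import numpy as np
from scipy.sparse import coo_matrix

A = coo_matrix(np.array([[0.0, 1.5, 0.0],
                         [2.0, 0.0, 0.0]]))
# the contract: three aligned 1D arrays (one entry per stored value) plus the shape
print(A.row)    # [0 1]
print(A.col)    # [1 0]
print(A.data)   # [1.5 2. ]
print(A.shape)  # (2, 3)
```
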
@@ -1349,6 +1383,15 @@ def todense(self, a):
         else:
             return a

+    def sparse_coo_data(self, a):
+        # Convert to COO format if needed
+        if not isinstance(a, coo_matrix):
+            a_coo = coo_matrix(a)
+        else:
+            a_coo = a
+
+        return a_coo.row, a_coo.col, a_coo.data, a_coo.shape
+
     def where(self, condition, x=None, y=None):
         if x is None and y is None:
             return np.where(condition)
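
A usage sketch of the NumPy implementation, assuming this patch is applied (`NumpyBackend` is POT's existing backend class). CSR/CSC inputs are converted to COO on the fly:

```python
import numpy as np
from scipy.sparse import csr_matrix
from ot.backend import NumpyBackend

nx = NumpyBackend()
A = csr_matrix(np.array([[0.0, 3.0], [4.0, 0.0]]))  # CSR input, converted to COO internally
row, col, data, shape = nx.sparse_coo_data(A)
print(row, col, data, shape)  # [0 1] [1 0] [3. 4.] (2, 2)
```
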
@@ -1509,7 +1552,7 @@ def __init__(self):
         self.__type_list__ = []
         # available_devices = jax.devices("cpu")
         available_devices = []
-        if xla_bridge.get_backend().platform == "gpu":
+        if _jax_get_backend().platform == "gpu":
             available_devices += jax.devices("gpu")
         for d in available_devices:
             self.__type_list__ += [
@@ -1768,6 +1811,15 @@ def todense(self, a):
         # Currently, JAX does not support sparse matrices
         return a

+    def sparse_coo_data(self, a):
+        # JAX doesn't support sparse matrices, so this shouldn't be called
+        # But if it is, convert the dense array to sparse using scipy
+        a_np = self.to_numpy(a)
+        from scipy.sparse import coo_matrix
+
+        a_coo = coo_matrix(a_np)
+        return a_coo.row, a_coo.col, a_coo.data, a_coo.shape
+
     def where(self, condition, x=None, y=None):
         if x is None and y is None:
             return jnp.where(condition)
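
Because JAX arrays are always dense here, the fallback densifies to NumPy and lets scipy locate the nonzeros, so explicit zeros are dropped. A sketch of the same conversion the method performs:

```python
import jax.numpy as jnp
import numpy as np
from scipy.sparse import coo_matrix

a = jnp.array([[0.0, 1.0], [0.0, 2.0]])
a_coo = coo_matrix(np.asarray(a))        # identical to what sparse_coo_data does
print(a_coo.row, a_coo.col, a_coo.data)  # [0 1] [1 1] [1. 2.]
```
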
@@ -1938,6 +1990,7 @@ def __init__(self):
         self.rng_cuda_ = torch.Generator("cpu")

         from torch.autograd import Function
+        from torch.autograd.function import once_differentiable

         # define a function that takes inputs val and grads
         # ad returns a val tensor with proper gradients
@@ -1952,7 +2005,31 @@ def backward(ctx, grad_output):
             # the gradients are grad
             return (None, None) + tuple(g * grad_output for g in ctx.grads)

+        # define a differentiable SPD matrix sqrt
+        # with closed-form VJP
+        class MatrixSqrtFunction(Function):
+            @staticmethod
+            def forward(ctx, a):
+                a_sym = 0.5 * (a + a.transpose(-2, -1))
+                L, V = torch.linalg.eigh(a_sym)
+                s = L.clamp_min(0).sqrt()
+                y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
+                ctx.save_for_backward(s, V)
+                return y
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, g):
+                s, V = ctx.saved_tensors
+                g_sym = 0.5 * (g + g.transpose(-2, -1))
+                ghat = V.transpose(-2, -1) @ g_sym @ V
+                d = s.unsqueeze(-1) + s.unsqueeze(-2)
+                xhat = ghat / d
+                xhat = xhat.masked_fill(d == 0, 0)
+                return V @ xhat @ V.transpose(-2, -1)
+
         self.ValFunction = ValFunction
+        self.MatrixSqrtFunction = MatrixSqrtFunction

     def _to_numpy(self, a):
         if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
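
The closed form comes from differentiating Y Y = A: dY Y + Y dY = dA is a Sylvester equation, and in the eigenbasis A = V diag(s^2) V^T its adjoint solution divides the rotated cotangent elementwise by s_i + s_j. `once_differentiable` flags that this hand-written backward cannot itself be differentiated for higher-order gradients. A standalone sketch (not part of the patch) checking the formula against autograd's gradient through `eigh` on a well-conditioned SPD matrix:

```python
import torch

def sqrtm_ref(a):
    # reference: differentiate straight through eigh
    L, V = torch.linalg.eigh(0.5 * (a + a.mT))
    return (V * L.clamp_min(0).sqrt().unsqueeze(-2)) @ V.mT

B = torch.randn(4, 4, dtype=torch.float64)
A = (B @ B.mT + 4 * torch.eye(4, dtype=torch.float64)).requires_grad_()

y = sqrtm_ref(A)
g = torch.randn_like(y)
(grad_ref,) = torch.autograd.grad(y, A, g)

# closed-form VJP: V [ (V^T g_sym V)_ij / (s_i + s_j) ] V^T
with torch.no_grad():
    L, V = torch.linalg.eigh(0.5 * (A + A.mT))
    s = L.clamp_min(0).sqrt()
    ghat = V.mT @ (0.5 * (g + g.mT)) @ V
    grad_cf = V @ (ghat / (s.unsqueeze(-1) + s.unsqueeze(-2))) @ V.mT

print(torch.allclose(grad_ref, grad_cf))  # True
```
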
@@ -2315,6 +2392,20 @@ def todense(self, a):
         else:
             return a

+    def sparse_coo_data(self, a):
+        # For torch sparse tensors, coalesce first to ensure unique indices
+        a_coalesced = a.coalesce()
+        indices = a_coalesced.indices()
+        values = a_coalesced.values()
+
+        # Convert to numpy
+        row = self.to_numpy(indices[0])
+        col = self.to_numpy(indices[1])
+        data = self.to_numpy(values)
+        shape = tuple(a_coalesced.shape)
+
+        return row, col, data, shape
+
     def where(self, condition, x=None, y=None):
         if x is None and y is None:
             return torch.where(condition)
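
A usage sketch for the torch implementation, assuming this patch is applied (`TorchBackend` is POT's existing backend class). `coalesce()` sums duplicate COO entries before extraction:

```python
import torch
from ot.backend import TorchBackend

nx = TorchBackend()
i = torch.tensor([[0, 0, 1], [1, 1, 0]])  # entry (0, 1) listed twice
v = torch.tensor([1.0, 2.0, 4.0])
A = torch.sparse_coo_tensor(i, v, (2, 2))

row, col, data, shape = nx.sparse_coo_data(A)
print(row, col, data, shape)  # [0 1] [1 0] [3. 4.] (2, 2)
```
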
@@ -2395,12 +2486,7 @@ def pinv(self, a, hermitian=False):
         return torch.linalg.pinv(a, hermitian=hermitian)

     def sqrtm(self, a):
-        L, V = torch.linalg.eigh(a)
-        L = torch.sqrt(L)
-        # Q[...] = V[...] @ diag(L[...])
-        Q = torch.einsum("...jk,...k->...jk", V, L)
-        # R[...] = Q[...] @ V[...].T
-        return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
+        return self.MatrixSqrtFunction.apply(a)

     def eigh(self, a):
         return torch.linalg.eigh(a)
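
Delegating `sqrtm` to `MatrixSqrtFunction` avoids autograd's trace through `eigh`, whose backward divides by eigenvalue gaps and can return non-finite gradients when eigenvalues coincide; the closed-form VJP only ever divides by s_i + s_j. A quick sanity check, assuming this patch is applied:

```python
import torch
from ot.backend import TorchBackend

nx = TorchBackend()
A = torch.eye(3, dtype=torch.float64, requires_grad=True)  # fully degenerate spectrum
Y = nx.sqrtm(A)
Y.sum().backward()
print(torch.allclose(Y, torch.eye(3, dtype=torch.float64)))  # True: sqrt(I) = I
print(torch.isfinite(A.grad).all())                          # tensor(True)
```
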