66# Nicolas Courty <ncourty@irisa.fr>
77# Rémi Flamary <remi.flamary@polytechnique.edu>
88# Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
9+ # Laetitia Chapel <laetitia.chapel@irisa.fr>
910#
1011# License: MIT License
1112
@@ -820,21 +821,23 @@ def sliced_plans(
820821 for k in range (n_proj )
821822 ]
822823
824+ if not dense and str (nx ) == "jax" :
825+ warnings .warn ("JAX does not support sparse matrices, converting to dense" )
826+ plan = [nx .todense (plan [k ]) for k in range (n_proj )]
827+
823828 else : # we compute plans
824829 _ , plan = wasserstein_1d (
825830 X_theta , Y_theta , a , b , p , require_sort = True , return_plan = True
826831 )
827832
828- if str (nx ) == "tensorflow" : # tf does not support duplicate entries
829- plan = [plan [k ].tocsr ().tocoo () for k in range (n_proj )]
830-
831- if str (nx ) == "jax" :
832- plan = [nx .todense (plan [k ]) for k in range (n_proj )]
833+ if str (nx ) == "jax" : # dense computation for jax
833834 if not dense :
834835 warnings .warn (
835- "JAX does not support sparse matrices, converting" " to dense"
836+ "JAX does not support sparse matrices, converting to dense"
836837 )
837838
839+ plan = [nx .todense (plan [k ]) for k in range (n_proj )]
840+
838841 costs = [
839842 nx .sum (
840843 (
@@ -854,7 +857,11 @@ def sliced_plans(
854857 )
855858 for k in range (n_proj )
856859 ]
860+
857861 else :
862+ if str (nx ) == "tensorflow" : # tf does not support multiple indexing
863+ plan = [plan [k ].tocsr ().tocoo () for k in range (n_proj )]
864+
858865 if metric in ("minkowski" , "euclidean" , "cityblock" ):
859866 costs = [
860867 nx .sum (
@@ -870,7 +877,7 @@ def sliced_plans(
870877 )
871878 for k in range (n_proj )
872879 ]
873- else : # metric = "sqeuclidean"
880+ else : # metric == "sqeuclidean"
874881 costs = [
875882 nx .sum (
876883 (nx .sum ((X [plan [k ].row ] - Y [plan [k ].col ]) ** 2 , axis = 1 ))
@@ -879,7 +886,7 @@ def sliced_plans(
879886 for k in range (n_proj )
880887 ]
881888
882- if dense :
889+ if dense and not str ( nx ) == "jax" :
883890 plan = [nx .todense (plan [k ]) for k in range (n_proj )]
884891
885892 if log :
0 commit comments