File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed
Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -873,21 +873,24 @@ def sliced_plans(
873873 )
874874 ** (1 / p )
875875 )
876- * plan [k ]. data
876+ * plan [k ]
877877 )
878878 for k in range (n_proj )
879879 ]
880880 else : # metric == "sqeuclidean"
881881 costs = [
882882 nx .sum (
883883 (nx .sum ((X [plan [k ].row ] - Y [plan [k ].col ]) ** 2 , axis = 1 ))
884- * plan [k ]. data
884+ * plan [k ]
885885 )
886886 for k in range (n_proj )
887887 ]
888888
889889 if dense and not str (nx ) == "jax" :
890890 plan = [nx .todense (plan [k ]) for k in range (n_proj )]
891+ elif str (nx ) == "jax" :
892+ warnings .warn ("JAX does not support sparse matrices, converting to dense" )
893+ plan = [nx .todense (plan [k ]) for k in range (n_proj )]
891894
892895 if log :
893896 log_dict = {"X_theta" : X_theta , "Y_theta" : Y_theta , "thetas" : thetas }
You can’t perform that action at this time.
0 commit comments