@@ -289,15 +289,10 @@ def emd(
     ot.optim.cg : General regularized OT
     """
 
-    edge_sources = None
-    edge_targets = None
-    edge_costs = None
     n1, n2 = None, None
 
-    # Get backend from M first, then use it for list_to_array
-    # This ensures empty lists [] are converted to arrays in the correct backend
-    nx_M = get_backend(M)
-    a, b = list_to_array(a, b, nx=nx_M)
+    # Convert lists to arrays, using M to detect backend when a, b are empty
+    a, b, M = list_to_array(a, b, M)
     nx = get_backend(a, b, M)
 
     # Check if M is sparse using backend's issparse method
@@ -325,15 +320,6 @@ def emd(
         if edge_costs.dtype != np.float64:
             edge_costs = edge_costs.astype(np.float64)
 
-    elif isinstance(M, tuple):
-        raise ValueError(
-            "Tuple format for sparse cost matrix is not supported. "
-            "Please use backend-appropriate sparse COO format (e.g., scipy.sparse.coo_matrix, torch.sparse_coo_tensor, etc.)."
-        )
-    else:
-        is_sparse = False
-        a, b, M = list_to_array(a, b, M)
-
     if len(a) != 0:
         type_as = a
     elif len(b) != 0:
@@ -458,10 +444,10 @@ def emd2(
     processes=1,
     numItermax=100000,
     log=False,
+    return_matrix=False,
     center_dual=True,
     numThreads=1,
     check_marginals=True,
-    return_matrix=False,
 ):
     r"""Solves the Earth Movers distance problem and returns the loss
 
@@ -514,7 +500,7 @@ def emd2(
         The maximum number of iterations before stopping the optimization
         algorithm if it has not converged.
     log: boolean, optional (default=False)
-        If True, returns a dictionary containing dual
+        If True, returns a dictionary containing the cost and dual
         variables. Otherwise returns only the optimal transportation cost.
     return_matrix: boolean, optional (default=False)
         If True, returns the optimal transportation matrix in the log.
@@ -542,8 +528,9 @@ def emd2(
     W: float, array-like
         Optimal transportation loss for the given parameters
     log: dict
-        If input log is true, a dictionary containing dual
-        variables and exit status
+        If input log is true, a dictionary containing the cost, dual
+        variables (u, v), exit status, and optionally the optimal
+        transportation matrix (G) if return_matrix is True
 
 
     Examples
@@ -575,15 +562,9 @@ def emd2(
     ot.optim.cg : General regularized OT
     """
 
-    edge_sources = None
-    edge_targets = None
-    edge_costs = None
     n1, n2 = None, None
 
-    # Get backend from M first, then use it for list_to_array
-    # This ensures empty lists [] are converted to arrays in the correct backend
-    nx_M = get_backend(M)
-    a, b = list_to_array(a, b, nx=nx_M)
+    a, b, M = list_to_array(a, b, M)
     nx = get_backend(a, b, M)
 
     # Check if M is sparse using backend's issparse method
@@ -596,43 +577,26 @@ def emd2(
         # Check if backend supports sparse matrices
         backend_name = nx.__class__.__name__
         if backend_name in ["JaxBackend", "TensorflowBackend"]:
-            raise NotImplementedError(
-                f"Sparse optimal transport is not supported for {backend_name}. "
-                "JAX does not have native sparse matrix support, and TensorFlow's "
-                "sparse implementation is incomplete. Please convert your sparse "
-                "matrix to dense format using M.toarray() or equivalent before calling emd2()."
-            )
+            raise NotImplementedError()
 
         # Save original M for gradient tracking (before numpy conversion)
         M_original_sparse = M
 
-        # Extract COO data using backend method - returns numpy arrays
         edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)
 
-        # Ensure correct dtypes for C++ solver
         if edge_sources.dtype != np.uint64:
             edge_sources = edge_sources.astype(np.uint64)
         if edge_targets.dtype != np.uint64:
             edge_targets = edge_targets.astype(np.uint64)
         if edge_costs.dtype != np.float64:
             edge_costs = edge_costs.astype(np.float64)
 
-    elif isinstance(M, tuple):
-        raise ValueError(
-            "Tuple format for sparse cost matrix is not supported. "
-            "Please use backend-appropriate sparse COO format (e.g., scipy.sparse.coo_matrix, torch.sparse_coo_tensor, etc.)."
-        )
-    else:
-        # Dense matrix
-        is_sparse = False
-        a, b, M = list_to_array(a, b, M)
-
     if len(a) != 0:
         type_as = a
     elif len(b) != 0:
         type_as = b
     else:
-        type_as = a  # Can't use M for sparse case
+        type_as = a
 
     # Set n1, n2 if not already set (dense case)
     if n1 is None:
@@ -649,7 +613,6 @@ def emd2(
 
     if is_sparse:
         # Use the original sparse tensor (preserves gradients for PyTorch)
-        # instead of converting from numpy
         edge_costs_original = M_original_sparse
     else:
         edge_costs_original = None
@@ -682,12 +645,11 @@ def emd2(
     numThreads = check_number_threads(numThreads)
 
     # ============================================================================
-    # DEFINE SOLVER FUNCTION (works for both sparse and dense)
+    # DEFINE SOLVER FUNCTION
     # ============================================================================
     def f(b):
         bsel = b != 0
 
-        # Call appropriate solver
         if is_sparse:
             # Solve sparse EMD
             flow_sources, flow_targets, flow_values, cost, u, v, result_code = (
@@ -745,6 +707,23 @@ def f(b):
                     grad_M_sparse,
                 ),
             )
+
+            # Build transport plan in backend sparse format
+            flow_values_backend = nx.from_numpy(flow_values, type_as=type_as)
+            flow_sources_backend = nx.from_numpy(
+                flow_sources.astype(np.int64), type_as=type_as
+            )
+            flow_targets_backend = nx.from_numpy(
+                flow_targets.astype(np.int64), type_as=type_as
+            )
+
+            G_backend = nx.coo_matrix(
+                flow_values_backend,
+                flow_sources_backend,
+                flow_targets_backend,
+                shape=(n1, n2),
+                type_as=type_as,
+            )
         else:
             # Dense case: warn about integer casting
             if not nx.is_floating_point(type_as):
@@ -772,20 +751,14 @@ def f(b):
         # Return results
         if log or return_matrix:
             log_dict = {
+                "cost": cost,
                 "u": nx.from_numpy(u, type_as=type_as),
                 "v": nx.from_numpy(v, type_as=type_as),
                 "warning": check_result(result_code),
                 "result_code": result_code,
             }
-
             if return_matrix:
-                if is_sparse:
-                    G = np.zeros((len(a), len(b)), dtype=np.float64)
-                    G[flow_sources, flow_targets] = flow_values
-                    log_dict["G"] = nx.from_numpy(G, type_as=type_as)
-                else:
-                    log_dict["G"] = G_backend
-
+                log_dict["G"] = G_backend
             return [cost, log_dict]
         else:
             return cost
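
For reference, a minimal usage sketch of the sparse path touched above, assuming this change is applied and a NumPy/SciPy backend; the weights, edge list, and the .toarray() call at the end are illustrative only, and the exact sparse type of log["G"] follows whatever nx.coo_matrix returns for the active backend:

    import numpy as np
    import scipy.sparse as sp
    import ot

    n = 4
    a = np.full(n, 1.0 / n)  # uniform source weights
    b = np.full(n, 1.0 / n)  # uniform target weights

    # Sparse COO cost matrix: only the listed edges are available to the solver
    rows = np.array([0, 1, 2, 3, 0])
    cols = np.array([0, 1, 2, 3, 3])
    vals = np.array([1.0, 0.5, 2.0, 1.5, 3.0])
    M = sp.coo_matrix((vals, (rows, cols)), shape=(n, n))

    # With log=True the dict now also carries the cost; with return_matrix=True
    # it additionally holds the optimal plan G, built in backend sparse format
    cost, log = ot.emd2(a, b, M, log=True, return_matrix=True)
    print(cost, log["result_code"])
    print(log["G"].toarray())  # densify the sparse plan for inspection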