@@ -35,13 +35,12 @@ def bsp_solve(np.ndarray[double, ndim=2, mode="c"] X, np.ndarray[double, ndim=2,
3535 """
3636 cdef int n = X.shape[0 ]
3737 cdef int d = X.shape[1 ]
38- cdef uint64_t nb_plans = plans.shape[0 ]
39- cdef np.ndarray[np.int64, ndim= 2 , mode= " c" ] plans = np.zeros((n, n_plan), dtype = np.int64)
40- cdef np.ndarray[np.int64, ndim= 2 , mode= " c" ] plann = np.zeros((n,), dtype = np.int64)
38+ cdef np.ndarray[int , ndim= 2 , mode= " c" ] plans = np.zeros((n, n_plans), dtype = np.int64)
39+ cdef np.ndarray[int , ndim= 2 , mode= " c" ] plan = np.zeros((n,), dtype = np.int64)
4140
4241 cdef double cost
4342
44- cost = BSPOT_wrap(n, n, d, < double * > X.data, < double * > Y.data, nb_plans , < int * > c_plans .data, < int * > c_plan .data)
43+ cost = BSPOT_wrap(n, n, d, < double * > X.data, < double * > Y.data, n_plans , < int * > plans .data, < int * > plan .data)
4544
4645 # add
4746
@@ -50,7 +49,7 @@ def bsp_solve(np.ndarray[double, ndim=2, mode="c"] X, np.ndarray[double, ndim=2,
5049
5150@ cython.boundscheck (False )
5251@ cython.wraparound (False )
53- def merge_plans (np.ndarray[np. int64 , ndim = 2 , mode = " c" ] plans):
52+ def merge_plans (np.ndarray[int , ndim = 2 , mode = " c" ] plans):
5453 """
5554 Merges OT plans
5655
0 commit comments