@@ -273,52 +273,68 @@ def __post_init__(self): # noqa
273273 ((1 - self .weight_mt ) * weight_free , self .weight_mt ), axis = - 1
274274 )
275275
276+ # build exchange matrix
277+ # if self.model == "bm":
278+ # # build exchange matrix rows
279+ # k0 = 0 * self.kbm
280+ # kiew = torch.cat(
281+ # (k0, self.kbm * self.weight[..., [0]]), axis=-1
282+ # ) # intra-/extra-cellular water
283+ # kmw = torch.cat(
284+ # (self.kbm * self.weight[..., [1]], k0), axis=-1
285+ # ) # myelin water
286+ # self.k = torch.stack((kiew, kmw), axis=-2)
287+ # elif self.model == "mt":
288+ # k0 = 0 * self.kmt
289+ # # build exchange matrix rows
290+ # kfree = torch.cat(
291+ # (k0, self.kmt * self.weight[..., [0]]), axis=-1
292+ # ) # myelin water (exchange both with intra-/extra-cellular water and semisolid)
293+ # kbound = torch.cat(
294+ # (self.kmt * self.weight[..., [1]], k0), axis=-1
295+ # ) # semisolid (exchange with myelin water only)
296+ # self.k = torch.stack((kfree, kbound), axis=-2)
297+ # elif self.model == "bm-mt":
298+ # k0 = 0 * self.kbm
299+ # # build exchange matrix rows
300+ # kiew = torch.cat(
301+ # (k0, self.kbm * self.weight[..., [1]], k0), axis=-1
302+ # ) # intra-/extra-cellular water (exchange with myelin water only)
303+ # kmw = torch.cat(
304+ # (self.kbm * self.weight[..., [0]], k0, self.kmt * self.weight[..., 2]),
305+ # axis=-1,
306+ # ) # myelin water (exchange both with intra-/extra-cellular water and semisolid)
307+ # kbound = torch.cat(
308+ # (k0, self.kmt * self.weight[..., [1]], k0), axis=-1
309+ # ) # semisolid (exchange with myelin water only)
310+ # self.k = np.stack((kiew, kmw, kbound), axis=-2)
311+ # else:
312+ # self.k = None
313+
314+ # # finalize exchange
315+ # if self.k is not None:
316+ # self.k = _particle_conservation(self.k)
317+
318+ # # single pool voxels do not exchange
319+ # idx = (self.weight == 1).sum(axis=-1) == 1
320+ # self.k[idx, :, :] = 0.0
321+
276322 # build exchange matrix
277323 if self .model == "bm" :
278- # build exchange matrix rows
279- k0 = 0 * self .kbm
280- kiew = torch .cat (
281- (k0 , self .kbm * self .weight [..., [0 ]]), axis = - 1
282- ) # intra-/extra-cellular water
283- kmw = torch .cat (
284- (self .kbm * self .weight [..., [1 ]], k0 ), axis = - 1
285- ) # myelin water
286- self .k = torch .stack ((kiew , kmw ), axis = - 2 )
324+ self .k = self .kbm
287325 elif self .model == "mt" :
288- k0 = 0 * self .kmt
289- # build exchange matrix rows
290- kfree = torch .cat (
291- (k0 , self .kmt * self .weight [..., [0 ]]), axis = - 1
292- ) # myelin water (exchange both with intra-/extra-cellular water and semisolid)
293- kbound = torch .cat (
294- (self .kmt * self .weight [..., [1 ]], k0 ), axis = - 1
295- ) # semisolid (exchange with myelin water only)
296- self .k = torch .stack ((kfree , kbound ), axis = - 2 )
326+ self .k = self .kmt
297327 elif self .model == "bm-mt" :
298- k0 = 0 * self .kbm
299- # build exchange matrix rows
300- kiew = torch .cat (
301- (k0 , self .kbm * self .weight [..., [1 ]], k0 ), axis = - 1
302- ) # intra-/extra-cellular water (exchange with myelin water only)
303- kmw = torch .cat (
304- (self .kbm * self .weight [..., [0 ]], k0 , self .kmt * self .weight [..., 2 ]),
305- axis = - 1 ,
306- ) # myelin water (exchange both with intra-/extra-cellular water and semisolid)
307- kbound = torch .cat (
308- (k0 , self .kmt * self .weight [..., [1 ]], k0 ), axis = - 1
309- ) # semisolid (exchange with myelin water only)
310- self .k = np .stack ((kiew , kmw , kbound ), axis = - 2 )
328+ self .k = torch .cat ((self .kbm , self .kmt ), axis = - 1 )
311329 else :
312330 self .k = None
313331
314332 # finalize exchange
315333 if self .k is not None :
316- self .k = _particle_conservation (self .k )
317-
318334 # single pool voxels do not exchange
319- idx = (self .weight == 1 ).sum (axis = - 1 ) == 1
320- self .k [idx , :, : ] = 0.0
321-
335+ idx = torch . isclose (self .weight , torch . tensor ( 1.0 ) ).sum (axis = - 1 ) == 1
336+ self .k [idx , :] = 0.0
337+
322338 # chemical shift
323339 if self .model is not None and "bm" in self .model :
324340 if self .chemshift is not None and self .chemshift_bm is None :
@@ -422,13 +438,16 @@ def get_sim_inputs(self, modelsig): # noqa
422438 def reformat (self , input ): # noqa
423439 # handle tuples
424440 if isinstance (input , (list , tuple )):
425- output = [item [..., 0 , 0 ] + 1j * item [..., - 1 , 0 ] for item in input ]
426-
441+ output = [item [..., 0 , :] + 1j * item [..., - 1 , :] for item in input ]
442+ # output = [torch.diagonal(item, dim1=-2, dim2=-1) if len(item.shape) == 4 else item for item in output]
443+ output = [item .reshape (* item .shape [:2 ], - 1 ) if len (item .shape ) == 4 else item for item in output ]
444+
427445 # stack
428446 if len (output ) == 1 :
429447 output = output [0 ]
430448 else :
431- output = torch .stack (output , dim = 1 )
449+ output = torch .concatenate (output , dim = - 1 )
450+ output = output .permute (2 , 0 , 1 )
432451 else :
433452 output = input [..., 0 ] + 1j * input [..., - 1 ]
434453
@@ -543,16 +562,16 @@ def __call__(self, **seq_kwargs): # noqa
543562
544563
545564# %% local utils
546- def _particle_conservation (k ):
547- """Adjust diagonal of exchange matrix by imposing particle conservation."""
548- # get shape
549- npools = k .shape [- 1 ]
565+ # def _particle_conservation(k):
566+ # """Adjust diagonal of exchange matrix by imposing particle conservation."""
567+ # # get shape
568+ # npools = k.shape[-1]
550569
551- for n in range (npools ):
552- k [..., n , n ] = 0.0 # ignore existing diagonal
553- k [..., n , n ] = - k [..., n ].sum (dim = - 1 )
570+ # for n in range(npools):
571+ # k[..., n, n] = 0.0 # ignore existing diagonal
572+ # k[..., n, n] = -k[..., n].sum(dim=-1)
554573
555- return k
574+ # return k
556575
557576def inspect_signature (input ):
558577 return list (inspect .signature (input ).parameters )
0 commit comments