Skip to content

Commit 70e8e3d

Browse files
committed
Bug fix in DE ssfp/bssfp + more consistent derivative wrt exchange
1 parent 5de4bf3 commit 70e8e3d

File tree

5 files changed

+106
-54
lines changed

5 files changed

+106
-54
lines changed

examples/scripts/02-derivatives.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def crlb_finitediff_cost(flip, ESP, T1, T2):
105105

106106
# plot derivative
107107
#%%plots
108-
fsz = 20
108+
fsz = 14
109109
plt.figure()
110110
plt.subplot(2,2,1)
111111
plt.rcParams.update({'font.size': 0.5 * fsz})
@@ -120,7 +120,7 @@ def crlb_finitediff_cost(flip, ESP, T1, T2):
120120
plt.xlabel("Echo #", fontsize=fsz)
121121
plt.xlim([-1, 97])
122122
plt.ylabel(r"$\frac{\partial signal}{\partial T2}$ [a.u.]", fontsize=fsz)
123-
plt.legend(["Finite Diff", "Auto Diff"])
123+
plt.legend(["Finite Diff", "Auto Diff"], prop={"size": 14})
124124

125125

126126
plt.subplot(2,2,3)
@@ -129,7 +129,7 @@ def crlb_finitediff_cost(flip, ESP, T1, T2):
129129
plt.xlabel("Echo #", fontsize=fsz)
130130
plt.xlim([-1, 97])
131131
plt.ylabel(r"$\frac{\partial CRLB}{\partial FA}$ [a.u.]", fontsize=fsz)
132-
plt.legend(["Finite Diff", "Auto Diff"])
132+
plt.legend(["Finite Diff", "Auto Diff"], prop={"size": 14})
133133

134134
plt.subplot(2,2,4)
135135

@@ -148,8 +148,8 @@ def crlb_finitediff_cost(flip, ESP, T1, T2):
148148
# Add some text for labels, title and custom x-axis tick labels, etc.
149149
plt.ylabel('Execution Time [s]', fontsize=fsz)
150150
plt.xticks(x, labels, fontsize=fsz)
151-
# plt.ylim([0, 25])
152-
plt.legend()
151+
plt.ylim([0, 50])
152+
plt.legend(loc='upper left', prop={"size": 14})
153153

154154
plt.bar_label(rects1, padding=3, fontsize=fsz)
155155
plt.bar_label(rects2, padding=3, fontsize=fsz)

src/epgtorchx/bloch/model/base.py

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -273,52 +273,68 @@ def __post_init__(self): # noqa
273273
((1 - self.weight_mt) * weight_free, self.weight_mt), axis=-1
274274
)
275275

276+
# build exchange matrix
277+
# if self.model == "bm":
278+
# # build exchange matrix rows
279+
# k0 = 0 * self.kbm
280+
# kiew = torch.cat(
281+
# (k0, self.kbm * self.weight[..., [0]]), axis=-1
282+
# ) # intra-/extra-cellular water
283+
# kmw = torch.cat(
284+
# (self.kbm * self.weight[..., [1]], k0), axis=-1
285+
# ) # myelin water
286+
# self.k = torch.stack((kiew, kmw), axis=-2)
287+
# elif self.model == "mt":
288+
# k0 = 0 * self.kmt
289+
# # build exchange matrix rows
290+
# kfree = torch.cat(
291+
# (k0, self.kmt * self.weight[..., [0]]), axis=-1
292+
# ) # myelin water (exchange both with intra-/extra-cellular water and semisolid)
293+
# kbound = torch.cat(
294+
# (self.kmt * self.weight[..., [1]], k0), axis=-1
295+
# ) # semisolid (exchange with myelin water only)
296+
# self.k = torch.stack((kfree, kbound), axis=-2)
297+
# elif self.model == "bm-mt":
298+
# k0 = 0 * self.kbm
299+
# # build exchange matrix rows
300+
# kiew = torch.cat(
301+
# (k0, self.kbm * self.weight[..., [1]], k0), axis=-1
302+
# ) # intra-/extra-cellular water (exchange with myelin water only)
303+
# kmw = torch.cat(
304+
# (self.kbm * self.weight[..., [0]], k0, self.kmt * self.weight[..., 2]),
305+
# axis=-1,
306+
# ) # myelin water (exchange both with intra-/extra-cellular water and semisolid)
307+
# kbound = torch.cat(
308+
# (k0, self.kmt * self.weight[..., [1]], k0), axis=-1
309+
# ) # semisolid (exchange with myelin water only)
310+
# self.k = np.stack((kiew, kmw, kbound), axis=-2)
311+
# else:
312+
# self.k = None
313+
314+
# # finalize exchange
315+
# if self.k is not None:
316+
# self.k = _particle_conservation(self.k)
317+
318+
# # single pool voxels do not exchange
319+
# idx = (self.weight == 1).sum(axis=-1) == 1
320+
# self.k[idx, :, :] = 0.0
321+
276322
# build exchange matrix
277323
if self.model == "bm":
278-
# build exchange matrix rows
279-
k0 = 0 * self.kbm
280-
kiew = torch.cat(
281-
(k0, self.kbm * self.weight[..., [0]]), axis=-1
282-
) # intra-/extra-cellular water
283-
kmw = torch.cat(
284-
(self.kbm * self.weight[..., [1]], k0), axis=-1
285-
) # myelin water
286-
self.k = torch.stack((kiew, kmw), axis=-2)
324+
self.k = self.kbm
287325
elif self.model == "mt":
288-
k0 = 0 * self.kmt
289-
# build exchange matrix rows
290-
kfree = torch.cat(
291-
(k0, self.kmt * self.weight[..., [0]]), axis=-1
292-
) # myelin water (exchange both with intra-/extra-cellular water and semisolid)
293-
kbound = torch.cat(
294-
(self.kmt * self.weight[..., [1]], k0), axis=-1
295-
) # semisolid (exchange with myelin water only)
296-
self.k = torch.stack((kfree, kbound), axis=-2)
326+
self.k = self.kmt
297327
elif self.model == "bm-mt":
298-
k0 = 0 * self.kbm
299-
# build exchange matrix rows
300-
kiew = torch.cat(
301-
(k0, self.kbm * self.weight[..., [1]], k0), axis=-1
302-
) # intra-/extra-cellular water (exchange with myelin water only)
303-
kmw = torch.cat(
304-
(self.kbm * self.weight[..., [0]], k0, self.kmt * self.weight[..., 2]),
305-
axis=-1,
306-
) # myelin water (exchange both with intra-/extra-cellular water and semisolid)
307-
kbound = torch.cat(
308-
(k0, self.kmt * self.weight[..., [1]], k0), axis=-1
309-
) # semisolid (exchange with myelin water only)
310-
self.k = np.stack((kiew, kmw, kbound), axis=-2)
328+
self.k = torch.cat((self.kbm, self.kmt), axis=-1)
311329
else:
312330
self.k = None
313331

314332
# finalize exchange
315333
if self.k is not None:
316-
self.k = _particle_conservation(self.k)
317-
318334
# single pool voxels do not exchange
319-
idx = (self.weight == 1).sum(axis=-1) == 1
320-
self.k[idx, :, :] = 0.0
321-
335+
idx = torch.isclose(self.weight, torch.tensor(1.0)).sum(axis=-1) == 1
336+
self.k[idx, :] = 0.0
337+
322338
# chemical shift
323339
if self.model is not None and "bm" in self.model:
324340
if self.chemshift is not None and self.chemshift_bm is None:
@@ -422,13 +438,16 @@ def get_sim_inputs(self, modelsig): # noqa
422438
def reformat(self, input): # noqa
423439
# handle tuples
424440
if isinstance(input, (list, tuple)):
425-
output = [item[..., 0, 0] + 1j * item[..., -1, 0] for item in input]
426-
441+
output = [item[..., 0, :] + 1j * item[..., -1, :] for item in input]
442+
# output = [torch.diagonal(item, dim1=-2, dim2=-1) if len(item.shape) == 4 else item for item in output]
443+
output = [item.reshape(*item.shape[:2], -1) if len(item.shape) == 4 else item for item in output]
444+
427445
# stack
428446
if len(output) == 1:
429447
output = output[0]
430448
else:
431-
output = torch.stack(output, dim=1)
449+
output = torch.concatenate(output, dim=-1)
450+
output = output.permute(2, 0, 1)
432451
else:
433452
output = input[..., 0] + 1j * input[..., -1]
434453

@@ -543,16 +562,16 @@ def __call__(self, **seq_kwargs): # noqa
543562

544563

545564
# %% local utils
546-
def _particle_conservation(k):
547-
"""Adjust diagonal of exchange matrix by imposing particle conservation."""
548-
# get shape
549-
npools = k.shape[-1]
565+
# def _particle_conservation(k):
566+
# """Adjust diagonal of exchange matrix by imposing particle conservation."""
567+
# # get shape
568+
# npools = k.shape[-1]
550569

551-
for n in range(npools):
552-
k[..., n, n] = 0.0 # ignore existing diagonal
553-
k[..., n, n] = -k[..., n].sum(dim=-1)
570+
# for n in range(npools):
571+
# k[..., n, n] = 0.0 # ignore existing diagonal
572+
# k[..., n, n] = -k[..., n].sum(dim=-1)
554573

555-
return k
574+
# return k
556575

557576
def inspect_signature(input):
558577
return list(inspect.signature(input).parameters)

src/epgtorchx/bloch/model/bssfpmrf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def sequence(
234234
# prepare free precession period
235235
Xpre, Xpost = blocks.bSSFPFidStep(states, TE, TR, T1, T2, weight, k, chemshift)
236236

237-
for r in (nreps):
237+
for r in range(nreps):
238238
# magnetization prep
239239
states = Prep(states)
240240

src/epgtorchx/bloch/model/ssfpmrf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def sequence(
280280
states, TE, TR, T1, T2, weight, k, chemshift, D, v, grad_props
281281
)
282282

283-
for r in (nreps):
283+
for r in range(nreps):
284284
# magnetization prep
285285
states = Prep(states)
286286

src/epgtorchx/bloch/ops/_relaxation_op.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(
4848
weight = torch.as_tensor(weight, dtype=torch.float32, device=device)
4949
if k is not None:
5050
k = torch.as_tensor(k, dtype=torch.float32, device=device)
51+
k = _prepare_exchange(weight, k)
52+
5153
if df is not None:
5254
df = torch.as_tensor(df, dtype=torch.float32, device=device)
5355
df = torch.atleast_1d(df)
@@ -102,6 +104,37 @@ def apply(self, states):
102104

103105

104106
# %% local utils
107+
def _prepare_exchange(weight, k):
108+
109+
# prepare
110+
if k.shape[-1] == 1: # BM or MT
111+
k0 = 0 * k
112+
k1 = torch.cat((k0, k * weight[..., [0]]), axis=-1)
113+
k2 = torch.cat((k * weight[..., [1]], k0), axis=-1)
114+
k = torch.stack((k1, k2), axis=-2)
115+
else: # BM-MT
116+
k0 = 0 * k[..., [0]]
117+
k1 = torch.cat((k0, k[...,[0]] * weight[..., [1]], k0), axis=-1)
118+
k2 = torch.cat((k[..., [0]] * weight[..., [0]], k0, k[..., [1]] * weight[..., [2]]), axis=-1)
119+
k3 = torch.cat((k0, k[...,[1]] * weight[..., [1]], k0), axis=-1)
120+
k = torch.stack((k1, k2, k3), axis=-2)
121+
122+
123+
# finalize exchange
124+
return _particle_conservation(k)
125+
126+
127+
def _particle_conservation(k):
128+
"""Adjust diagonal of exchange matrix by imposing particle conservation."""
129+
# get shape
130+
npools = k.shape[-1]
131+
132+
for n in range(npools):
133+
k[..., n, n] = 0.0 # ignore existing diagonal
134+
k[..., n, n] = -k[..., n].sum(dim=-1)
135+
136+
return k
137+
105138
def _transverse_relax_apply(states, E2):
106139
"""
107140
Apply transverse relaxation operator.

0 commit comments

Comments
 (0)