Skip to content

Commit 674ebad

Browse files
authored
fix(jax): workaround for "xxTracer is not a valid JAX type" (#4776)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved array handling in various components to ensure correct data is passed during calculations, enhancing reliability and consistency in model operations. - **Style** - Minor formatting adjustments for improved code readability, with no impact on end-user functionality. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 51f61cd commit 674ebad

File tree

11 files changed

+51
-16
lines changed

11 files changed

+51
-16
lines changed

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def _pair_tabulated_inter(
293293

294294
uu -= idx
295295
table_coef = self._extract_spline_coefficient(
296-
i_type, j_type, idx, self.tab_data, nspline
296+
i_type, j_type, idx, self.tab_data[...], nspline
297297
)
298298
table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4))
299299
ener = self._calculate_ener(table_coef, uu)

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,11 @@ def call(
951951
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
952952
# nf x nloc x nnei x 4
953953
dmatrix, diff, sw = self.env_mat.call(
954-
coord_ext, atype_ext, nlist, self.mean, self.stddev
954+
coord_ext,
955+
atype_ext,
956+
nlist,
957+
self.mean[...],
958+
self.stddev[...],
955959
)
956960
nf, nloc, nnei, _ = dmatrix.shape
957961
atype = atype_ext[:, :nloc]

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,11 @@ def call(
472472
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
473473
# nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
474474
dmatrix, diff, sw = self.env_mat_edge.call(
475-
coord_ext, atype_ext, nlist, self.mean, self.stddev
475+
coord_ext,
476+
atype_ext,
477+
nlist,
478+
self.mean[...],
479+
self.stddev[...],
476480
)
477481
# nb x nloc x nnei
478482
nlist_mask = nlist != -1

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,11 @@ def call(
441441
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
442442
# nf x nloc x nnei x 4
443443
dmatrix, diff, sw = self.env_mat.call(
444-
coord_ext, atype_ext, nlist, self.mean, self.stddev
444+
coord_ext,
445+
atype_ext,
446+
nlist,
447+
self.mean[...],
448+
self.stddev[...],
445449
)
446450
nf, nloc, nnei, _ = dmatrix.shape
447451
# nf x nloc x nnei

deepmd/dpmodel/descriptor/se_e2_a.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,11 @@ def call(
591591
input_dtype = coord_ext.dtype
592592
# nf x nloc x nnei x 4
593593
rr, diff, ww = self.env_mat.call(
594-
coord_ext, atype_ext, nlist, self.davg, self.dstd
594+
coord_ext,
595+
atype_ext,
596+
nlist,
597+
self.davg[...],
598+
self.dstd[...],
595599
)
596600
nf, nloc, nnei, _ = rr.shape
597601
sec = self.sel_cumsum

deepmd/dpmodel/descriptor/se_r.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,12 @@ def call(
373373
del mapping
374374
# nf x nloc x nnei x 1
375375
rr, diff, ww = self.env_mat.call(
376-
coord_ext, atype_ext, nlist, self.davg, self.dstd, True
376+
coord_ext,
377+
atype_ext,
378+
nlist,
379+
self.davg[...],
380+
self.dstd[...],
381+
True,
377382
)
378383
nf, nloc, nnei, _ = rr.shape
379384
sec = self.sel_cumsum

deepmd/dpmodel/descriptor/se_t.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,11 @@ def call(
349349
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
350350
# nf x nloc x nnei x 4
351351
rr, diff, ww = self.env_mat.call(
352-
coord_ext, atype_ext, nlist, self.davg, self.dstd
352+
coord_ext,
353+
atype_ext,
354+
nlist,
355+
self.davg[...],
356+
self.dstd[...],
353357
)
354358
nf, nloc, nnei, _ = rr.shape
355359
sec = self.sel_cumsum

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -733,7 +733,11 @@ def call(
733733
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
734734
# nf x nloc x nnei x 4
735735
dmatrix, diff, sw = self.env_mat.call(
736-
coord_ext, atype_ext, nlist, self.mean, self.stddev
736+
coord_ext,
737+
atype_ext,
738+
nlist,
739+
self.mean[...],
740+
self.stddev[...],
737741
)
738742
nf, nloc, nnei, _ = dmatrix.shape
739743
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def _call_common(
410410
f"get an input fparam of dim {fparam.shape[-1]}, "
411411
f"which is not consistent with {self.numb_fparam}."
412412
)
413-
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
413+
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
414414
fparam = xp.tile(
415415
xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1)
416416
)
@@ -432,7 +432,7 @@ def _call_common(
432432
f"which is not consistent with {self.numb_aparam}."
433433
)
434434
aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam])
435-
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std
435+
aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...]
436436
xx = xp.concat(
437437
[xx, aparam],
438438
axis=-1,
@@ -445,7 +445,9 @@ def _call_common(
445445

446446
if self.dim_case_embd > 0:
447447
assert self.case_embd is not None
448-
case_embd = xp.tile(xp.reshape(self.case_embd, [1, 1, -1]), [nf, nloc, 1])
448+
case_embd = xp.tile(
449+
xp.reshape(self.case_embd[...], [1, 1, -1]), [nf, nloc, 1]
450+
)
449451
xx = xp.concat(
450452
[xx, case_embd],
451453
axis=-1,
@@ -482,7 +484,9 @@ def _call_common(
482484
outs -= self.nets[()](xx_zeros)
483485
outs += xp.reshape(
484486
xp.take(
485-
xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0
487+
xp.astype(self.bias_atom_e[...], outs.dtype),
488+
xp.reshape(atype, [-1]),
489+
axis=0,
486490
),
487491
[nf, nloc, net_dim_out],
488492
)

deepmd/dpmodel/utils/exclude_mask.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def build_type_exclude_mask(
5353
xp = array_api_compat.array_namespace(atype)
5454
nf, natom = atype.shape
5555
return xp.reshape(
56-
xp.take(self.type_mask, xp.reshape(atype, [-1]), axis=0), (nf, natom)
56+
xp.take(self.type_mask[...], xp.reshape(atype, [-1]), axis=0),
57+
(nf, natom),
5758
)
5859

5960

@@ -131,7 +132,8 @@ def build_type_exclude_mask(
131132
# nf x (nloc x nnei)
132133
type_ij = xp.reshape(type_ij, (nf, nloc * nnei))
133134
mask = xp.reshape(
134-
xp.take(self.type_mask, xp.reshape(type_ij, (-1,))), (nf, nloc, nnei)
135+
xp.take(self.type_mask[...], xp.reshape(type_ij, (-1,))),
136+
(nf, nloc, nnei),
135137
)
136138
return mask
137139

0 commit comments

Comments
 (0)