Skip to content

Commit 5cdef87

Browse files
committed
resolve comments
1 parent 044023e commit 5cdef87

File tree

3 files changed

+23
-18
lines changed

3 files changed

+23
-18
lines changed

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,15 +155,17 @@ def __init__(
155155
self.fparam_inv_std = np.ones(self.numb_fparam) # pylint: disable=no-explicit-dtype
156156
else:
157157
self.fparam_avg, self.fparam_inv_std = None, None
158-
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
158+
if self.numb_aparam > 0:
159159
self.aparam_avg = np.zeros(self.numb_aparam) # pylint: disable=no-explicit-dtype
160160
self.aparam_inv_std = np.ones(self.numb_aparam) # pylint: disable=no-explicit-dtype
161161
else:
162162
self.aparam_avg, self.aparam_inv_std = None, None
163163
# init networks
164-
in_dim = self.dim_descrpt + self.numb_fparam
165-
if not self.use_aparam_as_mask:
166-
in_dim += self.numb_aparam
164+
in_dim = (
165+
self.dim_descrpt
166+
+ self.numb_fparam
167+
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
168+
)
167169
self.nets = NetworkCollection(
168170
1 if not self.mixed_types else 0,
169171
self.ntypes,
@@ -391,7 +393,7 @@ def _call_common(
391393
axis=-1,
392394
)
393395
# check aparam dim, concate to input descriptor
394-
if not self.use_aparam_as_mask and self.numb_aparam > 0:
396+
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
395397
assert aparam is not None, "aparam should not be None"
396398
if aparam.shape[-1] != self.numb_aparam:
397399
raise ValueError(

deepmd/pt/model/task/fitting.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def __init__(
198198
)
199199
else:
200200
self.fparam_avg, self.fparam_inv_std = None, None
201-
if not self.use_aparam_as_mask and self.numb_aparam > 0:
201+
if self.numb_aparam > 0:
202202
self.register_buffer(
203203
"aparam_avg",
204204
torch.zeros(self.numb_aparam, dtype=self.prec, device=device),
@@ -210,9 +210,11 @@ def __init__(
210210
else:
211211
self.aparam_avg, self.aparam_inv_std = None, None
212212

213-
in_dim = self.dim_descrpt + self.numb_fparam
214-
if not self.use_aparam_as_mask:
215-
in_dim += self.numb_aparam
213+
in_dim = (
214+
self.dim_descrpt
215+
+ self.numb_fparam
216+
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
217+
)
216218

217219
self.filter_layers = NetworkCollection(
218220
1 if not self.mixed_types else 0,
@@ -444,7 +446,7 @@ def _forward_common(
444446
dim=-1,
445447
)
446448
# check aparam dim, concate to input descriptor
447-
if not self.use_aparam_as_mask and self.numb_aparam > 0:
449+
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
448450
assert aparam is not None, "aparam should not be None"
449451
assert self.aparam_avg is not None
450452
assert self.aparam_inv_std is not None

deepmd/tf/fit/ener.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def compute_input_stats(self, all_stat: dict, protection: float = 1e-2) -> None:
340340
self.fparam_std[ii] = protection
341341
self.fparam_inv_std = 1.0 / self.fparam_std
342342
# stat aparam
343-
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
343+
if self.numb_aparam > 0:
344344
sys_sumv = []
345345
sys_sumv2 = []
346346
sys_sumn = []
@@ -505,7 +505,7 @@ def build(
505505
self.fparam_avg = 0.0
506506
if self.fparam_inv_std is None:
507507
self.fparam_inv_std = 1.0
508-
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
508+
if self.numb_aparam > 0:
509509
if self.aparam_avg is None:
510510
self.aparam_avg = 0.0
511511
if self.aparam_inv_std is None:
@@ -561,7 +561,7 @@ def build(
561561
trainable=False,
562562
initializer=tf.constant_initializer(self.fparam_inv_std),
563563
)
564-
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
564+
if self.numb_aparam > 0:
565565
t_aparam_avg = tf.get_variable(
566566
"t_aparam_avg",
567567
self.numb_aparam,
@@ -602,7 +602,7 @@ def build(
602602
fparam = (fparam - t_fparam_avg) * t_fparam_istd
603603

604604
aparam = None
605-
if not self.use_aparam_as_mask and self.numb_aparam > 0:
605+
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
606606
aparam = input_dict["aparam"]
607607
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
608608
aparam = (aparam - t_aparam_avg) * t_aparam_istd
@@ -895,9 +895,6 @@ def serialize(self, suffix: str = "") -> dict:
895895
dict
896896
The serialized data
897897
"""
898-
in_dim = self.dim_descrpt + self.numb_fparam
899-
if not self.use_aparam_as_mask:
900-
in_dim += self.numb_aparam
901898
data = {
902899
"@class": "Fitting",
903900
"type": "ener",
@@ -924,7 +921,11 @@ def serialize(self, suffix: str = "") -> dict:
924921
"nets": self.serialize_network(
925922
ntypes=self.ntypes,
926923
ndim=0 if self.mixed_types else 1,
927-
in_dim=in_dim,
924+
in_dim=(
925+
self.dim_descrpt
926+
+ self.numb_fparam
927+
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
928+
),
928929
neuron=self.n_neuron,
929930
activation_function=self.activation_function_name,
930931
resnet_dt=self.resnet_dt,

0 commit comments

Comments
 (0)