Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,15 @@ def __init__(
self.fparam_inv_std = np.ones(self.numb_fparam) # pylint: disable=no-explicit-dtype
else:
self.fparam_avg, self.fparam_inv_std = None, None
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
self.aparam_avg = np.zeros(self.numb_aparam) # pylint: disable=no-explicit-dtype
self.aparam_inv_std = np.ones(self.numb_aparam) # pylint: disable=no-explicit-dtype
else:
self.aparam_avg, self.aparam_inv_std = None, None
# init networks
in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
in_dim = self.dim_descrpt + self.numb_fparam
if not self.use_aparam_as_mask:
in_dim += self.numb_aparam
self.nets = NetworkCollection(
1 if not self.mixed_types else 0,
self.ntypes,
Expand Down Expand Up @@ -389,7 +391,7 @@ def _call_common(
axis=-1,
)
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0:
if not self.use_aparam_as_mask and self.numb_aparam > 0:
assert aparam is not None, "aparam should not be None"
if aparam.shape[-1] != self.numb_aparam:
raise ValueError(
Expand Down
4 changes: 0 additions & 4 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,6 @@ def __init__(
raise NotImplementedError("tot_ener_zero is not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if use_aparam_as_mask:
raise NotImplementedError("use_aparam_as_mask is not implemented")
if layer_name is not None:
raise NotImplementedError("layer_name is not implemented")

Expand Down
15 changes: 10 additions & 5 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class GeneralFitting(Fitting):
length as `ntypes` signaling if or not removing the vaccum contribution for the atom types in the list.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.
use_aparam_as_mask: bool
If True, the aparam will not be used in fitting net for embedding.
"""

def __init__(
Expand All @@ -147,6 +149,7 @@ def __init__(
trainable: Union[bool, list[bool]] = True,
remove_vaccum_contribution: Optional[list[bool]] = None,
type_map: Optional[list[str]] = None,
use_aparam_as_mask: bool = False,
**kwargs,
):
super().__init__()
Expand All @@ -164,6 +167,7 @@ def __init__(
self.rcond = rcond
self.seed = seed
self.type_map = type_map
self.use_aparam_as_mask = use_aparam_as_mask
# order matters, should be place after the assignment of ntypes
self.reinit_exclude(exclude_types)
self.trainable = trainable
Expand Down Expand Up @@ -194,7 +198,7 @@ def __init__(
)
else:
self.fparam_avg, self.fparam_inv_std = None, None
if self.numb_aparam > 0:
if not self.use_aparam_as_mask and self.numb_aparam > 0:
self.register_buffer(
"aparam_avg",
torch.zeros(self.numb_aparam, dtype=self.prec, device=device),
Expand All @@ -206,7 +210,9 @@ def __init__(
else:
self.aparam_avg, self.aparam_inv_std = None, None

in_dim = self.dim_descrpt + self.numb_fparam + self.numb_aparam
in_dim = self.dim_descrpt + self.numb_fparam
if not self.use_aparam_as_mask:
in_dim += self.numb_aparam

self.filter_layers = NetworkCollection(
1 if not self.mixed_types else 0,
Expand Down Expand Up @@ -291,13 +297,12 @@ def serialize(self) -> dict:
# "trainable": self.trainable ,
# "atom_ener": self.atom_ener ,
# "layer_name": self.layer_name ,
# "use_aparam_as_mask": self.use_aparam_as_mask ,
# "spin": self.spin ,
## NOTICE: not supported by far
"tot_ener_zero": False,
"trainable": [self.trainable] * (len(self.neuron) + 1),
"layer_name": None,
"use_aparam_as_mask": False,
"use_aparam_as_mask": self.use_aparam_as_mask,
"spin": None,
}

Expand Down Expand Up @@ -439,7 +444,7 @@ def _forward_common(
dim=-1,
)
# check aparam dim, concate to input descriptor
if self.numb_aparam > 0:
if not self.use_aparam_as_mask and self.numb_aparam > 0:
assert aparam is not None, "aparam should not be None"
assert self.aparam_avg is not None
assert self.aparam_inv_std is not None
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class InvarFitting(GeneralFitting):
The `set_davg_zero` key in the descrptor should be set.
type_map: list[str], Optional
A list of strings. Give the name to each type of atoms.

use_aparam_as_mask: bool
If True, the aparam will not be used in fitting net for embedding.
"""

def __init__(
Expand All @@ -99,6 +100,7 @@ def __init__(
exclude_types: list[int] = [],
atom_ener: Optional[list[Optional[torch.Tensor]]] = None,
type_map: Optional[list[str]] = None,
use_aparam_as_mask: bool = False,
**kwargs,
):
self.dim_out = dim_out
Expand All @@ -122,6 +124,7 @@ def __init__(
if atom_ener is None or len([x for x in atom_ener if x is not None]) == 0
else [x is not None for x in atom_ener],
type_map=type_map,
use_aparam_as_mask=use_aparam_as_mask,
**kwargs,
)

Expand Down
28 changes: 15 additions & 13 deletions deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@
self.fparam_std[ii] = protection
self.fparam_inv_std = 1.0 / self.fparam_std
# stat aparam
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
sys_sumv = []
sys_sumv2 = []
sys_sumn = []
Expand Down Expand Up @@ -384,7 +384,7 @@
ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
ext_fparam = tf.cast(ext_fparam, self.fitting_precision)
layer = tf.concat([layer, ext_fparam], axis=1)
if aparam is not None:
if aparam is not None and not self.use_aparam_as_mask:
ext_aparam = tf.slice(
aparam,
[0, start_index * self.numb_aparam],
Expand Down Expand Up @@ -505,7 +505,7 @@
self.fparam_avg = 0.0
if self.fparam_inv_std is None:
self.fparam_inv_std = 1.0
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
if self.aparam_avg is None:
self.aparam_avg = 0.0
if self.aparam_inv_std is None:
Expand Down Expand Up @@ -561,7 +561,7 @@
trainable=False,
initializer=tf.constant_initializer(self.fparam_inv_std),
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
t_aparam_avg = tf.get_variable(
"t_aparam_avg",
self.numb_aparam,
Expand Down Expand Up @@ -602,12 +602,11 @@
fparam = (fparam - t_fparam_avg) * t_fparam_istd

aparam = None
if not self.use_aparam_as_mask:
if self.numb_aparam > 0:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])
if not self.use_aparam_as_mask and self.numb_aparam > 0:
aparam = input_dict["aparam"]
aparam = tf.reshape(aparam, [-1, self.numb_aparam])
aparam = (aparam - t_aparam_avg) * t_aparam_istd
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

atype_nall = tf.reshape(atype, [-1, natoms[1]])
self.atype_nloc = tf.slice(
Expand Down Expand Up @@ -783,7 +782,7 @@
self.fparam_inv_std = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_fparam_istd"
)
if self.numb_aparam > 0:
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
self.aparam_avg = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_aparam_avg"
)
Expand Down Expand Up @@ -883,7 +882,7 @@
if fitting.numb_fparam > 0:
fitting.fparam_avg = data["@variables"]["fparam_avg"]
fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"]
if fitting.numb_aparam > 0:
if fitting.numb_aparam > 0 and not fitting.use_aparam_as_mask:
fitting.aparam_avg = data["@variables"]["aparam_avg"]
fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"]
return fitting
Expand All @@ -896,6 +895,9 @@
dict
The serialized data
"""
in_dim = self.dim_descrpt + self.numb_fparam
if not self.use_aparam_as_mask:
in_dim += self.numb_aparam
data = {
"@class": "Fitting",
"type": "ener",
Expand All @@ -922,7 +924,7 @@
"nets": self.serialize_network(
ntypes=self.ntypes,
ndim=0 if self.mixed_types else 1,
in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
in_dim=in_dim,
neuron=self.n_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.resnet_dt,
Expand Down
8 changes: 7 additions & 1 deletion source/tests/consistent/fitting/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class FittingTest:
"""Useful utilities for descriptor tests."""

def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, aparam, suffix):
t_inputs = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_inputs")
t_natoms = tf.placeholder(tf.int32, natoms.shape, name="i_natoms")
t_atype = tf.placeholder(tf.int32, [None], name="i_atype")
Expand All @@ -30,6 +30,12 @@ def build_tf_fitting(self, obj, inputs, natoms, atype, fparam, suffix):
)
extras["fparam"] = t_fparam
feed_dict[t_fparam] = fparam
if aparam is not None:
t_aparam = tf.placeholder(
GLOBAL_TF_FLOAT_PRECISION, [None, None], name="i_aparam"
)
extras["aparam"] = t_aparam
feed_dict[t_aparam] = aparam
t_out = obj.build(
t_inputs,
t_natoms,
Expand Down
22 changes: 22 additions & 0 deletions source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
("float64", "float32"), # precision
(True, False), # mixed_types
(0, 1), # numb_fparam
(0, 1), # numb_aparam
(10, 20), # numb_dos
)
class TestDOS(CommonTest, FittingTest, unittest.TestCase):
Expand All @@ -68,13 +69,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return {
"neuron": [5, 5, 5],
"resnet_dt": resnet_dt,
"precision": precision,
"numb_fparam": numb_fparam,
"numb_aparam": numb_aparam,
"seed": 20240217,
"numb_dos": numb_dos,
}
Expand All @@ -86,6 +89,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable numb_aparam is not used.
numb_dos,
) = self.param
return CommonTest.skip_pt
Expand Down Expand Up @@ -115,6 +119,9 @@
# inconsistent if not sorted
self.atype.sort()
self.fparam = -np.ones((1,), dtype=GLOBAL_NP_FLOAT_PRECISION)
self.aparam = np.zeros_like(
self.atype, dtype=GLOBAL_NP_FLOAT_PRECISION
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
Expand All @@ -123,6 +130,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return {
Expand All @@ -137,6 +145,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return self.build_tf_fitting(
Expand All @@ -145,6 +154,7 @@
self.natoms,
self.atype,
self.fparam if numb_fparam else None,
self.aparam if numb_aparam else None,
suffix,
)

Expand All @@ -154,6 +164,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return (
Expand All @@ -163,6 +174,9 @@
fparam=torch.from_numpy(self.fparam).to(device=PT_DEVICE)
if numb_fparam
else None,
aparam=torch.from_numpy(self.aparam).to(device=PT_DEVICE)
if numb_aparam
else None,
)["dos"]
.detach()
.cpu()
Expand All @@ -175,12 +189,14 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return dp_obj(
self.inputs,
self.atype.reshape(1, -1),
fparam=self.fparam if numb_fparam else None,
aparam=self.aparam if numb_aparam else None,
)["dos"]

def eval_jax(self, jax_obj: Any) -> Any:
Expand All @@ -189,13 +205,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
jax_obj(
jnp.asarray(self.inputs),
jnp.asarray(self.atype.reshape(1, -1)),
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
aparam=jnp.asarray(self.aparam) if numb_aparam else None,
)["dos"]
)

Expand All @@ -206,13 +224,15 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return np.asarray(
array_api_strict_obj(
array_api_strict.asarray(self.inputs),
array_api_strict.asarray(self.atype.reshape(1, -1)),
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None,
)["dos"]
)

Expand All @@ -230,6 +250,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float64":
Expand All @@ -247,6 +268,7 @@
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float64":
Expand Down
Loading
Loading