Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class RepFlowArgs:
optim_update : bool, optional
Whether to enable the optimized update method.
Uses a more efficient process when enabled. Defaults to True
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
"""

def __init__(
Expand All @@ -147,6 +150,7 @@ def __init__(
fix_stat_std: float = 0.3,
skip_stat: bool = False,
optim_update: bool = True,
smooth_edge_update: bool = False,
) -> None:
self.n_dim = n_dim
self.e_dim = e_dim
Expand All @@ -172,6 +176,7 @@ def __init__(
self.a_compress_e_rate = a_compress_e_rate
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

def __getitem__(self, key):
if hasattr(self, key):
Expand Down Expand Up @@ -202,6 +207,7 @@ def serialize(self) -> dict:
"update_residual_init": self.update_residual_init,
"fix_stat_std": self.fix_stat_std,
"optim_update": self.optim_update,
"smooth_edge_update": self.smooth_edge_update,
}

@classmethod
Expand Down Expand Up @@ -297,6 +303,7 @@ def init_subclass_params(sub_data, sub_class):
update_residual_init=self.repflow_args.update_residual_init,
fix_stat_std=self.repflow_args.fix_stat_std,
optim_update=self.repflow_args.optim_update,
smooth_edge_update=self.repflow_args.smooth_edge_update,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
Expand Down
41 changes: 27 additions & 14 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
optim_update : bool, optional
Whether to enable the optimized update method.
Uses a more efficient process when enabled. Defaults to True
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
ntypes : int
Number of element types
activation_function : str, optional
Expand Down Expand Up @@ -161,6 +164,7 @@ def __init__(
precision: str = "float64",
fix_stat_std: float = 0.3,
optim_update: bool = True,
smooth_edge_update: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -191,6 +195,7 @@ def __init__(
self.set_stddev_constant = fix_stat_std != 0.0
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -243,6 +248,7 @@ def __init__(
update_residual_init=self.update_residual_init,
precision=precision,
optim_update=self.optim_update,
smooth_edge_update=self.smooth_edge_update,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down Expand Up @@ -563,6 +569,7 @@ def __init__(
axis_neuron: int = 4,
update_angle: bool = True,
optim_update: bool = True,
smooth_edge_update: bool = False,
activation_function: str = "silu",
update_style: str = "res_residual",
update_residual: float = 0.1,
Expand Down Expand Up @@ -607,6 +614,7 @@ def __init__(
self.seed = seed
self.prec = PRECISION_DICT[precision]
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

assert update_residual_init in [
"norm",
Expand Down Expand Up @@ -1136,19 +1144,23 @@ def call(
],
axis=2,
)
full_mask = xp.concat(
[
a_nlist_mask,
xp.zeros(
(nb, nloc, self.nnei - self.a_sel),
dtype=a_nlist_mask.dtype,
),
],
axis=-1,
)
padding_edge_angle_update = xp.where(
xp.expand_dims(full_mask, axis=-1), padding_edge_angle_update, edge_ebd
)
if not self.smooth_edge_update:
# will be deprecated in the future
full_mask = xp.concat(
[
a_nlist_mask,
xp.zeros(
(nb, nloc, self.nnei - self.a_sel),
dtype=a_nlist_mask.dtype,
),
],
axis=-1,
)
padding_edge_angle_update = xp.where(
xp.expand_dims(full_mask, axis=-1),
padding_edge_angle_update,
edge_ebd,
)
e_update_list.append(
self.act(self.edge_angle_linear2(padding_edge_angle_update))
)
Expand Down Expand Up @@ -1235,7 +1247,7 @@ def serialize(self) -> dict:
The serialized networks.
"""
data = {
"@class": "RepformerLayer",
"@class": "RepFlowLayer",
"@version": 1,
"e_rcut": self.e_rcut,
"e_rcut_smth": self.e_rcut_smth,
Expand All @@ -1259,6 +1271,7 @@ def serialize(self) -> dict:
"update_residual_init": self.update_residual_init,
"precision": self.precision,
"optim_update": self.optim_update,
"smooth_edge_update": self.smooth_edge_update,
"node_self_mlp": self.node_self_mlp.serialize(),
"node_sym_linear": self.node_sym_linear.serialize(),
"node_edge_linear": self.node_edge_linear.serialize(),
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def init_subclass_params(sub_data, sub_class):
update_residual_init=self.repflow_args.update_residual_init,
fix_stat_std=self.repflow_args.fix_stat_std,
optim_update=self.repflow_args.optim_update,
smooth_edge_update=self.repflow_args.smooth_edge_update,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
Expand Down
35 changes: 20 additions & 15 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
axis_neuron: int = 4,
update_angle: bool = True,
optim_update: bool = True,
smooth_edge_update: bool = False,
activation_function: str = "silu",
update_style: str = "res_residual",
update_residual: float = 0.1,
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(
self.seed = seed
self.prec = PRECISION_DICT[precision]
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

assert update_residual_init in [
"norm",
Expand Down Expand Up @@ -718,20 +720,22 @@ def forward(
],
dim=2,
)
full_mask = torch.concat(
[
a_nlist_mask,
torch.zeros(
[nb, nloc, self.nnei - self.a_sel],
dtype=a_nlist_mask.dtype,
device=a_nlist_mask.device,
),
],
dim=-1,
)
padding_edge_angle_update = torch.where(
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
)
if not self.smooth_edge_update:
# will be deprecated in the future
full_mask = torch.concat(
[
a_nlist_mask,
torch.zeros(
[nb, nloc, self.nnei - self.a_sel],
dtype=a_nlist_mask.dtype,
device=a_nlist_mask.device,
),
],
dim=-1,
)
padding_edge_angle_update = torch.where(
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
)
e_update_list.append(
self.act(self.edge_angle_linear2(padding_edge_angle_update))
)
Expand Down Expand Up @@ -823,7 +827,7 @@ def serialize(self) -> dict:
The serialized networks.
"""
data = {
"@class": "RepformerLayer",
"@class": "RepFlowLayer",
"@version": 1,
"e_rcut": self.e_rcut,
"e_rcut_smth": self.e_rcut_smth,
Expand All @@ -847,6 +851,7 @@ def serialize(self) -> dict:
"update_residual_init": self.update_residual_init,
"precision": self.precision,
"optim_update": self.optim_update,
"smooth_edge_update": self.smooth_edge_update,
"node_self_mlp": self.node_self_mlp.serialize(),
"node_sym_linear": self.node_sym_linear.serialize(),
"node_edge_linear": self.node_edge_linear.serialize(),
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ class DescrptBlockRepflows(DescriptorBlock):
fix_stat_std : float, optional
If non-zero (default is 0.3), use this constant as the normalization standard deviation
instead of computing it from data statistics.
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
optim_update : bool, optional
Whether to enable the optimized update method.
Uses a more efficient process when enabled. Defaults to True
Expand Down Expand Up @@ -179,6 +182,7 @@ def __init__(
env_protection: float = 0.0,
precision: str = "float64",
fix_stat_std: float = 0.3,
smooth_edge_update: bool = False,
optim_update: bool = True,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
Expand Down Expand Up @@ -210,6 +214,7 @@ def __init__(
self.set_stddev_constant = fix_stat_std != 0.0
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -262,6 +267,7 @@ def __init__(
update_residual_init=self.update_residual_init,
precision=precision,
optim_update=self.optim_update,
smooth_edge_update=self.smooth_edge_update,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down
11 changes: 11 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,10 @@ def dpa3_repflow_args():
"Whether to enable the optimized update method. "
"Uses a more efficient process when enabled. Defaults to True"
)
doc_smooth_edge_update = (
"Whether to make edge update smooth."
"If True, the edge update from angle message will not use self as padding."
)

return [
# repflow args
Expand Down Expand Up @@ -1586,6 +1590,13 @@ def dpa3_repflow_args():
default=True,
doc=doc_optim_update,
),
Argument(
"smooth_edge_update",
bool,
optional=True,
default=False, # For compatability. This will be True in the future
doc=doc_smooth_edge_update,
),
]


Expand Down
2 changes: 2 additions & 0 deletions source/tests/pt/model/test_dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_consistency(
update_angle=ua,
update_style=rus,
update_residual_init=ruri,
smooth_edge_update=True,
)

# dpa3 new impl
Expand Down Expand Up @@ -190,6 +191,7 @@ def test_jit(
update_angle=ua,
update_style=rus,
update_residual_init=ruri,
smooth_edge_update=True,
)

# dpa3 new impl
Expand Down
9 changes: 5 additions & 4 deletions source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,10 @@
"a_dim": 8,
"nlayers": 6,
"e_rcut": 6.0,
"e_rcut_smth": 5.0,
"e_rcut_smth": 3.0,
"e_sel": 20,
"a_rcut": 4.0,
"a_rcut_smth": 3.5,
"a_rcut_smth": 2.0,
"a_sel": 10,
"axis_neuron": 4,
"a_compress_rate": 1,
Expand All @@ -190,8 +190,9 @@
"update_style": "res_residual",
"update_residual": 0.1,
"update_residual_init": "const",
"smooth_edge_update": True,
},
"activation_function": "silu",
"activation_function": "silut:10.0",
"use_tebd_bias": False,
"precision": "float32",
"concat_output_tebd": False,
Expand All @@ -200,7 +201,7 @@
"neuron": [24, 24],
"resnet_dt": True,
"precision": "float32",
"activation_function": "silu",
"activation_function": "silut:10.0",
"seed": 1,
},
}
Expand Down
24 changes: 23 additions & 1 deletion source/tests/pt/model/test_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
model_dos,
model_dpa1,
model_dpa2,
model_dpa3,
model_hybrid,
model_se_e2_a,
model_spin,
Expand Down Expand Up @@ -59,10 +60,16 @@ def test(
0.0,
4.0 - 0.5 * epsilon,
0.0,
6.0 - 0.5 * epsilon,
0.0,
0.0,
0.0,
6.0 - 0.5 * epsilon,
0.0,
],
dtype=dtype,
device=env.DEVICE,
).view([-1, 3])
).view([-1, 3]) # to test descriptors with two rcuts, e.g. DPA2/3
coord1 = torch.rand(
[natoms - coord0.shape[0], 3],
dtype=dtype,
Expand All @@ -77,11 +84,15 @@ def test(
coord0 = torch.clone(coord)
coord1 = torch.clone(coord)
coord1[1][0] += epsilon
coord1[3][0] += epsilon
coord2 = torch.clone(coord)
coord2[2][1] += epsilon
coord2[4][1] += epsilon
coord3 = torch.clone(coord)
coord3[1][0] += epsilon
coord1[3][0] += epsilon
coord3[2][1] += epsilon
coord2[4][1] += epsilon
test_spin = getattr(self, "test_spin", False)
if not test_spin:
test_keys = ["energy", "force", "virial"]
Expand Down Expand Up @@ -226,6 +237,17 @@ def setUp(self) -> None:
self.epsilon, self.aprec = None, None


class TestEnergyModelDPA3(unittest.TestCase, SmoothTest):
def setUp(self) -> None:
model_params = copy.deepcopy(model_dpa3)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)
# less degree of smoothness,
# error can be systematically removed by reducing epsilon
self.epsilon = 1e-5
self.aprec = 1e-5


class TestEnergyModelHybrid(unittest.TestCase, SmoothTest):
def setUp(self) -> None:
model_params = copy.deepcopy(model_hybrid)
Expand Down
Loading