Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class RepFlowArgs:
optim_update : bool, optional
Whether to enable the optimized update method.
Uses a more efficient process when enabled. Defaults to True
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
"""

def __init__(
Expand All @@ -147,6 +150,7 @@ def __init__(
fix_stat_std: float = 0.3,
skip_stat: bool = False,
optim_update: bool = True,
smooth_edge_update: bool = False,
) -> None:
self.n_dim = n_dim
self.e_dim = e_dim
Expand All @@ -172,6 +176,7 @@ def __init__(
self.a_compress_e_rate = a_compress_e_rate
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

def __getitem__(self, key):
if hasattr(self, key):
Expand Down Expand Up @@ -297,6 +302,7 @@ def init_subclass_params(sub_data, sub_class):
update_residual_init=self.repflow_args.update_residual_init,
fix_stat_std=self.repflow_args.fix_stat_std,
optim_update=self.repflow_args.optim_update,
smooth_edge_update=self.repflow_args.smooth_edge_update,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
Expand Down
38 changes: 25 additions & 13 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
optim_update : bool, optional
Whether to enable the optimized update method.
Uses a more efficient process when enabled. Defaults to True
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
ntypes : int
Number of element types
activation_function : str, optional
Expand Down Expand Up @@ -161,6 +164,7 @@ def __init__(
precision: str = "float64",
fix_stat_std: float = 0.3,
optim_update: bool = True,
smooth_edge_update: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -191,6 +195,7 @@ def __init__(
self.set_stddev_constant = fix_stat_std != 0.0
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -243,6 +248,7 @@ def __init__(
update_residual_init=self.update_residual_init,
precision=precision,
optim_update=self.optim_update,
smooth_edge_update=self.smooth_edge_update,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down Expand Up @@ -563,6 +569,7 @@ def __init__(
axis_neuron: int = 4,
update_angle: bool = True,
optim_update: bool = True,
smooth_edge_update: bool = False,
activation_function: str = "silu",
update_style: str = "res_residual",
update_residual: float = 0.1,
Expand Down Expand Up @@ -607,6 +614,7 @@ def __init__(
self.seed = seed
self.prec = PRECISION_DICT[precision]
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

assert update_residual_init in [
"norm",
Expand Down Expand Up @@ -1136,19 +1144,23 @@ def call(
],
axis=2,
)
full_mask = xp.concat(
[
a_nlist_mask,
xp.zeros(
(nb, nloc, self.nnei - self.a_sel),
dtype=a_nlist_mask.dtype,
),
],
axis=-1,
)
padding_edge_angle_update = xp.where(
xp.expand_dims(full_mask, axis=-1), padding_edge_angle_update, edge_ebd
)
if not self.smooth_edge_update:
# will be deprecated in the future
full_mask = xp.concat(
[
a_nlist_mask,
xp.zeros(
(nb, nloc, self.nnei - self.a_sel),
dtype=a_nlist_mask.dtype,
),
],
axis=-1,
)
padding_edge_angle_update = xp.where(
xp.expand_dims(full_mask, axis=-1),
padding_edge_angle_update,
edge_ebd,
)
e_update_list.append(
self.act(self.edge_angle_linear2(padding_edge_angle_update))
)
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def init_subclass_params(sub_data, sub_class):
update_residual_init=self.repflow_args.update_residual_init,
fix_stat_std=self.repflow_args.fix_stat_std,
optim_update=self.repflow_args.optim_update,
smooth_edge_update=self.repflow_args.smooth_edge_update,
exclude_types=exclude_types,
env_protection=env_protection,
precision=precision,
Expand Down
32 changes: 18 additions & 14 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
axis_neuron: int = 4,
update_angle: bool = True,
optim_update: bool = True,
smooth_edge_update: bool = False,
activation_function: str = "silu",
update_style: str = "res_residual",
update_residual: float = 0.1,
Expand Down Expand Up @@ -96,6 +97,7 @@ def __init__(
self.seed = seed
self.prec = PRECISION_DICT[precision]
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

assert update_residual_init in [
"norm",
Expand Down Expand Up @@ -718,20 +720,22 @@ def forward(
],
dim=2,
)
full_mask = torch.concat(
[
a_nlist_mask,
torch.zeros(
[nb, nloc, self.nnei - self.a_sel],
dtype=a_nlist_mask.dtype,
device=a_nlist_mask.device,
),
],
dim=-1,
)
padding_edge_angle_update = torch.where(
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
)
if not self.smooth_edge_update:
# will be deprecated in the future
full_mask = torch.concat(
[
a_nlist_mask,
torch.zeros(
[nb, nloc, self.nnei - self.a_sel],
dtype=a_nlist_mask.dtype,
device=a_nlist_mask.device,
),
],
dim=-1,
)
padding_edge_angle_update = torch.where(
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
)
e_update_list.append(
self.act(self.edge_angle_linear2(padding_edge_angle_update))
)
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ class DescrptBlockRepflows(DescriptorBlock):
fix_stat_std : float, optional
If non-zero (default is 0.3), use this constant as the normalization standard deviation
instead of computing it from data statistics.
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
optim_update : bool, optional
Whether to enable the optimized update method.
Uses a more efficient process when enabled. Defaults to True
Expand Down Expand Up @@ -179,6 +182,7 @@ def __init__(
env_protection: float = 0.0,
precision: str = "float64",
fix_stat_std: float = 0.3,
smooth_edge_update: bool = False,
optim_update: bool = True,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
Expand Down Expand Up @@ -210,6 +214,7 @@ def __init__(
self.set_stddev_constant = fix_stat_std != 0.0
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -262,6 +267,7 @@ def __init__(
update_residual_init=self.update_residual_init,
precision=precision,
optim_update=self.optim_update,
smooth_edge_update=self.smooth_edge_update,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down
11 changes: 11 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,10 @@ def dpa3_repflow_args():
"Whether to enable the optimized update method. "
"Uses a more efficient process when enabled. Defaults to True"
)
doc_smooth_edge_update = (
"Whether to make edge update smooth."
"If True, the edge update from angle message will not use self as padding."
)

return [
# repflow args
Expand Down Expand Up @@ -1586,6 +1590,13 @@ def dpa3_repflow_args():
default=True,
doc=doc_optim_update,
),
Argument(
"smooth_edge_update",
bool,
optional=True,
default=False, # for compatability, will be True in the future
doc=doc_smooth_edge_update,
),
]


Expand Down
9 changes: 5 additions & 4 deletions source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,10 @@
"a_dim": 8,
"nlayers": 6,
"e_rcut": 6.0,
"e_rcut_smth": 5.0,
"e_rcut_smth": 3.0,
"e_sel": 20,
"a_rcut": 4.0,
"a_rcut_smth": 3.5,
"a_rcut_smth": 2.0,
"a_sel": 10,
"axis_neuron": 4,
"a_compress_rate": 1,
Expand All @@ -190,8 +190,9 @@
"update_style": "res_residual",
"update_residual": 0.1,
"update_residual_init": "const",
"smooth_edge_update": True,
},
"activation_function": "silu",
"activation_function": "silut:10.0",
"use_tebd_bias": False,
"precision": "float32",
"concat_output_tebd": False,
Expand All @@ -200,7 +201,7 @@
"neuron": [24, 24],
"resnet_dt": True,
"precision": "float32",
"activation_function": "silu",
"activation_function": "silut:10.0",
"seed": 1,
},
}
Expand Down
24 changes: 23 additions & 1 deletion source/tests/pt/model/test_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
model_dos,
model_dpa1,
model_dpa2,
model_dpa3,
model_hybrid,
model_se_e2_a,
model_spin,
Expand Down Expand Up @@ -59,10 +60,16 @@ def test(
0.0,
4.0 - 0.5 * epsilon,
0.0,
6.0 - 0.5 * epsilon,
0.0,
0.0,
0.0,
6.0 - 0.5 * epsilon,
0.0,
],
dtype=dtype,
device=env.DEVICE,
).view([-1, 3])
).view([-1, 3]) # to test descriptors with two rcuts, e.g. DPA2/3
coord1 = torch.rand(
[natoms - coord0.shape[0], 3],
dtype=dtype,
Expand All @@ -77,11 +84,15 @@ def test(
coord0 = torch.clone(coord)
coord1 = torch.clone(coord)
coord1[1][0] += epsilon
coord1[3][0] += epsilon
coord2 = torch.clone(coord)
coord2[2][1] += epsilon
coord2[4][1] += epsilon
coord3 = torch.clone(coord)
coord3[1][0] += epsilon
coord1[3][0] += epsilon
coord3[2][1] += epsilon
coord2[4][1] += epsilon
test_spin = getattr(self, "test_spin", False)
if not test_spin:
test_keys = ["energy", "force", "virial"]
Expand Down Expand Up @@ -226,6 +237,17 @@ def setUp(self) -> None:
self.epsilon, self.aprec = None, None


class TestEnergyModelDPA3(unittest.TestCase, SmoothTest):
def setUp(self) -> None:
model_params = copy.deepcopy(model_dpa3)
self.type_split = True
self.model = get_model(model_params).to(env.DEVICE)
# less degree of smoothness,
# error can be systematically removed by reducing epsilon
self.epsilon = 1e-5
self.aprec = 1e-5


class TestEnergyModelHybrid(unittest.TestCase, SmoothTest):
def setUp(self) -> None:
model_params = copy.deepcopy(model_hybrid)
Expand Down
12 changes: 11 additions & 1 deletion source/tests/universal/common/cases/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,20 +723,30 @@ def test_smooth(self) -> None:
0.0,
self.expected_rcut - 0.5 * epsilon,
0.0,
self.expected_rcut / 2 - 0.5 * epsilon,
0.0,
0.0,
0.0,
self.expected_rcut / 2 - 0.5 * epsilon,
0.0,
]
).reshape(-1, 3)
).reshape(-1, 3) # to test descriptors with two rcuts, e.g. DPA2/3
coord1 = rng.random([natoms - coord0.shape[0], 3])
coord1 = np.matmul(coord1, cell)
coord = np.concatenate([coord0, coord1], axis=0)

coord0 = deepcopy(coord)
coord1 = deepcopy(coord)
coord1[1][0] += epsilon
coord1[3][0] += epsilon
coord2 = deepcopy(coord)
coord2[2][1] += epsilon
coord2[4][1] += epsilon
coord3 = deepcopy(coord)
coord3[1][0] += epsilon
coord1[3][0] += epsilon
coord3[2][1] += epsilon
coord2[4][1] += epsilon

# reshape for input
coord0 = coord0.reshape([nf, -1])
Expand Down
3 changes: 3 additions & 0 deletions source/tests/universal/dpmodel/descriptor/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def DescriptorParamDPA3(
a_compress_e_rate=1,
a_compress_use_split=False,
optim_update=True,
smooth_edge_update=False,
fix_stat_std=0.3,
precision="float64",
):
Expand All @@ -502,6 +503,7 @@ def DescriptorParamDPA3(
"a_compress_e_rate": a_compress_e_rate,
"a_compress_use_split": a_compress_use_split,
"optim_update": optim_update,
"smooth_edge_update": smooth_edge_update,
"fix_stat_std": fix_stat_std,
"n_multi_edge_message": n_multi_edge_message,
"axis_neuron": 2,
Expand Down Expand Up @@ -537,6 +539,7 @@ def DescriptorParamDPA3(
"a_compress_e_rate": (2,),
"a_compress_use_split": (True, False),
"optim_update": (True, False),
"smooth_edge_update": (True,),
"fix_stat_std": (0.3,),
"n_multi_edge_message": (1, 2),
"env_protection": (0.0, 1e-8),
Expand Down
Loading