Skip to content

Commit 57acd99

Browse files
committed
add compress
1 parent 389287a commit 57acd99

File tree

7 files changed

+98
-9
lines changed

7 files changed

+98
-9
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(
1414
a_rcut: float = 4.0,
1515
a_rcut_smth: float = 3.5,
1616
a_sel: int = 20,
17+
a_compress_rate: int = 0,
1718
axis_neuron: int = 4,
1819
update_angle: bool = True,
1920
update_style: str = "res_residual",
@@ -45,6 +46,10 @@ def __init__(
4546
Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
4647
a_sel : int, optional
4748
Maximally possible number of selected angle neighbors.
49+
a_compress_rate : int, optional
50+
The compression rate for angular messages. The default value is 0, indicating no compression.
51+
If a non-zero integer c is provided, the node and edge dimensions will be compressed
52+
to n_dim/c and e_dim/2c, respectively, within the angular message.
4853
axis_neuron : int, optional
4954
The number of dimension of submatrix in the symmetrization ops.
5055
update_angle : bool, optional
@@ -72,6 +77,7 @@ def __init__(
7277
self.a_rcut = a_rcut
7378
self.a_rcut_smth = a_rcut_smth
7479
self.a_sel = a_sel
80+
self.a_compress_rate = a_compress_rate
7581
self.axis_neuron = axis_neuron
7682
self.update_angle = update_angle
7783
self.update_style = update_style
@@ -97,6 +103,7 @@ def serialize(self) -> dict:
97103
"a_rcut": self.a_rcut,
98104
"a_rcut_smth": self.a_rcut_smth,
99105
"a_sel": self.a_sel,
106+
"a_compress_rate": self.a_compress_rate,
100107
"axis_neuron": self.axis_neuron,
101108
"update_angle": self.update_angle,
102109
"update_style": self.update_style,

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def init_subclass_params(sub_data, sub_class):
151151
n_dim=self.repflow_args.n_dim,
152152
e_dim=self.repflow_args.e_dim,
153153
a_dim=self.repflow_args.a_dim,
154+
a_compress_rate=self.repflow_args.a_compress_rate,
154155
axis_neuron=self.repflow_args.axis_neuron,
155156
update_angle=self.repflow_args.update_angle,
156157
activation_function=self.activation_function,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
n_dim: int = 128,
4646
e_dim: int = 16,
4747
a_dim: int = 64,
48+
a_compress_rate: int = 0,
4849
axis_neuron: int = 4,
4950
update_angle: bool = True, # angle
5051
activation_function: str = "silu",
@@ -70,6 +71,12 @@ def __init__(
7071
self.n_dim = n_dim
7172
self.e_dim = e_dim
7273
self.a_dim = a_dim
74+
self.a_compress_rate = a_compress_rate
75+
if a_compress_rate != 0:
76+
assert a_dim % (2 * a_compress_rate) == 0, (
77+
f"For a_compress_rate of {a_compress_rate}, a_dim must be divisible by {2 * a_compress_rate}. "
78+
f"Currently, a_dim={a_dim} is not valid."
79+
)
7380
self.axis_neuron = axis_neuron
7481
self.update_angle = update_angle
7582
self.activation_function = activation_function
@@ -167,20 +174,42 @@ def __init__(
167174
)
168175

169176
if self.update_angle:
170-
self.angle_dim = self.a_dim + self.n_dim + 2 * self.e_dim
177+
self.angle_dim = self.a_dim
178+
if self.a_compress_rate == 0:
179+
# angle + node + edge * 2
180+
self.angle_dim += self.n_dim + 2 * self.e_dim
181+
self.a_compress_n_linear = None
182+
self.a_compress_e_linear = None
183+
else:
184+
# angle + node/c + edge/2c * 2
185+
self.angle_dim += 2 * (self.a_dim // self.a_compress_rate)
186+
self.a_compress_n_linear = MLPLayer(
187+
self.n_dim,
188+
self.a_dim // self.a_compress_rate,
189+
precision=precision,
190+
bias=False,
191+
seed=child_seed(seed, 8),
192+
)
193+
self.a_compress_e_linear = MLPLayer(
194+
self.e_dim,
195+
self.a_dim // (2 * self.a_compress_rate),
196+
precision=precision,
197+
bias=False,
198+
seed=child_seed(seed, 9),
199+
)
171200

172201
# edge angle message
173202
self.edge_angle_linear1 = MLPLayer(
174203
self.angle_dim,
175204
self.e_dim,
176205
precision=precision,
177-
seed=child_seed(seed, 8),
206+
seed=child_seed(seed, 10),
178207
)
179208
self.edge_angle_linear2 = MLPLayer(
180209
self.e_dim,
181210
self.e_dim,
182211
precision=precision,
183-
seed=child_seed(seed, 9),
212+
seed=child_seed(seed, 11),
184213
)
185214
if self.update_style == "res_residual":
186215
self.e_residual.append(
@@ -189,7 +218,7 @@ def __init__(
189218
self.update_residual,
190219
self.update_residual_init,
191220
precision=precision,
192-
seed=child_seed(seed, 10),
221+
seed=child_seed(seed, 12),
193222
)
194223
)
195224

@@ -198,7 +227,7 @@ def __init__(
198227
self.angle_dim,
199228
self.a_dim,
200229
precision=precision,
201-
seed=child_seed(seed, 11),
230+
seed=child_seed(seed, 13),
202231
)
203232
if self.update_style == "res_residual":
204233
self.a_residual.append(
@@ -207,13 +236,15 @@ def __init__(
207236
self.update_residual,
208237
self.update_residual_init,
209238
precision=precision,
210-
seed=child_seed(seed, 12),
239+
seed=child_seed(seed, 14),
211240
)
212241
)
213242
else:
214243
self.angle_self_linear = None
215244
self.edge_angle_linear1 = None
216245
self.edge_angle_linear2 = None
246+
self.a_compress_n_linear = None
247+
self.a_compress_e_linear = None
217248
self.angle_dim = 0
218249

219250
self.n_residual = nn.ParameterList(self.n_residual)
@@ -448,12 +479,22 @@ def forward(
448479
assert self.edge_angle_linear1 is not None
449480
assert self.edge_angle_linear2 is not None
450481
# get angle info
482+
if self.a_compress_rate != 0:
483+
assert self.a_compress_n_linear is not None
484+
assert self.a_compress_e_linear is not None
485+
node_ebd_for_angle = self.a_compress_n_linear(node_ebd)
486+
edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd)
487+
else:
488+
node_ebd_for_angle = node_ebd
489+
edge_ebd_for_angle = edge_ebd
490+
451491
# nb x nloc x a_nnei x a_nnei x n_dim
452492
node_for_angle_info = torch.tile(
453-
node_ebd.unsqueeze(2).unsqueeze(2), (1, 1, self.a_sel, self.a_sel, 1)
493+
node_ebd_for_angle.unsqueeze(2).unsqueeze(2),
494+
(1, 1, self.a_sel, self.a_sel, 1),
454495
)
455496
# nb x nloc x a_nnei x e_dim
456-
edge_for_angle = edge_ebd[:, :, : self.a_sel, :]
497+
edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :]
457498
# nb x nloc x a_nnei x e_dim
458499
edge_for_angle = torch.where(
459500
a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0
@@ -471,7 +512,7 @@ def forward(
471512
[edge_for_angle_i, edge_for_angle_j], dim=-1
472513
)
473514
angle_info_list = [angle_ebd, node_for_angle_info, edge_for_angle_info]
474-
# nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2)
515+
# nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c)
475516
angle_info = torch.cat(angle_info_list, dim=-1)
476517

477518
# edge angle message
@@ -605,6 +646,7 @@ def serialize(self) -> dict:
605646
"n_dim": self.n_dim,
606647
"e_dim": self.e_dim,
607648
"a_dim": self.a_dim,
649+
"a_compress_rate": self.a_compress_rate,
608650
"axis_neuron": self.axis_neuron,
609651
"activation_function": self.activation_function,
610652
"update_angle": self.update_angle,
@@ -625,6 +667,13 @@ def serialize(self) -> dict:
625667
"angle_self_linear": self.angle_self_linear.serialize(),
626668
}
627669
)
670+
if self.a_compress_rate != 0:
671+
data.update(
672+
{
673+
"a_compress_n_linear": self.a_compress_n_linear.serialize(),
674+
"a_compress_e_linear": self.a_compress_e_linear.serialize(),
675+
}
676+
)
628677
if self.update_style == "res_residual":
629678
data.update(
630679
{
@@ -650,13 +699,16 @@ def deserialize(cls, data: dict) -> "RepFlowLayer":
650699
check_version_compatibility(data.pop("@version"), 1, 1)
651700
data.pop("@class")
652701
update_angle = data["update_angle"]
702+
a_compress_rate = data["a_compress_rate"]
653703
node_self_mlp = data.pop("node_self_mlp")
654704
node_sym_linear = data.pop("node_sym_linear")
655705
node_edge_linear = data.pop("node_edge_linear")
656706
edge_self_linear = data.pop("edge_self_linear")
657707
edge_angle_linear1 = data.pop("edge_angle_linear1", None)
658708
edge_angle_linear2 = data.pop("edge_angle_linear2", None)
659709
angle_self_linear = data.pop("angle_self_linear", None)
710+
a_compress_n_linear = data.pop("a_compress_n_linear", None)
711+
a_compress_e_linear = data.pop("a_compress_e_linear", None)
660712
update_style = data["update_style"]
661713
variables = data.pop("@variables", {})
662714
n_residual = variables.get("n_residual", data.pop("n_residual", []))
@@ -676,6 +728,11 @@ def deserialize(cls, data: dict) -> "RepFlowLayer":
676728
obj.edge_angle_linear1 = MLPLayer.deserialize(edge_angle_linear1)
677729
obj.edge_angle_linear2 = MLPLayer.deserialize(edge_angle_linear2)
678730
obj.angle_self_linear = MLPLayer.deserialize(angle_self_linear)
731+
if a_compress_rate != 0:
732+
assert isinstance(a_compress_n_linear, dict)
733+
assert isinstance(a_compress_e_linear, dict)
734+
obj.a_compress_n_linear = MLPLayer.deserialize(a_compress_n_linear)
735+
obj.a_compress_e_linear = MLPLayer.deserialize(a_compress_e_linear)
679736

680737
if update_style == "res_residual":
681738
for ii, t in enumerate(obj.n_residual):

deepmd/pt/model/descriptor/repflows.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(
8585
n_dim: int = 128,
8686
e_dim: int = 64,
8787
a_dim: int = 64,
88+
a_compress_rate: int = 0,
8889
axis_neuron: int = 4,
8990
update_angle: bool = True,
9091
activation_function: str = "silu",
@@ -123,6 +124,10 @@ def __init__(
123124
Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
124125
a_sel : int, optional
125126
Maximally possible number of selected angle neighbors.
127+
a_compress_rate : int, optional
128+
The compression rate for angular messages. The default value is 0, indicating no compression.
129+
If a non-zero integer c is provided, the node and edge dimensions will be compressed
130+
to n_dim/c and e_dim/2c, respectively, within the angular message.
126131
axis_neuron : int, optional
127132
The number of dimension of submatrix in the symmetrization ops.
128133
update_angle : bool, optional
@@ -175,6 +180,7 @@ def __init__(
175180
self.rcut_smth = e_rcut_smth
176181
self.sec = self.sel
177182
self.split_sel = self.sel
183+
self.a_compress_rate = a_compress_rate
178184
self.axis_neuron = axis_neuron
179185
self.set_davg_zero = set_davg_zero
180186
self.skip_stat = skip_stat
@@ -218,6 +224,7 @@ def __init__(
218224
n_dim=self.n_dim,
219225
e_dim=self.e_dim,
220226
a_dim=self.a_dim,
227+
a_compress_rate=self.a_compress_rate,
221228
axis_neuron=self.axis_neuron,
222229
update_angle=self.update_angle,
223230
activation_function=self.activation_function,

deepmd/utils/argcheck.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,11 @@ def dpa3_repflow_args():
14401440
doc_a_sel = 'Maximally possible number of selected angle neighbors. It can be:\n\n\
14411441
- `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\
14421442
- `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".'
1443+
doc_a_compress_rate = (
1444+
"The compression rate for angular messages. The default value is 0, indicating no compression. "
1445+
" If a non-zero integer c is provided, the node and edge dimensions will be compressed "
1446+
"to n_dim/c and e_dim/2c, respectively, within the angular message."
1447+
)
14431448
doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops."
14441449
doc_update_angle = (
14451450
"Where to update the angle rep. If not, only node and edge rep will be used."
@@ -1475,6 +1480,9 @@ def dpa3_repflow_args():
14751480
Argument("a_rcut", float, doc=doc_a_rcut),
14761481
Argument("a_rcut_smth", float, doc=doc_a_rcut_smth),
14771482
Argument("a_sel", [int, str], doc=doc_a_sel),
1483+
Argument(
1484+
"a_compress_rate", int, optional=True, default=0, doc=doc_a_compress_rate
1485+
),
14781486
Argument(
14791487
"axis_neuron",
14801488
int,

source/tests/pt/model/test_dpa3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,14 @@ def test_consistency(
4848
ua,
4949
rus,
5050
ruri,
51+
acr,
5152
prec,
5253
ect,
5354
) in itertools.product(
5455
[True, False], # update_angle
5556
["res_residual"], # update_style
5657
["norm", "const"], # update_residual_init
58+
[0, 1], # a_compress_rate
5759
["float64"], # precision
5860
[False], # use_econf_tebd
5961
):
@@ -73,6 +75,7 @@ def test_consistency(
7375
a_rcut=self.rcut - 0.1,
7476
a_rcut_smth=self.rcut_smth,
7577
a_sel=nnei - 1,
78+
a_compress_rate=acr,
7679
axis_neuron=4,
7780
update_angle=ua,
7881
update_style=rus,
@@ -127,12 +130,14 @@ def test_jit(
127130
ua,
128131
rus,
129132
ruri,
133+
acr,
130134
prec,
131135
ect,
132136
) in itertools.product(
133137
[True, False], # update_angle
134138
["res_residual"], # update_style
135139
["norm", "const"], # update_residual_init
140+
[0, 1], # a_compress_rate
136141
["float64"], # precision
137142
[False], # use_econf_tebd
138143
):
@@ -150,6 +155,7 @@ def test_jit(
150155
a_rcut=self.rcut - 0.1,
151156
a_rcut_smth=self.rcut_smth,
152157
a_sel=nnei - 1,
158+
a_compress_rate=acr,
153159
axis_neuron=4,
154160
update_angle=ua,
155161
update_style=rus,

source/tests/universal/dpmodel/descriptor/test_descriptor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def DescriptorParamDPA3(
475475
update_residual=0.1,
476476
update_residual_init="const",
477477
update_angle=True,
478+
a_compress_rate=0,
478479
precision="float64",
479480
):
480481
input_dict = {
@@ -491,6 +492,7 @@ def DescriptorParamDPA3(
491492
"a_rcut": rcut / 2,
492493
"a_rcut_smth": rcut_smth / 2,
493494
"a_sel": sum(sel) // 4,
495+
"a_compress_rate": a_compress_rate,
494496
"axis_neuron": 4,
495497
"update_angle": update_angle,
496498
"update_style": update_style,
@@ -520,6 +522,7 @@ def DescriptorParamDPA3(
520522
"update_residual_init": ("const",),
521523
"exclude_types": ([], [[0, 1]]),
522524
"update_angle": (True, False),
525+
"a_compress_rate": (0, 1),
523526
"env_protection": (0.0, 1e-8),
524527
"precision": ("float64",),
525528
}

0 commit comments

Comments
 (0)