Skip to content

Commit 1309e26

Browse files
committed
add compress
1 parent 20a60c6 commit 1309e26

File tree

7 files changed

+98
-9
lines changed

7 files changed

+98
-9
lines changed

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def __init__(
1414
a_rcut: float = 4.0,
1515
a_rcut_smth: float = 3.5,
1616
a_sel: int = 20,
17+
a_compress_rate: int = 0,
1718
axis_neuron: int = 4,
1819
update_angle: bool = True,
1920
update_style: str = "res_residual",
@@ -44,6 +45,10 @@ def __init__(
4445
Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
4546
a_sel : int, optional
4647
Maximally possible number of selected angle neighbors.
48+
a_compress_rate : int, optional
49+
The compression rate for angular messages. The default value is 0, indicating no compression.
50+
If a non-zero integer c is provided, the node and edge dimensions will be compressed
51+
to n_dim/c and e_dim/2c, respectively, within the angular message.
4752
axis_neuron : int, optional
4853
The number of dimension of submatrix in the symmetrization ops.
4954
update_angle : bool, optional
@@ -71,6 +76,7 @@ def __init__(
7176
self.a_rcut = a_rcut
7277
self.a_rcut_smth = a_rcut_smth
7378
self.a_sel = a_sel
79+
self.a_compress_rate = a_compress_rate
7480
self.axis_neuron = axis_neuron
7581
self.update_angle = update_angle
7682
self.update_style = update_style
@@ -95,6 +101,7 @@ def serialize(self) -> dict:
95101
"a_rcut": self.a_rcut,
96102
"a_rcut_smth": self.a_rcut_smth,
97103
"a_sel": self.a_sel,
104+
"a_compress_rate": self.a_compress_rate,
98105
"axis_neuron": self.axis_neuron,
99106
"update_angle": self.update_angle,
100107
"update_style": self.update_style,

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def init_subclass_params(sub_data, sub_class):
151151
n_dim=self.repflow_args.n_dim,
152152
e_dim=self.repflow_args.e_dim,
153153
a_dim=self.repflow_args.a_dim,
154+
a_compress_rate=self.repflow_args.a_compress_rate,
154155
axis_neuron=self.repflow_args.axis_neuron,
155156
update_angle=self.repflow_args.update_angle,
156157
activation_function=self.activation_function,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
n_dim: int = 128,
4646
e_dim: int = 16,
4747
a_dim: int = 64,
48+
a_compress_rate: int = 0,
4849
axis_neuron: int = 4,
4950
update_angle: bool = True, # angle
5051
activation_function: str = "silu",
@@ -70,6 +71,12 @@ def __init__(
7071
self.n_dim = n_dim
7172
self.e_dim = e_dim
7273
self.a_dim = a_dim
74+
self.a_compress_rate = a_compress_rate
75+
if a_compress_rate != 0:
76+
assert a_dim % (2 * a_compress_rate) == 0, (
77+
f"For a_compress_rate of {a_compress_rate}, a_dim must be divisible by {2 * a_compress_rate}. "
78+
f"Currently, a_dim={a_dim} is not valid."
79+
)
7380
self.axis_neuron = axis_neuron
7481
self.update_angle = update_angle
7582
self.activation_function = activation_function
@@ -167,20 +174,42 @@ def __init__(
167174
)
168175

169176
if self.update_angle:
170-
self.angle_dim = self.a_dim + self.n_dim + 2 * self.e_dim
177+
self.angle_dim = self.a_dim
178+
if self.a_compress_rate == 0:
179+
# angle + node + edge * 2
180+
self.angle_dim += self.n_dim + 2 * self.e_dim
181+
self.a_compress_n_linear = None
182+
self.a_compress_e_linear = None
183+
else:
184+
# angle + node/c + edge/2c * 2
185+
self.angle_dim += 2 * (self.a_dim // self.a_compress_rate)
186+
self.a_compress_n_linear = MLPLayer(
187+
self.n_dim,
188+
self.a_dim // self.a_compress_rate,
189+
precision=precision,
190+
bias=False,
191+
seed=child_seed(seed, 8),
192+
)
193+
self.a_compress_e_linear = MLPLayer(
194+
self.e_dim,
195+
self.a_dim // (2 * self.a_compress_rate),
196+
precision=precision,
197+
bias=False,
198+
seed=child_seed(seed, 9),
199+
)
171200

172201
# edge angle message
173202
self.edge_angle_linear1 = MLPLayer(
174203
self.angle_dim,
175204
self.e_dim,
176205
precision=precision,
177-
seed=child_seed(seed, 8),
206+
seed=child_seed(seed, 10),
178207
)
179208
self.edge_angle_linear2 = MLPLayer(
180209
self.e_dim,
181210
self.e_dim,
182211
precision=precision,
183-
seed=child_seed(seed, 9),
212+
seed=child_seed(seed, 11),
184213
)
185214
if self.update_style == "res_residual":
186215
self.e_residual.append(
@@ -189,7 +218,7 @@ def __init__(
189218
self.update_residual,
190219
self.update_residual_init,
191220
precision=precision,
192-
seed=child_seed(seed, 10),
221+
seed=child_seed(seed, 12),
193222
)
194223
)
195224

@@ -198,7 +227,7 @@ def __init__(
198227
self.angle_dim,
199228
self.a_dim,
200229
precision=precision,
201-
seed=child_seed(seed, 11),
230+
seed=child_seed(seed, 13),
202231
)
203232
if self.update_style == "res_residual":
204233
self.a_residual.append(
@@ -207,13 +236,15 @@ def __init__(
207236
self.update_residual,
208237
self.update_residual_init,
209238
precision=precision,
210-
seed=child_seed(seed, 12),
239+
seed=child_seed(seed, 14),
211240
)
212241
)
213242
else:
214243
self.angle_self_linear = None
215244
self.edge_angle_linear1 = None
216245
self.edge_angle_linear2 = None
246+
self.a_compress_n_linear = None
247+
self.a_compress_e_linear = None
217248
self.angle_dim = 0
218249

219250
self.n_residual = nn.ParameterList(self.n_residual)
@@ -448,12 +479,22 @@ def forward(
448479
assert self.edge_angle_linear1 is not None
449480
assert self.edge_angle_linear2 is not None
450481
# get angle info
482+
if self.a_compress_rate != 0:
483+
assert self.a_compress_n_linear is not None
484+
assert self.a_compress_e_linear is not None
485+
node_ebd_for_angle = self.a_compress_n_linear(node_ebd)
486+
edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd)
487+
else:
488+
node_ebd_for_angle = node_ebd
489+
edge_ebd_for_angle = edge_ebd
490+
451491
# nb x nloc x a_nnei x a_nnei x n_dim
452492
node_for_angle_info = torch.tile(
453-
node_ebd.unsqueeze(2).unsqueeze(2), (1, 1, self.a_sel, self.a_sel, 1)
493+
node_ebd_for_angle.unsqueeze(2).unsqueeze(2),
494+
(1, 1, self.a_sel, self.a_sel, 1),
454495
)
455496
# nb x nloc x a_nnei x e_dim
456-
edge_for_angle = edge_ebd[:, :, : self.a_sel, :]
497+
edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :]
457498
# nb x nloc x a_nnei x e_dim
458499
edge_for_angle = torch.where(
459500
a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0
@@ -471,7 +512,7 @@ def forward(
471512
[edge_for_angle_i, edge_for_angle_j], dim=-1
472513
)
473514
angle_info_list = [angle_ebd, node_for_angle_info, edge_for_angle_info]
474-
# nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2)
515+
# nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c)
475516
angle_info = torch.cat(angle_info_list, dim=-1)
476517

477518
# edge angle message
@@ -605,6 +646,7 @@ def serialize(self) -> dict:
605646
"n_dim": self.n_dim,
606647
"e_dim": self.e_dim,
607648
"a_dim": self.a_dim,
649+
"a_compress_rate": self.a_compress_rate,
608650
"axis_neuron": self.axis_neuron,
609651
"activation_function": self.activation_function,
610652
"update_angle": self.update_angle,
@@ -625,6 +667,13 @@ def serialize(self) -> dict:
625667
"angle_self_linear": self.angle_self_linear.serialize(),
626668
}
627669
)
670+
if self.a_compress_rate != 0:
671+
data.update(
672+
{
673+
"a_compress_n_linear": self.a_compress_n_linear.serialize(),
674+
"a_compress_e_linear": self.a_compress_e_linear.serialize(),
675+
}
676+
)
628677
if self.update_style == "res_residual":
629678
data.update(
630679
{
@@ -650,13 +699,16 @@ def deserialize(cls, data: dict) -> "RepFlowLayer":
650699
check_version_compatibility(data.pop("@version"), 1, 1)
651700
data.pop("@class")
652701
update_angle = data["update_angle"]
702+
a_compress_rate = data["a_compress_rate"]
653703
node_self_mlp = data.pop("node_self_mlp")
654704
node_sym_linear = data.pop("node_sym_linear")
655705
node_edge_linear = data.pop("node_edge_linear")
656706
edge_self_linear = data.pop("edge_self_linear")
657707
edge_angle_linear1 = data.pop("edge_angle_linear1", None)
658708
edge_angle_linear2 = data.pop("edge_angle_linear2", None)
659709
angle_self_linear = data.pop("angle_self_linear", None)
710+
a_compress_n_linear = data.pop("a_compress_n_linear", None)
711+
a_compress_e_linear = data.pop("a_compress_e_linear", None)
660712
update_style = data["update_style"]
661713
variables = data.pop("@variables", {})
662714
n_residual = variables.get("n_residual", data.pop("n_residual", []))
@@ -676,6 +728,11 @@ def deserialize(cls, data: dict) -> "RepFlowLayer":
676728
obj.edge_angle_linear1 = MLPLayer.deserialize(edge_angle_linear1)
677729
obj.edge_angle_linear2 = MLPLayer.deserialize(edge_angle_linear2)
678730
obj.angle_self_linear = MLPLayer.deserialize(angle_self_linear)
731+
if a_compress_rate != 0:
732+
assert isinstance(a_compress_n_linear, dict)
733+
assert isinstance(a_compress_e_linear, dict)
734+
obj.a_compress_n_linear = MLPLayer.deserialize(a_compress_n_linear)
735+
obj.a_compress_e_linear = MLPLayer.deserialize(a_compress_e_linear)
679736

680737
if update_style == "res_residual":
681738
for ii, t in enumerate(obj.n_residual):

deepmd/pt/model/descriptor/repflows.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(
8585
n_dim: int = 128,
8686
e_dim: int = 64,
8787
a_dim: int = 64,
88+
a_compress_rate: int = 0,
8889
axis_neuron: int = 4,
8990
update_angle: bool = True,
9091
activation_function: str = "silu",
@@ -122,6 +123,10 @@ def __init__(
122123
Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth.
123124
a_sel : int, optional
124125
Maximally possible number of selected angle neighbors.
126+
a_compress_rate : int, optional
127+
The compression rate for angular messages. The default value is 0, indicating no compression.
128+
If a non-zero integer c is provided, the node and edge dimensions will be compressed
129+
to n_dim/c and e_dim/2c, respectively, within the angular message.
125130
axis_neuron : int, optional
126131
The number of dimension of submatrix in the symmetrization ops.
127132
update_angle : bool, optional
@@ -174,6 +179,7 @@ def __init__(
174179
self.rcut_smth = e_rcut_smth
175180
self.sec = self.sel
176181
self.split_sel = self.sel
182+
self.a_compress_rate = a_compress_rate
177183
self.axis_neuron = axis_neuron
178184
self.set_davg_zero = set_davg_zero
179185

@@ -216,6 +222,7 @@ def __init__(
216222
n_dim=self.n_dim,
217223
e_dim=self.e_dim,
218224
a_dim=self.a_dim,
225+
a_compress_rate=self.a_compress_rate,
219226
axis_neuron=self.axis_neuron,
220227
update_angle=self.update_angle,
221228
activation_function=self.activation_function,

deepmd/utils/argcheck.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,6 +1440,11 @@ def dpa3_repflow_args():
14401440
doc_a_sel = 'Maximally possible number of selected angle neighbors. It can be:\n\n\
14411441
- `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\
14421442
- `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".'
1443+
doc_a_compress_rate = (
1444+
"The compression rate for angular messages. The default value is 0, indicating no compression. "
1445+
" If a non-zero integer c is provided, the node and edge dimensions will be compressed "
1446+
"to n_dim/c and e_dim/2c, respectively, within the angular message."
1447+
)
14431448
doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops."
14441449
doc_update_angle = (
14451450
"Where to update the angle rep. If not, only node and edge rep will be used."
@@ -1475,6 +1480,9 @@ def dpa3_repflow_args():
14751480
Argument("a_rcut", float, doc=doc_a_rcut),
14761481
Argument("a_rcut_smth", float, doc=doc_a_rcut_smth),
14771482
Argument("a_sel", [int, str], doc=doc_a_sel),
1483+
Argument(
1484+
"a_compress_rate", int, optional=True, default=0, doc=doc_a_compress_rate
1485+
),
14781486
Argument(
14791487
"axis_neuron",
14801488
int,

source/tests/pt/model/test_dpa3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,14 @@ def test_consistency(
4848
ua,
4949
rus,
5050
ruri,
51+
acr,
5152
prec,
5253
ect,
5354
) in itertools.product(
5455
[True, False], # update_angle
5556
["res_residual"], # update_style
5657
["norm", "const"], # update_residual_init
58+
[0, 1], # a_compress_rate
5759
["float64"], # precision
5860
[False], # use_econf_tebd
5961
):
@@ -73,6 +75,7 @@ def test_consistency(
7375
a_rcut=self.rcut - 0.1,
7476
a_rcut_smth=self.rcut_smth,
7577
a_sel=nnei - 1,
78+
a_compress_rate=acr,
7679
axis_neuron=4,
7780
update_angle=ua,
7881
update_style=rus,
@@ -127,12 +130,14 @@ def test_jit(
127130
ua,
128131
rus,
129132
ruri,
133+
acr,
130134
prec,
131135
ect,
132136
) in itertools.product(
133137
[True, False], # update_angle
134138
["res_residual"], # update_style
135139
["norm", "const"], # update_residual_init
140+
[0, 1], # a_compress_rate
136141
["float64"], # precision
137142
[False], # use_econf_tebd
138143
):
@@ -150,6 +155,7 @@ def test_jit(
150155
a_rcut=self.rcut - 0.1,
151156
a_rcut_smth=self.rcut_smth,
152157
a_sel=nnei - 1,
158+
a_compress_rate=acr,
153159
axis_neuron=4,
154160
update_angle=ua,
155161
update_style=rus,

source/tests/universal/dpmodel/descriptor/test_descriptor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ def DescriptorParamDPA3(
475475
update_residual=0.1,
476476
update_residual_init="const",
477477
update_angle=True,
478+
a_compress_rate=0,
478479
precision="float64",
479480
):
480481
input_dict = {
@@ -491,6 +492,7 @@ def DescriptorParamDPA3(
491492
"a_rcut": rcut / 2,
492493
"a_rcut_smth": rcut_smth / 2,
493494
"a_sel": sum(sel) // 4,
495+
"a_compress_rate": a_compress_rate,
494496
"axis_neuron": 4,
495497
"update_angle": update_angle,
496498
"update_style": update_style,
@@ -520,6 +522,7 @@ def DescriptorParamDPA3(
520522
"update_residual_init": ("const",),
521523
"exclude_types": ([], [[0, 1]]),
522524
"update_angle": (True, False),
525+
"a_compress_rate": (0, 1),
523526
"env_protection": (0.0, 1e-8),
524527
"precision": ("float64",),
525528
}

0 commit comments

Comments
 (0)