Skip to content

Commit 023bb9c

Browse files
feat(pt): DPA-2 repinit compress (#4329)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Introduced model compression functionality in the descriptor class, allowing users to enable compression with specific parameters.
  - Enhanced handling of serialized data for different descriptor types, improving flexibility and efficiency.
- **Tests**
  - Added a new test suite for the descriptor class, ensuring robust testing of functionality with various configurations and floating-point precisions.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent c12bc01 commit 023bb9c

File tree

3 files changed

+250
-5
lines changed

3 files changed

+250
-5
lines changed

deepmd/pt/model/descriptor/dpa2.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,14 @@
3131
build_multiple_neighbor_list,
3232
get_multiple_nlist_key,
3333
)
34+
from deepmd.pt.utils.tabulate import (
35+
DPTabulate,
36+
)
3437
from deepmd.pt.utils.update_sel import (
3538
UpdateSel,
3639
)
3740
from deepmd.pt.utils.utils import (
41+
ActivationFn,
3842
to_numpy_array,
3943
)
4044
from deepmd.utils.data_system import (
@@ -306,6 +310,7 @@ def init_subclass_params(sub_data, sub_class):
306310
# set trainable
307311
for param in self.parameters():
308312
param.requires_grad = trainable
313+
self.compress = False
309314

310315
def get_rcut(self) -> float:
311316
"""Returns the cut-off radius."""
@@ -859,3 +864,85 @@ def update_sel(
859864
)
860865
local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0]
861866
return local_jdata_cpy, min_nbor_dist
867+
868+
def enable_compression(
    self,
    min_nbor_dist: float,
    table_extrapolate: float = 5,
    table_stride_1: float = 0.01,
    table_stride_2: float = 0.1,
    check_frequency: int = -1,
) -> None:
    """Enable tabulated compression for the repinit block of this descriptor.

    The method validates that the repinit configuration is compressible,
    builds a ``DPTabulate`` table from the serialized descriptor, forwards
    the table to ``self.repinit.enable_compression`` and finally sets
    ``self.compress`` to ``True``. On success it also stores ``self.table``,
    ``self.table_config``, ``self.lower`` and ``self.upper``.

    Parameters
    ----------
    min_nbor_dist
        The nearest distance between atoms
    table_extrapolate
        The scale of model extrapolation
    table_stride_1
        The uniform stride of the first table
    table_stride_2
        The uniform stride of the second table
    check_frequency
        The overflow check frequency

    Raises
    ------
    ValueError
        If compression has already been enabled on this descriptor.
    AssertionError
        If ``self.repinit.resnet_dt`` is truthy.
    RuntimeError
        If an excluded type pair refers to a type index outside
        ``range(self.repinit.ntypes)``, if every type pair is excluded,
        if ``self.repinit.attn_layer`` is not 0, or if
        ``self.repinit.tebd_input_mode`` is not ``"strip"``.
    """
    # do some checks before the model compression process
    if self.compress:
        raise ValueError("Compression is already enabled.")

    repinit = self.repinit
    if repinit.resnet_dt:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with -O; the exception type and message
        # are unchanged for existing callers.
        raise AssertionError(
            "Model compression error: repinit resnet_dt must be false!"
        )

    ntypes = repinit.ntypes
    valid_types = range(ntypes)
    for tt in repinit.exclude_types:
        if (tt[0] not in valid_types) or (tt[1] not in valid_types):
            raise RuntimeError(
                f"Repinit exclude types{tt}"
                f" must within the number of atomic types {ntypes}!"
            )
    # Every (type_i, type_j) pair excluded: no embedding net is left to tabulate.
    if ntypes * ntypes - len(repinit.exclude_types) == 0:
        raise RuntimeError(
            "Repinit empty embedding-nets are not supported in model compression!"
        )

    if repinit.attn_layer != 0:
        raise RuntimeError(
            "Cannot compress model when repinit attention layer is not 0."
        )

    if repinit.tebd_input_mode != "strip":
        raise RuntimeError(
            "Cannot compress model when repinit tebd_input_mode == 'concat'"
        )

    # repinit doesn't have a serialize method, so the whole descriptor is
    # serialized and the repinit settings are read from "repinit_args".
    data = self.serialize()
    repinit_args = data["repinit_args"]
    self.table = DPTabulate(
        self,
        repinit_args["neuron"],
        repinit_args["type_one_side"],
        data["exclude_types"],
        ActivationFn(repinit_args["activation_function"]),
    )
    self.table_config = [
        table_extrapolate,
        table_stride_1,
        table_stride_2,
        check_frequency,
    ]
    self.lower, self.upper = self.table.build(
        min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
    )

    self.repinit.enable_compression(
        self.table.data, self.table_config, self.lower, self.upper
    )
    # Only flag the descriptor as compressed once every step above succeeded.
    self.compress = True

deepmd/pt/utils/tabulate.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,14 @@ def __init__(
9595
raise RuntimeError("Unknown activation function type!")
9696

9797
self.activation_fn = activation_fn
98-
self.davg = self.descrpt.serialize()["@variables"]["davg"]
99-
self.dstd = self.descrpt.serialize()["@variables"]["dstd"]
100-
self.ntypes = self.descrpt.get_ntypes()
98+
serialized = self.descrpt.serialize()
99+
if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA2):
100+
serialized = serialized["repinit_variable"]
101+
self.davg = serialized["@variables"]["davg"]
102+
self.dstd = serialized["@variables"]["dstd"]
103+
self.embedding_net_nodes = serialized["embeddings"]["networks"]
101104

102-
self.embedding_net_nodes = self.descrpt.serialize()["embeddings"]["networks"]
105+
self.ntypes = self.descrpt.get_ntypes()
103106

104107
self.layer_size = self._get_layer_size()
105108
self.table_size = self._get_table_size()
@@ -291,7 +294,13 @@ def _layer_1(self, x, w, b):
291294
return t, self.activation_fn(torch.matmul(x, w) + b) + t
292295

293296
def _get_descrpt_type(self):
294-
if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA1):
297+
if isinstance(
298+
self.descrpt,
299+
(
300+
deepmd.pt.model.descriptor.DescrptDPA1,
301+
deepmd.pt.model.descriptor.DescrptDPA2,
302+
),
303+
):
295304
return "Atten"
296305
elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeA):
297306
return "A"
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import unittest
3+
from typing import (
4+
Any,
5+
)
6+
7+
import numpy as np
8+
import torch
9+
10+
from deepmd.dpmodel.descriptor.dpa2 import (
11+
RepformerArgs,
12+
RepinitArgs,
13+
)
14+
from deepmd.env import (
15+
GLOBAL_NP_FLOAT_PRECISION,
16+
)
17+
from deepmd.pt.model.descriptor.dpa2 import (
18+
DescrptDPA2,
19+
)
20+
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
21+
from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt
22+
from deepmd.pt.utils.nlist import (
23+
extend_coord_with_ghosts as extend_coord_with_ghosts_pt,
24+
)
25+
26+
from ...consistent.common import (
27+
parameterized,
28+
)
29+
30+
31+
def eval_pt_descriptor(
    pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False
) -> Any:
    """Run a PyTorch descriptor on one frame and return its first output.

    The NumPy inputs are moved to ``PT_DEVICE`` and reshaped to a single
    frame (``coords`` -> ``(1, -1, 3)``, ``atype`` -> ``(1, -1)``,
    ``box`` -> ``(1, 3, 3)``), extended with ghost atoms, and a neighbor
    list is built from the descriptor's own ``get_rcut()`` / ``get_sel()``.

    Parameters
    ----------
    pt_obj
        Callable descriptor exposing ``get_rcut()`` and ``get_sel()``; it is
        expected to return a 5-tuple when called.
    natoms
        Indexable; only ``natoms[0]`` is used, passed as the third argument
        of the neighbor-list builder (presumably the local atom count).
    coords, atype, box
        NumPy arrays describing a single frame.
    mixed_types
        When ``False`` (default) the neighbor list is built with
        ``distinguish_types=True``.

    Returns
    -------
    Any
        The first element of the descriptor's 5-tuple output.
    """
    frame_coords = torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3)
    frame_atype = torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1)
    frame_box = torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3)
    extended = extend_coord_with_ghosts_pt(
        frame_coords, frame_atype, frame_box, pt_obj.get_rcut()
    )
    ext_coords, ext_atype, mapping = extended

    nlist = build_neighbor_list_pt(
        ext_coords,
        ext_atype,
        natoms[0],
        pt_obj.get_rcut(),
        pt_obj.get_sel(),
        distinguish_types=not mixed_types,
    )

    outputs = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping)
    # Strict 5-way unpacking: only the first item is of interest here.
    result, _, _, _, _ = outputs
    return result
50+
51+
52+
@parameterized(("float32", "float64"), (True, False))
class TestDescriptorDPA2(unittest.TestCase):
    """Compare DPA-2 descriptor output before and after enabling compression.

    Parameterized over the precision string and ``type_one_side``.
    """

    def setUp(self):
        """Build the six-atom test frame and the descriptor under test."""
        self.dtype, self.type_one_side = self.param
        if self.dtype == "float32":
            # Known FP32 failure that motivates the skip:
            # ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward
            #     torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
            # E   RuntimeError: expected scalar type Float but found Double
            self.skipTest("FP32 has bugs:")

        # Tolerance per precision; the float32 branch takes effect once the
        # skip above is removed.
        if self.dtype == "float32":
            self.atol = 1e-5
        elif self.dtype == "float64":
            self.atol = 1e-10

        self.seed = 21
        self.sel = [10]
        self.rcut_smth = 5.80
        self.rcut = 6.00
        self.neuron = [6, 12, 24]
        self.axis_neuron = 3
        self.ntypes = 2

        # Flat list of coordinates, written three values (one atom) per row.
        self.coords = np.array(
            [
                12.83, 2.56, 2.18,
                12.09, 2.87, 2.74,
                0.25, 3.32, 1.68,
                3.36, 3.00, 1.81,
                3.51, 2.51, 2.60,
                4.27, 3.22, 1.56,
            ],  # fmt: skip
            dtype=GLOBAL_NP_FLOAT_PRECISION,
        )
        self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
        self.box = np.array(
            [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
            dtype=GLOBAL_NP_FLOAT_PRECISION,
        )
        self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)

        repinit_args = RepinitArgs(
            rcut=self.rcut,
            rcut_smth=self.rcut_smth,
            nsel=10,
            tebd_input_mode="strip",
            type_one_side=self.type_one_side,
        )
        repformer_args = RepformerArgs(
            rcut=self.rcut - 1,
            rcut_smth=self.rcut_smth - 1,
            nsel=9,
        )
        self.descriptor = DescrptDPA2(
            ntypes=self.ntypes,
            repinit=repinit_args,
            repformer=repformer_args,
            precision=self.dtype,
        )

    def test_compressed_forward(self):
        """Compressed and uncompressed outputs must agree in shape and value."""
        eval_args = (
            self.descriptor,
            self.natoms,
            self.coords,
            self.atype,
            self.box,
        )
        result_pt = eval_pt_descriptor(*eval_args)

        self.descriptor.enable_compression(0.5)
        result_pt_compressed = eval_pt_descriptor(*eval_args)

        self.assertEqual(result_pt.shape, result_pt_compressed.shape)
        torch.testing.assert_close(
            result_pt,
            result_pt_compressed,
            atol=self.atol,
            rtol=self.atol,
        )
146+
147+
148+
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)