From 22d01927b32e5518d7ee7a22c20fcd2b0ce8e538 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 19:12:27 -0500 Subject: [PATCH 1/3] feat(pt): DPA-2 repinit compress Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/dpa2.py | 86 ++++++++++ deepmd/pt/utils/tabulate.py | 19 ++- .../model/test_compressed_descriptor_dpa2.py | 149 ++++++++++++++++++ 3 files changed, 249 insertions(+), 5 deletions(-) create mode 100644 source/tests/pt/model/test_compressed_descriptor_dpa2.py diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index ad5167c572..40fffda835 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -31,10 +31,14 @@ build_multiple_neighbor_list, get_multiple_nlist_key, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) from deepmd.pt.utils.update_sel import ( UpdateSel, ) from deepmd.pt.utils.utils import ( + ActivationFn, to_numpy_array, ) from deepmd.utils.data_system import ( @@ -859,3 +863,85 @@ def update_sel( ) local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0] return local_jdata_cpy, min_nbor_dist + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + # do some checks before the mocel compression process + if self.repinit.compress: + raise ValueError("Compression is already enabled.") + assert ( + not self.repinit.resnet_dt + ), "Model compression error: repinit resnet_dt must be false!" + for tt in self.repinit.exclude_types: + if (tt[0] not in range(self.repinit.ntypes)) or ( + tt[1] not in range(self.repinit.ntypes) + ): + raise RuntimeError( + "Repinit exclude types" + + str(tt) + + " must within the number of atomic types " + + str(self.repinit.ntypes) + + "!" + ) + if ( + self.repinit.ntypes * self.repinit.ntypes - len(self.repinit.exclude_types) + == 0 + ): + raise RuntimeError( + "Repinit empty embedding-nets are not supported in model compression!" + ) + + if self.repinit.attn_layer != 0: + raise RuntimeError( + "Cannot compress model when repinit attention layer is not 0." + ) + + if self.repinit.tebd_input_mode != "strip": + raise RuntimeError( + "Cannot compress model when repinit tebd_input_mode == 'concat'" + ) + + # repinit doesn't have a serialize method + data = self.serialize() + self.table = DPTabulate( + self, + data["repinit_args"]["neuron"], + data["repinit_args"]["type_one_side"], + data["exclude_types"], + ActivationFn(data["repinit_args"]["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + + self.repinit.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py index 796f7dcd52..e21d2ec9a6 100644 --- a/deepmd/pt/utils/tabulate.py +++ b/deepmd/pt/utils/tabulate.py @@ -95,11 +95,14 @@ def __init__( raise RuntimeError("Unknown activation function type!") self.activation_fn = activation_fn - self.davg = self.descrpt.serialize()["@variables"]["davg"] - self.dstd = self.descrpt.serialize()["@variables"]["dstd"] - self.ntypes = self.descrpt.get_ntypes() + serialized = self.descrpt.serialize() + if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA2): + serialized = serialized["repinit_variable"] + self.davg = serialized["@variables"]["davg"] + self.dstd = serialized["@variables"]["dstd"] + self.embedding_net_nodes = serialized["embeddings"]["networks"] - self.embedding_net_nodes = self.descrpt.serialize()["embeddings"]["networks"] + self.ntypes = self.descrpt.get_ntypes() self.layer_size = self._get_layer_size() self.table_size = self._get_table_size() @@ -291,7 +294,13 @@ def _layer_1(self, x, w, b): return t, self.activation_fn(torch.matmul(x, w) + b) + t def _get_descrpt_type(self): - if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA1): + if isinstance( + self.descrpt, + ( + deepmd.pt.model.descriptor.DescrptDPA1, + deepmd.pt.model.descriptor.DescrptDPA2, + ), + ): return "Atten" elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeA): return "A" diff --git a/source/tests/pt/model/test_compressed_descriptor_dpa2.py b/source/tests/pt/model/test_compressed_descriptor_dpa2.py new file mode 100644 index 0000000000..05b1143eb1 --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_dpa2.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64"), (True, False)) +class TestDescriptorDPA2(unittest.TestCase): + def setUp(self): + (self.dtype, self.type_one_side) = self.param + if self.dtype == "float32": + self.skipTest("FP32 has bugs:") + # ../../../../deepmd/pt/model/descriptor/repformer_layer.py:521: in forward + # torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni) + # E RuntimeError: expected scalar type Float but found Double + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.axis_neuron = 3 + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=10, + tebd_input_mode="strip", + type_one_side=self.type_one_side, + ) + repformer = RepformerArgs( + rcut=self.rcut - 1, + rcut_smth=self.rcut_smth - 1, + nsel=9, + ) + + self.descriptor = DescrptDPA2( + ntypes=self.ntypes, + repinit=repinit, + repformer=repformer, + precision=self.dtype, + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.descriptor, + self.natoms, + self.coords, + self.atype, + self.box, + ) + self.descriptor.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.descriptor, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() From fe284290846aa382f8b1558b8d81c7311f2b7d27 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 20:13:00 -0500 Subject: [PATCH 2/3] set compress attr Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/dpa2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 40fffda835..d25cd45462 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -310,6 +310,7 @@ def init_subclass_params(sub_data, sub_class): # set trainable for param in self.parameters(): param.requires_grad = trainable + self.compress = False def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -888,7 +889,7 @@ def enable_compression( The overflow check frequency """ # do some checks before the mocel compression process - if self.repinit.compress: + if self.compress: raise ValueError("Compression is already enabled.") assert ( not self.repinit.resnet_dt From 51e86964c8c104aa18b4a3918d6871076a871f3c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 21:42:18 -0500 Subject: [PATCH 3/3] fix typo in docstring Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/dpa2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index d25cd45462..77e9f1d936 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -873,7 +873,7 @@ def enable_compression( table_stride_2: float = 0.1, check_frequency: int = -1, ) -> None: - """Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data. + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. Parameters ----------