Skip to content

Commit 13e247e

Browse files
fix(tf): fix compress suffix in DescrptDPA1Compat (#4243)
Fix #4114 . <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced compression capabilities in descriptor models with new optional parameters for improved flexibility. - Improved serialization processes for attention layers, allowing for better handling of scaling factors and normalization. - Dynamic tensor name construction in utility functions to accommodate varying suffixes. - **Bug Fixes** - Adjusted method parameters to ensure compatibility and functionality with new suffix options. - **Tests** - Introduced a new test suite to validate the functionality of the TensorFlow-based descriptor model, ensuring consistent output with the updated features. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5394854 commit 13e247e

File tree

4 files changed

+212
-4
lines changed

4 files changed

+212
-4
lines changed

deepmd/tf/descriptor/se_atten.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ def enable_compression(
423423
table_stride_2: float = 0.1,
424424
check_frequency: int = -1,
425425
suffix: str = "",
426+
tebd_suffix: str = "",
426427
) -> None:
427428
"""Reveive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
428429
@@ -444,6 +445,8 @@ def enable_compression(
444445
The overflow check frequency
445446
suffix : str, optional
446447
The suffix of the scope
448+
tebd_suffix : str, optional
449+
The suffix of the type embedding scope, only for DescrptDPA1Compat
447450
"""
448451
# do some checks before the mocel compression process
449452
assert (
@@ -496,7 +499,9 @@ def enable_compression(
496499
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
497500
)
498501

499-
self.final_type_embedding = get_two_side_type_embedding(self, graph)
502+
self.final_type_embedding = get_two_side_type_embedding(
503+
self, graph, suffix=tebd_suffix
504+
)
500505
type_side_suffix = get_extra_embedding_net_suffix(type_one_side=False)
501506
self.matrix = get_extra_side_embedding_net_variable(
502507
self, graph_def, type_side_suffix, "matrix", suffix
@@ -2248,6 +2253,56 @@ def build(
22482253
self.dout = tf.concat([self.dout, atom_embed], axis=-1)
22492254
return self.dout
22502255

2256+
def enable_compression(
2257+
self,
2258+
min_nbor_dist: float,
2259+
graph: tf.Graph,
2260+
graph_def: tf.GraphDef,
2261+
table_extrapolate: float = 5,
2262+
table_stride_1: float = 0.01,
2263+
table_stride_2: float = 0.1,
2264+
check_frequency: int = -1,
2265+
suffix: str = "",
2266+
tebd_suffix: str = "",
2267+
) -> None:
2268+
"""Reveive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
2269+
2270+
Parameters
2271+
----------
2272+
min_nbor_dist
2273+
The nearest distance between atoms
2274+
graph : tf.Graph
2275+
The graph of the model
2276+
graph_def : tf.GraphDef
2277+
The graph_def of the model
2278+
table_extrapolate
2279+
The scale of model extrapolation
2280+
table_stride_1
2281+
The uniform stride of the first table
2282+
table_stride_2
2283+
The uniform stride of the second table
2284+
check_frequency
2285+
The overflow check frequency
2286+
suffix : str, optional
2287+
The suffix of the scope
2288+
tebd_suffix : str, optional
2289+
Same as suffix.
2290+
"""
2291+
assert (
2292+
tebd_suffix == ""
2293+
), "DescrptDPA1Compat must use the same tebd_suffix as suffix!"
2294+
super().enable_compression(
2295+
min_nbor_dist,
2296+
graph,
2297+
graph_def,
2298+
table_extrapolate=table_extrapolate,
2299+
table_stride_1=table_stride_1,
2300+
table_stride_2=table_stride_2,
2301+
check_frequency=check_frequency,
2302+
suffix=suffix,
2303+
tebd_suffix=suffix,
2304+
)
2305+
22512306
def init_variables(
22522307
self,
22532308
graph: tf.Graph,

deepmd/tf/utils/compress.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def get_type_embedding(self, graph):
2020
return type_embedding
2121

2222

23-
def get_two_side_type_embedding(self, graph):
24-
type_embedding = get_tensor_by_name_from_graph(graph, "t_typeebd")
23+
def get_two_side_type_embedding(self, graph, suffix=""):
24+
type_embedding = get_tensor_by_name_from_graph(graph, f"t_typeebd{suffix}")
2525
type_embedding = type_embedding.astype(self.filter_np_precision)
2626
type_embedding_shape = type_embedding.shape
2727

deepmd/tf/utils/tabulate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__(
126126
self.dstd = get_tensor_by_name_from_graph(
127127
self.graph, f"descrpt_attr{self.suffix}/t_std"
128128
)
129-
self.ntypes = get_tensor_by_name_from_graph(self.graph, "descrpt_attr/ntypes")
129+
self.ntypes = self.descrpt.get_ntypes()
130130

131131
self.embedding_net_nodes = get_embedding_net_nodes_from_graph_def(
132132
self.graph_def, suffix=self.suffix
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import unittest
3+
4+
import numpy as np
5+
6+
from deepmd.common import (
7+
make_default_mesh,
8+
)
9+
from deepmd.env import (
10+
GLOBAL_NP_FLOAT_PRECISION,
11+
)
12+
from deepmd.tf.descriptor.se_atten import DescrptDPA1Compat as tf_SeAtten
13+
from deepmd.tf.env import (
14+
GLOBAL_TF_FLOAT_PRECISION,
15+
default_tf_session_config,
16+
tf,
17+
)
18+
from deepmd.tf.utils.sess import (
19+
run_sess,
20+
)
21+
22+
23+
def build_tf_descriptor(obj, natoms, coords, atype, box, suffix):
24+
t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord")
25+
t_type = tf.placeholder(tf.int32, [None], name="i_type")
26+
t_natoms = tf.placeholder(tf.int32, natoms.shape, name="i_natoms")
27+
t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [9], name="i_box")
28+
t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh")
29+
t_des = obj.build(
30+
t_coord,
31+
t_type,
32+
t_natoms,
33+
t_box,
34+
t_mesh,
35+
{},
36+
suffix=suffix,
37+
)
38+
return [t_des], {
39+
t_coord: coords,
40+
t_type: atype,
41+
t_natoms: natoms,
42+
t_box: box,
43+
t_mesh: make_default_mesh(True, False),
44+
}
45+
46+
47+
def build_eval_tf(sess, obj, natoms, coords, atype, box, suffix):
48+
t_out, feed_dict = build_tf_descriptor(obj, natoms, coords, atype, box, suffix)
49+
50+
t_out_indentity = [
51+
tf.identity(tt, name=f"o_{ii}_{suffix}") for ii, tt in enumerate(t_out)
52+
]
53+
run_sess(sess, tf.global_variables_initializer())
54+
return run_sess(
55+
sess,
56+
t_out_indentity,
57+
feed_dict=feed_dict,
58+
)
59+
60+
61+
class TestDescriptorSeA(unittest.TestCase):
62+
def setUp(self):
63+
self.device = "cpu"
64+
self.seed = 21
65+
self.sel = [9, 10]
66+
self.rcut_smth = 5.80
67+
self.rcut = 6.00
68+
self.neuron = [6, 12, 24]
69+
self.axis_neuron = 3
70+
self.ntypes = 2
71+
self.coords = np.array(
72+
[
73+
12.83,
74+
2.56,
75+
2.18,
76+
12.09,
77+
2.87,
78+
2.74,
79+
00.25,
80+
3.32,
81+
1.68,
82+
3.36,
83+
3.00,
84+
1.81,
85+
3.51,
86+
2.51,
87+
2.60,
88+
4.27,
89+
3.22,
90+
1.56,
91+
],
92+
dtype=GLOBAL_NP_FLOAT_PRECISION,
93+
)
94+
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
95+
# self.atype = np.array([0, 0, 1, 1, 1, 1], dtype=np.int32)
96+
self.box = np.array(
97+
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
98+
dtype=GLOBAL_NP_FLOAT_PRECISION,
99+
)
100+
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
101+
self.suffix = "test"
102+
self.type_one_side = False
103+
self.se_a_tf = tf_SeAtten(
104+
self.rcut,
105+
self.rcut_smth,
106+
self.sel,
107+
self.ntypes,
108+
self.neuron,
109+
self.axis_neuron,
110+
type_one_side=self.type_one_side,
111+
seed=21,
112+
precision="float32",
113+
tebd_input_mode="strip",
114+
temperature=1.0,
115+
attn_layer=0,
116+
)
117+
118+
def test_tf_pt_consistent(
119+
self,
120+
):
121+
with tf.Session(config=default_tf_session_config) as sess:
122+
graph = tf.get_default_graph()
123+
ret = build_eval_tf(
124+
sess,
125+
self.se_a_tf,
126+
self.natoms,
127+
self.coords,
128+
self.atype,
129+
self.box,
130+
self.suffix,
131+
)
132+
output_graph_def = tf.graph_util.convert_variables_to_constants(
133+
sess,
134+
graph.as_graph_def(),
135+
[f"o_{ii}_{self.suffix}" for ii, _ in enumerate(ret)],
136+
)
137+
with tf.Graph().as_default() as new_graph:
138+
tf.import_graph_def(output_graph_def, name="")
139+
self.se_a_tf.init_variables(
140+
new_graph,
141+
output_graph_def,
142+
suffix=self.suffix,
143+
)
144+
self.se_a_tf.enable_compression(
145+
1.0,
146+
new_graph,
147+
output_graph_def,
148+
suffix=self.suffix,
149+
)
150+
151+
152+
if __name__ == "__main__":
153+
unittest.main()

0 commit comments

Comments
 (0)