Skip to content

Commit 34091f9

Browse files
Merge branch 'devel' into add_cpp_infer
2 parents 9dbc547 + 65ca05a commit 34091f9

File tree

5 files changed

+175
-1
lines changed

5 files changed

+175
-1
lines changed

deepmd/tf/descriptor/descriptor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def get_dim_rot_mat_1(self) -> int:
105105
int
106106
the first dimension of the rotation matrix
107107
"""
108-
raise NotImplementedError
108+
# by default, no rotation matrix
109+
return 0
109110

110111
def get_nlist(self) -> tuple[tf.Tensor, tf.Tensor, list[int], list[int]]:
111112
"""Returns neighbor information.
@@ -534,3 +535,9 @@ def serialize(self, suffix: str = "") -> dict:
534535
def input_requirement(self) -> list[DataRequirementItem]:
535536
"""Return data requirements needed for the model input."""
536537
return []
538+
539+
def get_rot_mat(self) -> tf.Tensor:
540+
"""Get rotational matrix."""
541+
nframes = tf.shape(self.dout)[0]
542+
natoms = tf.shape(self.dout)[1]
543+
return tf.zeros([nframes, natoms, 0], dtype=GLOBAL_TF_FLOAT_PRECISION)

deepmd/tf/descriptor/hybrid.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,21 @@ def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid":
492492
if hasattr(ii, "type_embedding"):
493493
raise NotImplementedError("hybrid + type embedding is not supported")
494494
return obj
495+
496+
def get_dim_rot_mat_1(self) -> int:
497+
"""Returns the first dimension of the rotation matrix. The rotation is of shape
498+
dim_1 x 3.
499+
500+
Returns
501+
-------
502+
int
503+
the first dimension of the rotation matrix
504+
"""
505+
return sum([ii.get_dim_rot_mat_1() for ii in self.descrpt_list])
506+
507+
def get_rot_mat(self) -> tf.Tensor:
508+
"""Get rotational matrix."""
509+
all_rot_mat = []
510+
for ii in self.descrpt_list:
511+
all_rot_mat.append(ii.get_rot_mat())
512+
return tf.concat(all_rot_mat, axis=2)

deepmd/tf/model/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,11 @@ def __init__(
668668
else:
669669
if fitting_net["type"] in ["dipole", "polar"]:
670670
fitting_net["embedding_width"] = self.descrpt.get_dim_rot_mat_1()
671+
if fitting_net["embedding_width"] == 0:
672+
raise ValueError(
673+
"This descriptor cannot provide a rotation matrix "
674+
"for a tensorial fitting."
675+
)
671676
self.fitting = Fitting(
672677
**fitting_net,
673678
descrpt=self.descrpt,

source/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ if(ENABLE_PADDLE)
131131
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
132132
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
133133
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
134+
list(APPEND BACKEND_INCLUDE_DIRS "${PADDLE_INFERENCE_DIR}/paddle/include")
134135
list(APPEND BACKEND_INCLUDE_DIRS
135136
"${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
136137
list(APPEND BACKEND_INCLUDE_DIRS "${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import numpy as np
3+
4+
from deepmd.tf.descriptor.hybrid import (
5+
DescrptHybrid,
6+
)
7+
from deepmd.tf.env import (
8+
tf,
9+
)
10+
from deepmd.tf.fit import (
11+
DipoleFittingSeA,
12+
)
13+
from deepmd.tf.model import (
14+
DipoleModel,
15+
)
16+
17+
from .common import (
18+
DataSystem,
19+
gen_data,
20+
j_loader,
21+
)
22+
23+
GLOBAL_ENER_FLOAT_PRECISION = tf.float64
24+
GLOBAL_TF_FLOAT_PRECISION = tf.float64
25+
GLOBAL_NP_FLOAT_PRECISION = np.float64
26+
27+
28+
class TestModel(tf.test.TestCase):
29+
def setUp(self) -> None:
30+
gen_data()
31+
32+
def test_model(self) -> None:
33+
jfile = "polar_se_a.json"
34+
jdata = j_loader(jfile)
35+
36+
systems = jdata["systems"]
37+
set_pfx = "set"
38+
batch_size = 1
39+
test_size = 1
40+
rcut = jdata["model"]["descriptor"]["rcut"]
41+
42+
data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None)
43+
44+
test_data = data.get_test()
45+
numb_test = 1
46+
47+
descrpt = DescrptHybrid(
48+
list=[
49+
{
50+
"type": "se_e2_a",
51+
"sel": [20, 20],
52+
"rcut_smth": 1.8,
53+
"rcut": 6.0,
54+
"neuron": [2, 4, 8],
55+
"resnet_dt": False,
56+
"axis_neuron": 8,
57+
"precision": "float64",
58+
"type_one_side": True,
59+
"seed": 1,
60+
},
61+
{
62+
"type": "se_e2_a",
63+
"sel": [20, 20],
64+
"rcut_smth": 1.8,
65+
"rcut": 6.0,
66+
"neuron": [2, 4, 8],
67+
"resnet_dt": False,
68+
"axis_neuron": 8,
69+
"precision": "float64",
70+
"type_one_side": True,
71+
"seed": 1,
72+
},
73+
{
74+
"type": "se_e3",
75+
"sel": [5, 5],
76+
"rcut_smth": 1.8,
77+
"rcut": 2.0,
78+
"neuron": [2],
79+
"resnet_dt": False,
80+
"precision": "float64",
81+
"seed": 1,
82+
},
83+
]
84+
)
85+
jdata["model"]["fitting_net"].pop("type", None)
86+
jdata["model"]["fitting_net"].pop("fit_diag", None)
87+
jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes()
88+
jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out()
89+
jdata["model"]["fitting_net"]["embedding_width"] = descrpt.get_dim_rot_mat_1()
90+
fitting = DipoleFittingSeA(**jdata["model"]["fitting_net"], uniform_seed=True)
91+
model = DipoleModel(descrpt, fitting)
92+
93+
# model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']])
94+
input_data = {
95+
"coord": [test_data["coord"]],
96+
"box": [test_data["box"]],
97+
"type": [test_data["type"]],
98+
"natoms_vec": [test_data["natoms_vec"]],
99+
"default_mesh": [test_data["default_mesh"]],
100+
"fparam": [test_data["fparam"]],
101+
}
102+
model._compute_input_stat(input_data)
103+
104+
t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c")
105+
t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord")
106+
t_type = tf.placeholder(tf.int32, [None], name="i_type")
107+
t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms")
108+
t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box")
109+
t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh")
110+
is_training = tf.placeholder(tf.bool)
111+
t_fparam = None
112+
113+
model_pred = model.build(
114+
t_coord,
115+
t_type,
116+
t_natoms,
117+
t_box,
118+
t_mesh,
119+
t_fparam,
120+
suffix="dipole_hybrid",
121+
reuse=False,
122+
)
123+
dipole = model_pred["dipole"]
124+
gdipole = model_pred["global_dipole"]
125+
force = model_pred["force"]
126+
virial = model_pred["virial"]
127+
atom_virial = model_pred["atom_virial"]
128+
129+
feed_dict_test = {
130+
t_prop_c: test_data["prop_c"],
131+
t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]),
132+
t_box: test_data["box"][:numb_test, :],
133+
t_type: np.reshape(test_data["type"][:numb_test, :], [-1]),
134+
t_natoms: test_data["natoms_vec"],
135+
t_mesh: test_data["default_mesh"],
136+
is_training: False,
137+
}
138+
139+
sess = self.cached_session().__enter__()
140+
sess.run(tf.global_variables_initializer())
141+
[p, gp, f, v, av] = sess.run(
142+
[dipole, gdipole, force, virial, atom_virial], feed_dict=feed_dict_test
143+
)

0 commit comments

Comments
 (0)