|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +import numpy as np |
| 3 | + |
| 4 | +from deepmd.tf.descriptor.hybrid import ( |
| 5 | + DescrptHybrid, |
| 6 | +) |
| 7 | +from deepmd.tf.env import ( |
| 8 | + tf, |
| 9 | +) |
| 10 | +from deepmd.tf.fit import ( |
| 11 | + DipoleFittingSeA, |
| 12 | +) |
| 13 | +from deepmd.tf.model import ( |
| 14 | + DipoleModel, |
| 15 | +) |
| 16 | + |
| 17 | +from .common import ( |
| 18 | + DataSystem, |
| 19 | + gen_data, |
| 20 | + j_loader, |
| 21 | +) |
| 22 | + |
| 23 | +GLOBAL_ENER_FLOAT_PRECISION = tf.float64 |
| 24 | +GLOBAL_TF_FLOAT_PRECISION = tf.float64 |
| 25 | +GLOBAL_NP_FLOAT_PRECISION = np.float64 |
| 26 | + |
| 27 | + |
| 28 | +class TestModel(tf.test.TestCase): |
| 29 | + def setUp(self) -> None: |
| 30 | + gen_data() |
| 31 | + |
| 32 | + def test_model(self) -> None: |
| 33 | + jfile = "polar_se_a.json" |
| 34 | + jdata = j_loader(jfile) |
| 35 | + |
| 36 | + systems = jdata["systems"] |
| 37 | + set_pfx = "set" |
| 38 | + batch_size = 1 |
| 39 | + test_size = 1 |
| 40 | + rcut = jdata["model"]["descriptor"]["rcut"] |
| 41 | + |
| 42 | + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None) |
| 43 | + |
| 44 | + test_data = data.get_test() |
| 45 | + numb_test = 1 |
| 46 | + |
| 47 | + descrpt = DescrptHybrid( |
| 48 | + list=[ |
| 49 | + { |
| 50 | + "type": "se_e2_a", |
| 51 | + "sel": [20, 20], |
| 52 | + "rcut_smth": 1.8, |
| 53 | + "rcut": 6.0, |
| 54 | + "neuron": [2, 4, 8], |
| 55 | + "resnet_dt": False, |
| 56 | + "axis_neuron": 8, |
| 57 | + "precision": "float64", |
| 58 | + "type_one_side": True, |
| 59 | + "seed": 1, |
| 60 | + }, |
| 61 | + { |
| 62 | + "type": "se_e2_a", |
| 63 | + "sel": [20, 20], |
| 64 | + "rcut_smth": 1.8, |
| 65 | + "rcut": 6.0, |
| 66 | + "neuron": [2, 4, 8], |
| 67 | + "resnet_dt": False, |
| 68 | + "axis_neuron": 8, |
| 69 | + "precision": "float64", |
| 70 | + "type_one_side": True, |
| 71 | + "seed": 1, |
| 72 | + }, |
| 73 | + { |
| 74 | + "type": "se_e3", |
| 75 | + "sel": [5, 5], |
| 76 | + "rcut_smth": 1.8, |
| 77 | + "rcut": 2.0, |
| 78 | + "neuron": [2], |
| 79 | + "resnet_dt": False, |
| 80 | + "precision": "float64", |
| 81 | + "seed": 1, |
| 82 | + }, |
| 83 | + ] |
| 84 | + ) |
| 85 | + jdata["model"]["fitting_net"].pop("type", None) |
| 86 | + jdata["model"]["fitting_net"].pop("fit_diag", None) |
| 87 | + jdata["model"]["fitting_net"]["ntypes"] = descrpt.get_ntypes() |
| 88 | + jdata["model"]["fitting_net"]["dim_descrpt"] = descrpt.get_dim_out() |
| 89 | + jdata["model"]["fitting_net"]["embedding_width"] = descrpt.get_dim_rot_mat_1() |
| 90 | + fitting = DipoleFittingSeA(**jdata["model"]["fitting_net"], uniform_seed=True) |
| 91 | + model = DipoleModel(descrpt, fitting) |
| 92 | + |
| 93 | + # model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']]) |
| 94 | + input_data = { |
| 95 | + "coord": [test_data["coord"]], |
| 96 | + "box": [test_data["box"]], |
| 97 | + "type": [test_data["type"]], |
| 98 | + "natoms_vec": [test_data["natoms_vec"]], |
| 99 | + "default_mesh": [test_data["default_mesh"]], |
| 100 | + "fparam": [test_data["fparam"]], |
| 101 | + } |
| 102 | + model._compute_input_stat(input_data) |
| 103 | + |
| 104 | + t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c") |
| 105 | + t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord") |
| 106 | + t_type = tf.placeholder(tf.int32, [None], name="i_type") |
| 107 | + t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms") |
| 108 | + t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box") |
| 109 | + t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh") |
| 110 | + is_training = tf.placeholder(tf.bool) |
| 111 | + t_fparam = None |
| 112 | + |
| 113 | + model_pred = model.build( |
| 114 | + t_coord, |
| 115 | + t_type, |
| 116 | + t_natoms, |
| 117 | + t_box, |
| 118 | + t_mesh, |
| 119 | + t_fparam, |
| 120 | + suffix="dipole_hybrid", |
| 121 | + reuse=False, |
| 122 | + ) |
| 123 | + dipole = model_pred["dipole"] |
| 124 | + gdipole = model_pred["global_dipole"] |
| 125 | + force = model_pred["force"] |
| 126 | + virial = model_pred["virial"] |
| 127 | + atom_virial = model_pred["atom_virial"] |
| 128 | + |
| 129 | + feed_dict_test = { |
| 130 | + t_prop_c: test_data["prop_c"], |
| 131 | + t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]), |
| 132 | + t_box: test_data["box"][:numb_test, :], |
| 133 | + t_type: np.reshape(test_data["type"][:numb_test, :], [-1]), |
| 134 | + t_natoms: test_data["natoms_vec"], |
| 135 | + t_mesh: test_data["default_mesh"], |
| 136 | + is_training: False, |
| 137 | + } |
| 138 | + |
| 139 | + sess = self.cached_session().__enter__() |
| 140 | + sess.run(tf.global_variables_initializer()) |
| 141 | + [p, gp, f, v, av] = sess.run( |
| 142 | + [dipole, gdipole, force, virial, atom_virial], feed_dict=feed_dict_test |
| 143 | + ) |
0 commit comments