|
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 |
|
| 9 | +from deepmd.dpmodel.common import ( |
| 10 | + to_numpy_array, |
| 11 | +) |
9 | 12 | from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP |
10 | 13 | from deepmd.dpmodel.model.model import get_model as get_model_dp |
| 14 | +from deepmd.dpmodel.utils.nlist import ( |
| 15 | + build_neighbor_list, |
| 16 | + extend_coord_with_ghosts, |
| 17 | +) |
| 18 | +from deepmd.dpmodel.utils.region import ( |
| 19 | + normalize_coord, |
| 20 | +) |
11 | 21 | from deepmd.env import ( |
12 | 22 | GLOBAL_NP_FLOAT_PRECISION, |
13 | 23 | ) |
|
27 | 37 | if INSTALLED_PT: |
28 | 38 | from deepmd.pt.model.model import get_model as get_model_pt |
29 | 39 | from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT |
30 | | - |
| 40 | + from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy |
| 41 | + from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch |
31 | 42 | else: |
32 | 43 | EnergyModelPT = None |
33 | 44 | if INSTALLED_TF: |
|
39 | 50 | ) |
40 | 51 |
|
41 | 52 | if INSTALLED_JAX: |
| 53 | + from deepmd.jax.common import ( |
| 54 | + to_jax_array, |
| 55 | + ) |
42 | 56 | from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX |
43 | 57 | from deepmd.jax.model.model import get_model as get_model_jax |
44 | 58 | else: |
@@ -243,3 +257,207 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: |
243 | 257 | ret["energy_derv_c"].ravel(), |
244 | 258 | ) |
245 | 259 | raise ValueError(f"Unknown backend: {backend}") |
| 260 | + |
| 261 | + |
| 262 | +@parameterized( |
| 263 | + ( |
| 264 | + [], |
| 265 | + [[0, 1]], |
| 266 | + ), |
| 267 | + ( |
| 268 | + [], |
| 269 | + [1], |
| 270 | + ), |
| 271 | +) |
| 272 | +class TestEnerLower(CommonTest, ModelTest, unittest.TestCase): |
| 273 | + @property |
| 274 | + def data(self) -> dict: |
| 275 | + pair_exclude_types, atom_exclude_types = self.param |
| 276 | + return { |
| 277 | + "type_map": ["O", "H"], |
| 278 | + "pair_exclude_types": pair_exclude_types, |
| 279 | + "atom_exclude_types": atom_exclude_types, |
| 280 | + "descriptor": { |
| 281 | + "type": "se_e2_a", |
| 282 | + "sel": [20, 20], |
| 283 | + "rcut_smth": 0.50, |
| 284 | + "rcut": 6.00, |
| 285 | + "neuron": [ |
| 286 | + 3, |
| 287 | + 6, |
| 288 | + ], |
| 289 | + "resnet_dt": False, |
| 290 | + "axis_neuron": 2, |
| 291 | + "precision": "float64", |
| 292 | + "type_one_side": True, |
| 293 | + "seed": 1, |
| 294 | + }, |
| 295 | + "fitting_net": { |
| 296 | + "neuron": [ |
| 297 | + 5, |
| 298 | + 5, |
| 299 | + ], |
| 300 | + "resnet_dt": True, |
| 301 | + "precision": "float64", |
| 302 | + "seed": 1, |
| 303 | + }, |
| 304 | + } |
| 305 | + |
| 306 | + tf_class = EnergyModelTF |
| 307 | + dp_class = EnergyModelDP |
| 308 | + pt_class = EnergyModelPT |
| 309 | + jax_class = EnergyModelJAX |
| 310 | + args = model_args() |
| 311 | + |
| 312 | + def get_reference_backend(self): |
| 313 | + """Get the reference backend. |
| 314 | +
|
| 315 | + We need a reference backend that can reproduce forces. |
| 316 | + """ |
| 317 | + if not self.skip_pt: |
| 318 | + return self.RefBackend.PT |
| 319 | + if not self.skip_jax: |
| 320 | + return self.RefBackend.JAX |
| 321 | + if not self.skip_dp: |
| 322 | + return self.RefBackend.DP |
| 323 | + raise ValueError("No available reference") |
| 324 | + |
| 325 | + @property |
| 326 | + def skip_tf(self): |
| 327 | + # TF does not have lower interface |
| 328 | + return True |
| 329 | + |
| 330 | + @property |
| 331 | + def skip_jax(self): |
| 332 | + return not INSTALLED_JAX |
| 333 | + |
| 334 | + def pass_data_to_cls(self, cls, data) -> Any: |
| 335 | + """Pass data to the class.""" |
| 336 | + data = data.copy() |
| 337 | + if cls is EnergyModelDP: |
| 338 | + return get_model_dp(data) |
| 339 | + elif cls is EnergyModelPT: |
| 340 | + return get_model_pt(data) |
| 341 | + elif cls is EnergyModelJAX: |
| 342 | + return get_model_jax(data) |
| 343 | + return cls(**data, **self.additional_data) |
| 344 | + |
| 345 | + def setUp(self): |
| 346 | + CommonTest.setUp(self) |
| 347 | + |
| 348 | + self.ntypes = 2 |
| 349 | + coords = np.array( |
| 350 | + [ |
| 351 | + 12.83, |
| 352 | + 2.56, |
| 353 | + 2.18, |
| 354 | + 12.09, |
| 355 | + 2.87, |
| 356 | + 2.74, |
| 357 | + 00.25, |
| 358 | + 3.32, |
| 359 | + 1.68, |
| 360 | + 3.36, |
| 361 | + 3.00, |
| 362 | + 1.81, |
| 363 | + 3.51, |
| 364 | + 2.51, |
| 365 | + 2.60, |
| 366 | + 4.27, |
| 367 | + 3.22, |
| 368 | + 1.56, |
| 369 | + ], |
| 370 | + dtype=GLOBAL_NP_FLOAT_PRECISION, |
| 371 | + ).reshape(1, -1, 3) |
| 372 | + atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) |
| 373 | + box = np.array( |
| 374 | + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], |
| 375 | + dtype=GLOBAL_NP_FLOAT_PRECISION, |
| 376 | + ).reshape(1, 9) |
| 377 | + |
| 378 | + rcut = 6.0 |
| 379 | + nframes, nloc = atype.shape[:2] |
| 380 | + coord_normalized = normalize_coord( |
| 381 | + coords.reshape(nframes, nloc, 3), |
| 382 | + box.reshape(nframes, 3, 3), |
| 383 | + ) |
| 384 | + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( |
| 385 | + coord_normalized, atype, box, rcut |
| 386 | + ) |
| 387 | + nlist = build_neighbor_list( |
| 388 | + extended_coord, |
| 389 | + extended_atype, |
| 390 | + nloc, |
| 391 | + 6.0, |
| 392 | + [20, 20], |
| 393 | + distinguish_types=True, |
| 394 | + ) |
| 395 | + extended_coord = extended_coord.reshape(nframes, -1, 3) |
| 396 | + self.nlist = nlist |
| 397 | + self.extended_coord = extended_coord |
| 398 | + self.extended_atype = extended_atype |
| 399 | + self.mapping = mapping |
| 400 | + |
| 401 | + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: |
| 402 | + raise NotImplementedError("no TF in this test") |
| 403 | + |
| 404 | + def eval_dp(self, dp_obj: Any) -> Any: |
| 405 | + return dp_obj.call_lower( |
| 406 | + self.extended_coord, |
| 407 | + self.extended_atype, |
| 408 | + self.nlist, |
| 409 | + self.mapping, |
| 410 | + do_atomic_virial=True, |
| 411 | + ) |
| 412 | + |
| 413 | + def eval_pt(self, pt_obj: Any) -> Any: |
| 414 | + return { |
| 415 | + kk: torch_to_numpy(vv) |
| 416 | + for kk, vv in pt_obj.forward_lower( |
| 417 | + numpy_to_torch(self.extended_coord), |
| 418 | + numpy_to_torch(self.extended_atype), |
| 419 | + numpy_to_torch(self.nlist), |
| 420 | + numpy_to_torch(self.mapping), |
| 421 | + do_atomic_virial=True, |
| 422 | + ).items() |
| 423 | + } |
| 424 | + |
| 425 | + def eval_jax(self, jax_obj: Any) -> Any: |
| 426 | + return { |
| 427 | + kk: to_numpy_array(vv) |
| 428 | + for kk, vv in jax_obj.call_lower( |
| 429 | + to_jax_array(self.extended_coord), |
| 430 | + to_jax_array(self.extended_atype), |
| 431 | + to_jax_array(self.nlist), |
| 432 | + to_jax_array(self.mapping), |
| 433 | + do_atomic_virial=True, |
| 434 | + ).items() |
| 435 | + } |
| 436 | + |
| 437 | + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: |
| 438 | + # shape not matched. ravel... |
| 439 | + if backend is self.RefBackend.DP: |
| 440 | + return ( |
| 441 | + ret["energy_redu"].ravel(), |
| 442 | + ret["energy"].ravel(), |
| 443 | + SKIP_FLAG, |
| 444 | + SKIP_FLAG, |
| 445 | + SKIP_FLAG, |
| 446 | + ) |
| 447 | + elif backend is self.RefBackend.PT: |
| 448 | + return ( |
| 449 | + ret["energy"].ravel(), |
| 450 | + ret["atom_energy"].ravel(), |
| 451 | + ret["extended_force"].ravel(), |
| 452 | + ret["virial"].ravel(), |
| 453 | + ret["extended_virial"].ravel(), |
| 454 | + ) |
| 455 | + elif backend is self.RefBackend.JAX: |
| 456 | + return ( |
| 457 | + ret["energy_redu"].ravel(), |
| 458 | + ret["energy"].ravel(), |
| 459 | + ret["energy_derv_r"].ravel(), |
| 460 | + ret["energy_derv_c_redu"].ravel(), |
| 461 | + ret["energy_derv_c"].ravel(), |
| 462 | + ) |
| 463 | + raise ValueError(f"Unknown backend: {backend}") |
0 commit comments