27 | 27 | MLPLayer, |
28 | 28 | NetworkCollection, |
29 | 29 | ) |
| 30 | +from deepmd.pt.model.network.network import ( |
| 31 | + TypeEmbedNet, |
| 32 | +) |
30 | 33 | from deepmd.pt.utils import ( |
31 | 34 | env, |
32 | 35 | ) |
@@ -281,6 +284,9 @@ def __init__( |
281 | 284 | self.compress_data = nn.ParameterList( |
282 | 285 | [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))] |
283 | 286 | ) |
| 287 | + # For type embedding compression (strip mode, two-side only) |
| 288 | + self.compress_type_embd = False |
| 289 | + self.two_side_embd_data = None |
284 | 290 |
285 | 291 | def get_rcut(self) -> float: |
286 | 292 | """Returns the cut-off radius.""" |
@@ -447,6 +453,50 @@ def enable_compression( |
447 | 453 | self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec) |
448 | 454 | self.compress = True |
449 | 455 |
| 456 | + def enable_type_embedding_compression( |
| 457 | + self, type_embedding_net: TypeEmbedNet |
| 458 | + ) -> None: |
| 459 | + """Enable type embedding compression for strip mode (two-side only). |
| 460 | +
| 461 | + This method precomputes the type embedding network outputs for all possible |
| 462 | + type pairs, following the same approach as TF backend's compression: |
| 463 | +
| 464 | + TF approach: |
| 465 | + 1. get_two_side_type_embedding(): creates (ntypes+1)^2 type pair combinations |
| 466 | + 2. make_data(): applies embedding network to get precomputed outputs |
| 467 | + 3. In forward: look up precomputed values instead of recomputing them
| 468 | +
| 469 | + PyTorch implementation: |
| 470 | + - Precomputes all (ntypes+1)^2 type pair embedding network outputs |
| 471 | + - Stores the precomputed outputs as a tensor attribute on the descriptor
| 472 | + - Uses lookup during inference to avoid redundant computations |
| 473 | +
| 474 | + Parameters |
| 475 | + ---------- |
| 476 | + type_embedding_net : TypeEmbedNet |
| 477 | + The type embedding network that provides the get_full_embedding() method
| 478 | + """ |
| 479 | + with torch.no_grad(): |
| 480 | + # Get full type embedding: (ntypes+1) x t_dim |
| 481 | + full_embd = type_embedding_net.get_full_embedding(env.DEVICE) |
| 482 | + nt, t_dim = full_embd.shape |
| 483 | + |
| 484 | + # Create all type pair combinations [neighbor, center] |
| 485 | + # embd_nei[i, j] holds the embedding of neighbor type j (varies along columns)
| 486 | + embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim) |
| 487 | + # embd_center[i, j] holds the embedding of center type i (constant along each row)
| 488 | + embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim) |
| 489 | + two_side_embd = torch.cat([embd_nei, embd_center], dim=-1).reshape( |
| 490 | + -1, t_dim * 2 |
| 491 | + ) |
| 492 | + |
| 493 | + # Apply strip embedding network and store |
| 494 | + # index logic: index = center_type * nt + neighbor_type |
| 495 | + self.two_side_embd_data = self.filter_layers_strip.networks[0]( |
| 496 | + two_side_embd |
| 497 | + ).detach() |
| 498 | + self.compress_type_embd = True |
| 499 | + |
450 | 500 | def forward( |
451 | 501 | self, |
452 | 502 | nlist: torch.Tensor, |
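
The hunk above precomputes one embedding-network output per `(center, neighbor)` type pair and relies on the row convention `index = center_type * nt + neighbor_type`. A minimal, self-contained sketch of that idea is shown below; it is not the DeePMD-kit code itself. `strip_net` stands in for `filter_layers_strip.networks[0]`, `full_embd` for `type_embedding_net.get_full_embedding()`, and `ntypes`, `t_dim`, `ng` are toy sizes.

```python
import torch

torch.manual_seed(0)

ntypes, t_dim, ng = 3, 4, 8
nt = ntypes + 1                              # one padding slot, hence (ntypes + 1)^2 pairs
full_embd = torch.randn(nt, t_dim)           # stand-in for get_full_embedding()
strip_net = torch.nn.Linear(2 * t_dim, ng)   # stand-in for the strip embedding network

# Precompute the full (ntypes + 1)^2 pair table once.
embd_nei = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)     # [i, j] -> neighbor type j
embd_center = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)  # [i, j] -> center type i
with torch.no_grad():
    table = strip_net(
        torch.cat([embd_nei, embd_center], dim=-1).reshape(-1, 2 * t_dim)
    )

# Lookup for a single pair; row index = center_type * nt + neighbor_type.
center_type, neighbor_type = 2, 1
row = table[center_type * nt + neighbor_type]

# Same value computed on the fly, as the uncompressed path would do.
with torch.no_grad():
    direct = strip_net(
        torch.cat([full_embd[neighbor_type], full_embd[center_type]]).unsqueeze(0)
    ).squeeze(0)

assert torch.allclose(row, direct)
```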
@@ -605,7 +655,12 @@ def forward( |
605 | 655 | two_side_type_embedding = torch.cat( |
606 | 656 | [type_embedding_nei, type_embedding_center], -1 |
607 | 657 | ).reshape(-1, nt * 2) |
608 | | - tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding) |
| 658 | + if self.compress_type_embd and self.two_side_embd_data is not None: |
| 659 | + tt_full = self.two_side_embd_data |
| 660 | + else: |
| 661 | + tt_full = self.filter_layers_strip.networks[0]( |
| 662 | + two_side_type_embedding |
| 663 | + ) |
609 | 664 | # (nf x nl x nnei) x ng |
610 | 665 | gg_t = torch.gather(tt_full, dim=0, index=idx) |
611 | 666 | # (nf x nl) x nnei x ng |
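
For context, a hedged sketch of how such a cached table is consumed: a per-neighbor pair index is expanded and gathered along `dim=0`, mirroring the `torch.gather(tt_full, dim=0, index=idx)` call above. The names `nloc`, `nnei` and the way `idx` is built here are illustrative assumptions, not taken from the real call site.

```python
import torch

nt, ng = 4, 8                      # nt = ntypes + 1; ng = output width of the strip net
table = torch.randn(nt * nt, ng)   # plays the role of the cached two_side_embd_data / tt_full

nloc, nnei = 5, 3                                  # toy counts of local atoms and neighbors
center = torch.randint(0, nt, (nloc,))             # type of each local atom
neighbor = torch.randint(0, nt, (nloc, nnei))      # type of each neighbor

# Row index per (center, neighbor) pair, matching the table's layout.
pair_idx = (center.unsqueeze(-1) * nt + neighbor).reshape(-1)   # (nloc * nnei,)
idx = pair_idx.unsqueeze(-1).expand(-1, ng)                     # (nloc * nnei, ng)

# Gather rows from the precomputed table instead of running the strip network.
gg_t = torch.gather(table, dim=0, index=idx)                    # (nloc * nnei, ng)
assert torch.equal(gg_t[3], table[pair_idx[3]])
```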