
Commit 0ad1ab6

feat(pt): type embedding can still be compress even if attn_layer != 0 (#5066)
Summary by CodeRabbit

* **New Features**
  * Split compression into independent type-embedding (TEBD) and geometric modes, enabling partial compression when attention layer ≠ 0.
* **Documentation**
  * Expanded backend-specific guidance describing full vs partial compression rules and prerequisites (e.g., TEBD input mode).
* **Tests**
  * Added tests covering non-zero attention-layer scenarios to validate partial compression behavior.
* **Bug Fixes**
  * Improved eligibility checks and clearer runtime warnings when geometric compression is skipped.
1 parent a72b3af commit 0ad1ab6

File tree

7 files changed: +314 −71 lines changed

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 40 additions & 32 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import warnings
 from typing import (
     Any,
     Callable,
@@ -304,7 +305,8 @@ def __init__(
         self.use_econf_tebd = use_econf_tebd
         self.use_tebd_bias = use_tebd_bias
         self.type_map = type_map
-        self.compress = False
+        self.tebd_compress = False
+        self.geo_compress = False
         self.type_embedding = TypeEmbedNet(
             ntypes,
             tebd_dim,
@@ -592,12 +594,17 @@ def enable_compression(
         check_frequency
             The overflow check frequency
         """
-        # do some checks before the mocel compression process
-        if self.compress:
+        # do some checks before the model compression process
+        if self.tebd_compress or self.geo_compress:
             raise ValueError("Compression is already enabled.")
+
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Type embedding compression only works in strip mode")
+
         assert not self.se_atten.resnet_dt, (
             "Model compression error: descriptor resnet_dt must be false!"
         )
+
         for tt in self.se_atten.exclude_types:
             if (tt[0] not in range(self.se_atten.ntypes)) or (
                 tt[1] not in range(self.se_atten.ntypes)
@@ -609,6 +616,7 @@ def enable_compression(
                     + str(self.se_atten.ntypes)
                     + "!"
                 )
+
         if (
             self.se_atten.ntypes * self.se_atten.ntypes
             - len(self.se_atten.exclude_types)
@@ -618,38 +626,38 @@ def enable_compression(
                 "Empty embedding-nets are not supported in model compression!"
             )

-        if self.se_atten.attn_layer != 0:
-            raise RuntimeError("Cannot compress model when attention layer is not 0.")
-
-        if self.tebd_input_mode != "strip":
-            raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
-
-        data = self.serialize()
-        self.table = DPTabulate(
-            self,
-            data["neuron"],
-            data["type_one_side"],
-            data["exclude_types"],
-            ActivationFn(data["activation_function"]),
-        )
-        self.table_config = [
-            table_extrapolate,
-            table_stride_1,
-            table_stride_2,
-            check_frequency,
-        ]
-        self.lower, self.upper = self.table.build(
-            min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
-        )
-
-        self.se_atten.enable_compression(
-            self.table.data, self.table_config, self.lower, self.upper
-        )
-
         # Enable type embedding compression
         self.se_atten.type_embedding_compression(self.type_embedding)
+        self.tebd_compress = True
+
+        if self.se_atten.attn_layer == 0:
+            data = self.serialize()
+            self.table = DPTabulate(
+                self,
+                data["neuron"],
+                data["type_one_side"],
+                data["exclude_types"],
+                ActivationFn(data["activation_function"]),
+            )
+            self.table_config = [
+                table_extrapolate,
+                table_stride_1,
+                table_stride_2,
+                check_frequency,
+            ]
+            self.lower, self.upper = self.table.build(
+                min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
+            )

-        self.compress = True
+            self.se_atten.enable_compression(
+                self.table.data, self.table_config, self.lower, self.upper
+            )
+            self.geo_compress = True
+        else:
+            warnings.warn(
+                "Attention layer is not 0, only type embedding is compressed. Geometric part is not compressed.",
+                UserWarning,
+            )

     def forward(
         self,

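Read without the diff markup, the new `enable_compression` flow in `dpa1.py` is: require `tebd_input_mode == "strip"`, always pre-compute the type-embedding (TEBD) tables, and tabulate the geometric filters only when there are no attention layers. A minimal sketch of that control flow follows; `compress_type_embedding` and `tabulate_geometric_filters` are hypothetical stand-ins for the real tabulation steps, not the actual DeePMD-kit API.

```python
import warnings


def enable_compression_flow(descriptor) -> None:
    """Illustrative sketch of the decision logic introduced by this commit."""
    if descriptor.tebd_compress or descriptor.geo_compress:
        raise ValueError("Compression is already enabled.")
    if descriptor.tebd_input_mode != "strip":
        raise RuntimeError("Type embedding compression only works in strip mode")

    # TEBD tables depend only on atom types, so they can always be pre-computed.
    descriptor.compress_type_embedding()  # hypothetical helper
    descriptor.tebd_compress = True

    if descriptor.attn_layer == 0:
        # Without attention layers the geometric filter depends only on the
        # radial input, so it can be tabulated as well.
        descriptor.tabulate_geometric_filters()  # hypothetical helper
        descriptor.geo_compress = True
    else:
        warnings.warn(
            "Attention layer is not 0, only type embedding is compressed. "
            "Geometric part is not compressed.",
            UserWarning,
        )
```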
deepmd/pt/model/descriptor/dpa2.py

Lines changed: 29 additions & 27 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import warnings
 from typing import (
     Any,
     Callable,
@@ -938,38 +939,39 @@ def enable_compression(
                 "Repinit empty embedding-nets are not supported in model compression!"
             )

-        if self.repinit.attn_layer != 0:
-            raise RuntimeError(
-                "Cannot compress model when repinit attention layer is not 0."
-            )
-
         if self.repinit.tebd_input_mode != "strip":
             raise RuntimeError(
-                "Cannot compress model when repinit tebd_input_mode == 'concat'"
+                "Cannot compress model when repinit tebd_input_mode != 'strip'"
             )

-        # repinit doesn't have a serialize method
-        data = self.serialize()
-        self.table = DPTabulate(
-            self,
-            data["repinit_args"]["neuron"],
-            data["repinit_args"]["type_one_side"],
-            data["exclude_types"],
-            ActivationFn(data["repinit_args"]["activation_function"]),
-        )
-        self.table_config = [
-            table_extrapolate,
-            table_stride_1,
-            table_stride_2,
-            check_frequency,
-        ]
-        self.lower, self.upper = self.table.build(
-            min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
-        )
+        if self.repinit.attn_layer == 0:
+            # repinit doesn't have a serialize method
+            data = self.serialize()
+            self.table = DPTabulate(
+                self,
+                data["repinit_args"]["neuron"],
+                data["repinit_args"]["type_one_side"],
+                data["exclude_types"],
+                ActivationFn(data["repinit_args"]["activation_function"]),
+            )
+            self.table_config = [
+                table_extrapolate,
+                table_stride_1,
+                table_stride_2,
+                check_frequency,
+            ]
+            self.lower, self.upper = self.table.build(
+                min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
+            )

-        self.repinit.enable_compression(
-            self.table.data, self.table_config, self.lower, self.upper
-        )
+            self.repinit.enable_compression(
+                self.table.data, self.table_config, self.lower, self.upper
+            )
+        else:
+            warnings.warn(
+                "Attention layer is not 0, only type embedding is compressed. Geometric part is not compressed.",
+                UserWarning,
+            )

         # Enable type embedding compression for repinit (se_atten)
         self.repinit.type_embedding_compression(self.type_embedding)

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 10 additions & 7 deletions
@@ -275,9 +275,10 @@ def __init__(
         self.filter_layers_strip = filter_layers_strip
         self.stats = None

-        # For geometric compression
-        self.compress = False
+        self.tebd_compress = False
+        self.geo_compress = False
         self.is_sorted = False
+        # For geometric compression
         self.compress_info = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
         )
@@ -452,7 +453,7 @@ def enable_compression(
             device="cpu",
         )
         self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
-        self.compress = True
+        self.geo_compress = True

     def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
         """Enable type embedding compression for strip mode.
@@ -504,6 +505,8 @@ def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None:
         del self.type_embd_data
         self.register_buffer("type_embd_data", embd_tensor)

+        self.tebd_compress = True
+
     def forward(
         self,
         nlist: torch.Tensor,
@@ -630,7 +633,7 @@ def forward(
             # nf x (nl x nnei)
             nei_type = torch.gather(extended_atype, dim=1, index=nlist_index)
             if self.type_one_side:
-                if self.compress:
+                if self.tebd_compress:
                     tt_full = self.type_embd_data
                 else:
                     # (ntypes+1, tebd_dim) -> (ntypes+1, ng)
@@ -644,7 +647,7 @@ def forward(
                 idx_j = nei_type.view(-1)
                 # (nf x nl x nnei)
                 idx = (idx_i + idx_j).to(torch.long)
-                if self.compress:
+                if self.tebd_compress:
                     # ((ntypes+1)^2, ng)
                     tt_full = self.type_embd_data
                 else:
@@ -671,7 +674,7 @@ def forward(
             gg_t = gg_t.reshape(nfnl, nnei, ng)
             if self.smooth:
                 gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
-            if self.compress:
+            if self.geo_compress:
                 ss = ss.reshape(-1, 1)
                 gg_t = gg_t.reshape(-1, gg_t.size(-1))
                 xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
@@ -719,7 +722,7 @@ def forward(
         return (
             result.view(nframes, nloc, self.filter_neuron[-1] * self.axis_neuron),
             gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1])
-            if not self.compress
+            if not self.geo_compress
             else None,
             dmatrix.view(nframes, nloc, self.nnei, 4)[..., 1:],
             rot_mat.view(nframes, nloc, self.filter_neuron[-1], 3),

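In the forward pass above, the `tebd_compress` branches replace per-pair evaluation of the strip-mode type-embedding network with a row lookup into the pre-computed `type_embd_data` buffer. A toy, self-contained illustration of that gather, with made-up shapes and simplified index construction (the real code derives the indices from the extended atom types and the neighbor list):

```python
import torch

ntypes, ng = 4, 8
# Pre-computed at compression time: one row per (center type, neighbor type)
# pair, shape ((ntypes + 1) ** 2, ng).
type_embd_data = torch.randn((ntypes + 1) ** 2, ng)

# Simplified per-pair type indices for three example pairs.
center_type = torch.tensor([0, 1, 2])
neighbor_type = torch.tensor([3, 0, 2])
idx = (center_type * (ntypes + 1) + neighbor_type).to(torch.long)

# At inference, the two-body type factor becomes a table gather instead of an
# MLP evaluation.
gg_t = type_embd_data[idx]
print(gg_t.shape)  # torch.Size([3, 8])
```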
deepmd/utils/argcheck.py

Lines changed: 1 addition & 1 deletion
@@ -505,7 +505,7 @@ def descrpt_se_atten_common_args() -> list[Argument]:
     doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1."
     doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection."
     doc_attn = "The length of hidden vectors in attention layers"
-    doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` is only enabled when attn_layer==0 and tebd_input_mode=='strip'"
+    doc_attn_layer = "The number of attention layers. Note that model compression of `se_atten` works for any attn_layer value (for pytorch backend only, for other backends, attn_layer=0 is still needed to compress) when tebd_input_mode=='strip'. When attn_layer!=0, only type embedding is compressed, geometric parts are not compressed."
     doc_attn_dotr = "Whether to do dot product with the normalized relative coordinates"
     doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix"

doc/model/dpa2.md

Lines changed: 5 additions & 1 deletion
@@ -38,6 +38,10 @@ Type embedding is within this descriptor with the {ref}`tebd_dim <model[standard

 ## Model compression

-Model compression is supported when {ref}`repinit/tebd_input_mode <model[standard]/descriptor[dpa2]/repinit/tebd_input_mode>` is `strip`, but only the `repinit` part is compressed.
+Model compression is supported when {ref}`repinit/tebd_input_mode <model[standard]/descriptor[dpa2]/repinit/tebd_input_mode>` is `strip`.
+
+- If {ref}`repinit/attn_layer <model[standard]/descriptor[dpa2]/repinit/attn_layer>` is `0`, both the type embedding and geometric parts inside `repinit` are compressed.
+- If `repinit/attn_layer` is not `0`, only the type embedding tables are compressed and the geometric attention layers remain as neural networks.
+
 An example is given in `examples/water/dpa2/input_torch_compressible.json`.
 The performance improvement will be limited if other parts are more expensive.

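For illustration, a hypothetical DPA-2 descriptor fragment (written as a Python dict rather than the JSON input, with only compression-relevant keys and everything else omitted) that would be partially compressible under the rules above:

```python
# Hypothetical fragment; not a complete or recommended configuration.
descriptor = {
    "type": "dpa2",
    "repinit": {
        "tebd_input_mode": "strip",  # required for any compression
        "attn_layer": 2,  # non-zero: only the TEBD tables are compressed
        # remaining repinit settings omitted
    },
    # repformer and other settings omitted
}
```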
doc/model/train-se-atten.md

Lines changed: 10 additions & 2 deletions
@@ -134,7 +134,9 @@ You can use descriptor `"se_atten_v2"` and is not allowed to set `tebd_input_mod

 Practical evidence demonstrates that `"se_atten_v2"` offers better and more stable performance compared to `"se_atten"`.

-Notice: Model compression for the `se_atten_v2` descriptor is exclusively designed for models with the training parameter {ref}`attn_layer <model[standard]/descriptor[se_atten_v2]/attn_layer>` set to 0.
+:::{note}
+Model compression support differs across backends. See [Model compression](#model-compression) for backend-specific requirements.
+:::

 ## Type embedding

@@ -182,7 +184,13 @@ DPA-1 supports both the [standard data format](../data/system.md) and the [mixed

 ## Model compression

-Model compression is supported only when there is no attention layer (`attn_layer` is 0) and `tebd_input_mode` is `strip`.
+### TensorFlow {{ tensorflow_icon }}
+
+Model compression is supported only when the descriptor attention depth {ref}`attn_layer <model[standard]/descriptor[se_atten]/attn_layer>` is 0 and {ref}`tebd_input_mode <model[standard]/descriptor[se_atten]/tebd_input_mode>` is `"strip"`. Attention layers higher than 0 cannot be compressed in the TensorFlow implementation because the geometric part is tabulated from the static computation graph.
+
+### PyTorch {{ pytorch_icon }}
+
+Model compression is supported for any {ref}`attn_layer <model[standard]/descriptor[se_atten_v2]/attn_layer>` value when {ref}`tebd_input_mode <model[standard]/descriptor[se_atten_v2]/tebd_input_mode>` is `"strip"`. When `attn_layer` is 0, both the type embedding and geometric parts are compressed. When `attn_layer` is not 0, only the type embedding is compressed while the geometric part keeps the neural network implementation (a warning is emitted during compression).

 ## Training example

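As a concrete illustration of the PyTorch rule, a hypothetical `se_atten` descriptor fragment like the one below (placeholder values, not a recommended setup) would be fully compressible with `attn_layer` set to 0, and TEBD-only compressible with the non-zero value shown:

```python
# Hypothetical se_atten descriptor fragment (dict form of the JSON input).
descriptor = {
    "type": "se_atten",
    "tebd_input_mode": "strip",  # prerequisite for compression
    "attn_layer": 2,  # PyTorch: TEBD compressed, geometric part left uncompressed
    "rcut": 6.0,
    "rcut_smth": 0.5,
    "sel": 120,
    "neuron": [25, 50, 100],
    "axis_neuron": 16,
}
```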