deepmd/pt/model/descriptor/se_t_tebd.py (83 changes: 55 additions, 28 deletions)
@@ -553,6 +553,8 @@ def enable_compression(
         assert not self.se_ttebd.resnet_dt, (
             "Model compression error: descriptor resnet_dt must be false!"
         )
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Cannot compress model when tebd_input_mode != 'strip'")
         for tt in self.se_ttebd.exclude_types:
             if (tt[0] not in range(self.se_ttebd.ntypes)) or (
                 tt[1] not in range(self.se_ttebd.ntypes)
@@ -573,9 +575,6 @@ def enable_compression(
                 "Empty embedding-nets are not supported in model compression!"
             )

-        if self.tebd_input_mode != "strip":
-            raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
-
         data = self.serialize()
         self.table = DPTabulate(
             self,
@@ -597,7 +596,11 @@
         )

         self.se_ttebd.enable_compression(
-            self.table.data, self.table_config, self.lower, self.upper
+            self.type_embedding,
+            self.table.data,
+            self.table_config,
+            self.lower,
+            self.upper,
         )
         self.compress = True

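Taken together, the wrapper now threads its type-embedding network into the inner block's enable_compression as the new first argument. A minimal sketch of the new call shape, with stub classes standing in for the real descriptor (the argument order mirrors the diff; everything else here is illustrative):

# Stub sketch of the new call order; not the real DeePMD classes.
class _InnerBlock:
    def enable_compression(self, type_embedding, table_data, table_config, lower, upper):
        print("got type embedding:", type_embedding)

class _Wrapper:
    def __init__(self):
        self.se_ttebd = _InnerBlock()
        self.type_embedding = "type-embedding-net"  # stand-in object

    def enable_compression(self):
        # the type-embedding net now leads the argument list, as in the diff above
        self.se_ttebd.enable_compression(
            self.type_embedding, {"filter_net": None}, {}, {}, {}
        )

_Wrapper().enable_compression()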
@@ -694,12 +697,17 @@ def __init__(
         self.stats = None
         # compression related variables
         self.compress = False
+        # For geometric compression
         self.compress_info = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
         )
         self.compress_data = nn.ParameterList(
             [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
         )
+        # For type embedding compression
+        self.register_buffer(
+            "type_embd_data", torch.zeros(0, dtype=self.prec, device=env.DEVICE)
+        )

     def get_rcut(self) -> float:
         """Returns the cut-off radius."""
@@ -986,31 +994,24 @@ def forward(
             nei_type_j = nei_type.unsqueeze(1).expand([-1, nnei, -1])
             idx_i = nei_type_i * ntypes_with_padding
             idx_j = nei_type_j
-            # (nf x nl x nt_i x nt_j) x ng
-            idx = (
-                (idx_i + idx_j)
-                .view(-1, 1)
-                .expand(-1, ng)
-                .type(torch.long)
-                .to(torch.long)
-            )
-            # ntypes * (ntypes) * nt
-            type_embedding_i = torch.tile(
-                type_embedding.view(ntypes_with_padding, 1, nt),
-                [1, ntypes_with_padding, 1],
-            )
-            # (ntypes) * ntypes * nt
-            type_embedding_j = torch.tile(
-                type_embedding.view(1, ntypes_with_padding, nt),
-                [ntypes_with_padding, 1, 1],
-            )
-            # (ntypes * ntypes) * (nt+nt)
-            two_side_type_embedding = torch.cat(
-                [type_embedding_i, type_embedding_j], -1
-            ).reshape(-1, nt * 2)
-            tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
+            idx = (idx_i + idx_j).reshape(-1).to(torch.long)
+            if self.compress:
+                tt_full = self.type_embd_data
+            else:
+                type_embedding_i = torch.tile(
+                    type_embedding.view(ntypes_with_padding, 1, nt),
+                    [1, ntypes_with_padding, 1],
+                )
+                type_embedding_j = torch.tile(
+                    type_embedding.view(1, ntypes_with_padding, nt),
+                    [ntypes_with_padding, 1, 1],
+                )
+                two_side_type_embedding = torch.cat(
+                    [type_embedding_i, type_embedding_j], -1
+                ).reshape(-1, nt * 2)
+                tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding)
             # (nfnl x nt_i x nt_j) x ng
-            gg_t = torch.gather(tt_full, dim=0, index=idx)
+            gg_t = tt_full[idx]
             # (nfnl x nt_i x nt_j) x ng
             gg_t = gg_t.reshape(nfnl, nnei, nnei, ng)
             if self.smooth:
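The rewritten lookup flattens each (type_i, type_j) pair into a single row index, idx = type_i * ntypes + type_j, so plain fancy indexing tt_full[idx] replaces the old torch.gather over an index expanded to ng columns. A small self-contained sketch of the equivalence (shapes chosen arbitrarily):

import torch

ntypes, ng = 4, 8
table = torch.randn(ntypes * ntypes, ng)   # one row per (type_i, type_j) pair
type_i = torch.tensor([0, 2, 3])
type_j = torch.tensor([1, 1, 0])
idx = (type_i * ntypes + type_j).to(torch.long)

gg_t = table[idx]                          # (3, ng)
# equivalent to the old expanded-gather form:
gg_t_old = torch.gather(table, 0, idx.view(-1, 1).expand(-1, ng))
assert torch.equal(gg_t, gg_t_old)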
Expand Down Expand Up @@ -1042,6 +1043,7 @@ def forward(

def enable_compression(
self,
type_embedding_net: TypeEmbedNet,
table_data: dict,
table_config: dict,
lower: dict,
@@ -1051,6 +1053,8 @@

         Parameters
         ----------
+        type_embedding_net : TypeEmbedNet
+            The type embedding network
         table_data : dict
             The tabulated data from DPTabulate
         table_config : dict
@@ -1060,6 +1064,13 @@
         upper : dict
             Upper bounds for compression
         """
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Type embedding compression only works in strip mode")
+        if self.filter_layers_strip is None:
+            raise RuntimeError(
+                "filter_layers_strip must exist for type embedding compression"
+            )
+
         # Compress the main geometric embedding network (self.filter_layers)
         net_key = "filter_net"
         self.compress_info[0] = torch.as_tensor(
@@ -1078,6 +1089,22 @@
             device=env.DEVICE, dtype=self.prec
         )

+        # Compress the type embedding network (self.filter_layers_strip)
+        with torch.no_grad():
+            full_embd = type_embedding_net.get_full_embedding(env.DEVICE)
+            nt, t_dim = full_embd.shape
+            type_embedding_i = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)
+            type_embedding_j = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)
+            two_side_type_embedding = torch.cat(
+                [type_embedding_i, type_embedding_j], dim=-1
+            ).reshape(-1, t_dim * 2)
+            embd_tensor = self.filter_layers_strip.networks[0](
+                two_side_type_embedding
+            ).detach()
+        if hasattr(self, "type_embd_data"):
+            del self.type_embd_data
+        self.register_buffer("type_embd_data", embd_tensor)
+
         self.compress = True

     def has_message_passing(self) -> bool:
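This precompute step evaluates the strip-mode embedding net once over all ntypes² ordered type pairs and caches an (ntypes² x ng) table, so forward() only indexes into it at inference time. A self-contained shape walk-through, with a plain Linear standing in for filter_layers_strip (that stand-in is an assumption for illustration, not the real network):

import torch

nt, t_dim, ng = 3, 5, 8
full_embd = torch.randn(nt, t_dim)                   # stands in for get_full_embedding()
mlp = torch.nn.Linear(2 * t_dim, ng)                 # stand-in for filter_layers_strip
ti = full_embd.view(nt, 1, t_dim).expand(nt, nt, t_dim)   # row type broadcast
tj = full_embd.view(1, nt, t_dim).expand(nt, nt, t_dim)   # column type broadcast
pairs = torch.cat([ti, tj], dim=-1).reshape(-1, 2 * t_dim)
with torch.no_grad():
    table = mlp(pairs)                               # frozen lookup table
assert table.shape == (nt * nt, ng)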