8 changes: 8 additions & 0 deletions deepmd/pt/modifier/__init__.py
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_modifier import (
    BaseModifier,
)

__all__ = [
    "BaseModifier",
]
56 changes: 56 additions & 0 deletions deepmd/pt/modifier/base_modifier.py
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import torch

from deepmd.dpmodel.modifier.base_modifier import (
    make_base_modifier,
)


class BaseModifier(torch.nn.Module, make_base_modifier()):
    def __init__(self, *args, **kwargs) -> None:
        """Construct a basic modifier for different tasks."""
        torch.nn.Module.__init__(self)

    def modify_data(self, data: dict) -> None:
        """Modify data in place.

        Parameters
        ----------
        data
            Internal data of DeepmdData. A dict with the following keys:
            - coord: coordinates
            - box: simulation box
            - atype: atom types
            - find_energy: tells if data has energy
            - find_force: tells if data has force
            - find_virial: tells if data has virial
            - energy: energy
            - force: force
            - virial: virial
        """
        if (
"find_energy" not in data
and "find_force" not in data
and "find_virial" not in data
):
return

        get_nframes = None
        coord = data["coord"][:get_nframes, :]
        if data["box"] is None:
            box = None
        else:
            box = data["box"][:get_nframes, :]
        atype = data["atype"][:get_nframes, :]
        atype = atype[0]
        nframes = coord.shape[0]

Code scanning / CodeQL note: Variable nframes is not used.

        tot_e, tot_f, tot_v = self.forward(coord, atype, box, False, None, None)

if "find_energy" in data and data["find_energy"] == 1.0:
data["energy"] -= tot_e.reshape(data["energy"].shape)
if "find_force" in data and data["find_force"] == 1.0:
data["force"] -= tot_f.reshape(data["force"].shape)
if "find_virial" in data and data["find_virial"] == 1.0:
data["virial"] -= tot_v.reshape(data["virial"].shape)

53 changes: 49 additions & 4 deletions deepmd/pt/train/training.py
@@ -39,6 +39,9 @@
    get_model,
    get_zbl_model,
)
from deepmd.pt.modifier import (
    BaseModifier,
)
from deepmd.pt.optimizer import (
    KFOptimizerWrapper,
    LKFOptimizer,
@@ -136,6 +139,16 @@
        )
        self.num_model = len(self.model_keys)

        # modifier for the training data
        modifier_params = model_params.get("modifier", None)
        if modifier_params is not None:
            assert self.multi_task is False, (

Member: NotImplementedError is preferred to assert.

"Modifier is not supported for multi-task training!"
)
self.modifier = get_data_modifier(modifier_params)

        else:
            self.modifier = None

        # Iteration config
        self.num_steps = training_params["numb_steps"]
        self.disp_file = training_params.get("disp_file", "lcurve.out")
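Regarding the review note above that NotImplementedError is preferred to assert, a drop-in rewrite of the branch in this hunk could look like the sketch below (not code from the PR):

```python
# Reviewer-suggested variant (sketch): raise instead of assert, so the check
# survives `python -O` and reports a clearer error type.
modifier_params = model_params.get("modifier", None)
if modifier_params is not None:
    if self.multi_task:
        raise NotImplementedError(
            "Modifier is not supported for multi-task training!"
        )
    self.modifier = get_data_modifier(modifier_params)
else:
    self.modifier = None
```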
@@ -233,12 +246,26 @@
            _stat_file_path,
            _data_requirement,
            finetune_has_new_type=False,
            modifier=None,
        ):
            _data_requirement += get_additional_data_requirement(_model)
            _training_data.add_data_requirement(_data_requirement)
            if _validation_data is not None:
                _validation_data.add_data_requirement(_data_requirement)

            # modify data
            if modifier is not None:
                log.info(f"Using {modifier.modifier_type} as data modifier")
                for _data in [_training_data, _validation_data]:
                    if _data is not None:
                        all_sampled = make_stat_input(
                            _data.systems,
                            _data.dataloaders,
                            -1,
                        )
                        for sampled in all_sampled:
                            modifier.modify_data(sampled)

Comment on lines +256 to +267

Collaborator: I noticed that this modification only affects the data statistics. It still uses the original data in get_data for actual training.

Contributor Author: I don't think so. modifier.modify_data(sampled) in line 267 directly modifies the training data sampled.

Collaborator: The sampled data is only used for data statistics, while the actual training uses the original data from the get_data function in line 1092.
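If the correction also needs to reach the batches used for optimisation, one hypothetical option (names and call site are assumptions, not part of this PR) is to run the modifier on every batch fetched from the training dataloader:

```python
# Hypothetical sketch only: correct the labels of each fetched batch so the
# loss sees modified data, not just the statistics sample.
for batch_data in training_dataloader:  # assumed to yield the same dict layout
    if modifier is not None:
        modifier.modify_data(batch_data)
    # ... forward pass, loss, and backward on the corrected batch ...
```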


        @functools.lru_cache
        def get_sample():
            sampled = make_stat_input(
@@ -333,6 +360,7 @@
                finetune_has_new_type=self.finetune_links["Default"].get_has_new_type()
                if self.finetune_links is not None
                else False,
                modifier=self.modifier,
            )
            (
                self.training_dataloader,
@@ -371,6 +399,7 @@
                    ].get_has_new_type()
                    if self.finetune_links is not None
                    else False,
                    modifier=self.modifier,
                )
                (
                    self.training_dataloader[model_key],
@@ -1079,10 +1108,13 @@
        optim_state_dict = deepcopy(self.optimizer.state_dict())
        for item in optim_state_dict["param_groups"]:
            item["lr"] = float(item["lr"])
        torch.save(
            {"model": module.state_dict(), "optimizer": optim_state_dict},
            save_path,
        )
        save_dict = {
            "model": module.state_dict(),
            "optimizer": optim_state_dict,
        }
        if self.modifier is not None:
            save_dict["data_modifier"] = self.modifier.state_dict()

Comment on lines +1115 to +1116

Member: I see this is saved, but how is it recovered?

Contributor Author: I plan to restore it by adapting the wrapper. Do you have any other ideas for this?
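For reference, a minimal sketch of one possible restore path using the checkpoint key written below; the resume logic and the resume_path / model_params names are assumptions, while the "data_modifier" key and get_data_modifier() come from this PR:

```python
import torch

# Hypothetical restore sketch: rebuild the modifier from the input script and
# load the saved parameters back into it when resuming from a checkpoint.
state_dict = torch.load(resume_path, map_location=DEVICE)
modifier = get_data_modifier(model_params["modifier"])
if "data_modifier" in state_dict:
    modifier.load_state_dict(state_dict["data_modifier"])
```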

        torch.save(save_dict, save_path)
        checkpoint_dir = save_path.parent
        checkpoint_files = [
            f
Expand Down Expand Up @@ -1355,3 +1387,16 @@
f"to {to_numpy_array(new_bias).reshape(-1)!s}."
)
return _model


def get_data_modifier(_modifier_params: dict[str, Any]):
    modifier_params = deepcopy(_modifier_params)
    try:
        modifier_type = modifier_params.pop("type")
    except KeyError:
        raise ValueError("Data modifier type not specified!") from None
    return (
        BaseModifier.get_class_by_type(modifier_type)
        .get_modifier(modifier_params)
        .to(DEVICE)
    )
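A sketch of the corresponding input section: the "type" key selects the registered BaseModifier subclass and the remaining options are passed to its get_modifier(). The concrete type name below is a placeholder, not a modifier shipped by this PR.

```python
# Hypothetical "modifier" section of model_params: "type" picks the class,
# everything else is forwarded to get_modifier().
modifier_params = {
    "type": "zero",        # placeholder type name
    # "some_option": ...,  # type-specific options would go here
}
modifier = get_data_modifier(modifier_params)  # resolved class, moved to DEVICE
```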
8 changes: 6 additions & 2 deletions deepmd/pt/utils/stat.py
@@ -47,12 +47,16 @@
    - a list of dicts, each of which contains data from a system
    """
    lst = []
    log.info(f"Packing data for statistics from {len(datasets)} systems")
    if nbatches > 0:
        log.info(f"Packing data for statistics from {len(datasets)} systems")
    for i in range(len(datasets)):
        sys_stat = {}
        with torch.device("cpu"):
            iterator = iter(dataloaders[i])
            numb_batches = min(nbatches, len(dataloaders[i]))
            if nbatches == -1:
                numb_batches = len(dataloaders[i])
            else:
                numb_batches = min(nbatches, len(dataloaders[i]))
            for _ in range(numb_batches):
                try:
                    stat_data = next(iterator)
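A short usage sketch of the new convention, with the systems and dataloaders assumed to be built as in the trainer above: nbatches = -1 packs every batch of each system, which is how the data modifier obtains the full training data.

```python
# Pack every batch of every system (new -1 convention used by the modifier).
all_sampled = make_stat_input(_data.systems, _data.dataloaders, -1)

# Positive values keep the old behaviour: at most nbatches batches per system.
some_sampled = make_stat_input(_data.systems, _data.dataloaders, 10)
```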