diff --git a/deepmd/pt/modifier/__init__.py b/deepmd/pt/modifier/__init__.py
new file mode 100644
index 0000000000..bfa1540ce9
--- /dev/null
+++ b/deepmd/pt/modifier/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from .base_modifier import (
+    BaseModifier,
+)
+
+__all__ = [
+    "BaseModifier",
+]
diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py
new file mode 100644
index 0000000000..be9246a2ad
--- /dev/null
+++ b/deepmd/pt/modifier/base_modifier.py
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import torch
+
+from deepmd.dpmodel.modifier.base_modifier import (
+    make_base_modifier,
+)
+
+
+class BaseModifier(torch.nn.Module, make_base_modifier()):
+    def __init__(self, *args, **kwargs) -> None:
+        """Construct the base data modifier."""
+        torch.nn.Module.__init__(self)
+
+    def modify_data(self, data: dict) -> None:
+        """Modify data.
+
+        Parameters
+        ----------
+        data
+            Internal data of DeepmdData.
+            A dict with the following keys:
+            - coord coordinates
+            - box simulation box
+            - atype atom types
+            - find_energy tells if data has energy
+            - find_force tells if data has force
+            - find_virial tells if data has virial
+            - energy energy
+            - force force
+            - virial virial
+        """
+        if (
+            "find_energy" not in data
+            and "find_force" not in data
+            and "find_virial" not in data
+        ):
+            return
+
+        get_nframes = None  # slicing with [:None] keeps all frames
+        coord = data["coord"][:get_nframes, :]
+        if data["box"] is None:
+            box = None
+        else:
+            box = data["box"][:get_nframes, :]
+        atype = data["atype"][:get_nframes, :]
+        atype = atype[0]  # atom types are assumed identical across frames
+        nframes = coord.shape[0]
+
+        tot_e, tot_f, tot_v = self.forward(coord, atype, box, False, None, None)
+
+        if "find_energy" in data and data["find_energy"] == 1.0:
+            data["energy"] -= tot_e.reshape(data["energy"].shape)
+        if "find_force" in data and data["find_force"] == 1.0:
+            data["force"] -= tot_f.reshape(data["force"].shape)
+        if "find_virial" in data and data["find_virial"] == 1.0:
+            data["virial"] -= tot_v.reshape(data["virial"].shape)
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 7a6ff0ebde..485a4a2079 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -39,6 +39,9 @@
     get_model,
     get_zbl_model,
 )
+from deepmd.pt.modifier import (
+    BaseModifier,
+)
 from deepmd.pt.optimizer import (
     KFOptimizerWrapper,
     LKFOptimizer,
@@ -136,6 +139,16 @@ def __init__(
         )
         self.num_model = len(self.model_keys)
 
+        # modifier for the training data
+        modifier_params = model_params.get("modifier", None)
+        if modifier_params is not None:
+            assert self.multi_task is False, (
+                "Modifier is not supported for multi-task training!"
+            )
+            self.modifier = get_data_modifier(modifier_params)
+        else:
+            self.modifier = None
+
         # Iteration config
         self.num_steps = training_params["numb_steps"]
         self.disp_file = training_params.get("disp_file", "lcurve.out")
@@ -233,12 +246,26 @@ def single_model_stat(
            _stat_file_path,
            _data_requirement,
            finetune_has_new_type=False,
+            modifier=None,
        ):
            _data_requirement += get_additional_data_requirement(_model)
            _training_data.add_data_requirement(_data_requirement)
            if _validation_data is not None:
                _validation_data.add_data_requirement(_data_requirement)

+            # modify data
+            if modifier is not None:
+                log.info(f"Using {modifier.modifier_type} as data modifier")
+                for _data in [_training_data, _validation_data]:
+                    if _data is not None:
+                        all_sampled = make_stat_input(
+                            _data.systems,
+                            _data.dataloaders,
+                            -1,
+                        )
+                        for sampled in all_sampled:
+                            modifier.modify_data(sampled)
+
            @functools.lru_cache
            def get_sample():
                sampled = make_stat_input(
@@ -333,6 +360,7 @@ def get_lr(lr_params):
                finetune_has_new_type=self.finetune_links["Default"].get_has_new_type()
                if self.finetune_links is not None
                else False,
+                modifier=self.modifier,
            )
            (
                self.training_dataloader,
@@ -371,6 +399,7 @@ def get_lr(lr_params):
                    ].get_has_new_type()
                    if self.finetune_links is not None
                    else False,
+                    modifier=self.modifier,
                )
                (
                    self.training_dataloader[model_key],
@@ -1079,10 +1108,13 @@ def save_model(self, save_path, lr=0.0, step=0) -> None:
        optim_state_dict = deepcopy(self.optimizer.state_dict())
        for item in optim_state_dict["param_groups"]:
            item["lr"] = float(item["lr"])
-        torch.save(
-            {"model": module.state_dict(), "optimizer": optim_state_dict},
-            save_path,
-        )
+        save_dict = {
+            "model": module.state_dict(),
+            "optimizer": optim_state_dict,
+        }
+        if self.modifier is not None:
+            save_dict["data_modifier"] = self.modifier.state_dict()
+        torch.save(save_dict, save_path)
        checkpoint_dir = save_path.parent
        checkpoint_files = [
            f
@@ -1355,3 +1387,16 @@ def model_change_out_bias(
        f"to {to_numpy_array(new_bias).reshape(-1)!s}."
    )
    return _model
+
+
+def get_data_modifier(_modifier_params: dict[str, Any]):
+    modifier_params = deepcopy(_modifier_params)
+    try:
+        modifier_type = modifier_params.pop("type")
+    except KeyError:
+        raise ValueError("Data modifier type not specified!") from None
+    return (
+        BaseModifier.get_class_by_type(modifier_type)
+        .get_modifier(modifier_params)
+        .to(DEVICE)
+    )
diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py
index cf6892b49d..bb265a6b0d 100644
--- a/deepmd/pt/utils/stat.py
+++ b/deepmd/pt/utils/stat.py
@@ -47,12 +47,16 @@ def make_stat_input(datasets, dataloaders, nbatches):
    - a list of dicts, each of which contains data from a system
    """
    lst = []
-    log.info(f"Packing data for statistics from {len(datasets)} systems")
+    if nbatches > 0:
+        log.info(f"Packing data for statistics from {len(datasets)} systems")
    for i in range(len(datasets)):
        sys_stat = {}
        with torch.device("cpu"):
            iterator = iter(dataloaders[i])
-            numb_batches = min(nbatches, len(dataloaders[i]))
+            if nbatches == -1:
+                numb_batches = len(dataloaders[i])
+            else:
+                numb_batches = min(nbatches, len(dataloaders[i]))
            for _ in range(numb_batches):
                try:
                    stat_data = next(iterator)
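
For reference, here is a minimal sketch of what a concrete modifier plugging into this machinery could look like. It is illustrative only: the `@BaseModifier.register` decorator, the `modifier_type` attribute, the `get_modifier` classmethod, and the `forward` signature are all inferred from the call sites in the diff (`get_class_by_type(...)`, `modifier.modifier_type`, `.get_modifier(modifier_params)`, and `self.forward(coord, atype, box, False, None, None)`), not taken from a confirmed API.

```python
# Hypothetical modifier that subtracts a constant per-frame energy offset.
# Names such as "energy_offset" and the register() decorator are assumptions
# mirroring the call sites in the diff above.
import torch

from deepmd.pt.modifier import BaseModifier


@BaseModifier.register("energy_offset")  # assumed plugin-style registration
class EnergyOffsetModifier(BaseModifier):
    modifier_type = "energy_offset"  # read by the log line in single_model_stat

    def __init__(self, offset: float = 0.0, **kwargs) -> None:
        super().__init__()
        self.offset = offset

    @classmethod
    def get_modifier(cls, modifier_params: dict) -> "EnergyOffsetModifier":
        # "type" was already popped by get_data_modifier()
        return cls(**modifier_params)

    def forward(self, coord, atype, box, do_atomic_virial, fparam, aparam):
        # Element counts follow the reshape() targets in BaseModifier.modify_data():
        # energy per frame, force per atom, 9-component virial per frame.
        nframes = coord.shape[0]
        natoms = atype.shape[0]  # modify_data passes a single frame's atype
        tot_e = torch.full((nframes, 1), self.offset, dtype=coord.dtype)
        tot_f = torch.zeros((nframes, natoms, 3), dtype=coord.dtype)
        tot_v = torch.zeros((nframes, 9), dtype=coord.dtype)
        return tot_e, tot_f, tot_v
```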
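On the consuming side, `get_data_modifier` pops `"type"` from a copy of the model's `"modifier"` section, resolves the class, builds it via `get_modifier`, and moves the module to `DEVICE`. A usage sketch with the hypothetical type above (all keys other than `"type"` are illustrative):

```python
# Sketch: the "modifier" section a user would add under "model" in the
# training input; "energy_offset"/"offset" belong to the hypothetical
# modifier sketched above.
model_params = {
    "descriptor": {"...": "..."},   # usual model sections, elided
    "fitting_net": {"...": "..."},
    "modifier": {
        "type": "energy_offset",    # selects the registered class
        "offset": 1.5,              # forwarded to get_modifier()
    },
}

modifier = get_data_modifier(model_params["modifier"])
# During the data pass in single_model_stat, each sampled batch then has
# the modifier's prediction subtracted from its labels in place:
# modifier.modify_data(sampled)
```

Note the companion change in `make_stat_input`: `nbatches == -1` now means "iterate every batch of every system", which the modifier pass uses so the whole dataset is modified rather than only the statistics subset, and the "Packing data" log line is suppressed in that mode.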