feat(pt): add plugin for data modifier #4661
Base branch: devel
New file (@@ -0,0 +1,8 @@):

```python
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_modifier import (
    BaseModifier,
)

__all__ = [
    "BaseModifier",
]
```
New file (@@ -0,0 +1,56 @@):

```python
# SPDX-License-Identifier: LGPL-3.0-or-later
import torch

from deepmd.dpmodel.modifier.base_modifier import (
    make_base_modifier,
)


class BaseModifier(torch.nn.Module, make_base_modifier()):
    def __init__(self, *args, **kwargs) -> None:
        """Construct the base data modifier."""
        torch.nn.Module.__init__(self)

    def modify_data(self, data: dict) -> None:
        """Modify data.

        Parameters
        ----------
        data
            Internal data of DeepmdData.
            A dict with the following keys:
            - coord        coordinates
            - box          simulation box
            - atype        atom types
            - find_energy  tells if data has energy
            - find_force   tells if data has force
            - find_virial  tells if data has virial
            - energy       energy
            - force        force
            - virial       virial
        """
        if (
            "find_energy" not in data
            and "find_force" not in data
            and "find_virial" not in data
        ):
            return

        get_nframes = None
        coord = data["coord"][:get_nframes, :]
        if data["box"] is None:
            box = None
        else:
            box = data["box"][:get_nframes, :]
        atype = data["atype"][:get_nframes, :]
        atype = atype[0]
        nframes = coord.shape[0]

        tot_e, tot_f, tot_v = self.forward(coord, atype, box, False, None, None)

        if "find_energy" in data and data["find_energy"] == 1.0:
            data["energy"] -= tot_e.reshape(data["energy"].shape)
        if "find_force" in data and data["find_force"] == 1.0:
            data["force"] -= tot_f.reshape(data["force"].shape)
        if "find_virial" in data and data["find_virial"] == 1.0:
            data["virial"] -= tot_v.reshape(data["virial"].shape)
```
Changes to an existing file (the PyTorch trainer):

```diff
@@ -39,6 +39,9 @@
     get_model,
     get_zbl_model,
 )
+from deepmd.pt.modifier import (
+    BaseModifier,
+)
 from deepmd.pt.optimizer import (
     KFOptimizerWrapper,
     LKFOptimizer,
```

```diff
@@ -136,6 +139,16 @@
         )
         self.num_model = len(self.model_keys)
 
+        # modifier for the training data
+        modifier_params = model_params.get("modifier", None)
+        if modifier_params is not None:
+            assert self.multi_task is False, (
+                "Modifier is not supported for multi-task training!"
+            )
+            self.modifier = get_data_modifier(modifier_params)
+        else:
+            self.modifier = None
+
         # Iteration config
         self.num_steps = training_params["numb_steps"]
         self.disp_file = training_params.get("disp_file", "lcurve.out")
```
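For reference, `model_params.get("modifier", None)` implies a `modifier` section inside the `model` part of the training input. A minimal sketch of what that might contain is shown below as a Python dict; only the `type` key is required by `get_data_modifier()` further down, and the other keys are hypothetical, plugin-specific options.

```python
# Hypothetical excerpt of the "model" section of the training input,
# written as a Python dict; everything except "type" is an assumption.
model_params = {
    # ... descriptor / fitting_net sections as usual ...
    "modifier": {
        "type": "my_correction",  # resolved via BaseModifier.get_class_by_type()
        "scale": 1.0,             # hypothetical option forwarded to get_modifier()
    },
}
```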
```diff
@@ -233,12 +246,26 @@
             _stat_file_path,
             _data_requirement,
             finetune_has_new_type=False,
+            modifier=None,
         ):
             _data_requirement += get_additional_data_requirement(_model)
             _training_data.add_data_requirement(_data_requirement)
             if _validation_data is not None:
                 _validation_data.add_data_requirement(_data_requirement)
+
+            # modify data
+            if modifier is not None:
+                log.info(f"Using {modifier.modifier_type} as data modifier")
+                for _data in [_training_data, _validation_data]:
+                    if _data is not None:
+                        all_sampled = make_stat_input(
+                            _data.systems,
+                            _data.dataloaders,
+                            -1,
+                        )
+                        for sampled in all_sampled:
+                            modifier.modify_data(sampled)
 
             @functools.lru_cache
             def get_sample():
                 sampled = make_stat_input(
```

Comment on lines +256 to +267

Collaborator: I noticed that this modification only affects the data statistics. It still uses the original data in …

Author: I don't think so.

Collaborator: The …
```diff
@@ -333,6 +360,7 @@
             finetune_has_new_type=self.finetune_links["Default"].get_has_new_type()
             if self.finetune_links is not None
             else False,
+            modifier=self.modifier,
         )
         (
             self.training_dataloader,
```

```diff
@@ -371,6 +399,7 @@
                 ].get_has_new_type()
                 if self.finetune_links is not None
                 else False,
+                modifier=self.modifier,
             )
             (
                 self.training_dataloader[model_key],
```
```diff
@@ -1079,10 +1108,13 @@
         optim_state_dict = deepcopy(self.optimizer.state_dict())
         for item in optim_state_dict["param_groups"]:
             item["lr"] = float(item["lr"])
-        torch.save(
-            {"model": module.state_dict(), "optimizer": optim_state_dict},
-            save_path,
-        )
+        save_dict = {
+            "model": module.state_dict(),
+            "optimizer": optim_state_dict,
+        }
+        if self.modifier is not None:
+            save_dict["data_modifier"] = self.modifier.state_dict()
+        torch.save(save_dict, save_path)
         checkpoint_dir = save_path.parent
         checkpoint_files = [
             f
```

Comment on lines +1115 to +1116

Member: I see this is saved, but how is it recovered?

Author: I plan to restore it by adapting the wrapper. Do you have any other idea for this?
```diff
@@ -1355,3 +1387,16 @@
                 f"to {to_numpy_array(new_bias).reshape(-1)!s}."
             )
     return _model
+
+
+def get_data_modifier(_modifier_params: dict[str, Any]):
+    modifier_params = deepcopy(_modifier_params)
+    try:
+        modifier_type = modifier_params.pop("type")
+    except KeyError:
+        raise ValueError("Data modifier type not specified!") from None
+    return (
+        BaseModifier.get_class_by_type(modifier_type)
+        .get_modifier(modifier_params)
+        .to(DEVICE)
+    )
```
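As a small usage sketch, the helper resolves the plugin class by its `type` key and instantiates it on the target device; the type name, the extra option, and `sampled_batch` (standing for one of the dicts returned by `make_stat_input`) are hypothetical.

```python
# "my_correction" must correspond to a class registered on BaseModifier,
# otherwise get_class_by_type() will not find it; the extra key is illustrative.
modifier = get_data_modifier({"type": "my_correction", "scale": 1.0})
modifier.modify_data(sampled_batch)  # subtracts the predicted e/f/v from the labels in place
```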
CodeQL (code scanning) notice: unused local variable — `nframes` is assigned in `BaseModifier.modify_data` but never used.