From c3e9ce97d693b3def7ff27377085125d3bbda144 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 25 May 2025 12:53:37 +0800 Subject: [PATCH 01/32] checkpoint Signed-off-by: Jinzhe Zeng --- deepmd/backend/jax.py | 6 +- deepmd/dpmodel/loss/ener.py | 95 +++++++------ deepmd/jax/entrypoints/__init__.py | 1 + deepmd/jax/entrypoints/main.py | 54 ++++++++ deepmd/jax/entrypoints/train.py | 163 +++++++++++++++++++++++ deepmd/jax/train/__init__.py | 1 + deepmd/jax/train/trainer.py | 207 +++++++++++++++++++++++++++++ 7 files changed, 476 insertions(+), 51 deletions(-) create mode 100644 deepmd/jax/entrypoints/__init__.py create mode 100644 deepmd/jax/entrypoints/main.py create mode 100644 deepmd/jax/entrypoints/train.py create mode 100644 deepmd/jax/train/__init__.py create mode 100644 deepmd/jax/train/trainer.py diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index 7a714c2090..761b4f552f 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -60,7 +60,11 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]: Callable[[Namespace], None] The entry point hook of the backend. """ - raise NotImplementedError + from deepmd.jax.entrypoints.main import ( + main, + ) + + return main @property def deep_eval(self) -> type["DeepEvalBackend"]: diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index f119fd6050..44470dfa73 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -93,10 +93,10 @@ def call( label_dict: dict[str, np.ndarray], ) -> dict[str, np.ndarray]: """Calculate loss from model results and labeled results.""" - energy = model_dict["energy"] - force = model_dict["force"] - virial = model_dict["virial"] - atom_ener = model_dict["atom_ener"] + energy = model_dict["energy_redu"] + force = model_dict["energy_derv_r"] + virial = model_dict["energy_derv_c_redu"] + atom_ener = model_dict["energy"] energy_hat = label_dict["energy"] force_hat = label_dict["force"] virial_hat = label_dict["virial"] @@ -268,57 +268,52 @@ def call( def label_requirement(self) -> list[DataRequirementItem]: """Return data label requirements needed for this loss calculation.""" label_requirement = [] - if self.has_e: - label_requirement.append( - DataRequirementItem( - "energy", - ndof=1, - atomic=False, - must=False, - high_prec=True, - ) + label_requirement.append( + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, ) - if self.has_f: - label_requirement.append( - DataRequirementItem( - "force", - ndof=3, - atomic=True, - must=False, - high_prec=False, - ) + ) + label_requirement.append( + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, ) - if self.has_v: - label_requirement.append( - DataRequirementItem( - "virial", - ndof=9, - atomic=False, - must=False, - high_prec=False, - ) + ) + label_requirement.append( + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, ) - if self.has_ae: - label_requirement.append( - DataRequirementItem( - "atom_ener", - ndof=1, - atomic=True, - must=False, - high_prec=False, - ) + ) + label_requirement.append( + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, ) - if self.has_pf: - label_requirement.append( - DataRequirementItem( - "atom_pref", - ndof=1, - atomic=True, - must=False, - high_prec=False, - repeat=3, - ) + ) + label_requirement.append( + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, ) + 
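# A note on the key remap in call() above (a sketch of the assumed
# dpmodel output-def convention; shapes are illustrative):
#   model_dict["energy"]              per-atom energies, (nframes, nloc, 1)
#   model_dict["energy_redu"]         reduced (summed) energy, (nframes, 1)
#   model_dict["energy_derv_r"]       coordinate derivative, i.e. force, (nframes, natoms, 3)
#   model_dict["energy_derv_c_redu"]  reduced virial, (nframes, 9)
# i.e. the base name is the raw atomic property, and derived quantities
# carry "_redu"/"_derv" suffixes.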
) if self.has_gf > 0: label_requirement.append( DataRequirementItem( diff --git a/deepmd/jax/entrypoints/__init__.py b/deepmd/jax/entrypoints/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/entrypoints/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/entrypoints/main.py b/deepmd/jax/entrypoints/main.py new file mode 100644 index 0000000000..acaf30e664 --- /dev/null +++ b/deepmd/jax/entrypoints/main.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DeePMD-Kit entry point module.""" + +import argparse +from typing import ( + Optional, + Union, +) + +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.jax.entrypoints.train import ( + train, +) +from deepmd.main import ( + parse_args, +) + +__all__ = ["main"] + + +def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: + """DeePMD-Kit entry point. + + Parameters + ---------- + args : list[str] or argparse.Namespace, optional + list of command line arguments, used to avoid calling from the subprocess, + as it is quite slow to import tensorflow; if Namespace is given, it will + be used directly + + Raises + ------ + RuntimeError + if no command was input + """ + if not isinstance(args, argparse.Namespace): + args = parse_args(args=args) + + dict_args = vars(args) + + if args.command == "train": + train(**dict_args) + elif args.command == "freeze": + raise + dict_args["output"] = format_model_suffix( + dict_args["output"], preferred_backend=args.backend, strict_prefer=True + ) + # freeze(**dict_args) + elif args.command is None: + pass + else: + raise RuntimeError(f"unknown command {args.command}") diff --git a/deepmd/jax/entrypoints/train.py b/deepmd/jax/entrypoints/train.py new file mode 100644 index 0000000000..d1d47ef99a --- /dev/null +++ b/deepmd/jax/entrypoints/train.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DeePMD training entrypoint script. + +Can handle local or distributed training. +""" + +import json +import logging +import time +from typing import ( + Optional, +) + +from deepmd.common import ( + j_loader, +) +from deepmd.jax.train.trainer import ( + DPTrainer, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data_system import ( + get_data, +) + +__all__ = ["train"] + +log = logging.getLogger(__name__) + + +def train( + *, + INPUT: str, + init_model: Optional[str], + restart: Optional[str], + output: str, + init_frz_model: str, + mpi_log: str, + log_level: int, + log_path: Optional[str], + skip_neighbor_stat: bool = False, + finetune: Optional[str] = None, + use_pretrain_script: bool = False, + **kwargs, +) -> None: + """Run DeePMD model training. 
+ + Parameters + ---------- + INPUT : str + json/yaml control file + init_model : Optional[str] + path prefix of checkpoint files or None + restart : Optional[str] + path prefix of checkpoint files or None + output : str + path for dump file with arguments + init_frz_model : str + path to frozen model or None + mpi_log : str + mpi logging mode + log_level : int + logging level defined by int 0-3 + log_path : Optional[str] + logging file path or None if logs are to be output only to stdout + skip_neighbor_stat : bool, default=False + skip checking neighbor statistics + finetune : Optional[str] + path to pretrained model or None + use_pretrain_script : bool + Whether to use model script in pretrained model when doing init-model or init-frz-model. + Note that this option is true and unchangeable for fine-tuning. + **kwargs + additional arguments + + Raises + ------ + RuntimeError + if distributed training job name is wrong + """ + # load json database + jdata = j_loader(INPUT) + + origin_type_map = None + + jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") + + jdata = normalize(jdata) + jdata = update_sel(jdata) + + with open(output, "w") as fp: + json.dump(jdata, fp, indent=4) + # print_resource_summary() + + # make necessary checks + assert "training" in jdata + + # init the model + + model = DPTrainer(jdata) + rcut = model.model.get_rcut() + type_map = model.model.get_type_map() + if len(type_map) == 0: + ipt_type_map = None + else: + ipt_type_map = type_map + + # init random seed of data systems + seed = jdata["training"].get("seed", None) + + # init data + train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None) + train_data.add_data_requirements(model.data_requirements) + train_data.print_summary("training") + if jdata["training"].get("validation_data", None) is not None: + valid_data = get_data( + jdata["training"]["validation_data"], + rcut, + train_data.type_map, + None, + ) + valid_data.add_data_requirements(model.data_requirements) + valid_data.print_summary("validation") + + # get training info + stop_batch = jdata["training"]["numb_steps"] + origin_type_map = jdata["model"].get("origin_type_map", None) + if ( + origin_type_map is not None and not origin_type_map + ): # get the type_map from data if not provided + origin_type_map = get_data( + jdata["training"]["training_data"], rcut, None, None + ).get_type_map() + + # train the model with the provided systems in a cyclic way + start_time = time.time() + model.train(train_data, valid_data) + end_time = time.time() + log.info("finished training") + log.info(f"wall time: {(end_time - start_time):.3f} s") + + +def update_sel(jdata): + log.info( + "Calculate neighbor statistics... 
(add --skip-neighbor-stat to skip this step)" + ) + jdata_cpy = jdata.copy() + type_map = jdata["model"].get("type_map") + train_data = get_data( + jdata["training"]["training_data"], + 0, # not used + type_map, + None, # not used + ) + # TODO: OOM, need debug + # jdata_cpy["model"], min_nbor_dist = BaseModel.update_sel( + # train_data, type_map, jdata["model"] + # ) + return jdata_cpy diff --git a/deepmd/jax/train/__init__.py b/deepmd/jax/train/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/train/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py new file mode 100644 index 0000000000..31ff266371 --- /dev/null +++ b/deepmd/jax/train/trainer.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Optional, +) + +import numpy as np +import optax + +from deepmd.dpmodel.loss.ener import ( + EnergyLoss, +) +from deepmd.dpmodel.model.transform_output import ( + communicate_extended_output, +) +from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, +) +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.jax.env import ( + jax, + jnp, + nnx, +) +from deepmd.jax.model.model import ( + get_model, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + +log = logging.getLogger(__name__) + + +class DPTrainer: + def __init__(self, jdata) -> None: + self.model = get_model(jdata["model"]) + self.training_param = jdata["training"] + self.num_steps = self.training_param["numb_steps"] + + def get_lr_and_coef(lr_param): + lr_type = lr_param.get("type", "exp") + if lr_type == "exp": + lr = LearningRateExp( + lr_param["start_lr"], + lr_param["stop_lr"], + lr_param["decay_steps"], + self.num_steps, + ) + else: + raise RuntimeError("unknown learning_rate type " + lr_type) + return lr + + learning_rate_param = jdata["learning_rate"] + self.lr = get_lr_and_coef(learning_rate_param) + loss_param = jdata.get("loss", {}) + loss_param["starter_learning_rate"] = learning_rate_param["start_lr"] + self.loss = EnergyLoss.get_loss(loss_param) + + # training + tr_data = jdata["training"] + self.disp_file = tr_data.get("disp_file", "lcurve.out") + self.disp_freq = tr_data.get("disp_freq", 1000) + self.save_freq = tr_data.get("save_freq", 1000) + self.save_ckpt = tr_data.get("save_ckpt", "model.ckpt") + self.max_ckpt_keep = tr_data.get("max_ckpt_keep", 5) + self.display_in_training = tr_data.get("disp_training", True) + self.timing_in_training = tr_data.get("time_training", True) + self.profiling = tr_data.get("profiling", False) + self.profiling_file = tr_data.get("profiling_file", "timeline.json") + self.enable_profiler = tr_data.get("enable_profiler", False) + self.tensorboard = tr_data.get("tensorboard", False) + self.tensorboard_log_dir = tr_data.get("tensorboard_log_dir", "log") + self.tensorboard_freq = tr_data.get("tensorboard_freq", 1) + self.mixed_prec = tr_data.get("mixed_precision", None) + self.change_bias_after_training = tr_data.get( + "change_bias_after_training", False + ) + self.numb_fparam = self.model.get_dim_fparam() + + if tr_data.get("validation_data", None) is not None: + self.valid_numb_batch = tr_data["validation_data"].get("numb_btch", 1) + else: + self.valid_numb_batch = 1 + + # if init the graph with the frozen model + self.frz_model = None + 
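# For reference, the schedule produced by get_lr_and_coef() above is
# (a sketch, assuming LearningRateExp's usual semantics):
#   lr(step) = start_lr * decay_rate ** (step // decay_steps)
# clipped from below at stop_lr; when decay_rate is not supplied it is
# solved from lr(num_steps) == stop_lr. E.g. with start_lr=1e-3,
# stop_lr=1e-8, decay_steps=5000 and num_steps=1e6:
#   decay_rate = (1e-8 / 1e-3) ** (5000 / 1e6) ≈ 0.944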
self.ckpt_meta = None + self.model_type = None + + @property + def data_requirements(self) -> list[DataRequirementItem]: + return self.loss.label_requirement + + def train(self, train_data, valid_data=None) -> None: + optimizer = nnx.Optimizer(self.model, optax.adam(1e-3)) # reference sharing + + def train_step( + optimizer, + lr, + label_dict, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ): + def loss_fn(model): + model_dict_lower = jax.jit(model.call_lower)( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + model_dict = communicate_extended_output( + model_dict_lower, + model.model_output_def(), + mapping, + do_atomic_virial=False, + ) + loss, more_loss = self.loss( + learning_rate=lr, + natoms=label_dict["coord"].shape[1], + model_dict=model_dict, + label_dict=label_dict, + ) + return loss + + loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model) + optimizer.update(grads) + return loss + + for step in range(self.num_steps): + batch_data = train_data.get_batch() + # numpy to jax + jax_data = { + kk: jnp.asarray(vv) if not kk.startswith("find_") else bool(vv.item()) + for kk, vv in batch_data.items() + } + extended_coord, extended_atype, nlist, mapping, fp, ap = prepare_input( + rcut=self.model.get_rcut(), + sel=self.model.get_sel(), + coord=jax_data["coord"], + atype=jax_data["type"], + box=jax_data["box"] if jax_data["find_box"] else None, + fparam=jax_data.get("fparam", None), + aparam=jax_data.get("aparam", None), + ) + loss = train_step( + optimizer, + self.lr.value(step), + jax_data, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + # print(step, jnp.sqrt(loss)) + + +def prepare_input( + *, # enforce keyword-only arguments + rcut: float, + sel: list[int], + coord: np.ndarray, + atype: np.ndarray, + box: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, +): + nframes, nloc = atype.shape[:2] + cc, bb, fp, ap = coord, box, fparam, aparam + del coord, box, fparam, aparam + if bb is not None: + coord_normalized = normalize_coord( + cc.reshape(nframes, nloc, 3), + bb.reshape(nframes, 3, 3), + ) + else: + coord_normalized = cc.copy() + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, bb, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + distinguish_types=False, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + return extended_coord, extended_atype, nlist, mapping, fp, ap From ac5a7a31f7b383d3a56a2f7dc643b8cd34d56a31 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 26 May 2025 21:27:58 +0800 Subject: [PATCH 02/32] checkpoint --- deepmd/dpmodel/loss/loss.py | 2 +- deepmd/jax/train/trainer.py | 73 ++++++++++++++++++++++++------------- 2 files changed, 48 insertions(+), 27 deletions(-) diff --git a/deepmd/dpmodel/loss/loss.py b/deepmd/dpmodel/loss/loss.py index a297380cce..f5314ba5b0 100644 --- a/deepmd/dpmodel/loss/loss.py +++ b/deepmd/dpmodel/loss/loss.py @@ -51,7 +51,7 @@ def display_if_exist(loss: np.ndarray, find_property: float) -> np.ndarray: the loss scalar or NaN """ xp = array_api_compat.array_namespace(loss) - return loss if bool(find_property) else xp.nan + return loss #if bool(find_property) else xp.nan @classmethod def get_loss(cls, loss_params: dict) -> "Loss": diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py 
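Note on the display_if_exist change above: under jax.jit, find_property arrives as a tracer, and calling Python bool() on a tracer raises a concretization error, so commenting the branch out is a stopgap; a later patch in this series restores the NaN semantics with a traceable select. A minimal sketch of the failure mode and the jit-able form:

import jax
import jax.numpy as jnp

@jax.jit
def bad(loss, find_property):
    # bool(tracer) -> ConcretizationTypeError at trace time
    return loss if bool(find_property) else jnp.nan

@jax.jit
def good(loss, find_property):
    # where() selects without concretizing the condition
    return jnp.where(find_property, loss, jnp.nan)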
index 31ff266371..b4f0eaa915 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -7,6 +7,7 @@ import numpy as np import optax +from tqdm import trange from deepmd.dpmodel.loss.ener import ( EnergyLoss, @@ -102,6 +103,40 @@ def data_requirements(self) -> list[DataRequirementItem]: def train(self, train_data, valid_data=None) -> None: optimizer = nnx.Optimizer(self.model, optax.adam(1e-3)) # reference sharing + def loss_fn( + model, + lr, + label_dict, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ): + model_dict_lower = self.model.call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + model_dict = communicate_extended_output( + model_dict_lower, + model.model_output_def(), + mapping, + do_atomic_virial=False, + ) + loss, more_loss = self.loss( + learning_rate=lr, + natoms=label_dict["coord"].shape[1], + model_dict=model_dict, + label_dict=label_dict, + ) + return loss + + @nnx.jit def train_step( optimizer, lr, @@ -113,34 +148,20 @@ def train_step( fp, ap, ): - def loss_fn(model): - model_dict_lower = jax.jit(model.call_lower)( - extended_coord, - extended_atype, - nlist, - mapping, - fp, - ap, - ) - model_dict = communicate_extended_output( - model_dict_lower, - model.model_output_def(), - mapping, - do_atomic_virial=False, - ) - loss, more_loss = self.loss( - learning_rate=lr, - natoms=label_dict["coord"].shape[1], - model_dict=model_dict, - label_dict=label_dict, - ) - return loss - - loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model) + grads = nnx.grad(loss_fn)( + optimizer.model, + lr, + label_dict, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) optimizer.update(grads) - return loss - for step in range(self.num_steps): + for step in trange(self.num_steps): batch_data = train_data.get_batch() # numpy to jax jax_data = { From ad20de945b46b17c8f0a93a06d4ba953df671bc3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 26 May 2025 21:49:29 +0800 Subject: [PATCH 03/32] fix(jax): make display_if_exist jit-able Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/loss/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/loss/loss.py b/deepmd/dpmodel/loss/loss.py index f5314ba5b0..94f4a8fd36 100644 --- a/deepmd/dpmodel/loss/loss.py +++ b/deepmd/dpmodel/loss/loss.py @@ -51,7 +51,7 @@ def display_if_exist(loss: np.ndarray, find_property: float) -> np.ndarray: the loss scalar or NaN """ xp = array_api_compat.array_namespace(loss) - return loss #if bool(find_property) else xp.nan + return xp.where(find_property, loss, xp.nan) @classmethod def get_loss(cls, loss_params: dict) -> "Loss": From c62c3565be22bd857e7592227b71868b92a7d045 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 28 May 2025 00:30:57 +0800 Subject: [PATCH 04/32] fix(jax): workaround for "xxTracer is not a valid JAX type" Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/atomic_model/pairtab_atomic_model.py | 2 +- deepmd/dpmodel/descriptor/dpa1.py | 6 +++++- deepmd/dpmodel/descriptor/repflows.py | 6 +++++- deepmd/dpmodel/descriptor/repformers.py | 6 +++++- deepmd/dpmodel/descriptor/se_e2_a.py | 6 +++++- deepmd/dpmodel/descriptor/se_r.py | 7 ++++++- deepmd/dpmodel/descriptor/se_t.py | 6 +++++- deepmd/dpmodel/descriptor/se_t_tebd.py | 6 +++++- deepmd/dpmodel/fitting/general_fitting.py | 12 ++++++++---- deepmd/dpmodel/utils/exclude_mask.py | 6 ++++-- deepmd/dpmodel/utils/network.py | 4 ++-- 11 files changed, 51 insertions(+), 16 deletions(-) diff --git 
a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 0c35320e7f..9d7739d5c8 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -293,7 +293,7 @@ def _pair_tabulated_inter( uu -= idx table_coef = self._extract_spline_coefficient( - i_type, j_type, idx, self.tab_data, nspline + i_type, j_type, idx, self.tab_data[...], nspline ) table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4)) ener = self._calculate_ener(table_coef, uu) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 3b25e63fb5..51c56e9681 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -951,7 +951,11 @@ def call( xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( - coord_ext, atype_ext, nlist, self.mean, self.stddev + coord_ext, + atype_ext, + nlist, + self.mean[...], + self.stddev[...], ) nf, nloc, nnei, _ = dmatrix.shape atype = atype_ext[:, :nloc] diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index df0b81d9d2..f9f81d35cb 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -467,7 +467,11 @@ def call( nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 dmatrix, diff, sw = self.env_mat_edge.call( - coord_ext, atype_ext, nlist, self.mean, self.stddev + coord_ext, + atype_ext, + nlist, + self.mean[...], + self.stddev[...], ) # nb x nloc x nnei nlist_mask = nlist != -1 diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index a917f7a227..3d02054350 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -441,7 +441,11 @@ def call( nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( - coord_ext, atype_ext, nlist, self.mean, self.stddev + coord_ext, + atype_ext, + nlist, + self.mean[...], + self.stddev[...], ) nf, nloc, nnei, _ = dmatrix.shape # nf x nloc x nnei diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index d38955b98f..bd72d936e3 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -591,7 +591,11 @@ def call( input_dtype = coord_ext.dtype # nf x nloc x nnei x 4 rr, diff, ww = self.env_mat.call( - coord_ext, atype_ext, nlist, self.davg, self.dstd + coord_ext, + atype_ext, + nlist, + self.davg[...], + self.dstd[...], ) nf, nloc, nnei, _ = rr.shape sec = self.sel_cumsum diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 01df14394e..5b2931b23f 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -373,7 +373,12 @@ def call( del mapping # nf x nloc x nnei x 1 rr, diff, ww = self.env_mat.call( - coord_ext, atype_ext, nlist, self.davg, self.dstd, True + coord_ext, + atype_ext, + nlist, + self.davg[...], + self.dstd[...], + True, ) nf, nloc, nnei, _ = rr.shape sec = self.sel_cumsum diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 3540eae53c..fb30f04961 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -349,7 +349,11 @@ def call( xp = array_api_compat.array_namespace(coord_ext, 
atype_ext, nlist) # nf x nloc x nnei x 4 rr, diff, ww = self.env_mat.call( - coord_ext, atype_ext, nlist, self.davg, self.dstd + coord_ext, + atype_ext, + nlist, + self.davg[...], + self.dstd[...], ) nf, nloc, nnei, _ = rr.shape sec = self.sel_cumsum diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index c3ee41fe00..ff26024aad 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -733,7 +733,11 @@ def call( xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( - coord_ext, atype_ext, nlist, self.mean, self.stddev + coord_ext, + atype_ext, + nlist, + self.mean[...], + self.stddev[...], ) nf, nloc, nnei, _ = dmatrix.shape exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 7342663141..cd0d4e72d4 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -410,7 +410,7 @@ def _call_common( f"get an input fparam of dim {fparam.shape[-1]}, " f"which is not consistent with {self.numb_fparam}." ) - fparam = (fparam - self.fparam_avg) * self.fparam_inv_std + fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...] fparam = xp.tile( xp.reshape(fparam, [nf, 1, self.numb_fparam]), (1, nloc, 1) ) @@ -432,7 +432,7 @@ def _call_common( f"which is not consistent with {self.numb_aparam}." ) aparam = xp.reshape(aparam, [nf, nloc, self.numb_aparam]) - aparam = (aparam - self.aparam_avg) * self.aparam_inv_std + aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...] xx = xp.concat( [xx, aparam], axis=-1, @@ -445,7 +445,9 @@ def _call_common( if self.dim_case_embd > 0: assert self.case_embd is not None - case_embd = xp.tile(xp.reshape(self.case_embd, [1, 1, -1]), [nf, nloc, 1]) + case_embd = xp.tile( + xp.reshape(self.case_embd[...], [1, 1, -1]), [nf, nloc, 1] + ) xx = xp.concat( [xx, case_embd], axis=-1, @@ -482,7 +484,9 @@ def _call_common( outs -= self.nets[()](xx_zeros) outs += xp.reshape( xp.take( - xp.astype(self.bias_atom_e, outs.dtype), xp.reshape(atype, [-1]), axis=0 + xp.astype(self.bias_atom_e[...], outs.dtype), + xp.reshape(atype, [-1]), + axis=0, ), [nf, nloc, net_dim_out], ) diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index 372eba133e..f390bbc7c1 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -53,7 +53,8 @@ def build_type_exclude_mask( xp = array_api_compat.array_namespace(atype) nf, natom = atype.shape return xp.reshape( - xp.take(self.type_mask, xp.reshape(atype, [-1]), axis=0), (nf, natom) + xp.take(self.type_mask[...], xp.reshape(atype, [-1]), axis=0), + (nf, natom), ) @@ -131,7 +132,8 @@ def build_type_exclude_mask( # nf x (nloc x nnei) type_ij = xp.reshape(type_ij, (nf, nloc * nnei)) mask = xp.reshape( - xp.take(self.type_mask, xp.reshape(type_ij, (-1,))), (nf, nloc, nnei) + xp.take(self.type_mask[...], xp.reshape(type_ij, (-1,))), + (nf, nloc, nnei), ) return mask diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 6a3c6d8081..ed5e90b81b 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -259,9 +259,9 @@ def call(self, x: np.ndarray) -> np.ndarray: xp = array_api_compat.array_namespace(x) fn = get_activation_fn(self.activation_function) y = ( - xp.matmul(x, self.w) + self.b + 
xp.matmul(x, self.w[...]) + self.b[...] if self.b is not None - else xp.matmul(x, self.w) + else xp.matmul(x, self.w[...]) ) if y.dtype != x.dtype: # workaround for bfloat16 From d0a1ce799c2b7b3d049f69c6acd43fef92340a23 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 28 May 2025 01:19:51 +0800 Subject: [PATCH 05/32] checkpoint Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/dpa1.py | 1 + deepmd/dpmodel/descriptor/repflows.py | 4 +- deepmd/dpmodel/fitting/ener_fitting.py | 72 ++++++++++ deepmd/dpmodel/loss/ener.py | 13 +- deepmd/dpmodel/utils/env_mat_stat.py | 7 +- deepmd/jax/entrypoints/main.py | 11 ++ deepmd/jax/train/trainer.py | 185 +++++++++++++++++++++++-- 7 files changed, 275 insertions(+), 18 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 51c56e9681..703cc887a3 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -781,6 +781,7 @@ def __init__( self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.orig_sel = self.sel + self.ndescrpt = self.nnei * 4 def get_rcut(self) -> float: """Returns the cut-off radius.""" diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index f9f81d35cb..5029a6b2ec 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -1315,7 +1315,9 @@ def call( ) nb, nloc, nnei = nlist.shape nall = node_ebd_ext.shape[1] - n_edge = int(xp.sum(xp.astype(nlist_mask, xp.int32))) + n_edge = ( + int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0 + ) node_ebd = node_ebd_ext[:, :nloc, :] assert (nb, nloc) == node_ebd.shape[:2] if not self.use_dynamic_sel: diff --git a/deepmd/dpmodel/fitting/ener_fitting.py b/deepmd/dpmodel/fitting/ener_fitting.py index 6435b6468f..3f7684d1f9 100644 --- a/deepmd/dpmodel/fitting/ener_fitting.py +++ b/deepmd/dpmodel/fitting/ener_fitting.py @@ -6,6 +6,8 @@ Union, ) +import numpy as np + from deepmd.dpmodel.common import ( DEFAULT_PRECISION, ) @@ -17,6 +19,10 @@ from deepmd.dpmodel.fitting.general_fitting import ( GeneralFitting, ) + +from deepmd.utils.out_stat import ( + compute_stats_from_redu, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -86,3 +92,69 @@ def serialize(self) -> dict: **super().serialize(), "type": "ener", } + + def compute_output_stats(self, all_stat: dict, mixed_type: bool = False) -> None: + """Compute the output statistics. + + Parameters + ---------- + all_stat + must have the following components: + all_stat['energy'] of shape n_sys x n_batch x n_frame + can be prepared by model.make_stat_input + mixed_type + Whether to perform the mixed_type mode. + If True, the input data has the mixed_type format (see doc/model/train_se_atten.md), + in which frames in a system may have different natoms_vec(s), with the same nloc. 
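Notes: the bias itself is the least-squares solution of N @ bias ≈ E, where row s of N holds the per-type atom counts of system s and E[s] is its average frame energy; the solve (including optionally assigned per-type energies and the rcond cutoff) is delegated to compute_stats_from_redu below.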
+ """ + self.bias_atom_e = self._compute_output_stats( + all_stat, rcond=self.rcond, mixed_type=mixed_type + ) + + def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False): + data = all_stat["energy"] + # data[sys_idx][batch_idx][frame_idx] + sys_ener = [] + for ss in range(len(data)): + sys_data = [] + for ii in range(len(data[ss])): + for jj in range(len(data[ss][ii])): + sys_data.append(data[ss][ii][jj]) + sys_data = np.concatenate(sys_data) + sys_ener.append(np.average(sys_data)) + sys_ener = np.array(sys_ener) + sys_tynatom = [] + if mixed_type: + data = all_stat["real_natoms_vec"] + nsys = len(data) + for ss in range(len(data)): + tmp_tynatom = [] + for ii in range(len(data[ss])): + for jj in range(len(data[ss][ii])): + tmp_tynatom.append(data[ss][ii][jj].astype(np.float64)) + tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0) + sys_tynatom.append(tmp_tynatom) + else: + data = all_stat["natoms_vec"] + nsys = len(data) + for ss in range(len(data)): + sys_tynatom.append(data[ss][0].astype(np.float64)) + sys_tynatom = np.array(sys_tynatom) + sys_tynatom = np.reshape(sys_tynatom, [nsys, -1]) + sys_tynatom = sys_tynatom[:, 2:] + if len(self.atom_ener) > 0: + # Atomic energies stats are incorrect if atomic energies are assigned. + # In this situation, we directly use these assigned energies instead of computing stats. + # This will make the loss decrease quickly + assigned_atom_ener = np.array( + [ee if ee is not None else np.nan for ee in self.atom_ener_v] + ) + else: + assigned_atom_ener = None + energy_shift, _ = compute_stats_from_redu( + sys_ener.reshape(-1, 1), + sys_tynatom, + assigned_bias=assigned_atom_ener, + rcond=rcond, + ) + return energy_shift.ravel() diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index 44470dfa73..87257dc584 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -177,7 +177,9 @@ def call( delta=self.huber_delta, ) loss += pref_e * l_huber_loss - more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy) + more_loss["rmse_e"] = self.display_if_exist( + xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy + ) if self.has_f: l2_force_loss = xp.mean(xp.square(diff_f)) if not self.use_huber: @@ -189,8 +191,8 @@ def call( delta=self.huber_delta, ) loss += pref_f * l_huber_loss - more_loss["l2_force_loss"] = self.display_if_exist( - l2_force_loss, find_force + more_loss["rmse_f"] = self.display_if_exist( + xp.sqrt(l2_force_loss), find_force ) if self.has_v: virial_reshape = xp.reshape(virial, [-1]) @@ -207,9 +209,7 @@ def call( delta=self.huber_delta, ) loss += pref_v * l_huber_loss - more_loss["l2_virial_loss"] = self.display_if_exist( - l2_virial_loss, find_virial - ) + more_loss["rmse_v"] = self.display_if_exist(l2_virial_loss, find_virial) if self.has_ae: atom_ener_reshape = xp.reshape(atom_ener, [-1]) atom_ener_hat_reshape = xp.reshape(atom_ener_hat, [-1]) @@ -261,6 +261,7 @@ def call( ) self.l2_l = loss + more_loss["rmse"] = xp.sqrt(loss) self.l2_more = more_loss return loss, more_loss diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index e25739fa56..278f565a3a 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -119,12 +119,15 @@ def iter( "last_dim should be 1 for raial-only or 4 for full descriptor." 
) for system in data: - coord, atype, box, natoms = ( + coord, atype, box = ( system["coord"], system["atype"], system["box"], - system["natoms"], ) + coord = xp.reshape(coord, (coord.shape[0], -1, 3)) # (nframes, nloc, 3) + atype = xp.reshape(atype, (coord.shape[0], -1)) # (nframes, nloc) + if box is not None: + box = xp.reshape(box, (coord.shape[0], 3, 3)) ( extended_coord, extended_atype, diff --git a/deepmd/jax/entrypoints/main.py b/deepmd/jax/entrypoints/main.py index acaf30e664..e7ad32b56c 100644 --- a/deepmd/jax/entrypoints/main.py +++ b/deepmd/jax/entrypoints/main.py @@ -2,6 +2,9 @@ """DeePMD-Kit entry point module.""" import argparse +from pathlib import ( + Path, +) from typing import ( Optional, Union, @@ -13,6 +16,9 @@ from deepmd.jax.entrypoints.train import ( train, ) +from deepmd.loggers.loggers import ( + set_log_handles, +) from deepmd.main import ( parse_args, ) @@ -39,6 +45,11 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: args = parse_args(args=args) dict_args = vars(args) + set_log_handles( + args.log_level, + Path(args.log_path) if args.log_path else None, + mpi_log=None, + ) if args.command == "train": train(**dict_args) diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index b4f0eaa915..191e9d8466 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -1,13 +1,16 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import time from typing import ( Optional, ) import numpy as np import optax -from tqdm import trange +from tqdm import ( + trange, +) from deepmd.dpmodel.loss.ener import ( EnergyLoss, @@ -26,16 +29,22 @@ normalize_coord, ) from deepmd.jax.env import ( - jax, jnp, nnx, ) from deepmd.jax.model.model import ( get_model, ) +from deepmd.loggers.training import ( + format_training_message, + format_training_message_per_task, +) from deepmd.utils.data import ( DataRequirementItem, ) +from deepmd.utils.model_stat import ( + make_stat_input, +) log = logging.getLogger(__name__) @@ -101,7 +110,28 @@ def data_requirements(self) -> list[DataRequirementItem]: return self.loss.label_requirement def train(self, train_data, valid_data=None) -> None: - optimizer = nnx.Optimizer(self.model, optax.adam(1e-3)) # reference sharing + model = self.model + tx = optax.adam( + learning_rate=1e-3 # TODO + ) + optimizer = nnx.Optimizer(model, tx) + + # data stat + data_stat_nbatch = 10 # TODO + all_stat = make_stat_input(train_data, data_stat_nbatch, merge_sys=False) + all_stat["atype"] = all_stat.pop("type") + + # swap dict key and list idx + all_stat_sys = [ + { + kk: jnp.asarray(np.concatenate(vv[ii], axis=0)) + for kk, vv in all_stat.items() + if not kk.startswith("find_") + } + for ii in range(train_data.get_nsystems()) + ] + model.atomic_model.descriptor.compute_input_stats(all_stat_sys) + model.atomic_model.fitting.compute_output_stats(all_stat) def loss_fn( model, @@ -114,7 +144,7 @@ def loss_fn( fp, ap, ): - model_dict_lower = self.model.call_lower( + model_dict_lower = model.call_lower( extended_coord, extended_atype, nlist, @@ -136,8 +166,43 @@ def loss_fn( ) return loss + @nnx.jit + def loss_fn_more_loss( + model, + lr, + label_dict, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ): + model_dict_lower = model.call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + model_dict = communicate_extended_output( + model_dict_lower, + model.model_output_def(), + mapping, + do_atomic_virial=False, + ) + loss, 
more_loss = self.loss( + learning_rate=lr, + natoms=label_dict["coord"].shape[1], + model_dict=model_dict, + label_dict=label_dict, + ) + return more_loss + @nnx.jit def train_step( + model, optimizer, lr, label_dict, @@ -149,7 +214,7 @@ def train_step( ap, ): grads = nnx.grad(loss_fn)( - optimizer.model, + model, lr, label_dict, extended_coord, @@ -161,6 +226,8 @@ def train_step( ) optimizer.update(grads) + start_time = time.time() + disp_file_fp = open(self.disp_file, "w") for step in trange(self.num_steps): batch_data = train_data.get_batch() # numpy to jax @@ -169,15 +236,16 @@ def train_step( for kk, vv in batch_data.items() } extended_coord, extended_atype, nlist, mapping, fp, ap = prepare_input( - rcut=self.model.get_rcut(), - sel=self.model.get_sel(), + rcut=model.get_rcut(), + sel=model.get_sel(), coord=jax_data["coord"], atype=jax_data["type"], box=jax_data["box"] if jax_data["find_box"] else None, fparam=jax_data.get("fparam", None), aparam=jax_data.get("aparam", None), ) - loss = train_step( + train_step( + model, optimizer, self.lr.value(step), jax_data, @@ -188,7 +256,106 @@ def train_step( fp, ap, ) - # print(step, jnp.sqrt(loss)) + if self.display_in_training and ( + step == 0 or (step + 1) % self.disp_freq == 0 + ): + wall_time = time.time() - start_time + log.info( + format_training_message( + batch=step + 1, + wall_time=wall_time, + ) + ) + more_loss = loss_fn_more_loss( + optimizer.model, + self.lr.value(step), + jax_data, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + if valid_data is not None: + valid_batch_data = valid_data.get_batch() + jax_valid_data = { + kk: jnp.asarray(vv) for kk, vv in valid_batch_data.items() + } + extended_coord, extended_atype, nlist, mapping, fp, ap = ( + prepare_input( + rcut=model.get_rcut(), + sel=model.get_sel(), + coord=jax_valid_data["coord"], + atype=jax_valid_data["type"], + box=jax_valid_data["box"] + if jax_valid_data["find_box"] + else None, + fparam=jax_valid_data.get("fparam", None), + aparam=jax_valid_data.get("aparam", None), + ) + ) + valid_more_loss = loss_fn_more_loss( + optimizer.model, + self.lr.value(step), + jax_valid_data, + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + ) + self.print_on_training( + disp_file_fp, + train_results=more_loss, + valid_results=valid_more_loss, + cur_batch=step + 1, + cur_lr=self.lr.value(step), + ) + start_time = time.time() + + disp_file_fp.close() + + @staticmethod + def print_on_training( + fp, + train_results, + valid_results, + cur_batch, + cur_lr, + ) -> None: + print_str = "" + print_str += f"{cur_batch:7d}" + if valid_results is not None: + prop_fmt = " %11.2e %11.2e" + for k in valid_results.keys(): + # assert k in train_results.keys() + print_str += prop_fmt % (valid_results[k], train_results[k]) + else: + prop_fmt = " %11.2e" + for k in train_results.keys(): + print_str += prop_fmt % (train_results[k]) + print_str += f" {cur_lr:8.1e}\n" + log.info( + format_training_message_per_task( + batch=cur_batch, + task_name="trn", + rmse=train_results, + learning_rate=cur_lr, + ) + ) + if valid_results is not None: + log.info( + format_training_message_per_task( + batch=cur_batch, + task_name="val", + rmse=valid_results, + learning_rate=None, + ) + ) + fp.write(print_str) + fp.flush() def prepare_input( From e88838bec28b5ba624359e40b81568b1ec18a69d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 28 May 2025 02:13:19 +0800 Subject: [PATCH 06/32] fix scale of initial parameters Signed-off-by: Jinzhe Zeng --- 
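Note on the scaling below: dividing by sqrt(num_in + num_out) is Glorot-style initialization, which keeps the pre-activation variance O(1) instead of growing with fan-in. A quick illustrative check (a sketch, not part of the patch):

import numpy as np

rng = np.random.default_rng(0)
num_in, num_out = 64, 128
w = rng.normal(size=(num_in, num_out), scale=1.0 / np.sqrt(num_in + num_out))
x = rng.normal(size=(10_000, num_in))
# Var[(x @ w)_j] = num_in / (num_in + num_out) ≈ 0.33 for unit-variance x
print((x @ w).var())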
deepmd/dpmodel/utils/network.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index ed5e90b81b..d4d325a863 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -105,9 +105,23 @@ def __init__( # only use_timestep when skip connection is established. use_timestep = use_timestep and (num_out == num_in or num_out == num_in * 2) rng = np.random.default_rng(seed) - self.w = rng.normal(size=(num_in, num_out)).astype(prec) - self.b = rng.normal(size=(num_out,)).astype(prec) if bias else None - self.idt = rng.normal(size=(num_out,)).astype(prec) if use_timestep else None + self.w = rng.normal( + size=(num_in, num_out), scale=1.0 / np.sqrt(num_out + num_in) + ).astype(prec) + self.b = ( + rng.normal(size=(num_out,), scale=1.0 / np.sqrt(num_out + num_in)).astype( + prec + ) + if bias + else None + ) + self.idt = ( + rng.normal(size=(num_out,), scale=1.0 / np.sqrt(num_out + num_in)).astype( + prec + ) + if use_timestep + else None + ) self.activation_function = ( activation_function if activation_function is not None else "none" ) From 29922d70fe58380ad0aaf006f8b4b2e0476388b4 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 28 May 2025 04:39:39 +0800 Subject: [PATCH 07/32] set up lr Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/utils/learning_rate.py | 7 +++---- deepmd/jax/train/trainer.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/utils/learning_rate.py b/deepmd/dpmodel/utils/learning_rate.py index 90c18fca22..9bfc6d2a16 100644 --- a/deepmd/dpmodel/utils/learning_rate.py +++ b/deepmd/dpmodel/utils/learning_rate.py @@ -45,9 +45,8 @@ def __init__( self.decay_rate = decay_rate self.min_lr = stop_lr - def value(self, step) -> np.float64: + def value(self, step, xp=np) -> np.float64: """Get the learning rate at the given step.""" - step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps) - if step_lr < self.min_lr: - step_lr = self.min_lr + step_lr = self.start_lr * xp.power(self.decay_rate, step // self.decay_steps) + step_lr = xp.clip(step_lr, self.min_lr, None) return step_lr diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index 191e9d8466..1e1bf4c594 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -112,7 +112,7 @@ def data_requirements(self) -> list[DataRequirementItem]: def train(self, train_data, valid_data=None) -> None: model = self.model tx = optax.adam( - learning_rate=1e-3 # TODO + learning_rate=lambda step: self.lr.value(step, xp=jnp), ) optimizer = nnx.Optimizer(model, tx) From d5d5f06140dfc2b57f8b4f961f6e286c4192fdc5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 28 May 2025 04:40:53 +0800 Subject: [PATCH 08/32] clean up tqdm Signed-off-by: Jinzhe Zeng --- deepmd/jax/train/trainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index 1e1bf4c594..db4c50d4ff 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -8,9 +8,6 @@ import numpy as np import optax -from tqdm import ( - trange, -) from deepmd.dpmodel.loss.ener import ( EnergyLoss, @@ -228,7 +225,7 @@ def train_step( start_time = time.time() disp_file_fp = open(self.disp_file, "w") - for step in trange(self.num_steps): + for step in range(self.num_steps): batch_data = train_data.get_batch() # numpy to jax jax_data = { From 6947ee69c67f3910ec7824496f14d31ef46f63c0 Mon 
Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 28 May 2025 05:02:13 +0800 Subject: [PATCH 09/32] freeze Signed-off-by: Jinzhe Zeng --- deepmd/jax/entrypoints/freeze.py | 36 ++++++++++++++++++++++++++++++++ deepmd/jax/entrypoints/main.py | 6 ++++-- deepmd/jax/train/trainer.py | 21 +++++++++++++++++++ 3 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 deepmd/jax/entrypoints/freeze.py diff --git a/deepmd/jax/entrypoints/freeze.py b/deepmd/jax/entrypoints/freeze.py new file mode 100644 index 0000000000..bd283e8681 --- /dev/null +++ b/deepmd/jax/entrypoints/freeze.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from pathlib import ( + Path, +) + +from deepmd.jax.utils.serialization import ( + deserialize_to_file, + serialize_from_file, +) + + +def freeze( + *, + checkpoint_folder: str, + output: str, + **kwargs, +) -> None: + """Freeze the graph in supplied folder. + + Parameters + ---------- + checkpoint_folder : str + location of either the folder with checkpoint or the checkpoint prefix + output : str + output file name + **kwargs + other arguments + """ + if (Path(checkpoint_folder) / "checkpoint").is_file(): + checkpoint_meta = Path(checkpoint_folder) / "checkpoint" + checkpoint_folder = checkpoint_meta.read_text().strip() + if Path(checkpoint_folder).is_dir(): + data = serialize_from_file(checkpoint_folder) + deserialize_to_file(output, data) + else: + raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.") diff --git a/deepmd/jax/entrypoints/main.py b/deepmd/jax/entrypoints/main.py index e7ad32b56c..6bbb9f08f7 100644 --- a/deepmd/jax/entrypoints/main.py +++ b/deepmd/jax/entrypoints/main.py @@ -13,6 +13,9 @@ from deepmd.backend.suffix import ( format_model_suffix, ) +from deepmd.jax.entrypoints.freeze import ( + freeze, +) from deepmd.jax.entrypoints.train import ( train, ) @@ -54,11 +57,10 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: if args.command == "train": train(**dict_args) elif args.command == "freeze": - raise dict_args["output"] = format_model_suffix( dict_args["output"], preferred_backend=args.backend, strict_prefer=True ) - # freeze(**dict_args) + freeze(**dict_args) elif args.command is None: pass else: diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index db4c50d4ff..f01116d2ed 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -2,12 +2,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging import time +from pathlib import ( + Path, +) from typing import ( Optional, ) import numpy as np import optax +import orbax.checkpoint as ocp from deepmd.dpmodel.loss.ener import ( EnergyLoss, @@ -48,6 +52,7 @@ class DPTrainer: def __init__(self, jdata) -> None: + self.model_def_script = jdata["model"] self.model = get_model(jdata["model"]) self.training_param = jdata["training"] self.num_steps = self.training_param["numb_steps"] @@ -311,6 +316,22 @@ def train_step( cur_lr=self.lr.value(step), ) start_time = time.time() + if step % self.save_freq == 0: + # save model + _, state = nnx.split(model) + with ocp.Checkpointer( + ocp.CompositeCheckpointHandler("state", "model_def_script") + ) as checkpointer: + checkpointer.save( + Path(f"{self.save_ckpt}.jax").absolute(), + ocp.args.Composite( + state=ocp.args.StandardSave(state.to_pure_dict()), + model_def_script=ocp.args.JsonSave(self.model_def_script), + ), + ) + log.info(f"Trained model has been saved to: {self.save_ckpt}.jax") + with open("checkpoint", "w") as fp: + 
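# The restore path mirrors this composite layout (a sketch assuming the
# same orbax API used for .jax files in deepmd/jax/utils/serialization.py):
#   with ocp.Checkpointer(
#       ocp.CompositeCheckpointHandler("state", "model_def_script")
#   ) as checkpointer:
#       data = checkpointer.restore(
#           Path(f"{self.save_ckpt}.jax").absolute(),
#           ocp.args.Composite(
#               state=ocp.args.StandardRestore(),
#               model_def_script=ocp.args.JsonRestore(),
#           ),
#       )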
fp.write(f"{self.save_ckpt}.jax") disp_file_fp.close() From 9218157c0cefa9dc5360d2953115eb3dee59a76d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 29 May 2025 01:23:48 +0800 Subject: [PATCH 10/32] improve checkpoint Signed-off-by: Jinzhe Zeng --- deepmd/jax/train/trainer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index f01116d2ed..9c1dabfca4 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -13,6 +13,9 @@ import optax import orbax.checkpoint as ocp +from deepmd.common import ( + symlink_prefix_files, +) from deepmd.dpmodel.loss.ener import ( EnergyLoss, ) @@ -316,20 +319,25 @@ def train_step( cur_lr=self.lr.value(step), ) start_time = time.time() - if step % self.save_freq == 0: + if (step + 1) % self.save_freq == 0: # save model _, state = nnx.split(model) + ckpt_path = Path(f"{self.save_ckpt}-{step + 1}.jax") + if ckpt_path.exists(): + # remove old checkpoint if it exists + ckpt_path.unlink() with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") ) as checkpointer: checkpointer.save( - Path(f"{self.save_ckpt}.jax").absolute(), + ckpt_path.absolute(), ocp.args.Composite( state=ocp.args.StandardSave(state.to_pure_dict()), model_def_script=ocp.args.JsonSave(self.model_def_script), ), ) - log.info(f"Trained model has been saved to: {self.save_ckpt}.jax") + log.info(f"Trained model has been saved to: {ckpt_path!s}") + symlink_prefix_files(f"{self.save_ckpt}-{step + 1}", self.save_ckpt) with open("checkpoint", "w") as fp: fp.write(f"{self.save_ckpt}.jax") From e3dca7a2171517eaa5e0373d6588f7b496ada19c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 29 May 2025 02:26:27 +0800 Subject: [PATCH 11/32] fix unreference variable Signed-off-by: Jinzhe Zeng --- deepmd/jax/entrypoints/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepmd/jax/entrypoints/train.py b/deepmd/jax/entrypoints/train.py index d1d47ef99a..c1de799c7b 100644 --- a/deepmd/jax/entrypoints/train.py +++ b/deepmd/jax/entrypoints/train.py @@ -125,6 +125,8 @@ def train( ) valid_data.add_data_requirements(model.data_requirements) valid_data.print_summary("validation") + else: + valid_data = None # get training info stop_batch = jdata["training"]["numb_steps"] From 1fdb40cf9a728c7b5b4cae7f57530c899e40a46b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 29 May 2025 03:32:45 +0800 Subject: [PATCH 12/32] hessian loss Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/loss/ener.py | 76 +++++++++++++++++++++++++++++++++++++ deepmd/jax/train/trainer.py | 11 +++++- 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index 87257dc584..ec34ffd275 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -387,3 +387,79 @@ def deserialize(cls, data: dict) -> "Loss": check_version_compatibility(data.pop("@version"), 2, 1) data.pop("@class") return cls(**data) + + +class EnergyHessianLoss(EnergyLoss): + def __init__( + self, + start_pref_h=0.0, + limit_pref_h=0.0, + **kwargs, + ): + r"""Enable the layer to compute loss on hessian. + + Parameters + ---------- + start_pref_h : float + The prefactor of hessian loss at the start of the training. + limit_pref_h : float + The prefactor of hessian loss at the end of the training. + **kwargs + Other keyword arguments. 
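Notes: the hessian prefactor follows the same interpolation as the energy and force terms, pref_h = limit_pref_h + (start_pref_h - limit_pref_h) * lr / start_lr, so it relaxes from start_pref_h toward limit_pref_h as the learning rate decays.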
+ """ + EnergyLoss.__init__(self, **kwargs) + self.has_h = start_pref_h != 0.0 and limit_pref_h != 0.0 + + self.start_pref_h = start_pref_h + self.limit_pref_h = limit_pref_h + + def call( + self, + learning_rate: float, + natoms: int, + model_dict: dict[str, np.ndarray], + label_dict: dict[str, np.ndarray], + ) -> dict[str, np.ndarray]: + """Calculate loss from model results and labeled results.""" + loss, more_loss = EnergyLoss.call( + self, learning_rate, natoms, model_dict, label_dict + ) + xp = array_api_compat.array_namespace(model_dict["energy"]) + coef = learning_rate / self.starter_learning_rate + pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef + + if ( + self.has_h + and "energy_derv_r_derv_r" in model_dict + and "hessian" in label_dict + ): + find_hessian = label_dict.get("find_hessian", 0.0) + pref_h = pref_h * find_hessian + diff_h = label_dict["hessian"].reshape( + -1, + ) - model_dict["energy_derv_r_derv_r"].reshape( + -1, + ) + l2_hessian_loss = xp.mean(xp.square(diff_h)) + loss += pref_h * l2_hessian_loss + rmse_h = xp.sqrt(l2_hessian_loss) + more_loss["rmse_h"] = self.display_if_exist(rmse_h, find_hessian) + + more_loss["rmse"] = xp.sqrt(loss) + return loss, more_loss + + @property + def label_requirement(self) -> list[DataRequirementItem]: + """Add hessian label requirement needed for this loss calculation.""" + label_requirement = super().label_requirement + if self.has_h: + label_requirement.append( + DataRequirementItem( + "hessian", + ndof=1, # 9=3*3 --> 3N*3N=ndof*natoms*natoms + atomic=True, + must=False, + high_prec=False, + ) + ) + return label_requirement diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index 9c1dabfca4..8e0b77ebf5 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -17,6 +17,7 @@ symlink_prefix_files, ) from deepmd.dpmodel.loss.ener import ( + EnergyHessianLoss, EnergyLoss, ) from deepmd.dpmodel.model.transform_output import ( @@ -77,7 +78,15 @@ def get_lr_and_coef(lr_param): self.lr = get_lr_and_coef(learning_rate_param) loss_param = jdata.get("loss", {}) loss_param["starter_learning_rate"] = learning_rate_param["start_lr"] - self.loss = EnergyLoss.get_loss(loss_param) + + loss_type = loss_param.get("type", "ener") + if loss_type == "ener" and loss_param.get("start_pref_h", 0.0) > 0.0: + self.loss = EnergyHessianLoss.get_loss(loss_param) + self.model.enable_hessian() + elif loss_type == "ener": + self.loss = EnergyLoss.get_loss(loss_param) + else: + raise RuntimeError("unknown loss type " + loss_type) # training tr_data = jdata["training"] From 1474327ad5ef5f4a88c67fab9800bf4280f18183 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 29 May 2025 03:33:06 +0800 Subject: [PATCH 13/32] valid_more_loss Signed-off-by: Jinzhe Zeng --- deepmd/jax/train/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index 8e0b77ebf5..663f97ec63 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -320,6 +320,8 @@ def train_step( fp, ap, ) + else: + valid_more_loss = None self.print_on_training( disp_file_fp, train_results=more_loss, From 68b37274bac3277c9381820a9ad637a14cf28ca5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 2 Jun 2025 02:16:37 +0800 Subject: [PATCH 14/32] print summary Signed-off-by: Jinzhe Zeng --- deepmd/jax/entrypoints/train.py | 34 ++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/deepmd/jax/entrypoints/train.py 
b/deepmd/jax/entrypoints/train.py index d1d47ef99a..b6ba8fbab8 100644 --- a/deepmd/jax/entrypoints/train.py +++ b/deepmd/jax/entrypoints/train.py @@ -14,6 +14,10 @@ from deepmd.common import ( j_loader, ) +from deepmd.jax.env import ( + jax, + jax_export, +) from deepmd.jax.train.trainer import ( DPTrainer, ) @@ -26,12 +30,40 @@ from deepmd.utils.data_system import ( get_data, ) +from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter __all__ = ["train"] log = logging.getLogger(__name__) +class SummaryPrinter(BaseSummaryPrinter): + """Summary printer for JAX.""" + + def is_built_with_cuda(self) -> bool: + """Check if the backend is built with CUDA.""" + return jax_export.default_export_platform() == "cuda" + + def is_built_with_rocm(self) -> bool: + """Check if the backend is built with ROCm.""" + return jax_export.default_export_platform() == "rocm" + + def get_compute_device(self) -> str: + """Get Compute device.""" + return jax.default_backend() + + def get_ngpus(self) -> int: + """Get the number of GPUs.""" + return jax.device_count() + + def get_backend_info(self) -> dict: + """Get backend information.""" + return { + "Backend": "JAX", + "JAX ver": jax.__version__, + } + + def train( *, INPUT: str, @@ -94,7 +126,7 @@ def train( with open(output, "w") as fp: json.dump(jdata, fp, indent=4) - # print_resource_summary() + SummaryPrinter()() # make necessary checks assert "training" in jdata From a0cd67a347436adee31dbff1cb54e71d8a18bfb2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 2 Jun 2025 02:22:38 +0800 Subject: [PATCH 15/32] seed Signed-off-by: Jinzhe Zeng --- deepmd/jax/entrypoints/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deepmd/jax/entrypoints/train.py b/deepmd/jax/entrypoints/train.py index 67f8dba901..ec3303b6e4 100644 --- a/deepmd/jax/entrypoints/train.py +++ b/deepmd/jax/entrypoints/train.py @@ -21,6 +21,7 @@ from deepmd.jax.train.trainer import ( DPTrainer, ) +from deepmd.utils import random as dp_random from deepmd.utils.argcheck import ( normalize, ) @@ -143,6 +144,9 @@ def train( # init random seed of data systems seed = jdata["training"].get("seed", None) + if seed is not None: + seed = seed % (2**32) + dp_random.seed(seed) # init data train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None) From d21c39c7153378f7c09dfee348786f3e4ece5f07 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 3 Jun 2025 23:11:01 +0800 Subject: [PATCH 16/32] restart. 
bug to be fixed Signed-off-by: Jinzhe Zeng --- deepmd/jax/entrypoints/train.py | 6 ++- deepmd/jax/train/trainer.py | 68 ++++++++++++++++++++++--------- deepmd/jax/utils/serialization.py | 5 ++- 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/deepmd/jax/entrypoints/train.py b/deepmd/jax/entrypoints/train.py index ec3303b6e4..27b3e54e55 100644 --- a/deepmd/jax/entrypoints/train.py +++ b/deepmd/jax/entrypoints/train.py @@ -134,7 +134,11 @@ def train( # init the model - model = DPTrainer(jdata) + model = DPTrainer( + jdata, + init_model=init_model, + restart=restart, + ) rcut = model.model.get_rcut() type_map = model.model.get_type_map() if len(type_map) == 0: diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index 663f97ec63..95631f3be7 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import shutil import time from pathlib import ( Path, @@ -37,9 +38,15 @@ jnp, nnx, ) +from deepmd.jax.model.base_model import ( + BaseModel, +) from deepmd.jax.model.model import ( get_model, ) +from deepmd.jax.utils.serialization import ( + serialize_from_file, +) from deepmd.loggers.training import ( format_training_message, format_training_message_per_task, @@ -55,9 +62,26 @@ class DPTrainer: - def __init__(self, jdata) -> None: + def __init__( + self, + jdata, + init_model: Optional[str] = None, + restart: Optional[str] = None, + ) -> None: + self.init_model = init_model + self.restart = restart self.model_def_script = jdata["model"] - self.model = get_model(jdata["model"]) + self.start_step = 0 + if self.init_model is not None: + model_dict = serialize_from_file(self.init_model) + self.model = BaseModel.deserialize(model_dict["model"]) + elif self.restart is not None: + model_dict = serialize_from_file(self.restart) + self.model = BaseModel.deserialize(model_dict["model"]) + self.start_step = model_dict["@variables"].get("current_step", 0) + else: + # from scratch + self.model = get_model(jdata["model"]) self.training_param = jdata["training"] self.num_steps = self.training_param["numb_steps"] @@ -129,23 +153,25 @@ def train(self, train_data, valid_data=None) -> None: learning_rate=lambda step: self.lr.value(step, xp=jnp), ) optimizer = nnx.Optimizer(model, tx) + optimizer.step += self.start_step # data stat - data_stat_nbatch = 10 # TODO - all_stat = make_stat_input(train_data, data_stat_nbatch, merge_sys=False) - all_stat["atype"] = all_stat.pop("type") + if self.init_model is None and self.restart is None: + data_stat_nbatch = 10 # TODO + all_stat = make_stat_input(train_data, data_stat_nbatch, merge_sys=False) + all_stat["atype"] = all_stat.pop("type") - # swap dict key and list idx - all_stat_sys = [ - { - kk: jnp.asarray(np.concatenate(vv[ii], axis=0)) - for kk, vv in all_stat.items() - if not kk.startswith("find_") - } - for ii in range(train_data.get_nsystems()) - ] - model.atomic_model.descriptor.compute_input_stats(all_stat_sys) - model.atomic_model.fitting.compute_output_stats(all_stat) + # swap dict key and list idx + all_stat_sys = [ + { + kk: jnp.asarray(np.concatenate(vv[ii], axis=0)) + for kk, vv in all_stat.items() + if not kk.startswith("find_") + } + for ii in range(train_data.get_nsystems()) + ] + model.atomic_model.descriptor.compute_input_stats(all_stat_sys) + model.atomic_model.fitting.compute_output_stats(all_stat) def loss_fn( model, @@ -242,7 +268,7 @@ def train_step( start_time = time.time() disp_file_fp = 
open(self.disp_file, "w") - for step in range(self.num_steps): + for step in range(self.start_step, self.num_steps): batch_data = train_data.get_batch() # numpy to jax jax_data = { @@ -334,9 +360,11 @@ def train_step( # save model _, state = nnx.split(model) ckpt_path = Path(f"{self.save_ckpt}-{step + 1}.jax") - if ckpt_path.exists(): + if ckpt_path.is_dir(): # remove old checkpoint if it exists - ckpt_path.unlink() + shutil.rmtree(ckpt_path) + model_def_script_cpy = self.model_def_script.copy() + model_def_script_cpy["current_step"] = step + 1 with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") ) as checkpointer: @@ -344,7 +372,7 @@ def train_step( ckpt_path.absolute(), ocp.args.Composite( state=ocp.args.StandardSave(state.to_pure_dict()), - model_def_script=ocp.args.JsonSave(self.model_def_script), + model_def_script=ocp.args.JsonSave(model_def_script_cpy), ), ) log.info(f"Trained model has been saved to: {ckpt_path!s}") diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 5d4da49e08..454affba31 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -177,6 +177,7 @@ def convert_str_to_int_key(item: dict) -> None: convert_str_to_int_key(state) model_def_script = data.model_def_script + current_step = model_def_script.pop("current_step", 0) abstract_model = get_model(model_def_script) graphdef, abstract_state = nnx.split(abstract_model) abstract_state.replace_by_pure_dict(state) @@ -187,7 +188,9 @@ def convert_str_to_int_key(item: dict) -> None: "jax_version": jax.__version__, "model": model_dict, "model_def_script": model_def_script, - "@variables": {}, + "@variables": { + "current_step": current_step, + }, } return data elif model_file.endswith(".hlo"): From 15bb506071ee44ff1788616e1d5125a299eeb3b2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 3 Jun 2025 23:42:40 +0800 Subject: [PATCH 17/32] fix lr Signed-off-by: Jinzhe Zeng --- deepmd/jax/train/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py index 95631f3be7..408f67da9f 100644 --- a/deepmd/jax/train/trainer.py +++ b/deepmd/jax/train/trainer.py @@ -150,10 +150,9 @@ def data_requirements(self) -> list[DataRequirementItem]: def train(self, train_data, valid_data=None) -> None: model = self.model tx = optax.adam( - learning_rate=lambda step: self.lr.value(step, xp=jnp), + learning_rate=lambda step: self.lr.value(self.start_step + step, xp=jnp), ) optimizer = nnx.Optimizer(model, tx) - optimizer.step += self.start_step # data stat if self.init_model is None and self.restart is None: From b9111c8582a4065ca199c456d73cad29598cbd98 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 10 Jun 2025 19:46:14 +0800 Subject: [PATCH 18/32] fix(dpmodel/pt/pd/jax): pass trainable to layer & support JAX trainable 1. For dpmodel, pt, and pd, pass the trainable parameter to the layer (not actually used in this PR). 2. For JAX, support the `trainable` parameter in the layer. 
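
For point 1, the change is mechanical: every block that builds sub-layers now
forwards its own `trainable` flag into each constructor, and the layer records
the flag before any array attribute is assigned and includes it in
`serialize()`, so a frozen layer stays frozen across a save/load round trip.
A minimal sketch of the pattern, using illustrative stand-in classes rather
than the real deepmd ones:

    class NativeLayer:
        def __init__(self, num_in: int, num_out: int, trainable: bool = True) -> None:
            # The flag must be set before any array attribute, because the
            # JAX subclass decides at attribute-assignment time whether an
            # array becomes a learnable parameter or a frozen variable.
            self.trainable = trainable
            self.w = [[0.0] * num_out for _ in range(num_in)]

        def serialize(self) -> dict:
            # `trainable` is stored alongside the variables, so a
            # deserialized layer keeps its frozen/trainable status.
            return {"trainable": self.trainable, "@variables": {"w": self.w}}

    class DescrptBlock:
        def __init__(self, trainable: bool = True) -> None:
            # A block forwards its own flag to every sub-layer it builds.
            self.edge_embd = NativeLayer(1, 8, trainable=trainable)

    assert DescrptBlock(trainable=False).edge_embd.serialize()["trainable"] is False
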
Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/dpa1.py | 13 +++++ deepmd/dpmodel/descriptor/dpa2.py | 6 ++ deepmd/dpmodel/descriptor/dpa3.py | 2 + deepmd/dpmodel/descriptor/repflows.py | 33 ++++++++++- deepmd/dpmodel/descriptor/repformers.py | 56 +++++++++++++++++-- deepmd/dpmodel/descriptor/se_e2_a.py | 1 + deepmd/dpmodel/descriptor/se_r.py | 1 + deepmd/dpmodel/descriptor/se_t.py | 1 + deepmd/dpmodel/descriptor/se_t_tebd.py | 5 ++ deepmd/dpmodel/fitting/general_fitting.py | 1 + deepmd/dpmodel/utils/network.py | 14 +++++ deepmd/dpmodel/utils/type_embed.py | 1 + deepmd/jax/utils/network.py | 6 +- deepmd/pd/model/descriptor/dpa1.py | 2 + deepmd/pd/model/descriptor/dpa2.py | 6 ++ deepmd/pd/model/descriptor/dpa3.py | 2 + deepmd/pd/model/descriptor/repflow_layer.py | 16 ++++++ deepmd/pd/model/descriptor/repflows.py | 15 ++++- deepmd/pd/model/descriptor/repformer_layer.py | 46 ++++++++++++++- deepmd/pd/model/descriptor/repformers.py | 8 ++- deepmd/pd/model/descriptor/se_a.py | 1 + deepmd/pd/model/descriptor/se_atten.py | 11 ++++ deepmd/pd/model/descriptor/se_t_tebd.py | 5 ++ deepmd/pd/model/network/mlp.py | 2 + deepmd/pd/model/network/network.py | 2 + deepmd/pt/model/descriptor/dpa1.py | 2 + deepmd/pt/model/descriptor/dpa2.py | 6 ++ deepmd/pt/model/descriptor/dpa3.py | 2 + deepmd/pt/model/descriptor/repflow_layer.py | 16 ++++++ deepmd/pt/model/descriptor/repflows.py | 15 ++++- deepmd/pt/model/descriptor/repformer_layer.py | 46 ++++++++++++++- deepmd/pt/model/descriptor/repformers.py | 10 +++- deepmd/pt/model/descriptor/se_a.py | 1 + deepmd/pt/model/descriptor/se_atten.py | 11 ++++ deepmd/pt/model/descriptor/se_r.py | 1 + deepmd/pt/model/descriptor/se_t.py | 1 + deepmd/pt/model/descriptor/se_t_tebd.py | 5 ++ deepmd/pt/model/network/mlp.py | 4 ++ deepmd/pt/model/network/network.py | 2 + deepmd/pt/model/task/fitting.py | 1 + deepmd/tf/descriptor/se.py | 6 ++ deepmd/tf/descriptor/se_atten.py | 6 ++ deepmd/tf/descriptor/se_t.py | 5 ++ deepmd/tf/fit/fitting.py | 2 + .../tests/consistent/descriptor/test_dpa1.py | 1 + .../tests/consistent/descriptor/test_dpa2.py | 2 +- .../tests/consistent/descriptor/test_dpa3.py | 2 +- 47 files changed, 382 insertions(+), 21 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 51c56e9681..9e9da46aec 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -319,6 +319,7 @@ def __init__( trainable_ln=trainable_ln, ln_eps=ln_eps, seed=child_seed(seed, 0), + trainable=trainable, ) self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias @@ -333,6 +334,7 @@ def __init__( use_tebd_bias=use_tebd_bias, type_map=type_map, seed=child_seed(seed, 1), + trainable=trainable, ) self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd @@ -691,6 +693,7 @@ def __init__( ln_eps: Optional[float] = 1e-5, smooth: bool = True, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: self.rcut = rcut self.rcut_smth = rcut_smth @@ -741,6 +744,7 @@ def __init__( self.resnet_dt, self.precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.embeddings = embeddings if self.tebd_input_mode in ["strip"]: @@ -756,6 +760,7 @@ def __init__( self.resnet_dt, self.precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.embeddings_strip = embeddings_strip else: @@ -774,6 +779,7 @@ def __init__( smooth=self.smooth, precision=self.precision, seed=child_seed(seed, 2), + trainable=trainable, ) wanted_shape = (self.ntypes, self.nnei, 4) @@ -1186,6 
+1192,7 @@ def __init__( smooth: bool = True, precision: str = DEFAULT_PRECISION, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: """Construct a neighbor-wise attention net.""" super().__init__() @@ -1219,6 +1226,7 @@ def __init__( smooth=smooth, precision=precision, seed=child_seed(seed, ii), + trainable=trainable, ) for ii in range(layer_num) ] @@ -1314,6 +1322,7 @@ def __init__( smooth: bool = True, precision: str = DEFAULT_PRECISION, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: """Construct a neighbor-wise attention layer.""" super().__init__() @@ -1340,6 +1349,7 @@ def __init__( smooth=smooth, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.attn_layer_norm = LayerNorm( self.embed_dim, @@ -1420,6 +1430,7 @@ def __init__( smooth: bool = True, precision: str = DEFAULT_PRECISION, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: """Construct a multi-head neighbor-wise attention net.""" super().__init__() @@ -1449,6 +1460,7 @@ def __init__( use_timestep=False, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.out_proj = NativeLayer( hidden_dim, @@ -1457,6 +1469,7 @@ def __init__( use_timestep=False, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 70accefa30..da39afdc23 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -474,6 +474,7 @@ def init_subclass_params(sub_data, sub_class): smooth=smooth, type_one_side=self.repinit_args.type_one_side, seed=child_seed(seed, 0), + trainable=trainable, ) self.use_three_body = self.repinit_args.use_three_body if self.use_three_body: @@ -493,6 +494,7 @@ def init_subclass_params(sub_data, sub_class): resnet_dt=self.repinit_args.resnet_dt, smooth=smooth, seed=child_seed(seed, 5), + trainable=trainable, ) else: self.repinit_three_body = None @@ -533,6 +535,7 @@ def init_subclass_params(sub_data, sub_class): g1_out_mlp=self.repformer_args.g1_out_mlp, ln_eps=self.repformer_args.ln_eps, seed=child_seed(seed, 1), + trainable=trainable, ) self.rcsl_list = [ (self.repformers.get_rcut(), self.repformers.get_nsel()), @@ -562,6 +565,7 @@ def init_subclass_params(sub_data, sub_class): use_tebd_bias=use_tebd_bias, type_map=type_map, seed=child_seed(seed, 2), + trainable=trainable, ) self.concat_output_tebd = concat_output_tebd self.precision = precision @@ -585,6 +589,7 @@ def init_subclass_params(sub_data, sub_class): bias=False, precision=precision, seed=child_seed(seed, 3), + trainable=trainable, ) self.tebd_transform = None if self.add_tebd_to_repinit_out: @@ -594,6 +599,7 @@ def init_subclass_params(sub_data, sub_class): bias=False, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) assert self.repinit.rcut > self.repformers.rcut assert self.repinit.sel[0] > self.repformers.sel[0] diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index f9210b0574..3a03b2a9ad 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -357,6 +357,7 @@ def init_subclass_params(sub_data, sub_class): env_protection=env_protection, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.use_econf_tebd = use_econf_tebd @@ -374,6 +375,7 @@ def init_subclass_params(sub_data, sub_class): 
use_tebd_bias=use_tebd_bias, type_map=type_map, seed=child_seed(seed, 2), + trainable=trainable, ) self.concat_output_tebd = concat_output_tebd self.precision = precision diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index f8c329b515..43fe844262 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -167,6 +167,8 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock): For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. seed : int, optional Random seed for parameter initialization. + trainable : bool, default: True + Whether the block is trainable """ def __init__( @@ -205,6 +207,7 @@ def __init__( sel_reduce_factor: float = 10.0, use_loc_mapping: bool = True, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.e_rcut = float(e_rcut) @@ -269,10 +272,19 @@ def __init__( self.seed = seed self.edge_embd = NativeLayer( - 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) + 1, + self.e_dim, + precision=precision, + seed=child_seed(seed, 0), + trainable=trainable, ) self.angle_embd = NativeLayer( - 1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1) + 1, + self.a_dim, + precision=precision, + bias=False, + seed=child_seed(seed, 1), + trainable=trainable, ) layers = [] for ii in range(nlayers): @@ -304,6 +316,7 @@ def __init__( sel_reduce_factor=self.sel_reduce_factor, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), + trainable=trainable, ) ) self.layers = layers @@ -860,6 +873,7 @@ def __init__( update_residual_init: str = "const", precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.epsilon = 1e-4 # protection of 1./nnei @@ -922,6 +936,7 @@ def __init__( n_dim, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) if self.update_style == "res_residual": self.n_residual.append( @@ -931,6 +946,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) ) @@ -941,6 +957,7 @@ def __init__( n_dim, precision=precision, seed=child_seed(seed, 2), + trainable=trainable, ) if self.update_style == "res_residual": self.n_residual.append( @@ -950,6 +967,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 3), + trainable=trainable, ) ) @@ -959,6 +977,7 @@ def __init__( self.n_multi_edge_message * n_dim, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) if self.update_style == "res_residual": for head_index in range(self.n_multi_edge_message): @@ -969,6 +988,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(child_seed(seed, 5), head_index), + trainable=trainable, ) ) @@ -978,6 +998,7 @@ def __init__( e_dim, precision=precision, seed=child_seed(seed, 6), + trainable=trainable, ) if self.update_style == "res_residual": self.e_residual.append( @@ -987,6 +1008,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 7), + trainable=trainable, ) ) @@ -1015,6 +1037,7 @@ def __init__( precision=precision, bias=False, seed=child_seed(seed, 8), + trainable=trainable, ) self.a_compress_e_linear = NativeLayer( self.e_dim, @@ -1022,6 +1045,7 @@ def __init__( precision=precision, bias=False, seed=child_seed(seed, 9), 
+ trainable=trainable, ) else: self.a_compress_n_linear = None @@ -1033,12 +1057,14 @@ def __init__( self.e_dim, precision=precision, seed=child_seed(seed, 10), + trainable=trainable, ) self.edge_angle_linear2 = NativeLayer( self.e_dim, self.e_dim, precision=precision, seed=child_seed(seed, 11), + trainable=trainable, ) if self.update_style == "res_residual": self.e_residual.append( @@ -1048,6 +1074,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 12), + trainable=trainable, ) ) @@ -1057,6 +1084,7 @@ def __init__( self.a_dim, precision=precision, seed=child_seed(seed, 13), + trainable=trainable, ) if self.update_style == "res_residual": self.a_residual.append( @@ -1066,6 +1094,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 14), + trainable=trainable, ) ) else: diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 3d02054350..6ac9675d28 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -164,6 +164,8 @@ class DescrptBlockRepformers(NativeOP, DescriptorBlock): The epsilon value for layer normalization. seed : int, optional The random seed for initialization. + trainable : bool, default: True + Whether the block is trainable """ def __init__( @@ -204,6 +206,7 @@ def __init__( g1_out_mlp: bool = True, ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.rcut = rcut @@ -252,7 +255,11 @@ def __init__( self.epsilon = 1e-4 self.g2_embd = NativeLayer( - 1, self.g2_dim, precision=precision, seed=child_seed(seed, 0) + 1, + self.g2_dim, + precision=precision, + seed=child_seed(seed, 0), + trainable=trainable, ) layers = [] for ii in range(nlayers): @@ -290,6 +297,7 @@ def __init__( g1_out_conv=self.g1_out_conv, g1_out_mlp=self.g1_out_mlp, seed=child_seed(child_seed(seed, 1), ii), + trainable=trainable, ) ) self.layers = layers @@ -847,6 +855,7 @@ def __init__( attnw_shift: float = 20.0, precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: """Return neighbor-wise multi-head self-attention maps, with gate mechanism.""" super().__init__() @@ -859,6 +868,7 @@ def __init__( bias=False, precision=precision, seed=seed, + trainable=trainable, ) self.has_gate = has_gate self.smooth = smooth @@ -970,6 +980,7 @@ def __init__( head_num: int, precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.input_dim = input_dim @@ -980,12 +991,14 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.head_map = NativeLayer( input_dim * head_num, input_dim, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.precision = precision @@ -1058,12 +1071,18 @@ def __init__( head_num: int, precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.input_dim = input_dim self.head_num = head_num self.head_map = NativeLayer( - head_num, 1, bias=False, precision=precision, seed=seed + head_num, + 1, + bias=False, + precision=precision, + seed=seed, + trainable=trainable, ) self.precision = precision @@ -1133,6 +1152,7 @@ def __init__( attnw_shift: float = 20.0, precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() 
self.input_dim = input_dim @@ -1144,6 +1164,7 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.mapkv = NativeLayer( input_dim, @@ -1151,12 +1172,14 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.head_map = NativeLayer( input_dim * head_num, input_dim, precision=precision, seed=child_seed(seed, 2), + trainable=trainable, ) self.smooth = smooth self.attnw_shift = attnw_shift @@ -1295,6 +1318,7 @@ def __init__( g1_out_mlp: bool = True, ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.epsilon = 1e-4 # protection of 1./nnei @@ -1354,6 +1378,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) ) @@ -1363,6 +1388,7 @@ def __init__( g1_dim, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.linear2 = None self.proj_g1g2 = None @@ -1379,6 +1405,7 @@ def __init__( g2_dim, precision=precision, seed=child_seed(seed, 2), + trainable=trainable, ) if self.update_style == "res_residual": g2_residual.append( @@ -1388,6 +1415,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 3), + trainable=trainable, ) ) if self.g1_out_mlp: @@ -1396,6 +1424,7 @@ def __init__( g1_dim, precision=precision, seed=child_seed(seed, 15), + trainable=trainable, ) if self.update_style == "res_residual": g1_residual.append( @@ -1405,6 +1434,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 16), + trainable=trainable, ) ) else: @@ -1417,6 +1447,7 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) else: self.proj_g1g2 = NativeLayer( @@ -1425,6 +1456,7 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) if self.update_style == "res_residual": g1_residual.append( @@ -1434,6 +1466,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 17), + trainable=trainable, ) ) if self.update_g2_has_g1g1: @@ -1443,6 +1476,7 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 5), + trainable=trainable, ) if self.update_style == "res_residual": g2_residual.append( @@ -1452,6 +1486,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 6), + trainable=trainable, ) ) if self.update_g2_has_attn or self.update_h2: @@ -1463,10 +1498,15 @@ def __init__( self.smooth, precision=precision, seed=child_seed(seed, 7), + trainable=trainable, ) if self.update_g2_has_attn: self.attn2_mh_apply = Atten2MultiHeadApply( - g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 8) + g2_dim, + attn2_nhead, + precision=precision, + seed=child_seed(seed, 8), + trainable=trainable, ) self.attn2_lm = LayerNorm( g2_dim, @@ -1483,12 +1523,17 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 10), + trainable=trainable, ) ) if self.update_h2: self.attn2_ev_apply = Atten2EquiVarApply( - g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 11) + g2_dim, + attn2_nhead, + precision=precision, + seed=child_seed(seed, 11), + trainable=trainable, ) if self.update_style == "res_residual": h2_residual.append( @@ -1498,6 +1543,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 12), + trainable=trainable, ) ) if self.update_g1_has_attn: @@ -1508,6 
+1554,7 @@ def __init__( self.smooth, precision=precision, seed=child_seed(seed, 13), + trainable=trainable, ) if self.update_style == "res_residual": g1_residual.append( @@ -1517,6 +1564,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 14), + trainable=trainable, ) ) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index bd72d936e3..5bcffc6c53 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -207,6 +207,7 @@ def __init__( self.resnet_dt, self.precision, seed=child_seed(seed, ii), + trainable=trainable, ) self.embeddings = embeddings self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 5b2931b23f..9d485b15a9 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -166,6 +166,7 @@ def __init__( self.resnet_dt, self.precision, seed=child_seed(seed, ii), + trainable=trainable, ) self.embeddings = embeddings self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index fb30f04961..496dd3e090 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -147,6 +147,7 @@ def __init__( self.resnet_dt, self.precision, seed=child_seed(self.seed, ii), + trainable=trainable, ) self.embeddings = embeddings self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index ff26024aad..ae8f1280d2 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -157,6 +157,7 @@ def __init__( env_protection=env_protection, smooth=smooth, seed=child_seed(seed, 0), + trainable=trainable, ) self.use_econf_tebd = use_econf_tebd self.type_map = type_map @@ -171,6 +172,7 @@ def __init__( use_tebd_bias=use_tebd_bias, type_map=type_map, seed=child_seed(seed, 1), + trainable=trainable, ) self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd @@ -497,6 +499,7 @@ def __init__( env_protection: float = 0.0, smooth: bool = True, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: self.rcut = rcut self.rcut_smth = rcut_smth @@ -542,6 +545,7 @@ def __init__( self.resnet_dt, self.precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.embeddings = embeddings if self.tebd_input_mode in ["strip"]: @@ -557,6 +561,7 @@ def __init__( self.resnet_dt, self.precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.embeddings_strip = embeddings_strip else: diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index cd0d4e72d4..db94580243 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -198,6 +198,7 @@ def __init__( self.precision, bias_out=True, seed=child_seed(seed, ii), + trainable=trainable, ) for ii in range(self.ntypes if not self.mixed_types else 1) ], diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index bf28b66b7b..4d37e2ee5d 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -87,6 +87,8 @@ class NativeLayer(NativeOP): The precision of the layer. seed : int, optional Random seed. 
+ trainable : bool, default=True + Whether the layer is trainable. """ def __init__( @@ -99,7 +101,10 @@ def __init__( resnet: bool = False, precision: str = DEFAULT_PRECISION, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: + # trainable must be set before any array attribute is set + self.trainable = trainable prec = PRECISION_DICT[precision.lower()] self.precision = precision # only use_timestep when skip connection is established. @@ -146,6 +151,7 @@ def serialize(self) -> dict: "resnet": self.resnet, # make deterministic "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "trainable": self.trainable, "@variables": data, } @@ -240,6 +246,8 @@ def __getitem__(self, key): return self.resnet elif key == "precision": return self.precision + elif key == "trainable": + return self.trainable else: raise KeyError(key) @@ -429,6 +437,7 @@ def __init__( resnet=False, precision=precision, seed=seed, + trainable=trainable, ) xp = array_api_compat.array_namespace(self.w, self.b) self.w = xp.squeeze(self.w, 0) # keep the weight shape to be [num_in] @@ -681,6 +690,7 @@ def __init__( precision: str = DEFAULT_PRECISION, seed: Optional[Union[int, list[int]]] = None, bias: bool = True, + trainable: bool = True, ) -> None: layers = [] i_in = in_dim @@ -696,6 +706,7 @@ def __init__( resnet=True, precision=precision, seed=child_seed(seed, idx), + trainable=trainable, ).serialize() ) i_in = i_ot @@ -786,6 +797,7 @@ def __init__( precision: str = DEFAULT_PRECISION, bias_out: bool = True, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__( in_dim, @@ -794,6 +806,7 @@ def __init__( resnet_dt=resnet_dt, precision=precision, seed=seed, + trainable=trainable, ) i_in = neuron[-1] if len(neuron) > 0 else in_dim i_ot = out_dim @@ -807,6 +820,7 @@ def __init__( resnet=False, precision=precision, seed=child_seed(seed, len(neuron)), + trainable=trainable, ) ) self.out_dim = out_dim diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index 17e40f3592..d533d71ee9 100644 --- a/deepmd/dpmodel/utils/type_embed.py +++ b/deepmd/dpmodel/utils/type_embed.py @@ -93,6 +93,7 @@ def __init__( self.precision, seed=self.seed, bias=self.use_tebd_bias, + trainable=trainable, ) @support_array_api(version="2022.12") diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index 2c406095cd..78da4c96f5 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -16,6 +16,7 @@ make_multilayer_network, ) from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -44,7 +45,10 @@ def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"}: value = to_jax_array(value) if value is not None: - value = ArrayAPIParam(value) + if self.trainable: + value = ArrayAPIParam(value) + else: + value = ArrayAPIVariable(value) return super().__setattr__(name, value) diff --git a/deepmd/pd/model/descriptor/dpa1.py b/deepmd/pd/model/descriptor/dpa1.py index 6942b096c9..ad45c13d1d 100644 --- a/deepmd/pd/model/descriptor/dpa1.py +++ b/deepmd/pd/model/descriptor/dpa1.py @@ -292,6 +292,7 @@ def __init__( trainable_ln=trainable_ln, ln_eps=ln_eps, seed=child_seed(seed, 1), + trainable=trainable, ) self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias @@ -305,6 +306,7 @@ def __init__( use_econf_tebd=use_econf_tebd, use_tebd_bias=use_tebd_bias, type_map=type_map, + trainable=trainable, ) self.prec = PRECISION_DICT[precision] self.tebd_dim = tebd_dim 
diff --git a/deepmd/pd/model/descriptor/dpa2.py b/deepmd/pd/model/descriptor/dpa2.py index 0e3b24397f..44b3229f66 100644 --- a/deepmd/pd/model/descriptor/dpa2.py +++ b/deepmd/pd/model/descriptor/dpa2.py @@ -184,6 +184,7 @@ def init_subclass_params(sub_data, sub_class): smooth=smooth, type_one_side=self.repinit_args.type_one_side, seed=child_seed(seed, 0), + trainable=trainable, ) self.use_three_body = self.repinit_args.use_three_body if self.use_three_body: @@ -203,6 +204,7 @@ def init_subclass_params(sub_data, sub_class): resnet_dt=self.repinit_args.resnet_dt, smooth=smooth, seed=child_seed(seed, 5), + trainable=trainable, ) else: self.repinit_three_body = None @@ -243,6 +245,7 @@ def init_subclass_params(sub_data, sub_class): g1_out_conv=self.repformer_args.g1_out_conv, g1_out_mlp=self.repformer_args.g1_out_mlp, seed=child_seed(seed, 1), + trainable=trainable, ) self.rcsl_list = [ (self.repformers.get_rcut(), self.repformers.get_nsel()), @@ -270,6 +273,7 @@ def init_subclass_params(sub_data, sub_class): use_econf_tebd=self.use_econf_tebd, use_tebd_bias=use_tebd_bias, type_map=type_map, + trainable=trainable, ) self.concat_output_tebd = concat_output_tebd self.precision = precision @@ -295,6 +299,7 @@ def init_subclass_params(sub_data, sub_class): precision=precision, init="glorot", seed=child_seed(seed, 3), + trainable=trainable, ) self.tebd_transform = None if self.add_tebd_to_repinit_out: @@ -304,6 +309,7 @@ def init_subclass_params(sub_data, sub_class): bias=False, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) assert self.repinit.rcut > self.repformers.rcut assert self.repinit.sel[0] > self.repformers.sel[0] diff --git a/deepmd/pd/model/descriptor/dpa3.py b/deepmd/pd/model/descriptor/dpa3.py index 0f1a8f4c2f..e022169930 100644 --- a/deepmd/pd/model/descriptor/dpa3.py +++ b/deepmd/pd/model/descriptor/dpa3.py @@ -167,6 +167,7 @@ def init_subclass_params(sub_data, sub_class): env_protection=env_protection, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.use_econf_tebd = use_econf_tebd @@ -182,6 +183,7 @@ def init_subclass_params(sub_data, sub_class): use_econf_tebd=self.use_econf_tebd, use_tebd_bias=use_tebd_bias, type_map=type_map, + trainable=trainable, ) self.concat_output_tebd = concat_output_tebd self.precision = precision diff --git a/deepmd/pd/model/descriptor/repflow_layer.py b/deepmd/pd/model/descriptor/repflow_layer.py index f1bdd0439d..d059e13775 100644 --- a/deepmd/pd/model/descriptor/repflow_layer.py +++ b/deepmd/pd/model/descriptor/repflow_layer.py @@ -61,6 +61,7 @@ def __init__( update_residual_init: str = "const", precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.epsilon = 1e-4 # protection of 1./nnei @@ -123,6 +124,7 @@ def __init__( n_dim, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) if self.update_style == "res_residual": self.n_residual.append( @@ -132,6 +134,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) ) @@ -142,6 +145,7 @@ def __init__( n_dim, precision=precision, seed=child_seed(seed, 2), + trainable=trainable, ) if self.update_style == "res_residual": self.n_residual.append( @@ -151,6 +155,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 3), + trainable=trainable, ) ) @@ -160,6 +165,7 @@ def __init__( self.n_multi_edge_message * n_dim, precision=precision, seed=child_seed(seed, 4), + 
trainable=trainable, ) if self.update_style == "res_residual": for head_index in range(self.n_multi_edge_message): @@ -170,6 +176,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(child_seed(seed, 5), head_index), + trainable=trainable, ) ) @@ -179,6 +186,7 @@ def __init__( e_dim, precision=precision, seed=child_seed(seed, 6), + trainable=trainable, ) if self.update_style == "res_residual": self.e_residual.append( @@ -188,6 +196,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 7), + trainable=trainable, ) ) @@ -216,6 +225,7 @@ def __init__( precision=precision, bias=False, seed=child_seed(seed, 8), + trainable=trainable, ) self.a_compress_e_linear = MLPLayer( self.e_dim, @@ -223,6 +233,7 @@ def __init__( precision=precision, bias=False, seed=child_seed(seed, 9), + trainable=trainable, ) else: self.a_compress_n_linear = None @@ -234,12 +245,14 @@ def __init__( self.e_dim, precision=precision, seed=child_seed(seed, 10), + trainable=trainable, ) self.edge_angle_linear2 = MLPLayer( self.e_dim, self.e_dim, precision=precision, seed=child_seed(seed, 11), + trainable=trainable, ) if self.update_style == "res_residual": self.e_residual.append( @@ -249,6 +262,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 12), + trainable=trainable, ) ) @@ -258,6 +272,7 @@ def __init__( self.a_dim, precision=precision, seed=child_seed(seed, 13), + trainable=trainable, ) if self.update_style == "res_residual": self.a_residual.append( @@ -267,6 +282,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 14), + trainable=trainable, ) ) else: diff --git a/deepmd/pd/model/descriptor/repflows.py b/deepmd/pd/model/descriptor/repflows.py index 3200c26dba..04553253a1 100644 --- a/deepmd/pd/model/descriptor/repflows.py +++ b/deepmd/pd/model/descriptor/repflows.py @@ -167,6 +167,7 @@ def __init__( use_loc_mapping: bool = False, optim_update: bool = True, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.e_rcut = float(e_rcut) @@ -223,10 +224,19 @@ def __init__( self.seed = seed self.edge_embd = MLPLayer( - 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) + 1, + self.e_dim, + precision=precision, + seed=child_seed(seed, 0), + trainable=trainable, ) self.angle_embd = MLPLayer( - 1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1) + 1, + self.a_dim, + precision=precision, + bias=False, + seed=child_seed(seed, 1), + trainable=trainable, ) layers = [] for ii in range(nlayers): @@ -258,6 +268,7 @@ def __init__( sel_reduce_factor=self.sel_reduce_factor, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), + trainable=trainable, ) ) self.layers = paddle.nn.LayerList(layers) diff --git a/deepmd/pd/model/descriptor/repformer_layer.py b/deepmd/pd/model/descriptor/repformer_layer.py index b4d93d8301..4dad08fff8 100644 --- a/deepmd/pd/model/descriptor/repformer_layer.py +++ b/deepmd/pd/model/descriptor/repformer_layer.py @@ -163,6 +163,7 @@ def __init__( attnw_shift: float = 20.0, precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: """Return neighbor-wise multi-head self-attention maps, with gate mechanism.""" super().__init__() @@ -175,6 +176,7 @@ def __init__( bias=False, precision=precision, seed=seed, + trainable=trainable, ) self.has_gate = has_gate self.smooth = smooth @@ -288,6 +290,7 @@ def __init__( 
head_num: int, precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.input_dim = input_dim @@ -298,12 +301,14 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.head_map = MLPLayer( input_dim * head_num, input_dim, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.precision = precision @@ -375,12 +380,18 @@ def __init__( head_num: int, precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.input_dim = input_dim self.head_num = head_num self.head_map = MLPLayer( - head_num, 1, bias=False, precision=precision, seed=seed + head_num, + 1, + bias=False, + precision=precision, + seed=seed, + trainable=trainable, ) self.precision = precision @@ -448,6 +459,7 @@ def __init__( attnw_shift: float = 20.0, precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.input_dim = input_dim @@ -459,6 +471,7 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.mapkv = MLPLayer( input_dim, @@ -466,12 +479,14 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.head_map = MLPLayer( input_dim * head_num, input_dim, precision=precision, seed=child_seed(seed, 2), + trainable=trainable, ) self.smooth = smooth self.attnw_shift = attnw_shift @@ -612,6 +627,7 @@ def __init__( g1_out_conv: bool = True, g1_out_mlp: bool = True, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.epsilon = 1e-4 # protection of 1./nnei @@ -672,6 +688,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) ) @@ -681,6 +698,7 @@ def __init__( g1_dim, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) self.linear2 = None self.proj_g1g2 = None @@ -697,6 +715,7 @@ def __init__( g2_dim, precision=precision, seed=child_seed(seed, 2), + trainable=trainable, ) if self.update_style == "res_residual": self.g2_residual.append( @@ -706,6 +725,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 3), + trainable=trainable, ) ) if self.g1_out_mlp: @@ -714,6 +734,7 @@ def __init__( g1_dim, precision=precision, seed=child_seed(seed, 15), + trainable=trainable, ) if self.update_style == "res_residual": self.g1_residual.append( @@ -723,6 +744,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 16), + trainable=trainable, ) ) else: @@ -735,6 +757,7 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) else: self.proj_g1g2 = MLPLayer( @@ -743,6 +766,7 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) if self.update_style == "res_residual": self.g1_residual.append( @@ -752,6 +776,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 17), + trainable=trainable, ) ) if self.update_g2_has_g1g1: @@ -761,6 +786,7 @@ def __init__( bias=False, precision=precision, seed=child_seed(seed, 5), + trainable=trainable, ) if self.update_style == "res_residual": self.g2_residual.append( @@ -770,6 +796,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 6), + trainable=trainable, ) ) 
if self.update_g2_has_attn or self.update_h2: @@ -781,10 +808,15 @@ def __init__( self.smooth, precision=precision, seed=child_seed(seed, 7), + trainable=trainable, ) if self.update_g2_has_attn: self.attn2_mh_apply = Atten2MultiHeadApply( - g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 8) + g2_dim, + attn2_nhead, + precision=precision, + seed=child_seed(seed, 8), + trainable=trainable, ) self.attn2_lm = LayerNorm( g2_dim, @@ -801,12 +833,17 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 10), + trainable=trainable, ) ) if self.update_h2: self.attn2_ev_apply = Atten2EquiVarApply( - g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 11) + g2_dim, + attn2_nhead, + precision=precision, + seed=child_seed(seed, 11), + trainable=trainable, ) if self.update_style == "res_residual": self.h2_residual.append( @@ -816,6 +853,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 12), + trainable=trainable, ) ) if self.update_g1_has_attn: @@ -826,6 +864,7 @@ def __init__( self.smooth, precision=precision, seed=child_seed(seed, 13), + trainable=trainable, ) if self.update_style == "res_residual": self.g1_residual.append( @@ -835,6 +874,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 14), + trainable=trainable, ) ) diff --git a/deepmd/pd/model/descriptor/repformers.py b/deepmd/pd/model/descriptor/repformers.py index 32f88dd1d3..12778e3b1f 100644 --- a/deepmd/pd/model/descriptor/repformers.py +++ b/deepmd/pd/model/descriptor/repformers.py @@ -87,6 +87,7 @@ def __init__( use_sqrt_nnei: bool = True, g1_out_conv: bool = True, g1_out_mlp: bool = True, + trainable: bool = True, ) -> None: r""" The repformer descriptor block. @@ -223,7 +224,11 @@ def __init__( self.seed = seed self.g2_embd = MLPLayer( - 1, self.g2_dim, precision=precision, seed=child_seed(seed, 0) + 1, + self.g2_dim, + precision=precision, + seed=child_seed(seed, 0), + trainable=trainable, ) layers = [] for ii in range(nlayers): @@ -261,6 +266,7 @@ def __init__( g1_out_conv=self.g1_out_conv, g1_out_mlp=self.g1_out_mlp, seed=child_seed(child_seed(seed, 1), ii), + trainable=trainable, ) ) self.layers = paddle.nn.LayerList(layers) diff --git a/deepmd/pd/model/descriptor/se_a.py b/deepmd/pd/model/descriptor/se_a.py index 7b70a742ce..9cd9f7b0b7 100644 --- a/deepmd/pd/model/descriptor/se_a.py +++ b/deepmd/pd/model/descriptor/se_a.py @@ -481,6 +481,7 @@ def __init__( precision=self.precision, resnet_dt=self.resnet_dt, seed=child_seed(self.seed, ii), + trainable=trainable, ) self.filter_layers = filter_layers self.stats = None diff --git a/deepmd/pd/model/descriptor/se_atten.py b/deepmd/pd/model/descriptor/se_atten.py index 6bec47b12e..304ab72e3e 100644 --- a/deepmd/pd/model/descriptor/se_atten.py +++ b/deepmd/pd/model/descriptor/se_atten.py @@ -81,6 +81,7 @@ def __init__( ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, list[int]]] = None, type: Optional[str] = None, + trainable: bool = True, ) -> None: r"""Construct an embedding net of type `se_atten`. 
@@ -205,6 +206,7 @@ def __init__( smooth=self.smooth, precision=self.precision, seed=child_seed(self.seed, 0), + trainable=trainable, ) wanted_shape = (self.ntypes, self.nnei, 4) @@ -229,6 +231,7 @@ def __init__( precision=self.precision, resnet_dt=self.resnet_dt, seed=child_seed(self.seed, 1), + trainable=trainable, ) self.filter_layers = filter_layers if self.tebd_input_mode in ["strip"]: @@ -242,6 +245,7 @@ def __init__( precision=self.precision, resnet_dt=self.resnet_dt, seed=child_seed(self.seed, 2), + trainable=trainable, ) self.filter_layers_strip = filter_layers_strip self.stats = None @@ -655,6 +659,7 @@ def __init__( smooth: bool = True, precision: str = DEFAULT_PRECISION, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: """Construct a neighbor-wise attention net.""" super().__init__() @@ -690,6 +695,7 @@ def __init__( smooth=smooth, precision=precision, seed=child_seed(seed, i), + trainable=trainable, ) ) self.attention_layers = nn.LayerList(attention_layers) @@ -797,6 +803,7 @@ def __init__( ln_eps: float = 1e-5, precision: str = DEFAULT_PRECISION, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: """Construct a neighbor-wise attention layer.""" super().__init__() @@ -824,6 +831,7 @@ def __init__( smooth=smooth, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.attn_layer_norm = LayerNorm( self.embed_dim, @@ -904,6 +912,7 @@ def __init__( smooth: bool = True, precision: str = DEFAULT_PRECISION, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: """Construct a multi-head neighbor-wise attention net.""" super().__init__() @@ -936,6 +945,7 @@ def __init__( stddev=1.0, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) self.out_proj = MLPLayer( hidden_dim, @@ -946,6 +956,7 @@ def __init__( stddev=1.0, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) def forward( diff --git a/deepmd/pd/model/descriptor/se_t_tebd.py b/deepmd/pd/model/descriptor/se_t_tebd.py index 2898283f0c..3ebf62d7a5 100644 --- a/deepmd/pd/model/descriptor/se_t_tebd.py +++ b/deepmd/pd/model/descriptor/se_t_tebd.py @@ -160,6 +160,7 @@ def __init__( env_protection=env_protection, smooth=smooth, seed=child_seed(seed, 1), + trainable=trainable, ) self.prec = PRECISION_DICT[precision] self.use_econf_tebd = use_econf_tebd @@ -173,6 +174,7 @@ def __init__( use_econf_tebd=use_econf_tebd, type_map=type_map, use_tebd_bias=use_tebd_bias, + trainable=trainable, ) self.tebd_dim = tebd_dim self.tebd_input_mode = tebd_input_mode @@ -529,6 +531,7 @@ def __init__( env_protection: float = 0.0, smooth: bool = True, seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.rcut = float(rcut) @@ -585,6 +588,7 @@ def __init__( precision=self.precision, resnet_dt=self.resnet_dt, seed=child_seed(self.seed, 1), + trainable=trainable, ) self.filter_layers = filter_layers if self.tebd_input_mode in ["strip"]: @@ -598,6 +602,7 @@ def __init__( precision=self.precision, resnet_dt=self.resnet_dt, seed=child_seed(self.seed, 2), + trainable=trainable, ) self.filter_layers_strip = filter_layers_strip self.stats = None diff --git a/deepmd/pd/model/network/mlp.py b/deepmd/pd/model/network/mlp.py index 41286fbbae..0b413b9faf 100644 --- a/deepmd/pd/model/network/mlp.py +++ b/deepmd/pd/model/network/mlp.py @@ -85,6 +85,7 @@ def __init__( precision: str = DEFAULT_PRECISION, init: str = "default", seed: int | list[int] | None = None, + 
trainable: bool = True, ): super().__init__() # only use_timestep when skip connection is established. @@ -277,6 +278,7 @@ def deserialize(cls, data: dict) -> MLPLayer: activation_function=nl["activation_function"], resnet=nl["resnet"], precision=nl["precision"], + trainable=nl["trainable"], ) prec = PRECISION_DICT[obj.precision] diff --git a/deepmd/pd/model/network/network.py b/deepmd/pd/model/network/network.py index 9cdb7b3adc..0c97045ba4 100644 --- a/deepmd/pd/model/network/network.py +++ b/deepmd/pd/model/network/network.py @@ -45,6 +45,7 @@ def __init__( use_econf_tebd=False, use_tebd_bias: bool = False, type_map=None, + trainable: bool = True, ) -> None: """Construct a type embedding net.""" super().__init__() @@ -65,6 +66,7 @@ def __init__( type_map=type_map, precision=precision, seed=seed, + trainable=trainable, ) # init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 9c1e144f48..16603dc75d 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -298,6 +298,7 @@ def __init__( trainable_ln=trainable_ln, ln_eps=ln_eps, seed=child_seed(seed, 1), + trainable=trainable, ) self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias @@ -311,6 +312,7 @@ def __init__( use_econf_tebd=use_econf_tebd, use_tebd_bias=use_tebd_bias, type_map=type_map, + trainable=trainable, ) self.prec = PRECISION_DICT[precision] self.tebd_dim = tebd_dim diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 969fdca5fc..0d6fbd84e5 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -188,6 +188,7 @@ def init_subclass_params(sub_data, sub_class): smooth=smooth, type_one_side=self.repinit_args.type_one_side, seed=child_seed(seed, 0), + trainable=trainable, ) self.use_three_body = self.repinit_args.use_three_body if self.use_three_body: @@ -207,6 +208,7 @@ def init_subclass_params(sub_data, sub_class): resnet_dt=self.repinit_args.resnet_dt, smooth=smooth, seed=child_seed(seed, 5), + trainable=trainable, ) else: self.repinit_three_body = None @@ -247,6 +249,7 @@ def init_subclass_params(sub_data, sub_class): g1_out_conv=self.repformer_args.g1_out_conv, g1_out_mlp=self.repformer_args.g1_out_mlp, seed=child_seed(seed, 1), + trainable=trainable, ) self.rcsl_list = [ (self.repformers.get_rcut(), self.repformers.get_nsel()), @@ -274,6 +277,7 @@ def init_subclass_params(sub_data, sub_class): use_econf_tebd=self.use_econf_tebd, use_tebd_bias=use_tebd_bias, type_map=type_map, + trainable=trainable, ) self.concat_output_tebd = concat_output_tebd self.precision = precision @@ -299,6 +303,7 @@ def init_subclass_params(sub_data, sub_class): precision=precision, init="glorot", seed=child_seed(seed, 3), + trainable=trainable, ) self.tebd_transform = None if self.add_tebd_to_repinit_out: @@ -308,6 +313,7 @@ def init_subclass_params(sub_data, sub_class): bias=False, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) assert self.repinit.rcut > self.repformers.rcut assert self.repinit.sel[0] > self.repformers.sel[0] diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 5d45c0633a..36b09230de 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -169,6 +169,7 @@ def init_subclass_params(sub_data, sub_class): env_protection=env_protection, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) 
self.use_econf_tebd = use_econf_tebd @@ -184,6 +185,7 @@ def init_subclass_params(sub_data, sub_class): use_econf_tebd=self.use_econf_tebd, use_tebd_bias=use_tebd_bias, type_map=type_map, + trainable=trainable, ) self.concat_output_tebd = concat_output_tebd self.precision = precision diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 37d4f07bb4..a52e5eba30 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -64,6 +64,7 @@ def __init__( update_residual_init: str = "const", precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, + trainable: bool = True, ) -> None: super().__init__() self.epsilon = 1e-4 # protection of 1./nnei @@ -126,6 +127,7 @@ def __init__( n_dim, precision=precision, seed=child_seed(seed, 0), + trainable=trainable, ) if self.update_style == "res_residual": self.n_residual.append( @@ -135,6 +137,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 1), + trainable=trainable, ) ) @@ -145,6 +148,7 @@ def __init__( n_dim, precision=precision, seed=child_seed(seed, 2), + trainable=trainable, ) if self.update_style == "res_residual": self.n_residual.append( @@ -154,6 +158,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 3), + trainable=trainable, ) ) @@ -163,6 +168,7 @@ def __init__( self.n_multi_edge_message * n_dim, precision=precision, seed=child_seed(seed, 4), + trainable=trainable, ) if self.update_style == "res_residual": for head_index in range(self.n_multi_edge_message): @@ -173,6 +179,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(child_seed(seed, 5), head_index), + trainable=trainable, ) ) @@ -182,6 +189,7 @@ def __init__( e_dim, precision=precision, seed=child_seed(seed, 6), + trainable=trainable, ) if self.update_style == "res_residual": self.e_residual.append( @@ -191,6 +199,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 7), + trainable=trainable, ) ) @@ -219,6 +228,7 @@ def __init__( precision=precision, bias=False, seed=child_seed(seed, 8), + trainable=trainable, ) self.a_compress_e_linear = MLPLayer( self.e_dim, @@ -226,6 +236,7 @@ def __init__( precision=precision, bias=False, seed=child_seed(seed, 9), + trainable=trainable, ) else: self.a_compress_n_linear = None @@ -237,12 +248,14 @@ def __init__( self.e_dim, precision=precision, seed=child_seed(seed, 10), + trainable=trainable, ) self.edge_angle_linear2 = MLPLayer( self.e_dim, self.e_dim, precision=precision, seed=child_seed(seed, 11), + trainable=trainable, ) if self.update_style == "res_residual": self.e_residual.append( @@ -252,6 +265,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 12), + trainable=trainable, ) ) @@ -261,6 +275,7 @@ def __init__( self.a_dim, precision=precision, seed=child_seed(seed, 13), + trainable=trainable, ) if self.update_style == "res_residual": self.a_residual.append( @@ -270,6 +285,7 @@ def __init__( self.update_residual_init, precision=precision, seed=child_seed(seed, 14), + trainable=trainable, ) ) else: diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 5889b0a819..4d4cb4e748 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -219,6 +219,7 @@ def __init__( use_loc_mapping: bool = True, optim_update: bool = True, seed: Optional[Union[int, list[int]]] = None, + 
         trainable: bool = True,
     ) -> None:
         super().__init__()
         self.e_rcut = float(e_rcut)
@@ -283,10 +284,19 @@ def __init__(
         self.seed = seed
 
         self.edge_embd = MLPLayer(
-            1, self.e_dim, precision=precision, seed=child_seed(seed, 0)
+            1,
+            self.e_dim,
+            precision=precision,
+            seed=child_seed(seed, 0),
+            trainable=trainable,
         )
         self.angle_embd = MLPLayer(
-            1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1)
+            1,
+            self.a_dim,
+            precision=precision,
+            bias=False,
+            seed=child_seed(seed, 1),
+            trainable=trainable,
         )
         layers = []
         for ii in range(nlayers):
@@ -318,6 +328,7 @@ def __init__(
                     sel_reduce_factor=self.sel_reduce_factor,
                     smooth_edge_update=self.smooth_edge_update,
                     seed=child_seed(child_seed(seed, 1), ii),
+                    trainable=trainable,
                 )
             )
         self.layers = torch.nn.ModuleList(layers)
diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py
index 1e2cba66d6..9715b7479b 100644
--- a/deepmd/pt/model/descriptor/repformer_layer.py
+++ b/deepmd/pt/model/descriptor/repformer_layer.py
@@ -160,6 +160,7 @@ def __init__(
         attnw_shift: float = 20.0,
         precision: str = "float64",
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         """Return neighbor-wise multi-head self-attention maps, with gate mechanism."""
         super().__init__()
@@ -172,6 +173,7 @@ def __init__(
             bias=False,
             precision=precision,
             seed=seed,
+            trainable=trainable,
         )
         self.has_gate = has_gate
         self.smooth = smooth
@@ -285,6 +287,7 @@ def __init__(
         head_num: int,
         precision: str = "float64",
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         super().__init__()
         self.input_dim = input_dim
@@ -295,12 +298,14 @@ def __init__(
             bias=False,
             precision=precision,
             seed=child_seed(seed, 0),
+            trainable=trainable,
         )
         self.head_map = MLPLayer(
             input_dim * head_num,
             input_dim,
             precision=precision,
             seed=child_seed(seed, 1),
+            trainable=trainable,
         )
         self.precision = precision
@@ -370,12 +375,18 @@ def __init__(
         head_num: int,
         precision: str = "float64",
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         super().__init__()
         self.input_dim = input_dim
         self.head_num = head_num
         self.head_map = MLPLayer(
-            head_num, 1, bias=False, precision=precision, seed=seed
+            head_num,
+            1,
+            bias=False,
+            precision=precision,
+            seed=seed,
+            trainable=trainable,
         )
         self.precision = precision
@@ -443,6 +454,7 @@ def __init__(
         attnw_shift: float = 20.0,
         precision: str = "float64",
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         super().__init__()
         self.input_dim = input_dim
@@ -454,6 +466,7 @@ def __init__(
             bias=False,
             precision=precision,
             seed=child_seed(seed, 0),
+            trainable=trainable,
         )
         self.mapkv = MLPLayer(
             input_dim,
@@ -461,12 +474,14 @@ def __init__(
             bias=False,
             precision=precision,
             seed=child_seed(seed, 1),
+            trainable=trainable,
         )
         self.head_map = MLPLayer(
             input_dim * head_num,
             input_dim,
             precision=precision,
             seed=child_seed(seed, 2),
+            trainable=trainable,
         )
         self.smooth = smooth
         self.attnw_shift = attnw_shift
@@ -602,6 +617,7 @@ def __init__(
         g1_out_conv: bool = True,
         g1_out_mlp: bool = True,
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         super().__init__()
         self.epsilon = 1e-4  # protection of 1./nnei
@@ -662,6 +678,7 @@ def __init__(
                     self.update_residual_init,
                     precision=precision,
                     seed=child_seed(seed, 0),
+                    trainable=trainable,
                 )
             )
@@ -671,6 +688,7 @@ def __init__(
             g1_dim,
             precision=precision,
             seed=child_seed(seed, 1),
+            trainable=trainable,
         )
         self.linear2 = None
         self.proj_g1g2 = None
@@ -687,6 +705,7 @@ def __init__(
                 g2_dim,
                 precision=precision,
                 seed=child_seed(seed, 2),
+                trainable=trainable,
             )
             if self.update_style == "res_residual":
                 self.g2_residual.append(
@@ -696,6 +715,7 @@ def __init__(
                         self.update_residual_init,
                         precision=precision,
                         seed=child_seed(seed, 3),
+                        trainable=trainable,
                     )
                 )
         if self.g1_out_mlp:
@@ -704,6 +724,7 @@ def __init__(
                 g1_dim,
                 precision=precision,
                 seed=child_seed(seed, 15),
+                trainable=trainable,
             )
             if self.update_style == "res_residual":
                 self.g1_residual.append(
@@ -713,6 +734,7 @@ def __init__(
                         self.update_residual_init,
                         precision=precision,
                         seed=child_seed(seed, 16),
+                        trainable=trainable,
                     )
                 )
         else:
@@ -725,6 +747,7 @@ def __init__(
                     bias=False,
                     precision=precision,
                     seed=child_seed(seed, 4),
+                    trainable=trainable,
                 )
             else:
                 self.proj_g1g2 = MLPLayer(
@@ -733,6 +756,7 @@ def __init__(
                     bias=False,
                     precision=precision,
                     seed=child_seed(seed, 4),
+                    trainable=trainable,
                 )
         if self.update_style == "res_residual":
             self.g1_residual.append(
@@ -742,6 +766,7 @@ def __init__(
                     self.update_residual_init,
                     precision=precision,
                     seed=child_seed(seed, 17),
+                    trainable=trainable,
                 )
             )
         if self.update_g2_has_g1g1:
@@ -751,6 +776,7 @@ def __init__(
                 bias=False,
                 precision=precision,
                 seed=child_seed(seed, 5),
+                trainable=trainable,
             )
             if self.update_style == "res_residual":
                 self.g2_residual.append(
@@ -760,6 +786,7 @@ def __init__(
                         self.update_residual_init,
                         precision=precision,
                         seed=child_seed(seed, 6),
+                        trainable=trainable,
                     )
                 )
         if self.update_g2_has_attn or self.update_h2:
@@ -771,10 +798,15 @@ def __init__(
                 self.smooth,
                 precision=precision,
                 seed=child_seed(seed, 7),
+                trainable=trainable,
             )
             if self.update_g2_has_attn:
                 self.attn2_mh_apply = Atten2MultiHeadApply(
-                    g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 8)
+                    g2_dim,
+                    attn2_nhead,
+                    precision=precision,
+                    seed=child_seed(seed, 8),
+                    trainable=trainable,
                 )
                 self.attn2_lm = LayerNorm(
                     g2_dim,
@@ -791,12 +823,17 @@ def __init__(
                         self.update_residual_init,
                         precision=precision,
                         seed=child_seed(seed, 10),
+                        trainable=trainable,
                     )
                 )
             if self.update_h2:
                 self.attn2_ev_apply = Atten2EquiVarApply(
-                    g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 11)
+                    g2_dim,
+                    attn2_nhead,
+                    precision=precision,
+                    seed=child_seed(seed, 11),
+                    trainable=trainable,
                 )
                 if self.update_style == "res_residual":
                     self.h2_residual.append(
@@ -806,6 +843,7 @@ def __init__(
                             self.update_residual_init,
                             precision=precision,
                             seed=child_seed(seed, 12),
+                            trainable=trainable,
                         )
                     )
         if self.update_g1_has_attn:
@@ -816,6 +854,7 @@ def __init__(
                 self.smooth,
                 precision=precision,
                 seed=child_seed(seed, 13),
+                trainable=trainable,
             )
             if self.update_style == "res_residual":
                 self.g1_residual.append(
@@ -825,6 +864,7 @@ def __init__(
                         self.update_residual_init,
                         precision=precision,
                         seed=child_seed(seed, 14),
+                        trainable=trainable,
                     )
                 )
diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py
index 82773d1a78..022c7510df 100644
--- a/deepmd/pt/model/descriptor/repformers.py
+++ b/deepmd/pt/model/descriptor/repformers.py
@@ -111,6 +111,7 @@ def __init__(
         use_sqrt_nnei: bool = True,
         g1_out_conv: bool = True,
         g1_out_mlp: bool = True,
+        trainable: bool = True,
     ) -> None:
         r"""
         The repformer descriptor block.
@@ -197,6 +198,8 @@ def __init__(
             The epsilon value for layer normalization.
         seed : int, optional
             Random seed for parameter initialization.
+        trainable : bool
+            Whether the block is trainable
         """
         super().__init__()
         self.rcut = float(rcut)
@@ -247,7 +250,11 @@ def __init__(
         self.seed = seed
 
         self.g2_embd = MLPLayer(
-            1, self.g2_dim, precision=precision, seed=child_seed(seed, 0)
+            1,
+            self.g2_dim,
+            precision=precision,
+            seed=child_seed(seed, 0),
+            trainable=trainable,
         )
         layers = []
         for ii in range(nlayers):
@@ -285,6 +292,7 @@ def __init__(
                     g1_out_conv=self.g1_out_conv,
                     g1_out_mlp=self.g1_out_mlp,
                     seed=child_seed(child_seed(seed, 1), ii),
+                    trainable=trainable,
                 )
             )
         self.layers = torch.nn.ModuleList(layers)
diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py
index fc3e14bd25..f49b5a1276 100644
--- a/deepmd/pt/model/descriptor/se_a.py
+++ b/deepmd/pt/model/descriptor/se_a.py
@@ -525,6 +525,7 @@ def __init__(
                 precision=self.precision,
                 resnet_dt=self.resnet_dt,
                 seed=child_seed(self.seed, ii),
+                trainable=trainable,
             )
         self.filter_layers = filter_layers
         self.stats = None
diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py
index 1ce6ad4583..0c18bcc8a8 100644
--- a/deepmd/pt/model/descriptor/se_atten.py
+++ b/deepmd/pt/model/descriptor/se_atten.py
@@ -100,6 +100,7 @@ def __init__(
         ln_eps: Optional[float] = 1e-5,
         seed: Optional[Union[int, list[int]]] = None,
         type: Optional[str] = None,
+        trainable: bool = True,
     ) -> None:
         r"""Construct an embedding net of type `se_atten`.
@@ -224,6 +225,7 @@ def __init__(
             smooth=self.smooth,
             precision=self.precision,
             seed=child_seed(self.seed, 0),
+            trainable=trainable,
         )
 
         wanted_shape = (self.ntypes, self.nnei, 4)
@@ -248,6 +250,7 @@ def __init__(
                 precision=self.precision,
                 resnet_dt=self.resnet_dt,
                 seed=child_seed(self.seed, 1),
+                trainable=trainable,
             )
         self.filter_layers = filter_layers
         if self.tebd_input_mode in ["strip"]:
@@ -261,6 +264,7 @@ def __init__(
                     precision=self.precision,
                     resnet_dt=self.resnet_dt,
                     seed=child_seed(self.seed, 2),
+                    trainable=trainable,
                 )
             self.filter_layers_strip = filter_layers_strip
         self.stats = None
@@ -680,6 +684,7 @@ def __init__(
         smooth: bool = True,
         precision: str = DEFAULT_PRECISION,
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         """Construct a neighbor-wise attention net."""
         super().__init__()
@@ -715,6 +720,7 @@ def __init__(
                     smooth=smooth,
                     precision=precision,
                     seed=child_seed(seed, i),
+                    trainable=trainable,
                 )
             )
         self.attention_layers = nn.ModuleList(attention_layers)
@@ -823,6 +829,7 @@ def __init__(
         ln_eps: float = 1e-5,
         precision: str = DEFAULT_PRECISION,
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         """Construct a neighbor-wise attention layer."""
         super().__init__()
@@ -850,6 +857,7 @@ def __init__(
             smooth=smooth,
             precision=precision,
             seed=child_seed(seed, 0),
+            trainable=trainable,
         )
         self.attn_layer_norm = LayerNorm(
             self.embed_dim,
@@ -930,6 +938,7 @@ def __init__(
         smooth: bool = True,
         precision: str = DEFAULT_PRECISION,
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         """Construct a multi-head neighbor-wise attention net."""
         super().__init__()
@@ -962,6 +971,7 @@ def __init__(
             stddev=1.0,
             precision=precision,
             seed=child_seed(seed, 0),
+            trainable=trainable,
         )
         self.out_proj = MLPLayer(
             hidden_dim,
@@ -972,6 +982,7 @@ def __init__(
             stddev=1.0,
             precision=precision,
             seed=child_seed(seed, 1),
+            trainable=trainable,
         )
 
     def forward(
diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py
index a91757460c..9ce92fb8b4 100644
--- a/deepmd/pt/model/descriptor/se_r.py
+++ b/deepmd/pt/model/descriptor/se_r.py
@@ -142,6 +142,7 @@ def __init__(
                 precision=self.precision,
                 resnet_dt=self.resnet_dt,
                 seed=child_seed(self.seed, ii),
+                trainable=trainable,
             )
         self.filter_layers = filter_layers
         self.stats = None
diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py
index 6e075a04e4..f3bd0f65ef 100644
--- a/deepmd/pt/model/descriptor/se_t.py
+++ b/deepmd/pt/model/descriptor/se_t.py
@@ -575,6 +575,7 @@ def __init__(
                 precision=self.precision,
                 resnet_dt=self.resnet_dt,
                 seed=child_seed(self.seed, ii),
+                trainable=trainable,
             )
         self.filter_layers = filter_layers
         self.stats = None
diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py
index 7e27805bd5..3ee7929151 100644
--- a/deepmd/pt/model/descriptor/se_t_tebd.py
+++ b/deepmd/pt/model/descriptor/se_t_tebd.py
@@ -160,6 +160,7 @@ def __init__(
             env_protection=env_protection,
             smooth=smooth,
             seed=child_seed(seed, 1),
+            trainable=trainable,
         )
         self.prec = PRECISION_DICT[precision]
         self.use_econf_tebd = use_econf_tebd
@@ -170,6 +171,7 @@ def __init__(
             tebd_dim,
             precision=precision,
             seed=child_seed(seed, 2),
+            trainable=trainable,
             use_econf_tebd=use_econf_tebd,
             type_map=type_map,
             use_tebd_bias=use_tebd_bias,
@@ -525,6 +527,7 @@ def __init__(
         env_protection: float = 0.0,
         smooth: bool = True,
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         super().__init__()
         self.rcut = float(rcut)
@@ -577,6 +580,7 @@ def __init__(
                 precision=self.precision,
                 resnet_dt=self.resnet_dt,
                 seed=child_seed(self.seed, 1),
+                trainable=trainable,
             )
         self.filter_layers = filter_layers
         if self.tebd_input_mode in ["strip"]:
@@ -590,6 +594,7 @@ def __init__(
                     precision=self.precision,
                     resnet_dt=self.resnet_dt,
                     seed=child_seed(self.seed, 2),
+                    trainable=trainable,
                 )
             self.filter_layers_strip = filter_layers_strip
         self.stats = None
diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py
index 22675d6163..ea07f617d4 100644
--- a/deepmd/pt/model/network/mlp.py
+++ b/deepmd/pt/model/network/mlp.py
@@ -83,8 +83,10 @@ def __init__(
         precision: str = DEFAULT_PRECISION,
         init: str = "default",
         seed: Optional[Union[int, list[int]]] = None,
+        trainable: bool = True,
     ) -> None:
         super().__init__()
+        self.trainable = trainable
         # only use_timestep when skip connection is established.
         self.use_timestep = use_timestep and (
             num_out == num_in or num_out == num_in * 2
@@ -233,6 +235,7 @@ def serialize(self) -> dict:
             activation_function=self.activate_name,
             resnet=self.resnet,
             precision=self.precision,
+            trainable=self.trainable,
         )
         nl.w, nl.b, nl.idt = (
             to_numpy_array(self.matrix),
@@ -259,6 +262,7 @@ def deserialize(cls, data: dict) -> "MLPLayer":
             activation_function=nl["activation_function"],
             resnet=nl["resnet"],
             precision=nl["precision"],
+            trainable=nl["trainable"],
         )
         prec = PRECISION_DICT[obj.precision]
diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py
index ab01a90774..71f335e446 100644
--- a/deepmd/pt/model/network/network.py
+++ b/deepmd/pt/model/network/network.py
@@ -253,6 +253,7 @@ def __init__(
         use_econf_tebd=False,
         use_tebd_bias: bool = False,
         type_map=None,
+        trainable: bool = True,
     ) -> None:
         """Construct a type embedding net."""
         super().__init__()
@@ -273,6 +274,7 @@ def __init__(
             type_map=type_map,
             precision=precision,
             seed=seed,
+            trainable=trainable,
         )
         # nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev)
diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py
index 0865b61f52..6d0f3041dc 100644
--- a/deepmd/pt/model/task/fitting.py
+++ b/deepmd/pt/model/task/fitting.py
@@ -320,6 +320,7 @@ def __init__(
                     self.precision,
                     bias_out=True,
                     seed=child_seed(self.seed, ii),
+                    trainable=trainable,
                 )
                 for ii in range(self.ntypes if not self.mixed_types else 1)
             ],
diff --git a/deepmd/tf/descriptor/se.py b/deepmd/tf/descriptor/se.py
index 2863704143..5b04c5ba00 100644
--- a/deepmd/tf/descriptor/se.py
+++ b/deepmd/tf/descriptor/se.py
@@ -192,6 +192,7 @@ def serialize_network(
         resnet_dt: bool,
         variables: dict,
         excluded_types: set[tuple[int, int]] = set(),
+        trainable: bool = True,
         suffix: str = "",
     ) -> dict:
         """Serialize network.
@@ -214,6 +215,8 @@ def serialize_network(
             The input variables
         excluded_types : set[tuple[int, int]], optional
             The excluded types
+        trainable : bool
+            Whether the network is trainable
         suffix : str, optional
             The suffix of the scope
@@ -236,6 +239,7 @@ def serialize_network(
                     activation_function=activation_function,
                     resnet_dt=resnet_dt,
                     precision=self.precision.name,
+                    trainable=trainable,
                 )
                 embeddings[(type_j, type_i)] = EmbeddingNet(
                     in_dim=in_dim,
@@ -243,6 +247,7 @@ def serialize_network(
                     activation_function=activation_function,
                     resnet_dt=resnet_dt,
                     precision=self.precision.name,
+                    trainable=trainable,
                 )
                 embeddings[(type_i, type_j)].clear()
                 embeddings[(type_j, type_i)].clear()
@@ -278,6 +283,7 @@ def serialize_network(
                 activation_function=activation_function,
                 resnet_dt=resnet_dt,
                 precision=self.precision.name,
+                trainable=trainable,
             )
             assert embeddings[network_idx] is not None
             if weight_name == "idt":
diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py
index 3a9b86a0d6..6049512caf 100644
--- a/deepmd/tf/descriptor/se_atten.py
+++ b/deepmd/tf/descriptor/se_atten.py
@@ -1593,6 +1593,7 @@ def serialize_attention_layers(
                     bias=bias,
                     use_timestep=False,
                     precision=self.precision.name,
+                    trainable=self.trainable,
                 )
                 matrix_list = [
                     attention_layer_params[layer_idx][key]["matrix"]
@@ -1611,6 +1612,7 @@ def serialize_attention_layers(
                     bias=bias,
                     use_timestep=False,
                     precision=self.precision.name,
+                    trainable=self.trainable,
                 )
                 out_proj["matrix"] = attention_layer_params[layer_idx]["c_out"]["matrix"]
                 if bias:
@@ -1654,6 +1656,7 @@ def serialize_network_strip(
         variables: dict,
         suffix: str = "",
         type_one_side: bool = False,
+        trainable: bool = True,
     ) -> dict:
         """Serialize network.
@@ -1679,6 +1682,8 @@ def serialize_network_strip(
             If 'False', type embeddings of both neighbor and central atoms are considered.
             If 'True', only type embeddings of neighbor atoms are considered.
             Default is 'False'.
+        trainable : bool
+            Whether the network is trainable
 
         Returns
         -------
@@ -1719,6 +1724,7 @@ def serialize_network_strip(
                 activation_function=activation_function,
                 resnet_dt=resnet_dt,
                 precision=self.precision.name,
+                trainable=trainable,
             )
             assert embeddings[network_idx] is not None
             if weight_name == "idt":
diff --git a/deepmd/tf/descriptor/se_t.py b/deepmd/tf/descriptor/se_t.py
index c5d50744af..60ce025902 100644
--- a/deepmd/tf/descriptor/se_t.py
+++ b/deepmd/tf/descriptor/se_t.py
@@ -726,6 +726,7 @@ def serialize_network(
         resnet_dt: bool,
         variables: dict,
         excluded_types: set[tuple[int, int]] = set(),
+        trainable: bool = True,
         suffix: str = "",
     ) -> dict:
         """Serialize network.
@@ -748,6 +749,8 @@ def serialize_network(
             The input variables
         excluded_types : set[tuple[int, int]], optional
             The excluded types
+        trainable : bool, optional
+            Whether the network is trainable
         suffix : str, optional
             The suffix of the scope
@@ -771,6 +774,7 @@ def clear_ij(type_i, type_j) -> None:
                 activation_function=activation_function,
                 resnet_dt=resnet_dt,
                 precision=self.precision.name,
+                trainable=trainable,
             )
             embeddings[(type_i, type_j)].clear()
@@ -805,6 +809,7 @@ def clear_ij(type_i, type_j) -> None:
                 activation_function=activation_function,
                 resnet_dt=resnet_dt,
                 precision=self.precision.name,
+                trainable=trainable,
             )
             assert embeddings[network_idx] is not None
             if weight_name == "idt":
diff --git a/deepmd/tf/fit/fitting.py b/deepmd/tf/fit/fitting.py
index f159de1628..878e4e93bc 100644
--- a/deepmd/tf/fit/fitting.py
+++ b/deepmd/tf/fit/fitting.py
@@ -135,6 +135,7 @@ def serialize_network(
         resnet_dt: bool,
         variables: dict,
         out_dim: Optional[int] = 1,
+        trainable: bool = True,
         suffix: str = "",
     ) -> dict:
         """Serialize network.
@@ -199,6 +200,7 @@ def serialize_network(
                 resnet_dt=resnet_dt,
                 precision=self.precision.name,
                 bias_out=True,
+                trainable=trainable,
             )
             assert fittings[network_idx] is not None
             if weight_name == "idt":
diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py
index db5fe4dae0..d31cf289b9 100644
--- a/source/tests/consistent/descriptor/test_dpa1.py
+++ b/source/tests/consistent/descriptor/test_dpa1.py
@@ -127,6 +127,7 @@ def data(self) -> dict:
             "use_tebd_bias": use_tebd_bias,
             "type_map": ["O", "H"] if use_econf_tebd else None,
             "seed": 1145141919810,
+            "trainable": False,
         }
 
     def is_meaningless_zero_attention_layer_tests(
diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py
index ef840bf9d7..6864d91f26 100644
--- a/source/tests/consistent/descriptor/test_dpa2.py
+++ b/source/tests/consistent/descriptor/test_dpa2.py
@@ -181,7 +181,7 @@ def data(self) -> dict:
             "smooth": smooth,
             "exclude_types": exclude_types,
             "env_protection": 0.0,
-            "trainable": True,
+            "trainable": False,
             "use_econf_tebd": use_econf_tebd,
             "use_tebd_bias": use_tebd_bias,
             "type_map": ["O", "H"] if use_econf_tebd else None,
diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py
index b99117b9e7..47cc4e1004 100644
--- a/source/tests/consistent/descriptor/test_dpa3.py
+++ b/source/tests/consistent/descriptor/test_dpa3.py
@@ -130,7 +130,7 @@ def data(self) -> dict:
             "exclude_types": exclude_types,
             "env_protection": 0.0,
             "use_loc_mapping": use_loc_mapping,
-            "trainable": True,
+            "trainable": False,
         }
 
     @property

From 7d7e043f00550854a04919af1f05b2c9cf2555af Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Tue, 10 Jun 2025 19:49:40 +0800
Subject: [PATCH 19/32] bump the version of Layer data

Signed-off-by: Jinzhe Zeng
---
 deepmd/dpmodel/utils/network.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py
index 4d37e2ee5d..6eb651fdf5 100644
--- a/deepmd/dpmodel/utils/network.py
+++ b/deepmd/dpmodel/utils/network.py
@@ -144,7 +144,7 @@ def serialize(self) -> dict:
         }
         return {
             "@class": "Layer",
-            "@version": 1,
+            "@version": 2,
             "bias": self.b is not None,
             "use_timestep": self.idt is not None,
             "activation_function": self.activation_function,
@@ -165,7 +165,7 @@ def deserialize(cls, data: dict) -> "NativeLayer":
             The dict to deserialize from.
         """
         data = data.copy()
-        check_version_compatibility(data.pop("@version", 1), 1, 1)
+        check_version_compatibility(data.pop("@version", 1), 2, 1)
         data.pop("@class", None)
         variables = data.pop("@variables")
        assert variables["w"] is not None and len(variables["w"].shape) == 2
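Patch 19 is the bookkeeping for the preceding commits: serialized layers now carry the `trainable` flag, so the format version is bumped to 2 while the compatibility check keeps accepting version-1 dicts (the minimum stays 1). A minimal round-trip sketch of the intended behavior; the layer sizes here are made up, and `MLPLayer` is the PyTorch class patched above:

    # Hypothetical round-trip under the bumped layer format.
    layer = MLPLayer(128, 64, precision="float64", trainable=False)
    data = layer.serialize()  # dict now carries "@version": 2 and "trainable"
    restored = MLPLayer.deserialize(data)
    assert restored.trainable is False  # the flag survives the round trip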
""" data = data.copy() - check_version_compatibility(data.pop("@version", 1), 1, 1) + check_version_compatibility(data.pop("@version", 1), 2, 1) data.pop("@class", None) variables = data.pop("@variables") assert variables["w"] is not None and len(variables["w"].shape) == 2 From 77cc091bf711c9186fef1d2a513fe576a4d07a47 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 10 Jun 2025 20:49:56 +0800 Subject: [PATCH 20/32] fix pd trainable Signed-off-by: Jinzhe Zeng --- deepmd/pd/model/descriptor/repflows.py | 2 ++ deepmd/pd/model/descriptor/repformers.py | 2 ++ deepmd/pd/model/descriptor/se_atten.py | 2 ++ deepmd/pd/model/network/mlp.py | 2 ++ deepmd/pd/model/network/network.py | 1 + deepmd/pt/model/descriptor/repflows.py | 2 ++ deepmd/pt/model/descriptor/se_atten.py | 2 ++ 7 files changed, 13 insertions(+) diff --git a/deepmd/pd/model/descriptor/repflows.py b/deepmd/pd/model/descriptor/repflows.py index 04553253a1..f00bdd2cb5 100644 --- a/deepmd/pd/model/descriptor/repflows.py +++ b/deepmd/pd/model/descriptor/repflows.py @@ -131,6 +131,8 @@ class DescrptBlockRepflows(DescriptorBlock): For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. seed : int, optional Random seed for parameter initialization. + trainable : bool, default: True + Whether this block is trainable """ def __init__( diff --git a/deepmd/pd/model/descriptor/repformers.py b/deepmd/pd/model/descriptor/repformers.py index 12778e3b1f..0c197b3092 100644 --- a/deepmd/pd/model/descriptor/repformers.py +++ b/deepmd/pd/model/descriptor/repformers.py @@ -174,6 +174,8 @@ def __init__( The epsilon value for layer normalization. seed : int, optional Random seed for parameter initialization. + trainable : bool, default: True + Whether this block is trainable """ super().__init__() self.rcut = float(rcut) diff --git a/deepmd/pd/model/descriptor/se_atten.py b/deepmd/pd/model/descriptor/se_atten.py index 304ab72e3e..788ab211a7 100644 --- a/deepmd/pd/model/descriptor/se_atten.py +++ b/deepmd/pd/model/descriptor/se_atten.py @@ -147,6 +147,8 @@ def __init__( If not None, the scaling of attention weights is `temperature` itself. seed : int, Optional Random seed for parameter initialization. + trainable : bool, default: True + Whether this block is trainable """ super().__init__() del type diff --git a/deepmd/pd/model/network/mlp.py b/deepmd/pd/model/network/mlp.py index 0b413b9faf..ee408b8719 100644 --- a/deepmd/pd/model/network/mlp.py +++ b/deepmd/pd/model/network/mlp.py @@ -88,6 +88,7 @@ def __init__( trainable: bool = True, ): super().__init__() + self.trainable = trainable # only use_timestep when skip connection is established. 
self.use_timestep = use_timestep and ( num_out == num_in or num_out == num_in * 2 @@ -252,6 +253,7 @@ def serialize(self) -> dict: activation_function=self.activate_name, resnet=self.resnet, precision=self.precision, + trainable=self.trainable, ) nl.w, nl.b, nl.idt = ( to_numpy_array(self.matrix), diff --git a/deepmd/pd/model/network/network.py b/deepmd/pd/model/network/network.py index 0c97045ba4..320fc55eed 100644 --- a/deepmd/pd/model/network/network.py +++ b/deepmd/pd/model/network/network.py @@ -197,6 +197,7 @@ def __init__( self.precision, self.seed, bias=self.use_tebd_bias, + trainable=trainable, ) for param in self.parameters(): param.stop_gradient = not trainable diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 4d4cb4e748..d6b38a7f20 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -181,6 +181,8 @@ class DescrptBlockRepflows(DescriptorBlock): For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. seed : int, optional Random seed for parameter initialization. + trainable : bool, default: True + Whether this block is trainable """ def __init__( diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 0c18bcc8a8..27c5716919 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -166,6 +166,8 @@ def __init__( If not None, the scaling of attention weights is `temperature` itself. seed : int, Optional Random seed for parameter initialization. + trainable : bool, default: True + Whether this block is trainable """ super().__init__() del type From 6bd237d3b4efd8f7aeca960473789c2e7554d55f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 10 Jun 2025 21:31:31 +0800 Subject: [PATCH 21/32] fix(jax): fix DPA3 force NaN with edge_init_use_dist Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/repflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index f8c329b515..0a39e4c596 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -526,7 +526,7 @@ def call( # edge_input, h2 = xp.split(dmatrix, [1], axis=-1) # nb x nloc x nnei x 1 if self.edge_init_use_dist: - edge_input = xp.linalg.vector_norm(diff, axis=-1, keepdims=True) + edge_input = safe_for_vector_norm(diff, axis=-1, keepdims=True) else: edge_input = dmatrix[:, :, :, :1] h2 = dmatrix[:, :, :, 1:] From 243356601a1b77ce50956b13338feb000b16485d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 17 Jun 2025 22:28:56 +0800 Subject: [PATCH 22/32] fix(jax): use more safe_for_vector_norm Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/repflows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 28fa4a0549..6fee687c43 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -503,7 +503,7 @@ def call( sw = xp.where(nlist_mask, sw, xp.zeros_like(sw)) # get angle nlist (maybe smaller) - a_dist_mask = (xp.linalg.vector_norm(diff, axis=-1) < self.a_rcut)[ + a_dist_mask = (safe_for_vector_norm(diff, axis=-1) < self.a_rcut)[ :, :, : self.a_sel ] a_nlist = nlist[:, :, : self.a_sel] From dfaf6fbfcbf14f99f0298f4c537804d89f8ab145 Mon Sep 17 00:00:00 2001 
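Patches 21 and 22 swap `xp.linalg.vector_norm` for `safe_for_vector_norm` because the derivative of the Euclidean norm is singular at zero: under JAX autodiff, a zero-padded neighbor at zero distance turns the force gradient into NaN. The helper itself is defined elsewhere in the codebase and may differ in detail; a minimal sketch of the double-`where` guard such a helper typically relies on:

    import jax.numpy as jnp

    def safe_vector_norm_sketch(x, axis=-1, keepdims=False):
        # Never feed an exact zero into sqrt: mask zero entries first so the
        # backward pass stays finite for zero-padded neighbor entries.
        sq = jnp.sum(jnp.square(x), axis=axis, keepdims=keepdims)
        mask = sq > 0.0
        guarded = jnp.where(mask, sq, 1.0)
        return jnp.where(mask, jnp.sqrt(guarded), 0.0)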
From dfaf6fbfcbf14f99f0298f4c537804d89f8ab145 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 20 Jun 2025 11:01:11 +0800
Subject: [PATCH 23/32] fix nopbc behavior

Signed-off-by: Jinzhe Zeng
---
 deepmd/jax/train/trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py
index 408f67da9f..fd6d700011 100644
--- a/deepmd/jax/train/trainer.py
+++ b/deepmd/jax/train/trainer.py
@@ -169,6 +169,9 @@ def train(self, train_data, valid_data=None) -> None:
                 }
                 for ii in range(train_data.get_nsystems())
             ]
+            for ii, single_data in enumerate(all_stat_sys):
+                if not train_data.data_systems[ii].pbc:
+                    single_data["box"] = None
             model.atomic_model.descriptor.compute_input_stats(all_stat_sys)
             model.atomic_model.fitting.compute_output_stats(all_stat)
 
@@ -279,7 +282,7 @@ def train_step(
                 sel=model.get_sel(),
                 coord=jax_data["coord"],
                 atype=jax_data["type"],
-                box=jax_data["box"] if jax_data["find_box"] else None,
+                box=jax_data["box"] if jax_data["default_mesh"].size > 0 else None,
                 fparam=jax_data.get("fparam", None),
                 aparam=jax_data.get("aparam", None),
             )

From b76fb837e058c085290853677abb2e36ef88feaa Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Fri, 20 Jun 2025 11:52:30 +0800
Subject: [PATCH 24/32] should be >1

Signed-off-by: Jinzhe Zeng
---
 deepmd/jax/train/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py
index fd6d700011..c9af915218 100644
--- a/deepmd/jax/train/trainer.py
+++ b/deepmd/jax/train/trainer.py
@@ -282,7 +282,7 @@ def train_step(
                 sel=model.get_sel(),
                 coord=jax_data["coord"],
                 atype=jax_data["type"],
-                box=jax_data["box"] if jax_data["default_mesh"].size > 0 else None,
+                box=jax_data["box"] if jax_data["default_mesh"].size > 1 else None,
                 fparam=jax_data.get("fparam", None),
                 aparam=jax_data.get("aparam", None),
             )

From 4123ac2e93afd7d4f6d6c3ba8651c58b91e93d0b Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Wed, 25 Jun 2025 14:11:54 +0800
Subject: [PATCH 25/32] freeze with hessian

Signed-off-by: Jinzhe Zeng
---
 deepmd/jax/entrypoints/freeze.py   |  5 ++++-
 deepmd/jax/infer/deep_eval.py      |  7 ++++++-
 deepmd/jax/jax2tf/serialization.py |  7 ++++++-
 deepmd/jax/model/hlo.py            | 10 +++++++++-
 deepmd/jax/utils/serialization.py  | 10 +++++++++-
 deepmd/main.py                     |  6 ++++++
 6 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/deepmd/jax/entrypoints/freeze.py b/deepmd/jax/entrypoints/freeze.py
index bd283e8681..f9f6ac6a1f 100644
--- a/deepmd/jax/entrypoints/freeze.py
+++ b/deepmd/jax/entrypoints/freeze.py
@@ -13,6 +13,7 @@ def freeze(
     *,
     checkpoint_folder: str,
     output: str,
+    hessian: bool = False,
     **kwargs,
 ) -> None:
     """Freeze the graph in supplied folder.
@@ -23,6 +24,8 @@ def freeze(
         location of either the folder with checkpoint or the checkpoint prefix
     output : str
         output file name
+    hessian : bool, optional
+        whether to freeze the hessian, by default False
    **kwargs
        other arguments
    """
@@ -31,6 +34,6 @@ def freeze(
         checkpoint_folder = checkpoint_meta.read_text().strip()
     if Path(checkpoint_folder).is_dir():
         data = serialize_from_file(checkpoint_folder)
-        deserialize_to_file(output, data)
+        deserialize_to_file(output, data, hessian=hessian)
     else:
         raise FileNotFoundError(f"Checkpoint {checkpoint_folder} does not exist.")
diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py
index acfd42b66a..125ae7667a 100644
--- a/deepmd/jax/infer/deep_eval.py
+++ b/deepmd/jax/infer/deep_eval.py
@@ -279,6 +279,7 @@ def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]:
                     OutputVariableCategory.REDU,
                     OutputVariableCategory.DERV_R,
                     OutputVariableCategory.DERV_C_REDU,
+                    OutputVariableCategory.DERV_R_DERV_R,
                 )
             ]
@@ -419,4 +420,8 @@ def _get_output_shape(self, odef, nframes, natoms):
 
     def get_model_def_script(self) -> dict:
         """Get model definition script."""
-        return json.loads(self.dp.get_model_def_script())
+        return self.dp.get_model_def_script()
+
+    def get_has_hessian(self) -> bool:
+        model_def_script = self.get_model_def_script()
+        return model_def_script.get("hessian_mode", False)
diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py
index aac022ace9..75f46fb020 100644
--- a/deepmd/jax/jax2tf/serialization.py
+++ b/deepmd/jax/jax2tf/serialization.py
@@ -21,7 +21,7 @@
 )
 
-def deserialize_to_file(model_file: str, data: dict) -> None:
+def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None:
     """Deserialize the dictionary to a model file.
 
     Parameters
@@ -30,10 +30,15 @@
         The model file to be saved.
     data : dict
         The dictionary to be deserialized.
+    hessian : bool
+        Add the Hessian to the model output.
     """
     if model_file.endswith(".savedmodel"):
         model = BaseModel.deserialize(data["model"])
         model_def_script = data["model_def_script"]
+        if hessian:
+            model.enable_hessian()
+            model_def_script["hessian_mode"] = True
         call_lower = model.call_lower
 
         tf_model = tf.Module()
diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py
index 4d59957456..28993a8f5e 100644
--- a/deepmd/jax/model/hlo.py
+++ b/deepmd/jax/model/hlo.py
@@ -31,6 +31,14 @@
         r_differentiable=True,
         c_differentiable=True,
     ),
+    "energy_hessian": OutputVariableDef(
+        "energy",
+        shape=[1],
+        reducible=True,
+        r_differentiable=True,
+        c_differentiable=True,
+        r_hessian=True,
+    ),
     "mask": OutputVariableDef(
         "mask",
         shape=[1],
@@ -167,7 +175,7 @@ def call(
 
     def model_output_def(self):
         return ModelOutputDef(
-            FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()])
+            FittingOutputDef([OUTPUT_DEFS[tt if not (self.model_def_script.get("hessian_mode", False) and tt == "energy") else f"{tt}_hessian"] for tt in self.model_output_type()])
         )
 
     def call_lower(
diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py
index 454affba31..c9ed490d07 100644
--- a/deepmd/jax/utils/serialization.py
+++ b/deepmd/jax/utils/serialization.py
@@ -22,7 +22,7 @@
 )
 
-def deserialize_to_file(model_file: str, data: dict) -> None:
+def deserialize_to_file(model_file: str, data: dict, hessian: bool = False) -> None:
     """Deserialize the dictionary to a model file.
 
     Parameters
@@ -31,10 +31,15 @@
         The model file to be saved.
     data : dict
         The dictionary to be deserialized.
+    hessian : bool
+        Add the Hessian to the model output.
     """
     if model_file.endswith(".jax"):
         model = BaseModel.deserialize(data["model"])
         model_def_script = data["model_def_script"]
+        if hessian:
+            model.enable_hessian()
+            model_def_script["hessian_mode"] = True
         _, state = nnx.split(model)
         with ocp.Checkpointer(
             ocp.CompositeCheckpointHandler("state", "model_def_script")
@@ -49,6 +54,9 @@
     elif model_file.endswith(".hlo"):
         model = BaseModel.deserialize(data["model"])
         model_def_script = data["model_def_script"]
+        if hessian:
+            model.enable_hessian()
+            model_def_script["hessian_mode"] = True
         call_lower = model.call_lower
 
         nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")
diff --git a/deepmd/main.py b/deepmd/main.py
index 14c0390bdc..d8f4aad119 100644
--- a/deepmd/main.py
+++ b/deepmd/main.py
@@ -335,6 +335,12 @@ def main_parser() -> argparse.ArgumentParser:
         type=str,
         help="(Supported backend: PyTorch) Task head (alias: model branch) to freeze if in multi-task mode.",
     )
+    parser_frz.add_argument(
+        "--hessian",
+        action="store_true",
+        default=False,
+        help="Add the Hessian to the model output.",
+    )
 
     # * test script ********************************************************************
     parser_tst = subparsers.add_parser(
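With the flag plumbed from `main.py` down through `freeze` and both serializers, a Hessian-enabled freeze can be driven from the CLI (`dp freeze --hessian ...`, per the argparse hunk above) or programmatically. A sketch of the latter, with hypothetical file names:

    # Hypothetical invocation of the patched entry point.
    from deepmd.jax.entrypoints.freeze import freeze

    freeze(
        checkpoint_folder="model.ckpt",  # folder or checkpoint prefix
        output="frozen_model.hlo",       # .jax, .hlo, and .savedmodel all honor hessian
        hessian=True,                    # calls model.enable_hessian() before export
    )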
From 1cd66ee0c83b855ef5ca1ede6ae743e27d759455 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Thu, 10 Jul 2025 22:52:29 +0800
Subject: [PATCH 26/32] pass mixed_type

Signed-off-by: Jinzhe Zeng
---
 deepmd/jax/train/trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py
index c9af915218..5372e8f7d6 100644
--- a/deepmd/jax/train/trainer.py
+++ b/deepmd/jax/train/trainer.py
@@ -173,7 +173,9 @@ def train(self, train_data, valid_data=None) -> None:
                 if not train_data.data_systems[ii].pbc:
                     single_data["box"] = None
             model.atomic_model.descriptor.compute_input_stats(all_stat_sys)
-            model.atomic_model.fitting.compute_output_stats(all_stat)
+            model.atomic_model.fitting.compute_output_stats(
+                all_stat, mixed_type=train_data.mixed_type
+            )
 
         def loss_fn(
             model,

From 208b648c3dc016967e24fe3c79120469814db7ba Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 28 Jul 2025 08:44:45 +0000
Subject: [PATCH 27/32] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 deepmd/jax/infer/deep_eval.py |  1 -
 deepmd/jax/model/hlo.py       | 14 +++++++++++++-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py
index 125ae7667a..69690af8ef 100644
--- a/deepmd/jax/infer/deep_eval.py
+++ b/deepmd/jax/infer/deep_eval.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import json
 from typing import (
     TYPE_CHECKING,
     Any,
diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py
index 28993a8f5e..f3e0f1a8f2 100644
--- a/deepmd/jax/model/hlo.py
+++ b/deepmd/jax/model/hlo.py
@@ -175,7 +175,19 @@ def call(
 
     def model_output_def(self):
         return ModelOutputDef(
-            FittingOutputDef([OUTPUT_DEFS[tt if not (self.model_def_script.get("hessian_mode", False) and tt == "energy") else f"{tt}_hessian"] for tt in self.model_output_type()])
+            FittingOutputDef(
+                [
+                    OUTPUT_DEFS[
+                        tt
+                        if not (
+                            self.model_def_script.get("hessian_mode", False)
+                            and tt == "energy"
+                        )
+                        else f"{tt}_hessian"
+                    ]
+                    for tt in self.model_output_type()
+                ]
+            )
         )
 
     def call_lower(

From 516d9e5582b1f6d849a186beae1bf18a1e4b19fa Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Thu, 27 Nov 2025 17:23:06 +0800
Subject: [PATCH 28/32] fix type hints

---
 deepmd/dpmodel/loss/ener.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py
index c475a04024..d624d0c3e6 100644
--- a/deepmd/dpmodel/loss/ener.py
+++ b/deepmd/dpmodel/loss/ener.py
@@ -418,9 +418,9 @@ def call(
         self,
         learning_rate: float,
         natoms: int,
-        model_dict: dict[str, np.ndarray],
-        label_dict: dict[str, np.ndarray],
-    ) -> dict[str, np.ndarray]:
+        model_dict: dict[str, Array],
+        label_dict: dict[str, Array],
+    ) -> dict[str, Array]:
         """Calculate loss from model results and labeled results."""
         loss, more_loss = EnergyLoss.call(
             self, learning_rate, natoms, model_dict, label_dict

From 72b4455d6a3e554a90d5a42df7b83b3ed7be8672 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Tue, 2 Dec 2025 15:55:53 +0800
Subject: [PATCH 29/32] fix compatibility with flax 0.12

---
 deepmd/jax/train/trainer.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py
index 5372e8f7d6..9fb3ee1b54 100644
--- a/deepmd/jax/train/trainer.py
+++ b/deepmd/jax/train/trainer.py
@@ -13,6 +13,7 @@
 import numpy as np
 import optax
 import orbax.checkpoint as ocp
+from packaging.version import Version
 
 from deepmd.common import (
     symlink_prefix_files,
@@ -37,6 +38,7 @@
 from deepmd.jax.env import (
     jnp,
     nnx,
+    flax_version,
 )
 from deepmd.jax.model.base_model import (
     BaseModel,
@@ -152,7 +154,7 @@ def train(self, train_data, valid_data=None) -> None:
         tx = optax.adam(
             learning_rate=lambda step: self.lr.value(self.start_step + step, xp=jnp),
         )
-        optimizer = nnx.Optimizer(model, tx)
+        optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
 
         # data stat
         if self.init_model is None and self.restart is None:
@@ -268,7 +270,10 @@ def train_step(
                 fp,
                 ap,
             )
-            optimizer.update(grads)
+            if Version(flax_version) >= Version("0.11.0"):
+                optimizer.update(model, grads)
+            else:
+                optimizer.update(grads)
 
         start_time = time.time()
         disp_file_fp = open(self.disp_file, "w")
@@ -311,7 +316,7 @@ def train_step(
                     )
                 )
                 more_loss = loss_fn_more_loss(
-                    optimizer.model,
+                    model,
                     self.lr.value(step),
                     jax_data,
                     extended_coord,
@@ -340,7 +345,7 @@ def train_step(
                         )
                     )
                     valid_more_loss = loss_fn_more_loss(
-                        optimizer.model,
+                        model,
                         self.lr.value(step),
                         jax_valid_data,
                         extended_coord,

From 387d989e1dce70311c36f3cc3ba2759b2efa67b8 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Tue, 2 Dec 2025 16:10:32 +0800
Subject: [PATCH 30/32] add type annotations

---
 deepmd/dpmodel/fitting/ener_fitting.py |  4 +-
 deepmd/dpmodel/loss/ener.py            |  8 +--
 deepmd/dpmodel/utils/learning_rate.py  |  2 +-
 deepmd/jax/entrypoints/freeze.py       |  3 +-
 deepmd/jax/entrypoints/train.py        |  5 +-
 deepmd/jax/train/trainer.py            | 91 +++++++++++++++-----------
 6 files changed, 64 insertions(+), 49 deletions(-)

diff --git a/deepmd/dpmodel/fitting/ener_fitting.py b/deepmd/dpmodel/fitting/ener_fitting.py
index cf0c95663f..96c13cc4c7 100644
--- a/deepmd/dpmodel/fitting/ener_fitting.py
+++ b/deepmd/dpmodel/fitting/ener_fitting.py
@@ -113,7 +113,9 @@ def compute_output_stats(self, all_stat: dict, mixed_type: bool = False) -> None
             all_stat, rcond=self.rcond, mixed_type=mixed_type
         )
 
-    def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False):
+    def _compute_output_stats(
+        self, all_stat: dict, rcond: float = 1e-3, mixed_type: bool = False
+    ) -> np.ndarray:
         data = all_stat["energy"]
         # data[sys_idx][batch_idx][frame_idx]
         sys_ener = []
diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py
index d624d0c3e6..3d33f575a6 100644
--- a/deepmd/dpmodel/loss/ener.py
+++ b/deepmd/dpmodel/loss/ener.py
@@ -393,10 +393,10 @@ def deserialize(cls, data: dict) -> "Loss":
 class EnergyHessianLoss(EnergyLoss):
     def __init__(
         self,
-        start_pref_h=0.0,
-        limit_pref_h=0.0,
-        **kwargs,
-    ):
+        start_pref_h: float = 0.0,
+        limit_pref_h: float = 0.0,
+        **kwargs: Any,
+    ) -> None:
         r"""Enable the layer to compute loss on hessian.
 
         Parameters
diff --git a/deepmd/dpmodel/utils/learning_rate.py b/deepmd/dpmodel/utils/learning_rate.py
index 7dd01adbe8..5b88bf7aa3 100644
--- a/deepmd/dpmodel/utils/learning_rate.py
+++ b/deepmd/dpmodel/utils/learning_rate.py
@@ -50,7 +50,7 @@ def __init__(
         self.decay_rate = decay_rate
         self.min_lr = stop_lr
 
-    def value(self, step: int, xp=np) -> np.float64:
+    def value(self, step: int, xp: Any = np) -> np.float64:
         """Get the learning rate at the given step."""
         step_lr = self.start_lr * xp.power(self.decay_rate, step // self.decay_steps)
         step_lr = xp.clip(step_lr, self.min_lr, None)
diff --git a/deepmd/jax/entrypoints/freeze.py b/deepmd/jax/entrypoints/freeze.py
index f9f6ac6a1f..28ac063c6d 100644
--- a/deepmd/jax/entrypoints/freeze.py
+++ b/deepmd/jax/entrypoints/freeze.py
@@ -2,6 +2,7 @@
 from pathlib import (
     Path,
 )
+from typing import Any
 
 from deepmd.jax.utils.serialization import (
     deserialize_to_file,
@@ -14,7 +15,7 @@ def freeze(
     *,
     checkpoint_folder: str,
     output: str,
     hessian: bool = False,
-    **kwargs,
+    **kwargs: Any,
 ) -> None:
     """Freeze the graph in supplied folder.
diff --git a/deepmd/jax/entrypoints/train.py b/deepmd/jax/entrypoints/train.py
index 27b3e54e55..5146e0965c 100644
--- a/deepmd/jax/entrypoints/train.py
+++ b/deepmd/jax/entrypoints/train.py
@@ -8,6 +8,7 @@
 import logging
 import time
 from typing import (
+    Any,
     Optional,
 )
 
@@ -78,7 +79,7 @@ def train(
     skip_neighbor_stat: bool = False,
     finetune: Optional[str] = None,
     use_pretrain_script: bool = False,
-    **kwargs,
+    **kwargs: Any,
 ) -> None:
     """Run DeePMD model training.
@@ -186,7 +187,7 @@ def train(
     log.info(f"wall time: {(end_time - start_time):.3f} s")
 
-def update_sel(jdata):
+def update_sel(jdata: dict) -> dict:
     log.info(
         "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
     )
diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py
index 9fb3ee1b54..1930ad2d88 100644
--- a/deepmd/jax/train/trainer.py
+++ b/deepmd/jax/train/trainer.py
@@ -8,6 +8,7 @@
 )
 from typing import (
     Optional,
+    TextIO,
 )
 
 import numpy as np
 import optax
@@ -56,6 +57,7 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.data_system import DeepmdDataSystem
 from deepmd.utils.model_stat import (
     make_stat_input,
 )
@@ -66,7 +68,7 @@
 class DPTrainer:
     def __init__(
         self,
-        jdata,
+        jdata: dict,
         init_model: Optional[str] = None,
         restart: Optional[str] = None,
     ) -> None:
@@ -87,7 +89,7 @@ def __init__(
         self.training_param = jdata["training"]
         self.num_steps = self.training_param["numb_steps"]
 
-        def get_lr_and_coef(lr_param):
+        def get_lr_and_coef(lr_param: dict) -> LearningRateExp:
             lr_type = lr_param.get("type", "exp")
             if lr_type == "exp":
                 lr = LearningRateExp(
@@ -149,7 +151,9 @@ def get_lr_and_coef(lr_param):
     def data_requirements(self) -> list[DataRequirementItem]:
         return self.loss.label_requirement
 
-    def train(self, train_data, valid_data=None) -> None:
+    def train(
+        self, train_data: DeepmdDataSystem, valid_data: DeepmdDataSystem | None = None
+    ) -> None:
         model = self.model
         tx = optax.adam(
             learning_rate=lambda step: self.lr.value(self.start_step + step, xp=jnp),
@@ -180,16 +184,16 @@ def train(
         def loss_fn(
-            model,
-            lr,
-            label_dict,
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping,
-            fp,
-            ap,
-        ):
+            model: BaseModel,
+            lr: float,
+            label_dict: dict[str, jnp.ndarray],
+            extended_coord: jnp.ndarray,
+            extended_atype: jnp.ndarray,
+            nlist: jnp.ndarray,
+            mapping: jnp.ndarray | None,
+            fp: jnp.ndarray | None,
+            ap: jnp.ndarray | None,
+        ) -> jnp.ndarray:
             model_dict_lower = model.call_lower(
                 extended_coord,
                 extended_atype,
@@ -214,16 +218,16 @@ def loss_fn(
         @nnx.jit
         def loss_fn_more_loss(
-            model,
-            lr,
-            label_dict,
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping,
-            fp,
-            ap,
-        ):
+            model: BaseModel,
+            lr: float,
+            label_dict: dict[str, jnp.ndarray],
+            extended_coord: jnp.ndarray,
+            extended_atype: jnp.ndarray,
+            nlist: jnp.ndarray,
+            mapping: jnp.ndarray | None,
+            fp: jnp.ndarray | None,
+            ap: jnp.ndarray | None,
+        ) -> dict[str, jnp.ndarray]:
             model_dict_lower = model.call_lower(
                 extended_coord,
                 extended_atype,
@@ -248,17 +252,17 @@ def loss_fn_more_loss(
         @nnx.jit
         def train_step(
-            model,
-            optimizer,
-            lr,
-            label_dict,
-            extended_coord,
-            extended_atype,
-            nlist,
-            mapping,
-            fp,
-            ap,
-        ):
+            model: BaseModel,
+            optimizer: nnx.Optimizer,
+            lr: float,
+            label_dict: dict[str, jnp.ndarray],
+            extended_coord: jnp.ndarray,
+            extended_atype: jnp.ndarray,
+            nlist: jnp.ndarray,
+            mapping: jnp.ndarray | None,
+            fp: jnp.ndarray | None,
+            ap: jnp.ndarray | None,
+        ) -> None:
            grads = nnx.grad(loss_fn)(
                model,
                lr,
@@ -393,11 +397,11 @@ def train_step(
     @staticmethod
     def print_on_training(
-        fp,
-        train_results,
-        valid_results,
-        cur_batch,
-        cur_lr,
+        fp: TextIO,
+        train_results: dict[str, float],
+        valid_results: dict[str, float] | None,
+        cur_batch: int,
+        cur_lr: float,
     ) -> None:
         print_str = ""
         print_str += f"{cur_batch:7d}"
@@ -441,7 +445,14 @@ def prepare_input(
     box: Optional[np.ndarray] = None,
     fparam: Optional[np.ndarray] = None,
     aparam: Optional[np.ndarray] = None,
-):
+) -> tuple[
+    np.ndarray,
+    np.ndarray,
+    np.ndarray,
+    np.ndarray,
+    Optional[np.ndarray],
+    Optional[np.ndarray],
+]:
     nframes, nloc = atype.shape[:2]
     cc, bb, fp, ap = coord, box, fparam, aparam
     del coord, box, fparam, aparam
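For orientation, the newly annotated `prepare_input` is consumed as a six-tuple. A sketch of the call site, mirroring the arguments visible in the patch 23/24 hunks; `model` and `jax_data` are the trainer-local names used there, and any leading arguments the diffs do not show are omitted:

    # Sketch of a call site; returns the extended neighborhood data that
    # model.call_lower consumes inside the training step.
    extended_coord, extended_atype, nlist, mapping, fp, ap = prepare_input(
        sel=model.get_sel(),
        coord=jax_data["coord"],
        atype=jax_data["type"],
        box=jax_data["box"] if jax_data["default_mesh"].size > 1 else None,
        fparam=jax_data.get("fparam", None),
        aparam=jax_data.get("aparam", None),
    )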
From 7af98ab27ef2b1f72874247d5d078d82907fbf87 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Tue, 2 Dec 2025 16:27:47 +0800
Subject: [PATCH 31/32] print header

---
 deepmd/jax/train/trainer.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py
index 1930ad2d88..0ea15c17c9 100644
--- a/deepmd/jax/train/trainer.py
+++ b/deepmd/jax/train/trainer.py
@@ -361,6 +361,12 @@ def train_step(
                     )
                 else:
                     valid_more_loss = None
+                if step == 0:
+                    self.print_header(
+                        disp_file_fp,
+                        train_results=more_loss,
+                        valid_results=valid_more_loss,
+                    )
                 self.print_on_training(
                     disp_file_fp,
                     train_results=more_loss,
@@ -435,6 +441,27 @@ def print_on_training(
         fp.write(print_str)
         fp.flush()
 
+    @staticmethod
+    def print_header(
+        fp: TextIO,
+        train_results: dict[str, float],
+        valid_results: dict[str, float] | None,
+    ) -> None:
+        print_str = ""
+        print_str += "# {:5s}".format("step")
+        if valid_results is not None:
+            prop_fmt = " %11s %11s"
+            for k in train_results.keys():
+                print_str += prop_fmt % (k + "_val", k + "_trn")
+        else:
+            prop_fmt = " %11s"
+            for k in train_results.keys():
+                print_str += prop_fmt % (k + "_trn")
+        print_str += " {:8s}\n".format("lr")
+        print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n"
+        fp.write(print_str)
+        fp.flush()
+
 
 def prepare_input(
     *,  # enforce keyword-only arguments

From 8eff47052c5db584c3d862ba40fa48275ee0c7c4 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 2 Dec 2025 08:31:15 +0000
Subject: [PATCH 32/32] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 deepmd/jax/entrypoints/freeze.py |  4 +++-
 deepmd/jax/train/trainer.py      | 10 +++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/deepmd/jax/entrypoints/freeze.py b/deepmd/jax/entrypoints/freeze.py
index 28ac063c6d..345b2690ae 100644
--- a/deepmd/jax/entrypoints/freeze.py
+++ b/deepmd/jax/entrypoints/freeze.py
@@ -2,7 +2,9 @@
 from pathlib import (
     Path,
 )
-from typing import Any
+from typing import (
+    Any,
+)
 
 from deepmd.jax.utils.serialization import (
     deserialize_to_file,
diff --git a/deepmd/jax/train/trainer.py b/deepmd/jax/train/trainer.py
index 0ea15c17c9..7d29046d4f 100644
--- a/deepmd/jax/train/trainer.py
+++ b/deepmd/jax/train/trainer.py
@@ -14,7 +14,9 @@
 import numpy as np
 import optax
 import orbax.checkpoint as ocp
-from packaging.version import Version
+from packaging.version import (
+    Version,
+)
 
 from deepmd.common import (
     symlink_prefix_files,
@@ -37,9 +39,9 @@
 from deepmd.jax.env import (
+    flax_version,
     jnp,
     nnx,
-    flax_version,
 )
 from deepmd.jax.model.base_model import (
     BaseModel,
@@ -57,7 +59,9 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
-from deepmd.utils.data_system import DeepmdDataSystem
+from deepmd.utils.data_system import (
+    DeepmdDataSystem,
+)
 from deepmd.utils.model_stat import (
     make_stat_input,
 )