Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a51b924
support paddle inference in deepeval
HydrogenSulfate Sep 19, 2025
07118cf
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
HydrogenSulfate Sep 19, 2025
c71f369
fix bugs in deeppotpd.cc
HydrogenSulfate Sep 19, 2025
ec3420c
fix
HydrogenSulfate Sep 19, 2025
dfbae88
Merge remote-tracking branch 'upstream/devel' into pd_deepeval
HydrogenSulfate Sep 22, 2025
04e89a1
support deeppot for json:
HydrogenSulfate Sep 22, 2025
43dd34c
update ase document
HydrogenSulfate Sep 22, 2025
c734732
support get_dim_fparam and get_dim_aparam
HydrogenSulfate Sep 22, 2025
72352c8
refine code
HydrogenSulfate Sep 22, 2025
5e451f0
refine code
HydrogenSulfate Sep 22, 2025
ce06281
fix UT
HydrogenSulfate Sep 22, 2025
bb43b26
fix UT
HydrogenSulfate Sep 22, 2025
9ab08b8
fix
HydrogenSulfate Sep 22, 2025
e30b7cc
fix
HydrogenSulfate Sep 22, 2025
171675b
restore
HydrogenSulfate Sep 22, 2025
80a4c5f
fix
HydrogenSulfate Sep 22, 2025
b650406
fix conversion of type_map
HydrogenSulfate Sep 23, 2025
4554fcc
fix code QL
HydrogenSulfate Sep 23, 2025
59d5c6f
fix
HydrogenSulfate Sep 23, 2025
8d71b11
fix
HydrogenSulfate Sep 23, 2025
6c22d40
fix UT
HydrogenSulfate Sep 23, 2025
dd119df
bump paddle to 3.2.0
HydrogenSulfate Sep 23, 2025
187f7f2
fix masked_add__decomp
HydrogenSulfate Sep 23, 2025
db5d3b4
bump to 3.1.1
HydrogenSulfate Sep 24, 2025
b370799
restore to 3.0.0
HydrogenSulfate Sep 24, 2025
8c29799
add deep_eval_test
HydrogenSulfate Sep 24, 2025
f47c45a
fix type annotations and refine docstrings
HydrogenSulfate Sep 25, 2025
37d7271
Merge remote-tracking branch 'upstream/devel' into pd_deepeval
HydrogenSulfate Sep 25, 2025
e019e12
Update deepmd/pd/infer/deep_eval.py
HydrogenSulfate Sep 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Path,
)
from typing import (
Any,
Optional,
Union,
)
Expand Down Expand Up @@ -80,15 +81,15 @@


def get_trainer(
config,
init_model=None,
restart_model=None,
finetune_model=None,
force_load=False,
init_frz_model=None,
shared_links=None,
finetune_links=None,
):
config: dict[str, Any],
init_model: Optional[str] = None,
restart_model: Optional[str] = None,
finetune_model: Optional[str] = None,
force_load: bool = False,
init_frz_model: Optional[str] = None,
shared_links: Optional[dict[str, Any]] = None,
finetune_links: Optional[dict[str, Any]] = None,
) -> training.Trainer:
multi_task = "model_dict" in config.get("model", {})

# Initialize DDP
Expand All @@ -98,17 +99,22 @@ def get_trainer(
fleet.init(is_collective=True)

def prepare_trainer_input_single(
model_params_single, data_dict_single, rank=0, seed=None
):
model_params_single: dict[str, Any],
data_dict_single: dict[str, Any],
rank: int = 0,
seed: Optional[int] = None,
) -> tuple[DpLoaderSet, Optional[DpLoaderSet], Optional[DPPath]]:
training_dataset_params = data_dict_single["training_data"]
validation_dataset_params = data_dict_single.get("validation_data", None)
validation_systems = (
validation_dataset_params["systems"] if validation_dataset_params else None
)
training_systems = training_dataset_params["systems"]
training_systems = process_systems(training_systems)
trn_patterns = training_dataset_params.get("rglob_patterns", None)
training_systems = process_systems(training_systems, patterns=trn_patterns)
if validation_systems is not None:
validation_systems = process_systems(validation_systems)
val_patterns = validation_dataset_params.get("rglob_patterns", None)
validation_systems = process_systems(validation_systems, val_patterns)

# stat files
stat_file_path_single = data_dict_single.get("stat_file", None)
Expand Down Expand Up @@ -342,6 +348,7 @@ def freeze(
model: str,
output: str = "frozen_model.json",
head: Optional[str] = None,
do_atomic_virial: bool = False,
) -> None:
paddle.set_flags(
{
Expand Down Expand Up @@ -374,7 +381,7 @@ def freeze(
None, # fparam
None, # aparam
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
False, # do_atomic_virial
do_atomic_virial, # do_atomic_virial
],
full_graph=True,
)
Expand All @@ -396,7 +403,7 @@ def freeze(
None, # fparam
None, # aparam
# InputSpec([], dtype="bool", name="do_atomic_virial"), # do_atomic_virial
False, # do_atomic_virial
do_atomic_virial, # do_atomic_virial
(
InputSpec([-1], "int64", name="send_list"),
InputSpec([-1], "int32", name="send_proc"),
Expand All @@ -409,6 +416,26 @@ def freeze(
],
full_graph=True,
)
for method_name in [
"get_buffer_rcut",
"get_buffer_type_map",
"get_buffer_dim_fparam",
"get_buffer_dim_aparam",
"get_buffer_intensive",
"get_buffer_sel_type",
"get_buffer_numb_dos",
"get_buffer_task_dim",
]:
if hasattr(model, method_name):
setattr(
model,
method_name,
paddle.jit.to_static(
getattr(model, method_name),
input_spec=[],
full_graph=True,
),
)
if output.endswith(".json"):
output = output[:-5]
paddle.jit.save(
Expand Down
Loading