
Commit 881d95e

pd: support CINN for se_e2_a inference (#4770)

1. Support CINN for se_e2_a inference.
2. Adjust the order between the to_static (CINN) wrapper and the distributed wrapper for compatibility.

## Summary by CodeRabbit

- **New Features**
  - Added informational logging to notify users when CINN (Compiler Infrastructure for Neural Networks) compilation is enabled during model training and inference.
  - Improved training performance by initializing CINN compilation earlier to reduce runtime overhead.
- **Style**
  - Improved the clarity of CINN-related messages to inform users about potential compilation time and backend usage.
1 parent 2dd47b5 commit 881d95e

File tree: 2 files changed, +61 −14 lines


deepmd/pd/train/training.py

Lines changed: 43 additions & 14 deletions
```diff
@@ -599,6 +599,49 @@ def warm_up_linear(step, warmup_steps):
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

+        # NOTE: to_static + compiler should be before distributed wrapper
+        if CINN:
+            from paddle import (
+                jit,
+                static,
+            )
+
+            backend = "CINN" if CINN else None
+            self.wrapper.forward = jit.to_static(
+                backend=backend,
+                input_spec=[
+                    static.InputSpec([1, -1, 3], "float64", name="coord"),  # coord
+                    static.InputSpec([1, -1], "int32", name="atype"),  # atype
+                    None,  # spin
+                    static.InputSpec([1, 9], "float64", name="box"),  # box
+                    static.InputSpec([], "float64", name="cur_lr"),  # cur_lr
+                    {
+                        "find_box": np.float32(1.0),
+                        "find_coord": np.float32(1.0),
+                        "find_numb_copy": np.float32(0.0),
+                        "numb_copy": static.InputSpec(
+                            [1, 1], "int64", name="numb_copy"
+                        ),
+                        "find_energy": np.float32(1.0),
+                        "energy": static.InputSpec([1, 1], "float64", name="energy"),
+                        "find_force": np.float32(1.0),
+                        "force": static.InputSpec([1, -1, 3], "float64", name="force"),
+                        "natoms": static.InputSpec([1, -1], "int32", name="natoms"),
+                    },  # label
+                    # None,  # task_key
+                    # False,  # inference_only
+                    # False,  # do_atomic_virial
+                    # None,  # fparam
+                    # None,  # aparam
+                ],
+                full_graph=True,
+            )(self.wrapper.forward)
+
+            log.info(
+                "Enable CINN during training, there may be some additional "
+                "compilation time in the first training step."
+            )
+
         if dist.is_available() and dist.is_initialized():
             # DDP will guarantee the model parameters are identical across all processes
             self.wrapper = fleet.distributed_model(
```
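For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the ordering this hunk enforces. The `Linear` layer is a hypothetical stand-in for the real deepmd wrapper; only the wrap-then-distribute order mirrors the change, and `backend="CINN"` assumes a Paddle build with CINN enabled:

```python
# Minimal sketch, not the deepmd code: compile with to_static *before*
# wrapping the model for distributed training, as the NOTE above requires.
import paddle
from paddle import jit, static
from paddle.distributed import fleet

net = paddle.nn.Linear(3, 1)  # hypothetical stand-in for self.wrapper

# 1) to_static must capture the plain forward; if it ran after
#    fleet.distributed_model, it would trace the DDP proxy instead.
net.forward = jit.to_static(
    backend="CINN",  # assumes a Paddle build with CINN support
    input_spec=[static.InputSpec([-1, 3], "float32", name="x")],
    full_graph=True,
)(net.forward)

# 2) Only afterwards hand the module to the distributed wrapper
#    (in a real run, fleet.init(is_collective=True) must come first).
if paddle.distributed.is_initialized():
    net = fleet.distributed_model(net)
```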
```diff
@@ -631,20 +674,6 @@ def warm_up_linear(step, warmup_steps):
         self.profiling_file = training_params.get("profiling_file", "timeline.json")

     def run(self) -> None:
-        if CINN:
-            from paddle import (
-                jit,
-            )
-
-            backend = "CINN" if CINN else None
-            self.wrapper.forward = jit.to_static(full_graph=True, backend=backend)(
-                self.wrapper.forward
-            )
-            log.info(
-                "Enable CINN during training, there may be some additional "
-                "compilation time in the first traning step."
-            )
-
         fout = (
             open(
                 self.disp_file,
```
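The removed block compiled `forward` lazily inside `run()` with no `input_spec`; the new block above supplies explicit specs at construction time, so compilation starts earlier. A toy sketch of why the `-1` dimensions in those specs matter (hypothetical function, not the deepmd wrapper): the symbolic dimension lets one compiled program serve any atom count instead of retracing per shape.

```python
# Toy sketch of dynamic-shape InputSpec; only the [1, -1, 3] convention
# is taken from the diff above.
import paddle
from paddle import jit, static

def total_coord(coord):
    # Sums over a [1, natoms, 3] coordinate tensor.
    return coord.sum()

total_coord = jit.to_static(
    # -1 marks natoms as dynamic, so changing it does not retrigger
    # a full recompilation of the static program.
    input_spec=[static.InputSpec([1, -1, 3], "float64", name="coord")],
    full_graph=True,
)(total_coord)

print(total_coord(paddle.ones([1, 4, 3], dtype="float64")))  # 12.0
print(total_coord(paddle.ones([1, 7, 3], dtype="float64")))  # 21.0
```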

source/api_cc/src/DeepPotPD.cc

Lines changed: 18 additions & 0 deletions
```diff
@@ -120,6 +120,24 @@ void DeepPotPD::init(const std::string& model,
     std::cout << "load model from: " << model << " to gpu:" << gpu_id
               << std::endl;
   }
+  if (config->cinn_enabled()) {
+    std::cout << "model.forward will be compiled with cinn." << std::endl;
+  } else {
+    std::cout << "NOTE: You can try: \n'export FLAGS_prim_all=true"
+                 " FLAGS_enable_pir_in_executor=1"
+                 " FLAGS_prim_enable_dynamic=true FLAGS_use_cinn=true'\n"
+                 "to speed up C++ inference with paddle backend"
+              << std::endl;
+  }
+  if (config_fl->cinn_enabled()) {
+    std::cout << "model.forward_lower will be compiled with cinn." << std::endl;
+  } else {
+    std::cout << "NOTE: You can try: \n'export FLAGS_prim_all=true"
+                 " FLAGS_enable_pir_in_executor=1"
+                 " FLAGS_prim_enable_dynamic=true FLAGS_use_cinn=true'\n"
+                 "to speed up C++ inference with paddle backend"
+              << std::endl;
+  }

   // NOTE: Both set to 1 now.
   // get_env_nthreads(num_intra_nthreads,
```
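The fallback branches echo the CINN-related environment flags that users can export before launching inference. As a hedged illustration, the same flags can also be set from a Python process before `paddle` is imported; the flag names are copied from the log string above, and whether each takes effect depends on the installed Paddle build:

```python
# Sketch only: mirror the 'export ...' hint from the C++ log message.
# The flags must be set before paddle is imported so the executor
# picks them up.
import os

os.environ["FLAGS_prim_all"] = "true"
os.environ["FLAGS_enable_pir_in_executor"] = "1"
os.environ["FLAGS_prim_enable_dynamic"] = "true"
os.environ["FLAGS_use_cinn"] = "true"

import paddle  # noqa: E402 -- imported after the flags on purpose
```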
