Skip to content

Commit e698253

Browse files
fix
1 parent cb6d855 commit e698253

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

deepmd/pd/model/task/fitting.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,17 @@ def compute_input_stats(
111111
fparam_std,
112112
)
113113
fparam_inv_std = 1.0 / fparam_std
114-
self.fparam_avg.copy_(
114+
paddle.assign(
115115
paddle.to_tensor(
116116
fparam_avg, place=env.DEVICE, dtype=self.fparam_avg.dtype
117-
)
117+
),
118+
self.fparam_avg,
118119
)
119-
self.fparam_inv_std.copy_(
120+
paddle.assign(
120121
paddle.to_tensor(
121122
fparam_inv_std, place=env.DEVICE, dtype=self.fparam_inv_std.dtype
122-
)
123+
),
124+
self.fparam_inv_std,
123125
)
124126
# stat aparam
125127
if self.numb_aparam > 0:
@@ -144,15 +146,17 @@ def compute_input_stats(
144146
aparam_std,
145147
)
146148
aparam_inv_std = 1.0 / aparam_std
147-
self.aparam_avg.copy_(
149+
paddle.assign(
148150
paddle.to_tensor(
149151
aparam_avg, place=env.DEVICE, dtype=self.aparam_avg.dtype
150-
)
152+
),
153+
self.aparam_avg,
151154
)
152-
self.aparam_inv_std.copy_(
155+
paddle.assign(
153156
paddle.to_tensor(
154157
aparam_inv_std, place=env.DEVICE, dtype=self.aparam_inv_std.dtype
155-
)
158+
),
159+
self.aparam_inv_std,
156160
)
157161

158162

0 commit comments

Comments
 (0)