@@ -22,8 +22,8 @@ void BatchNormKernel(const Context& dev_ctx,
2222 const phi::DenseTensor& x,
2323 const phi::DenseTensor& running_mean,
2424 const phi::DenseTensor& running_var,
25- const phi::DenseTensor& scale,
26- const phi::DenseTensor& bias,
25+ const paddle::optional< phi::DenseTensor> & scale,
26+ const paddle::optional< phi::DenseTensor> & bias,
2727 bool is_test,
2828 float momentum,
2929 float epsilon,
@@ -61,6 +61,26 @@ void BatchNormKernel(const Context& dev_ctx,
6161 : x_dims[x_dims.size () - 1 ]);
6262 const int sample_size = x.numel () / N / C;
6363
64+ auto * Scale = scale.get_ptr ();
65+ auto * Bias = bias.get_ptr ();
66+
67+ phi::DenseTensor new_scale, new_bias;
68+ if (Scale) {
69+ new_scale = scale.get ();
70+ } else {
71+ new_scale.Resize ({C});
72+ dev_ctx.template Alloc <T>(&new_scale);
73+ FillMLUTensorWithHostValue<T>(dev_ctx, static_cast <T>(1 ), &new_scale);
74+ }
75+
76+ if (Bias) {
77+ new_bias = bias.get ();
78+ } else {
79+ new_bias.Resize ({C});
80+ dev_ctx.template Alloc <T>(&new_bias);
81+ FillMLUTensorWithHostValue<T>(dev_ctx, static_cast <T>(0 ), &new_bias);
82+ }
83+
6484 // alloc memory
6585 dev_ctx.template Alloc <T>(y);
6686
@@ -78,7 +98,7 @@ void BatchNormKernel(const Context& dev_ctx,
7898 transformed_shape,
7999 ToCnnlDataType<T>(),
80100 CNNL_LAYOUT_NHWC);
81- MLUCnnlTensorDesc others_input_desc (scale );
101+ MLUCnnlTensorDesc others_input_desc (new_scale );
82102 // input dimension is 2 and the format is NCHW. The input can be regarded as
83103 // NHWC format. Don't need to transpose.
84104 bool need_transpose =
@@ -123,8 +143,8 @@ void BatchNormKernel(const Context& dev_ctx,
123143 transformed_desc.get (),
124144 GetBasePtr (&transformed_x),
125145 others_input_desc.get (),
126- GetBasePtr (&scale ),
127- GetBasePtr (&bias ),
146+ GetBasePtr (&new_scale ),
147+ GetBasePtr (&new_bias ),
128148 GetBasePtr (&running_mean),
129149 GetBasePtr (&running_var),
130150 epsilon,
@@ -155,8 +175,8 @@ template <typename T, typename Context>
155175void BatchNormGradKernel (
156176 const Context& dev_ctx,
157177 const phi::DenseTensor& x,
158- const phi::DenseTensor& scale,
159- const phi::DenseTensor& bias,
178+ const paddle::optional< phi::DenseTensor> & scale,
179+ const paddle::optional< phi::DenseTensor> & bias,
160180 const paddle::optional<phi::DenseTensor>& mean,
161181 const paddle::optional<phi::DenseTensor>& variance,
162182 const phi::DenseTensor& saved_mean,
@@ -172,7 +192,47 @@ void BatchNormGradKernel(
172192 phi::DenseTensor* d_x,
173193 phi::DenseTensor* d_scale,
174194 phi::DenseTensor* d_bias) {
195+ const auto & x_dims = x.dims ();
196+ PADDLE_ENFORCE_GE (
197+ x_dims.size (),
198+ 2 ,
199+ phi::errors::InvalidArgument (
200+ " The size of input X's dimensions should be larger than 1."
201+ " But received: the size of input X's dimensions is [%d]" ,
202+ x_dims.size ()));
203+ PADDLE_ENFORCE_LE (
204+ x_dims.size (),
205+ 5 ,
206+ phi::errors::InvalidArgument (
207+ " The size of input X's dimensions should be less than 6."
208+ " But received: the size of input X's dimensions is [%d]" ,
209+ x_dims.size ()));
210+
175211 DataLayout data_layout = StringToDataLayout (data_layout_str);
212+ const int N = x_dims[0 ];
213+ const int C = (data_layout == DataLayout::kNCHW ? x_dims[1 ]
214+ : x_dims[x_dims.size () - 1 ]);
215+ const int sample_size = x.numel () / N / C;
216+
217+ auto * Scale = scale.get_ptr ();
218+ auto * Bias = bias.get_ptr ();
219+
220+ phi::DenseTensor new_scale, new_bias;
221+ if (Scale) {
222+ new_scale = scale.get ();
223+ } else {
224+ new_scale.Resize ({C});
225+ dev_ctx.template Alloc <T>(&new_scale);
226+ FillMLUTensorWithHostValue<T>(dev_ctx, static_cast <T>(1 ), &new_scale);
227+ }
228+
229+ if (Bias) {
230+ new_bias = bias.get ();
231+ } else {
232+ new_bias.Resize ({C});
233+ dev_ctx.template Alloc <T>(&new_bias);
234+ FillMLUTensorWithHostValue<T>(dev_ctx, static_cast <T>(0 ), &new_bias);
235+ }
176236
177237 Tensor d_x_tmp;
178238 if (d_x == nullptr ) {
@@ -182,12 +242,12 @@ void BatchNormGradKernel(
182242 Tensor scale_grad_tmp;
183243 if (d_scale == nullptr ) {
184244 d_scale = &scale_grad_tmp;
185- d_scale->Resize (scale .dims ());
245+ d_scale->Resize (new_scale .dims ());
186246 }
187247 Tensor bias_grad_tmp;
188248 if (d_bias == nullptr ) {
189249 d_bias = &bias_grad_tmp;
190- d_bias->Resize (bias .dims ());
250+ d_bias->Resize (new_bias .dims ());
191251 }
192252
193253 dev_ctx.template Alloc <T>(d_x);
@@ -197,26 +257,6 @@ void BatchNormGradKernel(
197257
198258 use_global_stats = is_test || use_global_stats;
199259
200- const auto & x_dims = x.dims ();
201- PADDLE_ENFORCE_GE (
202- x_dims.size (),
203- 2 ,
204- phi::errors::InvalidArgument (
205- " The size of input X's dimensions should be larger than 1."
206- " But received: the size of input X's dimensions is [%d]" ,
207- x_dims.size ()));
208- PADDLE_ENFORCE_LE (
209- x_dims.size (),
210- 5 ,
211- phi::errors::InvalidArgument (
212- " The size of input X's dimensions should be less than 6."
213- " But received: the size of input X's dimensions is [%d]" ,
214- x_dims.size ()));
215- const int N = x_dims[0 ];
216- const int C = (data_layout == DataLayout::kNCHW ? x_dims[1 ]
217- : x_dims[x_dims.size () - 1 ]);
218- const int sample_size = x.numel () / N / C;
219-
220260 Tensor transformed_d_y;
221261 Tensor transformed_x;
222262 Tensor transformed_d_x;
@@ -227,7 +267,7 @@ void BatchNormGradKernel(
227267 transformed_shape,
228268 ToCnnlDataType<T>(),
229269 CNNL_LAYOUT_NHWC);
230- MLUCnnlTensorDesc others_input_desc (scale );
270+ MLUCnnlTensorDesc others_input_desc (new_scale );
231271
232272 bool need_transpose =
233273 (data_layout == DataLayout::kNCHW && x_dims.size () != 2 );
@@ -286,7 +326,7 @@ void BatchNormGradKernel(
286326 transformed_desc.get (),
287327 GetBasePtr (&transformed_x),
288328 others_input_desc.get (),
289- GetBasePtr (&scale ),
329+ GetBasePtr (&new_scale ),
290330 GetBasePtr (running_mean),
291331 GetBasePtr (running_variance),
292332 epsilon,
@@ -305,7 +345,7 @@ void BatchNormGradKernel(
305345 transformed_desc.get (),
306346 GetBasePtr (&transformed_x),
307347 others_input_desc.get (),
308- GetBasePtr (&scale ),
348+ GetBasePtr (&new_scale ),
309349 GetBasePtr (&saved_mean),
310350 GetBasePtr (&saved_inv_variance),
311351 epsilon,
0 commit comments