Commit 9ec7969
Parent: 7d280aa

[MLU] fix bn and bn_grad (#828)

This commit makes the scale and bias inputs of the MLU batch norm forward and gradient kernels optional (paddle::optional<phi::DenseTensor>). When either is omitted, the kernels materialize a default tensor before calling into CNNL: ones for scale, zeros for bias. In the gradient kernel, the input-rank checks and the derivation of N, C, and sample_size are also hoisted to the top of the function so that C is available when building the defaults.

File tree: 1 file changed (+72 / -32 lines)


backends/mlu/kernels/batch_norm_kernel.cc

Lines changed: 72 additions & 32 deletions
@@ -22,8 +22,8 @@ void BatchNormKernel(const Context& dev_ctx,
                      const phi::DenseTensor& x,
                      const phi::DenseTensor& running_mean,
                      const phi::DenseTensor& running_var,
-                     const phi::DenseTensor& scale,
-                     const phi::DenseTensor& bias,
+                     const paddle::optional<phi::DenseTensor>& scale,
+                     const paddle::optional<phi::DenseTensor>& bias,
                      bool is_test,
                      float momentum,
                      float epsilon,
@@ -61,6 +61,26 @@ void BatchNormKernel(const Context& dev_ctx,
                         : x_dims[x_dims.size() - 1]);
   const int sample_size = x.numel() / N / C;
 
+  auto* Scale = scale.get_ptr();
+  auto* Bias = bias.get_ptr();
+
+  phi::DenseTensor new_scale, new_bias;
+  if (Scale) {
+    new_scale = scale.get();
+  } else {
+    new_scale.Resize({C});
+    dev_ctx.template Alloc<T>(&new_scale);
+    FillMLUTensorWithHostValue<T>(dev_ctx, static_cast<T>(1), &new_scale);
+  }
+
+  if (Bias) {
+    new_bias = bias.get();
+  } else {
+    new_bias.Resize({C});
+    dev_ctx.template Alloc<T>(&new_bias);
+    FillMLUTensorWithHostValue<T>(dev_ctx, static_cast<T>(0), &new_bias);
+  }
+
   // alloc memory
   dev_ctx.template Alloc<T>(y);
 
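The fallback introduced here is a common optional-with-default pattern: take the caller's tensor when present, otherwise materialize one filled with the parameter's identity value (1 for scale, 0 for bias). Below is a minimal standalone sketch of the same pattern, using std::optional and std::vector as stand-ins for paddle::optional and phi::DenseTensor; ResolveOrFill is a hypothetical helper, not part of this file.

#include <optional>
#include <vector>

// Stand-in for a 1-D channel-parameter tensor of length C.
using ChannelTensor = std::vector<float>;

// Mirrors the hunk's fallback: use the caller's tensor when present,
// otherwise build one filled with a default (1 for scale, 0 for bias).
ChannelTensor ResolveOrFill(const std::optional<ChannelTensor>& maybe,
                            int C,
                            float fill) {
  if (maybe.has_value()) {
    return *maybe;               // analogous to new_scale = scale.get()
  }
  return ChannelTensor(C, fill); // analogous to Resize + Alloc + Fill
}

int main() {
  const int C = 4;
  std::optional<ChannelTensor> scale;  // omitted by the caller
  std::optional<ChannelTensor> bias = ChannelTensor{0.1f, 0.2f, 0.3f, 0.4f};
  ChannelTensor new_scale = ResolveOrFill(scale, C, 1.0f);  // {1, 1, 1, 1}
  ChannelTensor new_bias = ResolveOrFill(bias, C, 0.0f);    // caller's values
  return 0;
}

Returning by value keeps the sketch simple; the kernel instead binds new_scale to the existing tensor or allocates device memory through dev_ctx.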
@@ -78,7 +98,7 @@ void BatchNormKernel(const Context& dev_ctx,
                                     transformed_shape,
                                     ToCnnlDataType<T>(),
                                     CNNL_LAYOUT_NHWC);
-  MLUCnnlTensorDesc others_input_desc(scale);
+  MLUCnnlTensorDesc others_input_desc(new_scale);
   // input dimension is 2 and the format is NCHW. The input can be regarded as
   // NHWC format. Don't need to transpose.
   bool need_transpose =
@@ -123,8 +143,8 @@ void BatchNormKernel(const Context& dev_ctx,
                          transformed_desc.get(),
                          GetBasePtr(&transformed_x),
                          others_input_desc.get(),
-                         GetBasePtr(&scale),
-                         GetBasePtr(&bias),
+                         GetBasePtr(&new_scale),
+                         GetBasePtr(&new_bias),
                          GetBasePtr(&running_mean),
                          GetBasePtr(&running_var),
                          epsilon,
@@ -155,8 +175,8 @@ template <typename T, typename Context>
 void BatchNormGradKernel(
     const Context& dev_ctx,
     const phi::DenseTensor& x,
-    const phi::DenseTensor& scale,
-    const phi::DenseTensor& bias,
+    const paddle::optional<phi::DenseTensor>& scale,
+    const paddle::optional<phi::DenseTensor>& bias,
    const paddle::optional<phi::DenseTensor>& mean,
    const paddle::optional<phi::DenseTensor>& variance,
    const phi::DenseTensor& saved_mean,
@@ -172,7 +192,47 @@ void BatchNormGradKernel(
     phi::DenseTensor* d_x,
     phi::DenseTensor* d_scale,
     phi::DenseTensor* d_bias) {
+  const auto& x_dims = x.dims();
+  PADDLE_ENFORCE_GE(
+      x_dims.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "The size of input X's dimensions should be larger than 1."
+          "But received: the size of input X's dimensions is [%d]",
+          x_dims.size()));
+  PADDLE_ENFORCE_LE(
+      x_dims.size(),
+      5,
+      phi::errors::InvalidArgument(
+          "The size of input X's dimensions should be less than 6."
+          "But received: the size of input X's dimensions is [%d]",
+          x_dims.size()));
+
   DataLayout data_layout = StringToDataLayout(data_layout_str);
+  const int N = x_dims[0];
+  const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
+                                                  : x_dims[x_dims.size() - 1]);
+  const int sample_size = x.numel() / N / C;
+
+  auto* Scale = scale.get_ptr();
+  auto* Bias = bias.get_ptr();
+
+  phi::DenseTensor new_scale, new_bias;
+  if (Scale) {
+    new_scale = scale.get();
+  } else {
+    new_scale.Resize({C});
+    dev_ctx.template Alloc<T>(&new_scale);
+    FillMLUTensorWithHostValue<T>(dev_ctx, static_cast<T>(1), &new_scale);
+  }
+
+  if (Bias) {
+    new_bias = bias.get();
+  } else {
+    new_bias.Resize({C});
+    dev_ctx.template Alloc<T>(&new_bias);
+    FillMLUTensorWithHostValue<T>(dev_ctx, static_cast<T>(0), &new_bias);
+  }
 
   Tensor d_x_tmp;
   if (d_x == nullptr) {
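The rank checks are hoisted above the first use of x_dims so the kernel fails fast before N, C, and sample_size are derived. The same guard in isolation, with a plain exception standing in for PADDLE_ENFORCE (CheckBatchNormRank is a hypothetical name):

#include <stdexcept>
#include <string>

// Same contract as the hoisted PADDLE_ENFORCE_GE/LE pair: batch norm
// here accepts inputs with 2 to 5 dimensions, inclusive.
void CheckBatchNormRank(int rank) {
  if (rank < 2 || rank > 5) {
    throw std::invalid_argument(
        "The size of input X's dimensions should be in [2, 5], "
        "but received [" + std::to_string(rank) + "]");
  }
}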
@@ -182,12 +242,12 @@ void BatchNormGradKernel(
   Tensor scale_grad_tmp;
   if (d_scale == nullptr) {
     d_scale = &scale_grad_tmp;
-    d_scale->Resize(scale.dims());
+    d_scale->Resize(new_scale.dims());
   }
   Tensor bias_grad_tmp;
   if (d_bias == nullptr) {
     d_bias = &bias_grad_tmp;
-    d_bias->Resize(bias.dims());
+    d_bias->Resize(new_bias.dims());
   }
 
   dev_ctx.template Alloc<T>(d_x);
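Sizing the fallback gradient buffers from new_scale and new_bias instead of the raw inputs matters because scale and bias are now optionals that may be empty, while the materialized tensors always have shape {C}. A one-function sketch of that sizing rule (stand-in types, hypothetical AllocParamGrad helper):

#include <vector>

// new_param always exists with C elements after the fallback above, so a
// gradient buffer sized from it is valid whether or not the caller passed
// the parameter; sizing from the raw optional would not be.
std::vector<float> AllocParamGrad(const std::vector<float>& new_param) {
  return std::vector<float>(new_param.size(), 0.0f);
}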
@@ -197,26 +257,6 @@ void BatchNormGradKernel(
 
   use_global_stats = is_test || use_global_stats;
 
-  const auto& x_dims = x.dims();
-  PADDLE_ENFORCE_GE(
-      x_dims.size(),
-      2,
-      phi::errors::InvalidArgument(
-          "The size of input X's dimensions should be larger than 1."
-          "But received: the size of input X's dimensions is [%d]",
-          x_dims.size()));
-  PADDLE_ENFORCE_LE(
-      x_dims.size(),
-      5,
-      phi::errors::InvalidArgument(
-          "The size of input X's dimensions should be less than 6."
-          "But received: the size of input X's dimensions is [%d]",
-          x_dims.size()));
-  const int N = x_dims[0];
-  const int C = (data_layout == DataLayout::kNCHW ? x_dims[1]
-                                                  : x_dims[x_dims.size() - 1]);
-  const int sample_size = x.numel() / N / C;
-
   Tensor transformed_d_y;
   Tensor transformed_x;
   Tensor transformed_d_x;
@@ -227,7 +267,7 @@ void BatchNormGradKernel(
                                     transformed_shape,
                                     ToCnnlDataType<T>(),
                                     CNNL_LAYOUT_NHWC);
-  MLUCnnlTensorDesc others_input_desc(scale);
+  MLUCnnlTensorDesc others_input_desc(new_scale);
 
   bool need_transpose =
       (data_layout == DataLayout::kNCHW && x_dims.size() != 2);
@@ -286,7 +326,7 @@ void BatchNormGradKernel(
                     transformed_desc.get(),
                     GetBasePtr(&transformed_x),
                     others_input_desc.get(),
-                    GetBasePtr(&scale),
+                    GetBasePtr(&new_scale),
                     GetBasePtr(running_mean),
                     GetBasePtr(running_variance),
                     epsilon,
@@ -305,7 +345,7 @@ void BatchNormGradKernel(
                     transformed_desc.get(),
                     GetBasePtr(&transformed_x),
                     others_input_desc.get(),
-                    GetBasePtr(&scale),
+                    GetBasePtr(&new_scale),
                     GetBasePtr(&saved_mean),
                     GetBasePtr(&saved_inv_variance),
                     epsilon,
