@@ -141,11 +141,13 @@ class RmsNormFusePattern : public paddle::drr::DrrPatternBase {
 class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
  private:
   const bool extra_add_;
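+  // When true, the extra consumer add is matched with its operands swapped.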
+  const bool trans_extra_add_;
 
  public:
-  explicit AddRmsNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+  AddRmsNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
 
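+  // Benefit raised from 2/1 to 4/3, presumably so these fuse patterns apply
+  // before the group_norm+silu patterns (benefit 2 and 1) added below.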
-  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
 
   std::string name() const override { return "AddRmsNormFusePattern"; }
 
@@ -176,7 +178,9 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
     if (extra_add_) {
       const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
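+      // Match the extra consumer add in either operand order.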
       pat.Tensor("add_out1") =
-          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
     }
     paddle::drr::ResultPattern res = pat.ResultPattern();
     const auto &res_rms_norm =
@@ -207,11 +211,13 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
 class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
  private:
   const bool extra_add_;
+  const bool trans_extra_add_;
 
  public:
-  explicit AddLayerNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
+  AddLayerNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
 
-  uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
   std::string name() const override { return "AddLayerNormFusePattern"; }
 
   void operator()(paddle::drr::DrrPatternContext *ctx) const override {
@@ -231,22 +237,20 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
     if (extra_add_) {
       const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
       pat.Tensor("add_out1") =
-          add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
     }
 
     paddle::drr::ResultPattern res = pat.ResultPattern();
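+    // The cast dtype is now fixed to float32 instead of following x's dtype.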
     const auto &cast_op_dtype = res.ComputeAttr(
         [](const paddle::drr::MatchContext &match_ctx) -> phi::DataType {
-          auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
-          return paddle::dialect::TransToPhiDataType(x_dtype);
+          return phi::DataType::FLOAT32;
         });
-    const auto &cast_op_1 =
+    const auto &cast_1_op =
         res.Op(paddle::dialect::CastOp::name(), {{"dtype", cast_op_dtype}});
-    res.Tensor("casted_bias") = cast_op_1(res.Tensor("bias"));
-    const auto &cast_op_2 =
+    const auto &cast_2_op =
         res.Op(paddle::dialect::CastOp::name(), {{"dtype", cast_op_dtype}});
-    res.Tensor("casted_w") = cast_op_2(res.Tensor("w"));
-
     const auto &fuse_layer_norm =
         res.Op(paddle::dialect::FusedBiasResidualLayernormOp::name(),
                {{"epsilon", pat.Attr("epsilon")},
@@ -256,14 +260,15 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
256260 {" quant_round_type" , res.Int32Attr (0 )},
257261 {" quant_max_bound" , res.Float32Attr (0.0 )},
258262 {" quant_min_bound" , res.Float32Attr (0.0 )}});
259-
263+ res.Tensor (" w_cast" ) = cast_1_op (res.Tensor (" w" ));
264+ res.Tensor (" bias_cast" ) = cast_1_op (res.Tensor (" bias" ));
     fuse_layer_norm(
         {
             &res.Tensor("x"),
-            &res.Tensor("casted_bias"),
-            &res.Tensor("residual"),
-            &res.Tensor("casted_w"),
             &res.InputNoneTensor(),
+            &res.Tensor("residual"),
+            &res.Tensor("w_cast"),
+            &res.Tensor("bias_cast"),
         },
         {&res.Tensor("layer_norm_out"),
          &res.Tensor("add_out"),
@@ -272,6 +277,163 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
   }
 };
 
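+// Fuses add + group_norm into add_group_norm_silu, leaving the activation
+// attribute empty so a later pattern can fold a trailing silu into it.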
+class AddGroupNormFusePattern : public paddle::drr::DrrPatternBase {
+ private:
+  const bool extra_add_;
+  const bool trans_extra_add_;
+
+ public:
+  AddGroupNormFusePattern(bool extra_add, bool trans_extra_add)
+      : extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}
+
+  uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
+  std::string name() const override { return "AddGroupNormFusePattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &add = pat.Op(paddle::dialect::AddOp::name());
+    const auto &group_norm = pat.Op(paddle::dialect::GroupNormOp::name(),
+                                    {{"epsilon", pat.Attr("epsilon")},
+                                     {"groups", pat.Attr("groups")},
+                                     {"data_format", pat.Attr("data_format")}});
+    pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual"));
+    group_norm(
+        {&pat.Tensor("add_out"), &pat.Tensor("scale"), &pat.Tensor("bias")},
+        {&pat.Tensor("group_out"),
+         &pat.Tensor("mean_out_0"),
+         &pat.Tensor("variance_out_0")});
+    // TODO(bukejiyu): once DRR supports matching a placeholder op, the
+    // following can be deleted.
+    if (extra_add_) {
+      const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
+      pat.Tensor("add_out1") =
+          trans_extra_add_
+              ? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
+              : add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
+    }
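+    // Only fuse when x is float16 or bfloat16.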
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("")}});
+
+    add_group_norm_silu_op({&res.Tensor("x"),
+                            &res.Tensor("residual"),
+                            &res.Tensor("scale"),
+                            &res.Tensor("bias")},
+                           {&res.Tensor("group_out"),
+                            &res.Tensor("add_out"),
+                            &res.Tensor("mean_out"),
+                            &res.Tensor("variance_out")});
+  }
+};
+
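+// Folds a trailing silu into an add_group_norm_silu whose activation
+// attribute is still empty.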
+class AddGroupNormWithActPattern : public paddle::drr::DrrPatternBase {
+ public:
+  uint32_t benefit() const override { return 2; }
+  std::string name() const override { return "AddGroupNormWithActPattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &add_group_norm_silu_op =
+        pat.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", pat.Attr("activation")}});
+    const auto &silu = pat.Op(paddle::dialect::SiluOp::name());
+    add_group_norm_silu_op({&pat.Tensor("x"),
+                            &pat.Tensor("residual"),
+                            &pat.Tensor("scale"),
+                            &pat.Tensor("bias")},
+                           {&pat.Tensor("group_out"),
+                            &pat.Tensor("add_out"),
+                            &pat.Tensor("mean_out_0"),
+                            &pat.Tensor("variance_out_0")});
+    pat.Tensor("silu_out") = silu(pat.Tensor("group_out"));
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      auto activation = match_ctx.Attr<std::string>("activation");
+      if (activation != "") {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &res_add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("silu")}});
+    res_add_group_norm_silu_op({&res.Tensor("x"),
+                                &res.Tensor("residual"),
+                                &res.Tensor("scale"),
+                                &res.Tensor("bias")},
+                               {&res.Tensor("silu_out"),
+                                &res.Tensor("add_out"),
+                                &res.Tensor("mean_out"),
+                                &res.Tensor("variance_out")});
+  }
+};
+
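+// Rewrites a plain group_norm + silu as add_group_norm_silu with the
+// residual input left as none.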
+class GroupNormWithActPattern : public paddle::drr::DrrPatternBase {
+ public:
+  uint32_t benefit() const override { return 1; }
+  std::string name() const override { return "GroupNormWithActPattern"; }
+
+  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
+    paddle::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &group_norm = pat.Op(paddle::dialect::GroupNormOp::name(),
+                                    {{"epsilon", pat.Attr("epsilon")},
+                                     {"groups", pat.Attr("groups")},
+                                     {"data_format", pat.Attr("data_format")}});
+    const auto &silu = pat.Op(paddle::dialect::SiluOp::name());
+    group_norm({&pat.Tensor("x"), &pat.Tensor("scale"), &pat.Tensor("bias")},
+               {&pat.Tensor("group_out"),
+                &pat.Tensor("mean_out_0"),
+                &pat.Tensor("variance_out_0")});
+    pat.Tensor("silu_out") = silu(pat.Tensor("group_out"));
+    pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
+      auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
+      if (!x_dtype.isa<pir::Float16Type>() &&
+          !x_dtype.isa<pir::BFloat16Type>()) {
+        return false;
+      }
+      return true;
+    });
+    paddle::drr::ResultPattern res = pat.ResultPattern();
+    const auto &add_group_norm_silu_op =
+        res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
+               {{"epsilon", pat.Attr("epsilon")},
+                {"groups", pat.Attr("groups")},
+                {"data_format", pat.Attr("data_format")},
+                {"activation", res.StrAttr("silu")}});
+    add_group_norm_silu_op({&res.Tensor("x"),
+                            &res.InputNoneTensor(),
+                            &res.Tensor("scale"),
+                            &res.Tensor("bias")},
+                           {&res.Tensor("silu_out"),
+                            &res.OutputNoneTensor(),
+                            &res.Tensor("mean_out"),
+                            &res.Tensor("variance_out")});
+  }
+};
+
 class AddNormFusePass : public pir::PatternRewritePass {
  public:
   AddNormFusePass() : pir::PatternRewritePass("add_norm_fuse_pass", 2) {}
@@ -290,13 +452,37 @@ class AddNormFusePass : public pir::PatternRewritePass {
     // x--------
     //          add-rms_norm ---> rms_norm
     // residual-
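+    // Each fuse pattern is registered in three variants: no extra add, extra
+    // add with swapped operands, and extra add in the original operand order.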
-    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add));
-    ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add, false));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add, true));
+    ps.Add(
+        paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add, false));
+
     // x--------
     //          add-layer_norm ----> fused_bias_residual_layernorm
     // residual-
-    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, !extra_add));
-    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(
+        context, !extra_add, false));
+    ps.Add(
+        paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add, true));
+    ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(
+        context, extra_add, false));
+
+    // x--------
+    //          add-group_norm ----> add_group_norm_silu
+    // residual-
+    ps.Add(paddle::drr::Create<AddGroupNormFusePattern>(
+        context, !extra_add, true));
+    ps.Add(
+        paddle::drr::Create<AddGroupNormFusePattern>(context, extra_add, true));
+    ps.Add(paddle::drr::Create<AddGroupNormFusePattern>(
+        context, extra_add, false));
+
+    // add_group_norm_silu + silu ---> add_group_norm_silu
+    ps.Add(paddle::drr::Create<AddGroupNormWithActPattern>(context));
+    // group_norm + silu ---> add_group_norm_silu
+    ps.Add(paddle::drr::Create<GroupNormWithActPattern>(context));
     return ps;
   }
 };