[mlir] added a check in the walk to prevent catching a cos in a nested region (#190064)
The walk in SincosFusion may detect a cos within a nested region of the sin block. This triggers an assertion in `isBeforeInBlock` later on. Added a check within the walk so it filters operations in nested regions, which are not in the same block and should not be fused anyway. --------- Co-authored-by: Yebin Chon <ychon@nvidia.com>
This commit is contained in:
parent
d52daeac79
commit
495e1a4257
@ -27,13 +27,11 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
|
||||
mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
|
||||
|
||||
math::CosOp cosOp = nullptr;
|
||||
sinOp->getBlock()->walk([&](math::CosOp op) {
|
||||
for (auto op : sinOp->getBlock()->getOps<math::CosOp>())
|
||||
if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
|
||||
cosOp = op;
|
||||
return WalkResult::interrupt();
|
||||
break;
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (!cosOp)
|
||||
return failure();
|
||||
|
||||
@ -74,6 +74,29 @@ func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
|
||||
func.return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @sincos_no_fusion_nested_region(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: f32,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: i1) -> (f32, f32) {
|
||||
// CHECK: %[[SIN:.*]] = math.sin %[[ARG0]] : f32
|
||||
// CHECK: %[[IF:.*]] = scf.if %[[ARG1]] -> (f32) {
|
||||
// CHECK: %[[COS:.*]] = math.cos %[[ARG0]] : f32
|
||||
// CHECK: scf.yield %[[COS]] : f32
|
||||
// CHECK: } else {
|
||||
// CHECK: scf.yield %[[SIN]] : f32
|
||||
// CHECK: }
|
||||
// CHECK: return %[[SIN]], %[[IF]] : f32, f32
|
||||
// CHECK: }
|
||||
func.func @sincos_no_fusion_nested_region(%arg0 : f32, %flag : i1) -> (f32, f32) {
|
||||
%s = math.sin %arg0 : f32
|
||||
%r = scf.if %flag -> f32 {
|
||||
%c = math.cos %arg0 : f32
|
||||
scf.yield %c : f32
|
||||
} else {
|
||||
scf.yield %s : f32
|
||||
}
|
||||
func.return %s, %r : f32, f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
|
||||
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user