[mlir] added a check in the walk to prevent catching a cos in a nested region (#190064)

The walk in SincosFusion may detect a cos within a nested region of the
sin block. This triggers an assertion in `isBeforeInBlock` later on.
Added a check within the walk so it filters operations in nested
regions, which are not in the same block and should not be fused anyway.

---------

Co-authored-by: Yebin Chon <ychon@nvidia.com>
This commit is contained in:
yebinchon 2026-04-01 23:10:56 -04:00 committed by GitHub
parent d52daeac79
commit 495e1a4257
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 4 deletions

View File

@ -27,13 +27,11 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
math::CosOp cosOp = nullptr;
sinOp->getBlock()->walk([&](math::CosOp op) {
for (auto op : sinOp->getBlock()->getOps<math::CosOp>())
if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
cosOp = op;
return WalkResult::interrupt();
break;
}
return WalkResult::advance();
});
if (!cosOp)
return failure();

View File

@ -74,6 +74,29 @@ func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
func.return %0 : f32
}
// CHECK-LABEL: func.func @sincos_no_fusion_nested_region(
// CHECK-SAME: %[[ARG0:.*]]: f32,
// CHECK-SAME: %[[ARG1:.*]]: i1) -> (f32, f32) {
// CHECK: %[[SIN:.*]] = math.sin %[[ARG0]] : f32
// CHECK: %[[IF:.*]] = scf.if %[[ARG1]] -> (f32) {
// CHECK: %[[COS:.*]] = math.cos %[[ARG0]] : f32
// CHECK: scf.yield %[[COS]] : f32
// CHECK: } else {
// CHECK: scf.yield %[[SIN]] : f32
// CHECK: }
// CHECK: return %[[SIN]], %[[IF]] : f32, f32
// CHECK: }
func.func @sincos_no_fusion_nested_region(%arg0 : f32, %flag : i1) -> (f32, f32) {
%s = math.sin %arg0 : f32
%r = scf.if %flag -> f32 {
%c = math.cos %arg0 : f32
scf.yield %c : f32
} else {
scf.yield %s : f32
}
func.return %s, %r : f32, f32
}
// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath(
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32