From acbf3f318694a4f3c382caf040cf586e1ff02c5f Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 31 Mar 2026 00:46:28 +0200 Subject: [PATCH] [MLIR][SCF] Fix scf.index_switch lowering to preserve large case values (#189230) `IndexSwitchLowering` stored case values as `SmallVector<int32_t>`, which silently truncated any `int64_t` case value larger than INT32_MAX (e.g. `4294967296` became `0`). The `cf.switch` flag was also created via `arith.index_cast index -> i32`, losing the upper 32 bits on 64-bit platforms. Fix: store case values as `SmallVector<APInt>` with 64-bit width, cast the index argument to `i64`, and use the `ArrayRef<APInt>` overload of `cf::SwitchOp::create` so the resulting switch correctly uses `i64` case values and flag type. Fixes #111589 Assisted-by: Claude Code --- .../SCFToControlFlow/SCFToControlFlow.cpp | 11 ++++--- .../SCFToControlFlow/convert-to-cfg.mlir | 30 +++++++++++++++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 03842cc9bd3a..2972d79c4302 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -684,7 +684,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, // Convert the case regions. SmallVector caseSuccessors; - SmallVector caseValues; + SmallVector caseValues; caseSuccessors.reserve(op.getCases().size()); caseValues.reserve(op.getCases().size()); for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) { @@ -692,7 +692,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, if (failed(block)) return failure(); caseSuccessors.push_back(*block); - caseValues.push_back(value); + caseValues.push_back(APInt(64, value)); } // Convert the default region. 
@@ -704,13 +704,12 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, rewriter.setInsertionPointToEnd(condBlock); SmallVector caseOperands(caseSuccessors.size(), {}); - // Cast switch index to integer case value. + // Cast switch index to i64 to avoid truncation for large case values. Value caseValue = arith::IndexCastOp::create( - rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg()); + rewriter, op.getLoc(), rewriter.getI64Type(), op.getArg()); cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock, - ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues), - caseSuccessors, caseOperands); + ValueRange(), caseValues, caseSuccessors, caseOperands); rewriter.replaceOp(op, continueBlock->getArguments()); return success(); } diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index 0c4f20e8d1a0..7ad8d594a23a 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -641,8 +641,8 @@ func.func @func_execute_region_elim_multi_yield() { // CHECK-LABEL: @index_switch func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 { - // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32 - // CHECK: cf.switch %[[CASE]] : i32 + // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64 + // CHECK: cf.switch %[[CASE]] : i64 // CHECK-NEXT: default: ^[[DEFAULT:.+]], // CHECK-NEXT: 0: ^[[bb1:.+]], // CHECK-NEXT: 1: ^[[bb2:.+]] @@ -667,6 +667,32 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 { return %0 : i32 } +// Verify that case values larger than INT32_MAX are not truncated (issue #111589). +// In particular, case 4294967296 (2^32) must not alias with case 0 after lowering. 
+// CHECK-LABEL: @index_switch_large_cases +func.func @index_switch_large_cases(%i: index) { + // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64 + // CHECK: cf.switch %[[CASE]] : i64, [ + // CHECK-NEXT: default: ^[[DEFAULT:.+]], + // CHECK-NEXT: 0: ^[[bb0:.+]], + // CHECK-NEXT: 4294967296: ^[[bb1:.+]], + // CHECK-NEXT: 8589934592: ^[[bb2:.+]] + scf.index_switch %i + case 0 { + scf.yield + } + case 4294967296 { // 2^32, previously truncated to 0 + scf.yield + } + case 8589934592 { // 2^33 + scf.yield + } + default { + scf.yield + } + return +} + // Note: scf.forall is lowered to scf.parallel, which is currently lowered to // scf.for and then to unstructured control flow. scf.parallel could lower more // efficiently to multi-threaded IR, at which point scf.forall would