diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 03842cc9bd3a..2972d79c4302 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -684,7 +684,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, // Convert the case regions. SmallVector caseSuccessors; - SmallVector caseValues; + SmallVector caseValues; caseSuccessors.reserve(op.getCases().size()); caseValues.reserve(op.getCases().size()); for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) { @@ -692,7 +692,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, if (failed(block)) return failure(); caseSuccessors.push_back(*block); - caseValues.push_back(value); + caseValues.push_back(APInt(64, value)); } // Convert the default region. @@ -704,13 +704,12 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, rewriter.setInsertionPointToEnd(condBlock); SmallVector caseOperands(caseSuccessors.size(), {}); - // Cast switch index to integer case value. + // Cast switch index to i64 to avoid truncation for large case values. Value caseValue = arith::IndexCastOp::create( - rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg()); + rewriter, op.getLoc(), rewriter.getI64Type(), op.getArg()); cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock, - ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues), - caseSuccessors, caseOperands); + ValueRange(), caseValues, caseSuccessors, caseOperands); rewriter.replaceOp(op, continueBlock->getArguments()); return success(); } diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index 0c4f20e8d1a0..7ad8d594a23a 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -641,8 +641,8 @@ func.func @func_execute_region_elim_multi_yield() { // CHECK-LABEL: @index_switch func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 { - // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32 - // CHECK: cf.switch %[[CASE]] : i32 + // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64 + // CHECK: cf.switch %[[CASE]] : i64 // CHECK-NEXT: default: ^[[DEFAULT:.+]], // CHECK-NEXT: 0: ^[[bb1:.+]], // CHECK-NEXT: 1: ^[[bb2:.+]] @@ -667,6 +667,32 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 { return %0 : i32 } +// Verify that case values larger than INT32_MAX are not truncated (issue #111589). +// In particular, case 4294967296 (2^32) must not alias with case 0 after lowering. +// CHECK-LABEL: @index_switch_large_cases +func.func @index_switch_large_cases(%i: index) { + // CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64 + // CHECK: cf.switch %[[CASE]] : i64, [ + // CHECK-NEXT: default: ^[[DEFAULT:.+]], + // CHECK-NEXT: 0: ^[[bb0:.+]], + // CHECK-NEXT: 4294967296: ^[[bb1:.+]], + // CHECK-NEXT: 8589934592: ^[[bb2:.+]] + scf.index_switch %i + case 0 { + scf.yield + } + case 4294967296 { // 2^32, previously truncated to 0 + scf.yield + } + case 8589934592 { // 2^33 + scf.yield + } + default { + scf.yield + } + return +} + // Note: scf.forall is lowered to scf.parallel, which is currently lowered to // scf.for and then to unstructured control flow. scf.parallel could lower more // efficiently to multi-threaded IR, at which point scf.forall would