[MLIR][SCF] Fix scf.index_switch lowering to preserve large case values (#189230)
`IndexSwitchLowering` stored case values as `SmallVector<int32_t>`, which silently truncated any `int64_t` case value larger than INT32_MAX (e.g. `4294967296` became `0`). The `cf.switch` flag was also created via `arith.index_cast index -> i32`, losing the upper 32 bits on 64-bit platforms. Fix: store case values as `SmallVector<APInt>` with 64-bit width, cast the index argument to `i64`, and use the `ArrayRef<APInt>` overload of `cf::SwitchOp::create` so the resulting switch correctly uses `i64` case values and flag type. Fixes #111589. Assisted-by: Claude Code.
This commit is contained in:
parent
5da2546594
commit
acbf3f3186
@ -684,7 +684,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
|
||||
|
||||
// Convert the case regions.
|
||||
SmallVector<Block *> caseSuccessors;
|
||||
SmallVector<int32_t> caseValues;
|
||||
SmallVector<APInt> caseValues;
|
||||
caseSuccessors.reserve(op.getCases().size());
|
||||
caseValues.reserve(op.getCases().size());
|
||||
for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
|
||||
@ -692,7 +692,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
|
||||
if (failed(block))
|
||||
return failure();
|
||||
caseSuccessors.push_back(*block);
|
||||
caseValues.push_back(value);
|
||||
caseValues.push_back(APInt(64, value));
|
||||
}
|
||||
|
||||
// Convert the default region.
|
||||
@ -704,13 +704,12 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
|
||||
rewriter.setInsertionPointToEnd(condBlock);
|
||||
SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
|
||||
|
||||
// Cast switch index to integer case value.
|
||||
// Cast switch index to i64 to avoid truncation for large case values.
|
||||
Value caseValue = arith::IndexCastOp::create(
|
||||
rewriter, op.getLoc(), rewriter.getI32Type(), op.getArg());
|
||||
rewriter, op.getLoc(), rewriter.getI64Type(), op.getArg());
|
||||
|
||||
cf::SwitchOp::create(rewriter, op.getLoc(), caseValue, *defaultBlock,
|
||||
ValueRange(), rewriter.getDenseI32ArrayAttr(caseValues),
|
||||
caseSuccessors, caseOperands);
|
||||
ValueRange(), caseValues, caseSuccessors, caseOperands);
|
||||
rewriter.replaceOp(op, continueBlock->getArguments());
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -641,8 +641,8 @@ func.func @func_execute_region_elim_multi_yield() {
|
||||
|
||||
// CHECK-LABEL: @index_switch
|
||||
func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
|
||||
// CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i32
|
||||
// CHECK: cf.switch %[[CASE]] : i32
|
||||
// CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64
|
||||
// CHECK: cf.switch %[[CASE]] : i64
|
||||
// CHECK-NEXT: default: ^[[DEFAULT:.+]],
|
||||
// CHECK-NEXT: 0: ^[[bb1:.+]],
|
||||
// CHECK-NEXT: 1: ^[[bb2:.+]]
|
||||
@ -667,6 +667,32 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
|
||||
return %0 : i32
|
||||
}
|
||||
|
||||
// Verify that case values larger than INT32_MAX are not truncated (issue #111589).
|
||||
// In particular, case 4294967296 (2^32) must not alias with case 0 after lowering.
|
||||
// CHECK-LABEL: @index_switch_large_cases
|
||||
func.func @index_switch_large_cases(%i: index) {
|
||||
// CHECK: %[[CASE:.*]] = arith.index_cast %arg0 : index to i64
|
||||
// CHECK: cf.switch %[[CASE]] : i64, [
|
||||
// CHECK-NEXT: default: ^[[DEFAULT:.+]],
|
||||
// CHECK-NEXT: 0: ^[[bb0:.+]],
|
||||
// CHECK-NEXT: 4294967296: ^[[bb1:.+]],
|
||||
// CHECK-NEXT: 8589934592: ^[[bb2:.+]]
|
||||
scf.index_switch %i
|
||||
case 0 {
|
||||
scf.yield
|
||||
}
|
||||
case 4294967296 { // 2^32, previously truncated to 0
|
||||
scf.yield
|
||||
}
|
||||
case 8589934592 { // 2^33
|
||||
scf.yield
|
||||
}
|
||||
default {
|
||||
scf.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Note: scf.forall is lowered to scf.parallel, which is currently lowered to
|
||||
// scf.for and then to unstructured control flow. scf.parallel could lower more
|
||||
// efficiently to multi-threaded IR, at which point scf.forall would
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user