[MLIR][ArithToLLVM] Fix index_cast on memref types generating invalid LLVM IR (#189227)
`arith.index_cast` and `arith.index_castui` accept memref operands (via `IndexCastTypeConstraint`), but `IndexCastOpLowering::matchAndRewrite` did not handle this case. When the operand was a memref, the conversion framework substituted the converted LLVM struct type, and the lowering incorrectly attempted to emit `llvm.sext`/`llvm.zext`/`llvm.trunc` on a struct value, producing invalid LLVM IR. Since LLVM uses opaque pointers, all memrefs with integer or index element types lower to the same `\!llvm.struct<(ptr, ptr, i64, ...)>` type, making `arith.index_cast` on memrefs a no-op at the LLVM level. Add a check that treats the memref case as an identity conversion (same as the same-bit-width path). Fixes #92377 Assisted-by: Claude Code
This commit is contained in:
parent
b1f8c28559
commit
249e871fa4
@ -368,6 +368,14 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
// Memref index_cast is a no-op at the LLVM level since LLVM uses opaque
|
||||
// pointers and memrefs of different integer/index element types all convert
|
||||
// to the same LLVM struct type.
|
||||
if (isa<MemRefType>(op.getIn().getType())) {
|
||||
rewriter.replaceOp(op, adaptor.getIn());
|
||||
return success();
|
||||
}
|
||||
|
||||
bool isNonNeg = false;
|
||||
if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
|
||||
isNonNeg = op.getNonNeg();
|
||||
|
||||
@ -160,6 +160,30 @@ func.func @index_castui_nneg_not_set(%arg0: i1) {
|
||||
|
||||
// -----
|
||||
|
||||
// Memref index_cast is a no-op at the LLVM level since LLVM uses opaque
|
||||
// pointers and all memrefs with integer or index element types convert to the
|
||||
// same struct type. Verify that no sext/zext/trunc is generated.
|
||||
|
||||
// CHECK-LABEL: @memref_index_cast
|
||||
// CHECK-NOT: llvm.sext
|
||||
// CHECK-NOT: llvm.trunc
|
||||
func.func @memref_index_cast(%arg0: memref<3xi32>) -> memref<3xindex> {
|
||||
%0 = arith.index_cast %arg0 : memref<3xi32> to memref<3xindex>
|
||||
return %0 : memref<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @memref_index_castui
|
||||
// CHECK-NOT: llvm.zext
|
||||
// CHECK-NOT: llvm.trunc
|
||||
func.func @memref_index_castui(%arg0: memref<3xi32>) -> memref<3xindex> {
|
||||
%0 = arith.index_castui %arg0 : memref<3xi32> to memref<3xindex>
|
||||
return %0 : memref<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Checking conversion of signed integer types to floating point.
|
||||
// CHECK-LABEL: @sitofp
|
||||
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user