[MLIR][ArithToLLVM] Fix index_cast on memref types generating invalid LLVM IR (#189227)

`arith.index_cast` and `arith.index_castui` accept memref operands (via
`IndexCastTypeConstraint`), but `IndexCastOpLowering::matchAndRewrite`
did not handle this case. When the operand was a memref, the conversion
framework substituted the converted LLVM struct type, and the lowering
incorrectly attempted to emit `llvm.sext`/`llvm.zext`/`llvm.trunc` on a
struct value, producing invalid LLVM IR.

Since LLVM uses opaque pointers, all memrefs with integer or index
element types lower to the same `\!llvm.struct<(ptr, ptr, i64, ...)>`
type, making `arith.index_cast` on memrefs a no-op at the LLVM level.
Add a check that treats the memref case as an identity conversion (same
as the same-bit-width path).

Fixes #92377

Assisted-by: Claude Code
This commit is contained in:
Mehdi Amini 2026-04-01 11:03:14 +02:00 committed by GitHub
parent b1f8c28559
commit 249e871fa4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 0 deletions

View File

@ -368,6 +368,14 @@ LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
return success();
}
// Memref index_cast is a no-op at the LLVM level since LLVM uses opaque
// pointers and memrefs of different integer/index element types all convert
// to the same LLVM struct type.
if (isa<MemRefType>(op.getIn().getType())) {
rewriter.replaceOp(op, adaptor.getIn());
return success();
}
bool isNonNeg = false;
if constexpr (std::is_same_v<ExtCastTy, LLVM::ZExtOp>)
isNonNeg = op.getNonNeg();

View File

@ -160,6 +160,30 @@ func.func @index_castui_nneg_not_set(%arg0: i1) {
// -----
// Memref index_cast is a no-op at the LLVM level since LLVM uses opaque
// pointers and all memrefs with integer or index element types convert to the
// same struct type. Verify that no sext/zext/trunc is generated.
// CHECK-LABEL: @memref_index_cast
// CHECK-NOT: llvm.sext
// CHECK-NOT: llvm.trunc
func.func @memref_index_cast(%arg0: memref<3xi32>) -> memref<3xindex> {
%0 = arith.index_cast %arg0 : memref<3xi32> to memref<3xindex>
return %0 : memref<3xindex>
}
// -----
// CHECK-LABEL: @memref_index_castui
// CHECK-NOT: llvm.zext
// CHECK-NOT: llvm.trunc
func.func @memref_index_castui(%arg0: memref<3xi32>) -> memref<3xindex> {
%0 = arith.index_castui %arg0 : memref<3xi32> to memref<3xindex>
return %0 : memref<3xindex>
}
// -----
// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {