[CIR] Use zero-initializer for partial array fills (#154161)

If an array initializer list leaves eight or more elements that require
zero fill, we had been generating an individual zero element for every
one of them. This change instead follows the behavior of classic
codegen, which creates a constant structure with the specified elements
followed by a zero-initializer for the trailing zeros.
This commit is contained in:
Andy Kaylor 2025-08-19 12:14:05 -07:00 committed by GitHub
parent 0542355147
commit 6747139bc2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 241 additions and 19 deletions

View File

@ -341,6 +341,44 @@ def CIR_ConstVectorAttr : CIR_Attr<"ConstVector", "const_vector", [
let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
// ConstRecordAttr
//===----------------------------------------------------------------------===//
// Constant aggregate attribute: an ordered list of typed member attributes
// paired with the `!cir.record` type the constant has.
def CIR_ConstRecordAttr : CIR_Attr<"ConstRecord", "const_record", [
  TypedAttrInterface
]> {
  let summary = "Represents a constant record";
  let description = [{
    Effectively supports "struct-like" constants. It must be built from
    an `mlir::ArrayAttr` instance where each element is a typed attribute
    (`mlir::TypedAttr`). The number of elements and the type of each element
    must match the members of the record type.

    Example:
    ```
    !rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>

    cir.global external @arr = #cir.const_record<{
      #cir.int<1> : !s32i, #cir.int<2> : !s32i,
      #cir.zero : !cir.array<!s32i x 8>
    }> : !rec_anon_struct
    ```
  }];

  // `type` is the attribute's own type; `members` holds one initializer per
  // record member, in member order.
  let parameters = (ins AttributeSelfTypeParameter<"">:$type,
                        "mlir::ArrayAttr":$members);

  // Convenience builder that takes the MLIR context from the record type.
  let builders = [
    AttrBuilderWithInferredContext<(ins "cir::RecordType":$type,
                                        "mlir::ArrayAttr":$members), [{
      return $_get(type.getContext(), type, members);
    }]>
  ];

  // The member list is printed/parsed by the custom `RecordMembers` directive
  // as a brace-delimited, comma-separated list.
  let assemblyFormat = [{
    `<` custom<RecordMembers>($members) `>`
  }];

  // The verifier is implemented by hand in C++ (ConstRecordAttr::verify).
  let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
// ConstPtrAttr
//===----------------------------------------------------------------------===//

View File

@ -60,6 +60,23 @@ public:
trailingZerosNum);
}
/// Create a constant record attribute whose members are the elements of
/// \p arrayAttr.
///
/// \param arrayAttr member initializers; every element must be a
///                  `mlir::TypedAttr`.
/// \param packed    forwarded to getAnonRecordTy when no type is supplied.
/// \param padded    forwarded to getAnonRecordTy when no type is supplied.
/// \param ty        record type to use; when null, an anonymous record type is
///                  built from the element types.
cir::ConstRecordAttr getAnonConstRecord(mlir::ArrayAttr arrayAttr,
                                        bool packed = false,
                                        bool padded = false,
                                        mlir::Type ty = {}) {
  // Collect the type of each initializer, in order.
  llvm::SmallVector<mlir::Type, 4> memberTypes;
  for (mlir::Attribute initializer : arrayAttr)
    memberTypes.push_back(mlir::cast<mlir::TypedAttr>(initializer).getType());

  // Fall back to an anonymous record type when the caller gave none.
  if (!ty)
    ty = getAnonRecordTy(memberTypes, packed, padded);

  return cir::ConstRecordAttr::get(mlir::cast<cir::RecordType>(ty), arrayAttr);
}
std::string getUniqueAnonRecordName() { return getUniqueRecordName("anon"); }
std::string getUniqueRecordName(const std::string &baseName) {

View File

@ -285,7 +285,7 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
mlir::Type commonElementType, unsigned arrayBound,
SmallVectorImpl<mlir::TypedAttr> &elements,
mlir::TypedAttr filler) {
const CIRGenBuilderTy &builder = cgm.getBuilder();
CIRGenBuilderTy &builder = cgm.getBuilder();
unsigned nonzeroLength = arrayBound;
if (elements.size() < nonzeroLength && builder.isNullValue(filler))
@ -306,6 +306,33 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
if (trailingZeroes >= 8) {
assert(elements.size() >= nonzeroLength &&
"missing initializer for non-zero element");
if (commonElementType && nonzeroLength >= 8) {
// If all the elements had the same type up to the trailing zeroes and
// there are eight or more nonzero elements, emit a struct of two arrays
// (the nonzero data and the zeroinitializer).
SmallVector<mlir::Attribute, 4> eles;
eles.reserve(nonzeroLength);
for (const auto &element : elements)
eles.push_back(element);
auto initial = cir::ConstArrayAttr::get(
cir::ArrayType::get(commonElementType, nonzeroLength),
mlir::ArrayAttr::get(builder.getContext(), eles));
elements.resize(2);
elements[0] = initial;
} else {
// Otherwise, emit a struct with individual elements for each nonzero
// initializer, followed by a zeroinitializer array filler.
elements.resize(nonzeroLength + 1);
}
mlir::Type fillerType =
commonElementType
? commonElementType
: mlir::cast<cir::ArrayType>(desiredType).getElementType();
fillerType = cir::ArrayType::get(fillerType, trailingZeroes);
elements.back() = cir::ZeroAttr::get(fillerType);
commonElementType = nullptr;
} else if (elements.size() != arrayBound) {
elements.resize(arrayBound, filler);
@ -325,8 +352,13 @@ emitArrayConstant(CIRGenModule &cgm, mlir::Type desiredType,
mlir::ArrayAttr::get(builder.getContext(), eles));
}
cgm.errorNYI("array with different type elements");
return {};
SmallVector<mlir::Attribute, 4> eles;
eles.reserve(elements.size());
for (auto const &element : elements)
eles.push_back(element);
auto arrAttr = mlir::ArrayAttr::get(builder.getContext(), eles);
return builder.getAnonConstRecord(arrAttr, /*isPacked=*/true);
}
//===----------------------------------------------------------------------===//

View File

@ -15,6 +15,14 @@
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
//===-----------------------------------------------------------------===//
// RecordMembers
//===-----------------------------------------------------------------===//
static void printRecordMembers(mlir::AsmPrinter &p, mlir::ArrayAttr members);
static mlir::ParseResult parseRecordMembers(mlir::AsmParser &parser,
mlir::ArrayAttr &members);
//===-----------------------------------------------------------------===//
// IntLiteral
//===-----------------------------------------------------------------===//
@ -68,6 +76,61 @@ void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
llvm_unreachable("unexpected CIR type kind");
}
/// Print \p members as a brace-delimited, comma-separated list of attributes.
static void printRecordMembers(mlir::AsmPrinter &printer,
                               mlir::ArrayAttr members) {
  printer << '{';
  llvm::interleaveComma(members, printer,
                        [&](mlir::Attribute member) { printer << member; });
  printer << '}';
}
/// Parse a brace-delimited, comma-separated list of typed attributes and
/// store it in \p members as an `mlir::ArrayAttr`.
///
/// \returns failure if the list or any of its elements fails to parse; in
/// that case \p members is left untouched.
static ParseResult parseRecordMembers(mlir::AsmParser &parser,
                                      mlir::ArrayAttr &members) {
  llvm::SmallVector<mlir::Attribute, 4> parsedMembers;

  // Parses a single element and appends it to the list.
  auto parseOneMember = [&]() {
    mlir::TypedAttr member;
    if (parser.parseAttribute(member).failed())
      return mlir::failure();
    parsedMembers.push_back(member);
    return mlir::success();
  };

  if (parser
          .parseCommaSeparatedList(AsmParser::Delimiter::Braces,
                                   parseOneMember)
          .failed())
    return mlir::failure();

  members = mlir::ArrayAttr::get(parser.getContext(), parsedMembers);
  return mlir::success();
}
//===----------------------------------------------------------------------===//
// ConstRecordAttr definitions
//===----------------------------------------------------------------------===//
/// Verify the invariants of a `#cir.const_record` attribute.
///
/// The attached type must be a `!cir.record`, the number of member attributes
/// must equal the number of members of that record type, and every member
/// attribute must be a typed attribute whose type equals the corresponding
/// record member type.
///
/// \param emitError callback that starts the diagnostic reported on failure.
/// \param type      the type attached to the attribute.
/// \param members   the member initializers, in record member order.
/// \returns success() when all checks pass, otherwise the emitted error.
LogicalResult
ConstRecordAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                        mlir::Type type, ArrayAttr members) {
  auto recordTy = mlir::dyn_cast_if_present<cir::RecordType>(type);
  if (!recordTy)
    return emitError() << "expected !cir.record type";

  const auto memberTypes = recordTy.getMembers();
  if (memberTypes.size() != members.size())
    return emitError() << "number of elements must match";

  for (unsigned attrIdx = 0, numMembers = memberTypes.size();
       attrIdx != numMembers; ++attrIdx) {
    // Attributes created through a builder are not guaranteed to hold typed
    // members, so diagnose instead of asserting inside mlir::cast.
    auto typedMember =
        mlir::dyn_cast_if_present<mlir::TypedAttr>(members[attrIdx]);
    if (!typedMember)
      return emitError() << "element at index " << attrIdx
                         << " is not a typed attribute";

    mlir::Type expectedTy = memberTypes[attrIdx];
    if (expectedTy != typedMember.getType())
      return emitError() << "element at index " << attrIdx << " has type "
                         << typedMember.getType()
                         << " but the expected type for this element is "
                         << expectedTy;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// OptInfoAttr definitions
//===----------------------------------------------------------------------===//

View File

@ -341,8 +341,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
}
if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
cir::ConstComplexAttr, cir::GlobalViewAttr, cir::PoisonAttr>(
attrType))
cir::ConstComplexAttr, cir::ConstRecordAttr,
cir::GlobalViewAttr, cir::PoisonAttr>(attrType))
return success();
assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");

View File

@ -201,8 +201,8 @@ public:
mlir::Value visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
.Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
cir::GlobalViewAttr, cir::ZeroAttr>(
cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
cir::ConstPtrAttr, cir::GlobalViewAttr, cir::ZeroAttr>(
[&](auto attrT) { return visitCirAttr(attrT); })
.Default([&](auto attrT) { return mlir::Value(); });
}
@ -212,6 +212,7 @@ public:
mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
mlir::Value visitCirAttr(cir::ConstRecordAttr attr);
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
mlir::Value visitCirAttr(cir::GlobalViewAttr attr);
mlir::Value visitCirAttr(cir::ZeroAttr attr);
@ -386,6 +387,21 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
return result;
}
/// Lower a constant record attribute to an LLVM dialect value.
///
/// Starts from an `llvm.mlir.undef` of the converted record type and inserts
/// the lowered value of each member at its position, in member order.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstRecordAttr constRecord) {
  const mlir::Location loc = parentOp->getLoc();
  const mlir::Type llvmTy = converter->convertType(constRecord.getType());

  mlir::Value aggregate = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);

  // Lower the members one at a time, threading the aggregate through each
  // insertvalue.
  for (const auto &indexedMember : llvm::enumerate(constRecord.getMembers())) {
    mlir::Value memberValue = visit(indexedMember.value());
    aggregate = rewriter.create<mlir::LLVM::InsertValueOp>(
        loc, aggregate, memberValue, indexedMember.index());
  }
  return aggregate;
}
/// ConstVectorAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
const mlir::Type llvmTy = converter->convertType(attr.getType());
@ -1286,6 +1302,11 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
rewriter.eraseOp(op);
return mlir::success();
}
} else if (const auto recordAttr =
mlir::dyn_cast<cir::ConstRecordAttr>(op.getValue())) {
auto initVal = lowerCirAttrAsValue(op, recordAttr, rewriter, typeConverter);
rewriter.replaceOp(op, initVal);
return mlir::success();
} else if (const auto vecTy = mlir::dyn_cast<cir::VectorType>(op.getType())) {
rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter,
getTypeConverter()));
@ -1527,9 +1548,9 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
cir::GlobalOp op, mlir::Attribute init,
mlir::ConversionPatternRewriter &rewriter) const {
// TODO: Generalize this handling when more types are needed here.
assert(
(isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
cir::ConstComplexAttr, cir::GlobalViewAttr, cir::ZeroAttr>(init)));
assert((isa<cir::ConstArrayAttr, cir::ConstRecordAttr, cir::ConstVectorAttr,
cir::ConstPtrAttr, cir::ConstComplexAttr, cir::GlobalViewAttr,
cir::ZeroAttr>(init)));
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
@ -1582,8 +1603,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
return mlir::failure();
}
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
cir::ConstPtrAttr, cir::ConstComplexAttr,
cir::GlobalViewAttr, cir::ZeroAttr>(init.value())) {
cir::ConstRecordAttr, cir::ConstPtrAttr,
cir::ConstComplexAttr, cir::GlobalViewAttr,
cir::ZeroAttr>(init.value())) {
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
// to the appropriate value.

View File

@ -45,9 +45,9 @@ int dd[3][2] = {{1, 2}, {3, 4}, {5, 6}};
// OGCG: [i32 3, i32 4], [2 x i32] [i32 5, i32 6]]
int e[10] = {1, 2};
// CIR: cir.global external @e = #cir.const_array<[#cir.int<1> : !s32i, #cir.int<2> : !s32i], trailing_zeros> : !cir.array<!s32i x 10>
// CIR: cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
// LLVM: @e = global [10 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0]
// LLVM: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }>
// OGCG: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }>
@ -58,6 +58,28 @@ int f[5] = {1, 2};
// OGCG: @f = global [5 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0]
int g[16] = {1, 2, 3, 4, 5, 6, 7, 8};
// CIR: cir.global external @g = #cir.const_record<{
// CIR-SAME: #cir.const_array<[#cir.int<1> : !s32i, #cir.int<2> : !s32i,
// CIR-SAME: #cir.int<3> : !s32i, #cir.int<4> : !s32i,
// CIR-SAME: #cir.int<5> : !s32i, #cir.int<6> : !s32i,
// CIR-SAME: #cir.int<7> : !s32i, #cir.int<8> : !s32i]>
// CIR-SAME: : !cir.array<!s32i x 8>,
// CIR-SAME: #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct1
// LLVM: @g = global <{ [8 x i32], [8 x i32] }>
// LLVM-SAME: <{ [8 x i32]
// LLVM-SAME: [i32 1, i32 2, i32 3, i32 4,
// LLVM-SAME: i32 5, i32 6, i32 7, i32 8],
// LLVM-SAME: [8 x i32] zeroinitializer }>
// OGCG: @g = global <{ [8 x i32], [8 x i32] }>
// OGCG-SAME: <{ [8 x i32]
// OGCG-SAME: [i32 1, i32 2, i32 3, i32 4,
// OGCG-SAME: i32 5, i32 6, i32 7, i32 8],
// OGCG-SAME: [8 x i32] zeroinitializer }>
extern int b[10];
// CIR: cir.global "private" external @b : !cir.array<!s32i x 10>
// LLVM: @b = external global [10 x i32]

View File

@ -0,0 +1,23 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file
!s32i = !cir.int<s, 32>
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
// expected-error @below {{expected !cir.record type}}
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !cir.ptr<!rec_anon_struct>
// -----
!s32i = !cir.int<s, 32>
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
// expected-error @below {{number of elements must match}}
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
// -----
!s32i = !cir.int<s, 32>
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
// expected-error @below {{element at index 1 has type '!cir.float' but the expected type for this element is '!cir.int<s, 32>'}}
cir.global external @e = #cir.const_record<{#cir.int<1> : !s32i, #cir.fp<2.000000e+00> : !cir.float, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct

View File

@ -13,8 +13,9 @@
// CHECK-DAG: !rec_S = !cir.record<struct "S" incomplete>
// CHECK-DAG: !rec_U = !cir.record<union "U" incomplete>
!rec_anon_struct = !cir.record<struct {!cir.array<!cir.ptr<!u8i> x 5>}>
!rec_anon_struct1 = !cir.record<struct {!cir.ptr<!u8i>, !cir.ptr<!u8i>, !cir.ptr<!u8i>}>
!rec_anon_struct = !cir.record<struct packed {!s32i, !s32i, !cir.array<!s32i x 8>}>
!rec_anon_struct1 = !cir.record<struct {!cir.array<!cir.ptr<!u8i> x 5>}>
!rec_anon_struct2 = !cir.record<struct {!cir.ptr<!u8i>, !cir.ptr<!u8i>, !cir.ptr<!u8i>}>
!rec_S1 = !cir.record<struct "S1" {!s32i, !s32i}>
!rec_Sc = !cir.record<struct "Sc" {!u8i, !u16i, !u32i}>
@ -42,18 +43,22 @@
!rec_Node = !cir.record<struct "Node" {!cir.ptr<!cir.record<struct "Node">>}>
// CHECK-DAG: !cir.record<struct "Node" {!cir.ptr<!cir.record<struct "Node">>}>
module {
cir.global external @p1 = #cir.ptr<null> : !cir.ptr<!rec_S>
cir.global external @p2 = #cir.ptr<null> : !cir.ptr<!rec_U>
cir.global external @p3 = #cir.ptr<null> : !cir.ptr<!rec_C>
cir.global external @arr = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
// CHECK: cir.global external @p1 = #cir.ptr<null> : !cir.ptr<!rec_S>
// CHECK: cir.global external @p2 = #cir.ptr<null> : !cir.ptr<!rec_U>
// CHECK: cir.global external @p3 = #cir.ptr<null> : !cir.ptr<!rec_C>
// CHECK: cir.global external @arr = #cir.const_record<{#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.zero : !cir.array<!s32i x 8>}> : !rec_anon_struct
// Dummy function to use types and force them to be printed.
cir.func @useTypes(%arg0: !rec_Node,
%arg1: !rec_anon_struct1,
%arg2: !rec_anon_struct,
%arg1: !rec_anon_struct2,
%arg2: !rec_anon_struct1,
%arg3: !rec_S1,
%arg4: !rec_Ac,
%arg5: !rec_P1,

View File

@ -19,7 +19,7 @@ int dd[3][2] = {{1, 2}, {3, 4}, {5, 6}};
// CHECK: [i32 3, i32 4], [2 x i32] [i32 5, i32 6]]
int e[10] = {1, 2};
// CHECK: @e = global [10 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0]
// CHECK: @e = global <{ i32, i32, [8 x i32] }> <{ i32 1, i32 2, [8 x i32] zeroinitializer }>
int f[5] = {1, 2};
// CHECK: @f = global [5 x i32] [i32 1, i32 2, i32 0, i32 0, i32 0]