[mlir][LLVMIR] (NFC) Add convenience builders for ConstantOp

And clean up some of the user code
This commit is contained in:
Jeff Niu 2022-08-09 14:40:07 -04:00
parent c951edb7b2
commit 0af643f3ce
16 changed files with 77 additions and 92 deletions

View File

@ -53,8 +53,7 @@ static mlir::LLVM::ConstantOp
genConstantIndex(mlir::Location loc, mlir::Type ity,
mlir::ConversionPatternRewriter &rewriter,
std::int64_t offset) {
auto cattr = rewriter.getI64IntegerAttr(offset);
return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, offset);
}
static mlir::Block *createBlock(mlir::ConversionPatternRewriter &rewriter,
@ -102,8 +101,7 @@ protected:
genI32Constant(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
int value) const {
mlir::Type i32Ty = rewriter.getI32Type();
mlir::IntegerAttr attr = rewriter.getI32IntegerAttr(value);
return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, attr);
return rewriter.create<mlir::LLVM::ConstantOp>(loc, i32Ty, value);
}
mlir::LLVM::ConstantOp
@ -111,8 +109,7 @@ protected:
mlir::ConversionPatternRewriter &rewriter,
int offset) const {
mlir::Type ity = lowerTy().offsetType();
mlir::IntegerAttr cattr = rewriter.getI32IntegerAttr(offset);
return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, cattr);
return rewriter.create<mlir::LLVM::ConstantOp>(loc, ity, offset);
}
/// Perform an extension or truncation as needed on an integer value. Lowering
@ -630,13 +627,11 @@ struct StringLitOpConversion : public FIROpConversion<fir::StringLitOp> {
if (auto arr = attr.dyn_cast<mlir::DenseElementsAttr>()) {
cst = rewriter.create<mlir::LLVM::ConstantOp>(loc, ty, arr);
} else if (auto arr = attr.dyn_cast<mlir::ArrayAttr>()) {
for (auto a : llvm::enumerate(arr.getValue())) {
for (auto &a :
llvm::enumerate(arr.getAsValueRange<mlir::IntegerAttr>())) {
// convert each character to a precise bitsize
auto elemAttr = mlir::IntegerAttr::get(
intTy,
a.value().cast<mlir::IntegerAttr>().getValue().zextOrTrunc(bits));
auto elemCst =
rewriter.create<mlir::LLVM::ConstantOp>(loc, intTy, elemAttr);
auto elemCst = rewriter.create<mlir::LLVM::ConstantOp>(
loc, intTy, a.value().zextOrTrunc(bits));
auto index = mlir::ArrayAttr::get(
constop.getContext(), rewriter.getI32IntegerAttr(a.index()));
cst = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ty, cst, elemCst,
@ -733,12 +728,10 @@ struct ConstcOpConversion : public FIROpConversion<fir::ConstcOp> {
mlir::MLIRContext *ctx = conc.getContext();
mlir::Type ty = convertType(conc.getType());
mlir::Type ety = convertType(getComplexEleTy(conc.getType()));
auto realFloatAttr = mlir::FloatAttr::get(ety, getValue(conc.getReal()));
auto realPart =
rewriter.create<mlir::LLVM::ConstantOp>(loc, ety, realFloatAttr);
auto imFloatAttr = mlir::FloatAttr::get(ety, getValue(conc.getImaginary()));
auto imPart =
rewriter.create<mlir::LLVM::ConstantOp>(loc, ety, imFloatAttr);
auto realPart = rewriter.create<mlir::LLVM::ConstantOp>(
loc, ety, getValue(conc.getReal()));
auto imPart = rewriter.create<mlir::LLVM::ConstantOp>(
loc, ety, getValue(conc.getImaginary()));
auto realIndex = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(0));
auto imIndex = mlir::ArrayAttr::get(ctx, rewriter.getI32IntegerAttr(1));
auto undef = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);

View File

@ -154,9 +154,8 @@ private:
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 64),
builder.getIntegerAttr(builder.getIndexType(), 0));
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
builder.getIndexAttr(0));
return builder.create<LLVM::GEPOp>(
loc,
LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),

View File

@ -154,10 +154,8 @@ private:
// Get the pointer to the first character in the global string.
Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
Value cst0 = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 64),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(
Value cst0 = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
builder.getIndexAttr(0));
loc,
LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)),
globalPtr, ArrayRef<Value>({cst0, cst0}));

View File

@ -1521,8 +1521,25 @@ def LLVM_ConstantOp
let arguments = (ins AnyAttr:$value);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
let builders = [
LLVM_OneResultOpBuilder,
OpBuilder<(ins "Type":$type, "int64_t":$value), [{
build($_builder, $_state, type, $_builder.getIntegerAttr(type, value));
}]>,
OpBuilder<(ins "Type":$type, "const APInt &":$value), [{
build($_builder, $_state, type, $_builder.getIntegerAttr(type, value));
}]>,
OpBuilder<(ins "Type":$type, "const APFloat &":$value), [{
build($_builder, $_state, type, $_builder.getFloatAttr(type, value));
}]>,
OpBuilder<(ins "TypedAttr":$value), [{
build($_builder, $_state, value.getType(), value);
}]>
];
let hasFolder = 1;
let hasVerifier = 1;
}

View File

@ -18,9 +18,8 @@ using namespace mlir::amdgpu;
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
IntegerAttr valAttr = rewriter.getI32IntegerAttr(value);
Type llvmI32 = rewriter.getI32Type();
return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, valAttr);
return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
}
namespace {
@ -118,8 +117,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
MemRefDescriptor memrefDescriptor(memref);
Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));
Value c32I64 = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64, rewriter.getI64IntegerAttr(32));
Value c32I64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 32);
Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);

View File

@ -311,8 +311,8 @@ public:
auto loc = op->getLoc();
// Constants for initializing coroutine frame.
auto constZero = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
auto constZero =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, i8Ptr);
// Get coroutine id: @llvm.coro.id.
@ -351,7 +351,7 @@ public:
// parameter.
auto makeConstant = [&](uint64_t c) {
return rewriter.create<LLVM::ConstantOp>(
op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(c));
op->getLoc(), rewriter.getI64Type(), c);
};
coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
coroSize =

View File

@ -120,8 +120,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
.template cast<Type>(),
allocaAddrSpace);
Value numElements = rewriter.create<LLVM::ConstantOp>(
gpuFuncOp.getLoc(), int64Ty,
rewriter.getI64IntegerAttr(type.getNumElements()));
gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
Value allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
auto descr = MemRefDescriptor::fromStaticShape(
@ -219,8 +218,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
{llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
/// Start the printf hostcall
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64, rewriter.getI64IntegerAttr(0));
Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0);
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult(0);
@ -251,13 +249,11 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
Value stringStart = rewriter.create<LLVM::GEPOp>(
loc, i8Ptr, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringLen = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
Value stringLen =
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
Value oneI32 = rewriter.create<LLVM::ConstantOp>(
loc, llvmI32, rewriter.getI32IntegerAttr(1));
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
loc, llvmI32, rewriter.getI32IntegerAttr(0));
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
auto appendFormatCall = rewriter.create<LLVM::CallOp>(
loc, ocklAppendStringN,
@ -274,8 +270,8 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
arguments.push_back(printfDesc);
arguments.push_back(rewriter.create<LLVM::ConstantOp>(
loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
arguments.push_back(
rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
for (size_t i = group; i < bound; ++i) {
Value arg = adaptor.args()[i];
if (auto floatType = arg.getType().dyn_cast<FloatType>()) {

View File

@ -675,12 +675,11 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
argumentTypes.push_back(argument.getType());
auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
argumentTypes);
auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
builder.getI32IntegerAttr(1));
auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 1);
auto structPtr = builder.create<LLVM::AllocaOp>(
loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
auto arraySize = builder.create<LLVM::ConstantOp>(
loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
auto arraySize =
builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, numArguments);
auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
arraySize, /*alignment=*/0);
for (const auto &en : llvm::enumerate(arguments)) {
@ -786,8 +785,7 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
launchOp.getKernelName().getValue(), loc, rewriter);
auto function = moduleGetFunctionCallBuilder.create(
loc, rewriter, {module.getResult(0), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
rewriter.getI32IntegerAttr(0));
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 0);
Value stream =
adaptor.asyncDependencies().empty()
? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)

View File

@ -89,12 +89,9 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueTy, predTy});
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(1));
Value minusOne = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(-1));
Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(
loc, int32Type, rewriter.getI32IntegerAttr(32));
Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
loc, int32Type, thirtyTwo, adaptor.width());
// Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.

View File

@ -260,8 +260,7 @@ struct WmmaConstantOpToNVVMLowering
Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
Value idx = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getIntegerType(32)),
rewriter.getI32IntegerAttr(vecEl));
loc, rewriter.getI32Type(), vecEl);
vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
cst, idx);
}

View File

@ -218,15 +218,15 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
// Create LLVM constant for the descriptor set index.
// Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV`
// pass does.
Value descriptorSet = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(0));
Value descriptorSet =
builder.create<LLVM::ConstantOp>(loc, getInt32Type(), 0);
for (const auto &en :
llvm::enumerate(cInterfaceVulkanLaunchCallOp.getOperands().drop_front(
kVulkanLaunchNumConfigOperands))) {
// Create LLVM constant for the descriptor binding index.
Value descriptorBinding = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(en.index()));
Value descriptorBinding =
builder.create<LLVM::ConstantOp>(loc, getInt32Type(), en.index());
auto ptrToMemRefDescriptor = en.value();
uint32_t rank = 0;
@ -384,8 +384,7 @@ void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall(
// Create LLVM constant for the size of SPIR-V binary shader.
Value binarySize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(spirvAttributes.first.getValue().size()));
loc, getInt32Type(), spirvAttributes.first.getValue().size());
// Create call to `bindMemRef` for each memref operand.
createBindMemRefCalls(cInterfaceVulkanLaunchCallOp, vulkanRuntime);

View File

@ -98,8 +98,8 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
// integer attribute.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
return builder.create<LLVM::ConstantOp>(loc, resultType,
builder.getIndexAttr(value));
}
/// Builds IR extracting the offset from the descriptor.

View File

@ -56,8 +56,8 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
Location loc,
Type resultType,
int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
return builder.create<LLVM::ConstantOp>(loc, resultType,
builder.getIndexAttr(value));
}
Value ConvertToLLVMPattern::createIndexConstant(

View File

@ -474,14 +474,11 @@ Type LLVMTypeConverter::packFunctionResults(TypeRange types) {
Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
OpBuilder &builder) {
auto *context = builder.getContext();
auto int64Ty = IntegerType::get(builder.getContext(), 64);
auto indexType = IndexType::get(context);
// Alloca with proper alignment. We do not expect optimizations of this
// alloca op and so we omit allocating at the entry block.
auto ptrType = LLVM::LLVMPointerType::get(operand.getType());
Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
IntegerAttr::get(indexType, 1));
Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
builder.getIndexAttr(1));
Value allocated =
builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0);
// Store into the alloca'ed descriptor.

View File

@ -57,12 +57,10 @@ struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
auto boolType = rewriter.getIntegerType(1);
auto boolZero = rewriter.getIntegerAttr(boolType, 0);
auto boolZero = rewriter.getBoolAttr(false);
if (!operandType.template isa<LLVM::LLVMArrayType>()) {
LLVM::ConstantOp zero =
rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
zero);
return success();
@ -76,7 +74,7 @@ struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
LLVM::ConstantOp zero =
rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
rewriter.create<LLVM::ConstantOp>(loc, boolZero);
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
zero);
},

View File

@ -727,14 +727,12 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
// Replace with llvm.prefetch.
auto llvmI32Type = typeConverter->convertType(rewriter.getIntegerType(32));
auto isWrite = rewriter.create<LLVM::ConstantOp>(
loc, llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.getIsWrite()));
auto isWrite = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type,
prefetchOp.getIsWrite());
auto localityHint = rewriter.create<LLVM::ConstantOp>(
loc, llvmI32Type,
rewriter.getI32IntegerAttr(prefetchOp.getLocalityHint()));
loc, llvmI32Type, prefetchOp.getLocalityHint());
auto isData = rewriter.create<LLVM::ConstantOp>(
loc, llvmI32Type,
rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache()));
loc, llvmI32Type, prefetchOp.getIsDataCache());
rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
localityHint, isData);
@ -889,9 +887,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
Value targetOffset = targetDesc.offset(rewriter, loc);
Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
targetBasePtr, targetOffset);
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getI1Type()),
rewriter.getBoolAttr(false));
Value isVolatile =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(false));
rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
isVolatile);
rewriter.eraseOp(op);
@ -908,8 +905,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
// First make sure we have an unranked memref descriptor representation.
auto makeUnranked = [&, this](Value ranked, BaseMemRefType type) {
auto rank = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(type.getRank()));
auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
type.getRank());
auto *typeConverter = getTypeConverter();
auto ptr =
typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);
@ -1524,8 +1521,7 @@ static void fillInStridesForCollapsedMemDescriptor(
break;
}
Value one = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getI64Type()),
rewriter.getI32IntegerAttr(1));
loc, rewriter.getI64Type(), rewriter.getI32IntegerAttr(1));
Value predNeOne = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::ne, srcDesc.size(rewriter, loc, srcIndex),
one);