[mlir][llvm] Port overflowFlags to a native operation property (RELAND) (#89410)

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on
operations as native properties.

Reland to fix flang
This commit is contained in:
Jeff Niu 2024-04-19 09:23:00 -07:00 committed by GitHub
parent d86079f93c
commit e553ac4d81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 183 additions and 116 deletions

View File

@ -2110,9 +2110,8 @@ struct XArrayCoorOpConversion
const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>(); const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();
TypePair baseBoxTyPair = TypePair baseBoxTyPair =
baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{}; baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
mlir::LLVM::IntegerOverflowFlagsAttr nsw = mlir::LLVM::IntegerOverflowFlags nsw =
mlir::LLVM::IntegerOverflowFlagsAttr::get( mlir::LLVM::IntegerOverflowFlags::nsw;
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
// For each dimension of the array, generate the offset calculation. // For each dimension of the array, generate the offset calculation.
for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset, for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
@ -2396,9 +2395,8 @@ private:
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy); auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext()); mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
mlir::Type byteTy = ::getI8Type(coor.getContext()); mlir::Type byteTy = ::getI8Type(coor.getContext());
mlir::LLVM::IntegerOverflowFlagsAttr nsw = mlir::LLVM::IntegerOverflowFlags nsw =
mlir::LLVM::IntegerOverflowFlagsAttr::get( mlir::LLVM::IntegerOverflowFlags::nsw;
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
for (unsigned i = 1, last = operands.size(); i < last; ++i) { for (unsigned i = 1, last = operands.size(); i < last; ++i) {
if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) { if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {

View File

@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
LLVM::IntegerOverflowFlags LLVM::IntegerOverflowFlags
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags); convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
/// Creates an LLVM overflow attribute from a given arithmetic overflow
/// attribute.
LLVM::IntegerOverflowFlagsAttr
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
/// Creates an LLVM rounding mode enum value from a given arithmetic rounding /// Creates an LLVM rounding mode enum value from a given arithmetic rounding
/// mode enum value. /// mode enum value.
LLVM::RoundingMode LLVM::RoundingMode
@ -72,6 +67,9 @@ public:
} }
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}
private: private:
NamedAttrList convertedAttr; NamedAttrList convertedAttr;
@ -89,19 +87,18 @@ public:
// Get the name of the arith overflow attribute. // Get the name of the arith overflow attribute.
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
// Remove the source overflow attribute. // Remove the source overflow attribute.
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>( if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName)); convertedAttr.erase(arithAttrName))) {
if (arithAttr) { overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
convertedAttr.set(targetAttrName,
convertArithOverflowAttrToLLVM(arithAttr));
} }
} }
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
private: private:
NamedAttrList convertedAttr; NamedAttrList convertedAttr;
LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
}; };
template <typename SourceOp, typename TargetOp> template <typename SourceOp, typename TargetOp>
@ -132,6 +129,9 @@ public:
} }
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}
private: private:
NamedAttrList convertedAttr; NamedAttrList convertedAttr;

View File

@ -11,6 +11,7 @@
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
@ -18,13 +19,16 @@ class CallOpInterface;
namespace LLVM { namespace LLVM {
namespace detail { namespace detail {
/// Handle generically setting flags as native properties on LLVM operations.
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
/// Replaces the given operation "op" with a new operation of type "targetOp" /// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands. /// and given operands.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, LogicalResult oneToOneRewrite(
ValueRange operands, Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs, ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriter &rewriter); IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
} // namespace detail } // namespace detail
} // namespace LLVM } // namespace LLVM

View File

@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors(
std::function<Value(Type, ValueRange)> createOperand, std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter); ConversionPatternRewriter &rewriter);
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, LogicalResult vectorOneToOneRewrite(
ValueRange operands, Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs, ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
ConversionPatternRewriter &rewriter); IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
} // namespace detail } // namespace detail
} // namespace LLVM } // namespace LLVM
@ -70,6 +70,9 @@ public:
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {} AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; } ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}
private: private:
ArrayRef<NamedAttribute> srcAttrs; ArrayRef<NamedAttribute> srcAttrs;
@ -100,7 +103,8 @@ public:
return LLVM::detail::vectorOneToOneRewrite( return LLVM::detail::vectorOneToOneRewrite(
op, TargetOp::getOperationName(), adaptor.getOperands(), op, TargetOp::getOperationName(), adaptor.getOperands(),
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter); attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
attrConvert.getOverflowFlags());
} }
}; };
} // namespace mlir } // namespace mlir

View File

@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> { def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
let description = [{ let description = [{
Access to op integer overflow flags. This interface defines an LLVM operation with integer overflow flags and
provides a uniform API for accessing them.
}]; }];
let cppNamespace = "::mlir::LLVM"; let cppNamespace = "::mlir::LLVM";
let methods = [ let methods = [
InterfaceMethod< InterfaceMethod<[{
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation", Get the integer overflow flags for the operation.
/*returnType=*/ "IntegerOverflowFlagsAttr", }], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{
/*methodName=*/ "getOverflowAttr", return $_op.getProperties().overflowFlags;
/*args=*/ (ins), }]>,
/*methodBody=*/ [{}], InterfaceMethod<[{
/*defaultImpl=*/ [{ Set the integer overflow flags for the operation.
auto op = cast<ConcreteOp>(this->getOperation()); }], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{
return op.getOverflowFlagsAttr(); $_op.getProperties().overflowFlags = flags;
}] }]>,
>, InterfaceMethod<[{
InterfaceMethod< Returns whether the operation has the No Unsigned Wrap keyword.
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword", }], "bool", "hasNoUnsignedWrap", (ins), [{}], [{
/*returnType=*/ "bool", return bitEnumContainsAll($_op.getOverflowFlags(),
/*methodName=*/ "hasNoUnsignedWrap", IntegerOverflowFlags::nuw);
/*args=*/ (ins), }]>,
/*methodBody=*/ [{}], InterfaceMethod<[{
/*defaultImpl=*/ [{ Returns whether the operation has the No Signed Wrap keyword.
auto op = cast<ConcreteOp>(this->getOperation()); }], "bool", "hasNoSignedWrap", (ins), [{}], [{
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue(); return bitEnumContainsAll($_op.getOverflowFlags(),
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw); IntegerOverflowFlags::nsw);
}] }]>,
>, StaticInterfaceMethod<[{
InterfaceMethod< Get the attribute name of the overflow flags property.
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword", }], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{
/*returnType=*/ "bool", return "overflowFlags";
/*methodName=*/ "hasNoSignedWrap", }]>,
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getIntegerOverflowAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "overflowFlags";
}]
>
]; ];
} }

View File

@ -59,17 +59,30 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
list<Trait> traits = []> : list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName, LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> { !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
dag iofArg = ( dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags);
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
let arguments = !con(commonArgs, iofArg); let arguments = !con(commonArgs, iofArg);
let builders = [
OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
"IntegerOverflowFlags":$overflowFlags), [{
build($_builder, $_state, type, lhs, rhs);
$_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
}]>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs,
"IntegerOverflowFlags":$overflowFlags), [{
build($_builder, $_state, lhs, rhs);
$_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
}]>
];
string mlirBuilder = [{ string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
moduleImport.setIntegerOverflowFlagsAttr(inst, op); moduleImport.setIntegerOverflowFlags(inst, op);
$res = op; $res = op;
}]; }];
let assemblyFormat = [{ let assemblyFormat = [{
$lhs `,` $rhs (`overflow` `` $overflowFlags^)? $lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags)
custom<LLVMOpAttrs>(attr-dict) `:` type($res) `` custom<LLVMOpAttrs>(attr-dict) `:` type($res)
}]; }];
string llvmBuilder = string llvmBuilder =
"$res = builder.Create" # instName # "$res = builder.Create" # instName #

View File

@ -183,8 +183,7 @@ public:
/// Sets the integer overflow flags (nsw/nuw) attribute for the imported /// Sets the integer overflow flags (nsw/nuw) attribute for the imported
/// operation `op` given the original instruction `inst`. Asserts if the /// operation `op` given the original instruction `inst`. Asserts if the
/// operation does not implement the integer overflow flag interface. /// operation does not implement the integer overflow flag interface.
void setIntegerOverflowFlagsAttr(llvm::Instruction *inst, void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
Operation *op) const;
/// Sets the fastmath flags attribute for the imported operation `op` given /// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not /// the original instruction `inst`. Asserts if the operation does not

View File

@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
return llvmFlags; return llvmFlags;
} }
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
arith::IntegerOverflowFlagsAttr flagsAttr) {
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
return LLVM::IntegerOverflowFlagsAttr::get(
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
}
LLVM::RoundingMode LLVM::RoundingMode
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) { mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
switch (roundingMode) { switch (roundingMode) {

View File

@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Detail methods // Detail methods
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
void LLVM::detail::setNativeProperties(Operation *op,
IntegerOverflowFlags overflowFlags) {
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
iface.setOverflowFlags(overflowFlags);
}
/// Replaces the given operation "op" with a new operation of type "targetOp" /// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands. /// and given operands.
LogicalResult LogicalResult LLVM::detail::oneToOneRewrite(
LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp, Operation *op, StringRef targetOp, ValueRange operands,
ValueRange operands, ArrayRef<NamedAttribute> targetAttrs,
ArrayRef<NamedAttribute> targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, IntegerOverflowFlags overflowFlags) {
ConversionPatternRewriter &rewriter) {
unsigned numResults = op->getNumResults(); unsigned numResults = op->getNumResults();
SmallVector<Type> resultTypes; SmallVector<Type> resultTypes;
@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
resultTypes, targetAttrs); resultTypes, targetAttrs);
setNativeProperties(newOp, overflowFlags);
// If the operation produced 0 or 1 result, return them immediately. // If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0) if (numResults == 0)
return rewriter.eraseOp(op), success(); return rewriter.eraseOp(op), success();

View File

@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
return success(); return success();
} }
LogicalResult LogicalResult LLVM::detail::vectorOneToOneRewrite(
LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp, Operation *op, StringRef targetOp, ValueRange operands,
ValueRange operands, ArrayRef<NamedAttribute> targetAttrs,
ArrayRef<NamedAttribute> targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, IntegerOverflowFlags overflowFlags) {
ConversionPatternRewriter &rewriter) {
assert(!operands.empty()); assert(!operands.empty());
// Cannot convert ops if their operands are not of LLVM type. // Cannot convert ops if their operands are not of LLVM type.
@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
auto llvmNDVectorTy = operands[0].getType(); auto llvmNDVectorTy = operands[0].getType();
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
rewriter); rewriter, overflowFlags);
auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy, auto callback = [op, targetOp, targetAttrs, overflowFlags,
ValueRange operands) { &rewriter](Type llvm1DVectorTy, ValueRange operands) {
return rewriter Operation *newOp =
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
llvm1DVectorTy, targetAttrs) operands, llvm1DVectorTy, targetAttrs);
->getResult(0); LLVM::detail::setNativeProperties(newOp, overflowFlags);
return newOp->getResult(0);
}; };
return handleMultidimensionalVectors(op, operands, typeConverter, callback, return handleMultidimensionalVectors(op, operands, typeConverter, callback,

View File

@ -47,6 +47,74 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage;
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc" #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// Property Helpers
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// IntegerOverflowFlags
namespace mlir {
static Attribute convertToAttribute(MLIRContext *ctx,
IntegerOverflowFlags flags) {
return IntegerOverflowFlagsAttr::get(ctx, flags);
}
static LogicalResult
convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
function_ref<InFlightDiagnostic()> emitError) {
auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
if (!flagsAttr) {
return emitError() << "expected 'overflowFlags' attribute to be an "
"IntegerOverflowFlagsAttr, but got "
<< attr;
}
flags = flagsAttr.getValue();
return success();
}
} // namespace mlir
static ParseResult parseOverflowFlags(AsmParser &p,
IntegerOverflowFlags &flags) {
if (failed(p.parseOptionalKeyword("overflow"))) {
flags = IntegerOverflowFlags::none;
return success();
}
if (p.parseLess())
return failure();
do {
StringRef kw;
SMLoc loc = p.getCurrentLocation();
if (p.parseKeyword(&kw))
return failure();
std::optional<IntegerOverflowFlags> flag =
symbolizeIntegerOverflowFlags(kw);
if (!flag)
return p.emitError(loc,
"invalid overflow flag: expected nsw, nuw, or none");
flags = flags | *flag;
} while (succeeded(p.parseOptionalComma()));
return p.parseGreater();
}
static void printOverflowFlags(AsmPrinter &p, Operation *op,
IntegerOverflowFlags flags) {
if (flags == IntegerOverflowFlags::none)
return;
p << " overflow<";
SmallVector<StringRef, 2> strs;
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
strs.push_back("nsw");
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
strs.push_back("nuw");
llvm::interleaveComma(strs, p);
p << ">";
}
//===----------------------------------------------------------------------===//
// Attribute Helpers
//===----------------------------------------------------------------------===//
static constexpr const char kElemTypeAttrName[] = "elem_type"; static constexpr const char kElemTypeAttrName[] = "elem_type";
static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) { static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
@ -70,12 +138,12 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
DictionaryAttr attrs) { DictionaryAttr attrs) {
auto filteredAttrs = processFMFAttr(attrs.getValue()); auto filteredAttrs = processFMFAttr(attrs.getValue());
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
printer.printOptionalAttrDict( printer.printOptionalAttrDict(
filteredAttrs, filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
/*elidedAttrs=*/{iface.getIntegerOverflowAttrName()}); } else {
else
printer.printOptionalAttrDict(filteredAttrs); printer.printOptionalAttrDict(filteredAttrs);
}
} }
/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and

View File

@ -627,8 +627,8 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
} }
} }
void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst, void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
Operation *op) const { Operation *op) const {
auto iface = cast<IntegerOverflowFlagsInterface>(op); auto iface = cast<IntegerOverflowFlagsInterface>(op);
IntegerOverflowFlags value = {}; IntegerOverflowFlags value = {};
@ -636,8 +636,7 @@ void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
value = value =
bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap()); bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value); iface.setOverflowFlags(value);
iface->setAttr(iface.getIntegerOverflowAttrName(), attr);
} }
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst, void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,