[mlir][llvm] Port overflowFlags
to a native operation property (RELAND) (#89410)
This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on operations as native properties. Reland to fix flang
This commit is contained in:
parent
d86079f93c
commit
e553ac4d81
@ -2110,9 +2110,8 @@ struct XArrayCoorOpConversion
|
|||||||
const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();
|
const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();
|
||||||
TypePair baseBoxTyPair =
|
TypePair baseBoxTyPair =
|
||||||
baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
|
baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
|
||||||
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
|
mlir::LLVM::IntegerOverflowFlags nsw =
|
||||||
mlir::LLVM::IntegerOverflowFlagsAttr::get(
|
mlir::LLVM::IntegerOverflowFlags::nsw;
|
||||||
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
|
|
||||||
|
|
||||||
// For each dimension of the array, generate the offset calculation.
|
// For each dimension of the array, generate the offset calculation.
|
||||||
for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
|
for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
|
||||||
@ -2396,9 +2395,8 @@ private:
|
|||||||
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
|
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
|
||||||
mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
|
mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
|
||||||
mlir::Type byteTy = ::getI8Type(coor.getContext());
|
mlir::Type byteTy = ::getI8Type(coor.getContext());
|
||||||
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
|
mlir::LLVM::IntegerOverflowFlags nsw =
|
||||||
mlir::LLVM::IntegerOverflowFlagsAttr::get(
|
mlir::LLVM::IntegerOverflowFlags::nsw;
|
||||||
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
|
|
||||||
|
|
||||||
for (unsigned i = 1, last = operands.size(); i < last; ++i) {
|
for (unsigned i = 1, last = operands.size(); i < last; ++i) {
|
||||||
if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
|
if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
|
||||||
|
@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
|
|||||||
LLVM::IntegerOverflowFlags
|
LLVM::IntegerOverflowFlags
|
||||||
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
|
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
|
||||||
|
|
||||||
/// Creates an LLVM overflow attribute from a given arithmetic overflow
|
|
||||||
/// attribute.
|
|
||||||
LLVM::IntegerOverflowFlagsAttr
|
|
||||||
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
|
|
||||||
|
|
||||||
/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
|
/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
|
||||||
/// mode enum value.
|
/// mode enum value.
|
||||||
LLVM::RoundingMode
|
LLVM::RoundingMode
|
||||||
@ -72,6 +67,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
|
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
|
||||||
|
LLVM::IntegerOverflowFlags getOverflowFlags() const {
|
||||||
|
return LLVM::IntegerOverflowFlags::none;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
NamedAttrList convertedAttr;
|
NamedAttrList convertedAttr;
|
||||||
@ -89,19 +87,18 @@ public:
|
|||||||
// Get the name of the arith overflow attribute.
|
// Get the name of the arith overflow attribute.
|
||||||
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
|
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
|
||||||
// Remove the source overflow attribute.
|
// Remove the source overflow attribute.
|
||||||
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
|
if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
|
||||||
convertedAttr.erase(arithAttrName));
|
convertedAttr.erase(arithAttrName))) {
|
||||||
if (arithAttr) {
|
overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
|
||||||
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
|
|
||||||
convertedAttr.set(targetAttrName,
|
|
||||||
convertArithOverflowAttrToLLVM(arithAttr));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
|
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
|
||||||
|
LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
NamedAttrList convertedAttr;
|
NamedAttrList convertedAttr;
|
||||||
|
LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename SourceOp, typename TargetOp>
|
template <typename SourceOp, typename TargetOp>
|
||||||
@ -132,6 +129,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
|
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
|
||||||
|
LLVM::IntegerOverflowFlags getOverflowFlags() const {
|
||||||
|
return LLVM::IntegerOverflowFlags::none;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
NamedAttrList convertedAttr;
|
NamedAttrList convertedAttr;
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
|
|
||||||
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
|
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
@ -18,13 +19,16 @@ class CallOpInterface;
|
|||||||
|
|
||||||
namespace LLVM {
|
namespace LLVM {
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
/// Handle generically setting flags as native properties on LLVM operations.
|
||||||
|
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
|
||||||
|
|
||||||
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
||||||
/// and given operands.
|
/// and given operands.
|
||||||
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
|
LogicalResult oneToOneRewrite(
|
||||||
ValueRange operands,
|
Operation *op, StringRef targetOp, ValueRange operands,
|
||||||
ArrayRef<NamedAttribute> targetAttrs,
|
ArrayRef<NamedAttribute> targetAttrs,
|
||||||
const LLVMTypeConverter &typeConverter,
|
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
|
||||||
ConversionPatternRewriter &rewriter);
|
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
} // namespace LLVM
|
} // namespace LLVM
|
||||||
|
@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors(
|
|||||||
std::function<Value(Type, ValueRange)> createOperand,
|
std::function<Value(Type, ValueRange)> createOperand,
|
||||||
ConversionPatternRewriter &rewriter);
|
ConversionPatternRewriter &rewriter);
|
||||||
|
|
||||||
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
|
LogicalResult vectorOneToOneRewrite(
|
||||||
ValueRange operands,
|
Operation *op, StringRef targetOp, ValueRange operands,
|
||||||
ArrayRef<NamedAttribute> targetAttrs,
|
ArrayRef<NamedAttribute> targetAttrs,
|
||||||
const LLVMTypeConverter &typeConverter,
|
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
|
||||||
ConversionPatternRewriter &rewriter);
|
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
} // namespace LLVM
|
} // namespace LLVM
|
||||||
|
|
||||||
@ -70,6 +70,9 @@ public:
|
|||||||
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
|
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
|
||||||
|
|
||||||
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
|
ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
|
||||||
|
LLVM::IntegerOverflowFlags getOverflowFlags() const {
|
||||||
|
return LLVM::IntegerOverflowFlags::none;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ArrayRef<NamedAttribute> srcAttrs;
|
ArrayRef<NamedAttribute> srcAttrs;
|
||||||
@ -100,7 +103,8 @@ public:
|
|||||||
|
|
||||||
return LLVM::detail::vectorOneToOneRewrite(
|
return LLVM::detail::vectorOneToOneRewrite(
|
||||||
op, TargetOp::getOperationName(), adaptor.getOperands(),
|
op, TargetOp::getOperationName(), adaptor.getOperands(),
|
||||||
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
|
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
|
||||||
|
attrConvert.getOverflowFlags());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
|
|||||||
|
|
||||||
def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
|
def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
|
||||||
let description = [{
|
let description = [{
|
||||||
Access to op integer overflow flags.
|
This interface defines an LLVM operation with integer overflow flags and
|
||||||
|
provides a uniform API for accessing them.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let cppNamespace = "::mlir::LLVM";
|
let cppNamespace = "::mlir::LLVM";
|
||||||
|
|
||||||
let methods = [
|
let methods = [
|
||||||
InterfaceMethod<
|
InterfaceMethod<[{
|
||||||
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
|
Get the integer overflow flags for the operation.
|
||||||
/*returnType=*/ "IntegerOverflowFlagsAttr",
|
}], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{
|
||||||
/*methodName=*/ "getOverflowAttr",
|
return $_op.getProperties().overflowFlags;
|
||||||
/*args=*/ (ins),
|
}]>,
|
||||||
/*methodBody=*/ [{}],
|
InterfaceMethod<[{
|
||||||
/*defaultImpl=*/ [{
|
Set the integer overflow flags for the operation.
|
||||||
auto op = cast<ConcreteOp>(this->getOperation());
|
}], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{
|
||||||
return op.getOverflowFlagsAttr();
|
$_op.getProperties().overflowFlags = flags;
|
||||||
}]
|
}]>,
|
||||||
>,
|
InterfaceMethod<[{
|
||||||
InterfaceMethod<
|
Returns whether the operation has the No Unsigned Wrap keyword.
|
||||||
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
|
}], "bool", "hasNoUnsignedWrap", (ins), [{}], [{
|
||||||
/*returnType=*/ "bool",
|
return bitEnumContainsAll($_op.getOverflowFlags(),
|
||||||
/*methodName=*/ "hasNoUnsignedWrap",
|
IntegerOverflowFlags::nuw);
|
||||||
/*args=*/ (ins),
|
}]>,
|
||||||
/*methodBody=*/ [{}],
|
InterfaceMethod<[{
|
||||||
/*defaultImpl=*/ [{
|
Returns whether the operation has the No Signed Wrap keyword.
|
||||||
auto op = cast<ConcreteOp>(this->getOperation());
|
}], "bool", "hasNoSignedWrap", (ins), [{}], [{
|
||||||
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
|
return bitEnumContainsAll($_op.getOverflowFlags(),
|
||||||
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
|
IntegerOverflowFlags::nsw);
|
||||||
}]
|
}]>,
|
||||||
>,
|
StaticInterfaceMethod<[{
|
||||||
InterfaceMethod<
|
Get the attribute name of the overflow flags property.
|
||||||
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
|
}], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{
|
||||||
/*returnType=*/ "bool",
|
return "overflowFlags";
|
||||||
/*methodName=*/ "hasNoSignedWrap",
|
}]>,
|
||||||
/*args=*/ (ins),
|
|
||||||
/*methodBody=*/ [{}],
|
|
||||||
/*defaultImpl=*/ [{
|
|
||||||
auto op = cast<ConcreteOp>(this->getOperation());
|
|
||||||
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
|
|
||||||
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
|
|
||||||
}]
|
|
||||||
>,
|
|
||||||
StaticInterfaceMethod<
|
|
||||||
/*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute
|
|
||||||
for the operation}],
|
|
||||||
/*returnType=*/ "StringRef",
|
|
||||||
/*methodName=*/ "getIntegerOverflowAttrName",
|
|
||||||
/*args=*/ (ins),
|
|
||||||
/*methodBody=*/ [{}],
|
|
||||||
/*defaultImpl=*/ [{
|
|
||||||
return "overflowFlags";
|
|
||||||
}]
|
|
||||||
>
|
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,17 +59,30 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
|
|||||||
list<Trait> traits = []> :
|
list<Trait> traits = []> :
|
||||||
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
|
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
|
||||||
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
|
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
|
||||||
dag iofArg = (
|
dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags);
|
||||||
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
|
|
||||||
let arguments = !con(commonArgs, iofArg);
|
let arguments = !con(commonArgs, iofArg);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
|
||||||
|
"IntegerOverflowFlags":$overflowFlags), [{
|
||||||
|
build($_builder, $_state, type, lhs, rhs);
|
||||||
|
$_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
|
||||||
|
}]>,
|
||||||
|
OpBuilder<(ins "Value":$lhs, "Value":$rhs,
|
||||||
|
"IntegerOverflowFlags":$overflowFlags), [{
|
||||||
|
build($_builder, $_state, lhs, rhs);
|
||||||
|
$_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
|
||||||
|
}]>
|
||||||
|
];
|
||||||
|
|
||||||
string mlirBuilder = [{
|
string mlirBuilder = [{
|
||||||
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
|
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
|
||||||
moduleImport.setIntegerOverflowFlagsAttr(inst, op);
|
moduleImport.setIntegerOverflowFlags(inst, op);
|
||||||
$res = op;
|
$res = op;
|
||||||
}];
|
}];
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$lhs `,` $rhs (`overflow` `` $overflowFlags^)?
|
$lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags)
|
||||||
custom<LLVMOpAttrs>(attr-dict) `:` type($res)
|
`` custom<LLVMOpAttrs>(attr-dict) `:` type($res)
|
||||||
}];
|
}];
|
||||||
string llvmBuilder =
|
string llvmBuilder =
|
||||||
"$res = builder.Create" # instName #
|
"$res = builder.Create" # instName #
|
||||||
|
@ -183,8 +183,7 @@ public:
|
|||||||
/// Sets the integer overflow flags (nsw/nuw) attribute for the imported
|
/// Sets the integer overflow flags (nsw/nuw) attribute for the imported
|
||||||
/// operation `op` given the original instruction `inst`. Asserts if the
|
/// operation `op` given the original instruction `inst`. Asserts if the
|
||||||
/// operation does not implement the integer overflow flag interface.
|
/// operation does not implement the integer overflow flag interface.
|
||||||
void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
|
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
|
||||||
Operation *op) const;
|
|
||||||
|
|
||||||
/// Sets the fastmath flags attribute for the imported operation `op` given
|
/// Sets the fastmath flags attribute for the imported operation `op` given
|
||||||
/// the original instruction `inst`. Asserts if the operation does not
|
/// the original instruction `inst`. Asserts if the operation does not
|
||||||
|
@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
|
|||||||
return llvmFlags;
|
return llvmFlags;
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
|
|
||||||
arith::IntegerOverflowFlagsAttr flagsAttr) {
|
|
||||||
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
|
|
||||||
return LLVM::IntegerOverflowFlagsAttr::get(
|
|
||||||
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
|
|
||||||
}
|
|
||||||
|
|
||||||
LLVM::RoundingMode
|
LLVM::RoundingMode
|
||||||
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
|
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
|
||||||
switch (roundingMode) {
|
switch (roundingMode) {
|
||||||
|
@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
|
|||||||
// Detail methods
|
// Detail methods
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void LLVM::detail::setNativeProperties(Operation *op,
|
||||||
|
IntegerOverflowFlags overflowFlags) {
|
||||||
|
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
|
||||||
|
iface.setOverflowFlags(overflowFlags);
|
||||||
|
}
|
||||||
|
|
||||||
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
/// Replaces the given operation "op" with a new operation of type "targetOp"
|
||||||
/// and given operands.
|
/// and given operands.
|
||||||
LogicalResult
|
LogicalResult LLVM::detail::oneToOneRewrite(
|
||||||
LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
|
Operation *op, StringRef targetOp, ValueRange operands,
|
||||||
ValueRange operands,
|
ArrayRef<NamedAttribute> targetAttrs,
|
||||||
ArrayRef<NamedAttribute> targetAttrs,
|
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
|
||||||
const LLVMTypeConverter &typeConverter,
|
IntegerOverflowFlags overflowFlags) {
|
||||||
ConversionPatternRewriter &rewriter) {
|
|
||||||
unsigned numResults = op->getNumResults();
|
unsigned numResults = op->getNumResults();
|
||||||
|
|
||||||
SmallVector<Type> resultTypes;
|
SmallVector<Type> resultTypes;
|
||||||
@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
|
|||||||
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
|
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
|
||||||
resultTypes, targetAttrs);
|
resultTypes, targetAttrs);
|
||||||
|
|
||||||
|
setNativeProperties(newOp, overflowFlags);
|
||||||
|
|
||||||
// If the operation produced 0 or 1 result, return them immediately.
|
// If the operation produced 0 or 1 result, return them immediately.
|
||||||
if (numResults == 0)
|
if (numResults == 0)
|
||||||
return rewriter.eraseOp(op), success();
|
return rewriter.eraseOp(op), success();
|
||||||
|
@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult LLVM::detail::vectorOneToOneRewrite(
|
||||||
LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
|
Operation *op, StringRef targetOp, ValueRange operands,
|
||||||
ValueRange operands,
|
ArrayRef<NamedAttribute> targetAttrs,
|
||||||
ArrayRef<NamedAttribute> targetAttrs,
|
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
|
||||||
const LLVMTypeConverter &typeConverter,
|
IntegerOverflowFlags overflowFlags) {
|
||||||
ConversionPatternRewriter &rewriter) {
|
|
||||||
assert(!operands.empty());
|
assert(!operands.empty());
|
||||||
|
|
||||||
// Cannot convert ops if their operands are not of LLVM type.
|
// Cannot convert ops if their operands are not of LLVM type.
|
||||||
@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
|
|||||||
auto llvmNDVectorTy = operands[0].getType();
|
auto llvmNDVectorTy = operands[0].getType();
|
||||||
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
|
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
|
||||||
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
|
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
|
||||||
rewriter);
|
rewriter, overflowFlags);
|
||||||
|
|
||||||
auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
|
auto callback = [op, targetOp, targetAttrs, overflowFlags,
|
||||||
ValueRange operands) {
|
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
|
||||||
return rewriter
|
Operation *newOp =
|
||||||
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
|
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
|
||||||
llvm1DVectorTy, targetAttrs)
|
operands, llvm1DVectorTy, targetAttrs);
|
||||||
->getResult(0);
|
LLVM::detail::setNativeProperties(newOp, overflowFlags);
|
||||||
|
return newOp->getResult(0);
|
||||||
};
|
};
|
||||||
|
|
||||||
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
|
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
|
||||||
|
@ -47,6 +47,74 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage;
|
|||||||
|
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
|
#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Property Helpers
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// IntegerOverflowFlags
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
static Attribute convertToAttribute(MLIRContext *ctx,
|
||||||
|
IntegerOverflowFlags flags) {
|
||||||
|
return IntegerOverflowFlagsAttr::get(ctx, flags);
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult
|
||||||
|
convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
|
||||||
|
function_ref<InFlightDiagnostic()> emitError) {
|
||||||
|
auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
|
||||||
|
if (!flagsAttr) {
|
||||||
|
return emitError() << "expected 'overflowFlags' attribute to be an "
|
||||||
|
"IntegerOverflowFlagsAttr, but got "
|
||||||
|
<< attr;
|
||||||
|
}
|
||||||
|
flags = flagsAttr.getValue();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
static ParseResult parseOverflowFlags(AsmParser &p,
|
||||||
|
IntegerOverflowFlags &flags) {
|
||||||
|
if (failed(p.parseOptionalKeyword("overflow"))) {
|
||||||
|
flags = IntegerOverflowFlags::none;
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (p.parseLess())
|
||||||
|
return failure();
|
||||||
|
do {
|
||||||
|
StringRef kw;
|
||||||
|
SMLoc loc = p.getCurrentLocation();
|
||||||
|
if (p.parseKeyword(&kw))
|
||||||
|
return failure();
|
||||||
|
std::optional<IntegerOverflowFlags> flag =
|
||||||
|
symbolizeIntegerOverflowFlags(kw);
|
||||||
|
if (!flag)
|
||||||
|
return p.emitError(loc,
|
||||||
|
"invalid overflow flag: expected nsw, nuw, or none");
|
||||||
|
flags = flags | *flag;
|
||||||
|
} while (succeeded(p.parseOptionalComma()));
|
||||||
|
return p.parseGreater();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printOverflowFlags(AsmPrinter &p, Operation *op,
|
||||||
|
IntegerOverflowFlags flags) {
|
||||||
|
if (flags == IntegerOverflowFlags::none)
|
||||||
|
return;
|
||||||
|
p << " overflow<";
|
||||||
|
SmallVector<StringRef, 2> strs;
|
||||||
|
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
|
||||||
|
strs.push_back("nsw");
|
||||||
|
if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
|
||||||
|
strs.push_back("nuw");
|
||||||
|
llvm::interleaveComma(strs, p);
|
||||||
|
p << ">";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Attribute Helpers
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static constexpr const char kElemTypeAttrName[] = "elem_type";
|
static constexpr const char kElemTypeAttrName[] = "elem_type";
|
||||||
|
|
||||||
static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
|
static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
|
||||||
@ -70,12 +138,12 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
|
|||||||
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
|
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
|
||||||
DictionaryAttr attrs) {
|
DictionaryAttr attrs) {
|
||||||
auto filteredAttrs = processFMFAttr(attrs.getValue());
|
auto filteredAttrs = processFMFAttr(attrs.getValue());
|
||||||
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
|
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op)) {
|
||||||
printer.printOptionalAttrDict(
|
printer.printOptionalAttrDict(
|
||||||
filteredAttrs,
|
filteredAttrs, /*elidedAttrs=*/{iface.getOverflowFlagsAttrName()});
|
||||||
/*elidedAttrs=*/{iface.getIntegerOverflowAttrName()});
|
} else {
|
||||||
else
|
|
||||||
printer.printOptionalAttrDict(filteredAttrs);
|
printer.printOptionalAttrDict(filteredAttrs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
|
/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
|
||||||
|
@ -627,8 +627,8 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
|
void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
|
||||||
Operation *op) const {
|
Operation *op) const {
|
||||||
auto iface = cast<IntegerOverflowFlagsInterface>(op);
|
auto iface = cast<IntegerOverflowFlagsInterface>(op);
|
||||||
|
|
||||||
IntegerOverflowFlags value = {};
|
IntegerOverflowFlags value = {};
|
||||||
@ -636,8 +636,7 @@ void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
|
|||||||
value =
|
value =
|
||||||
bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
|
bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
|
||||||
|
|
||||||
auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value);
|
iface.setOverflowFlags(value);
|
||||||
iface->setAttr(iface.getIntegerOverflowAttrName(), attr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
|
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user