From ae40d625410036d65cfe09f2122b81450f62ea99 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 9 Nov 2021 00:05:55 +0000 Subject: [PATCH] [mlir] Refactor ElementsAttr's value access API There are several aspects of the API that either aren't easy to use, or are deceptively easy to do the wrong thing. The main change of this commit is to remove all of the `getValue`/`getFlatValue` from ElementsAttr and instead provide operator[] methods on the ranges returned by `getValues`. This provides a much more convenient API for the value ranges. It also removes the easy-to-be-inefficient nature of getValue/getFlatValue, which under the hood would construct a new range for the type `T`. Constructing a range is not necessarily cheap in all cases, and could lead to very poor performance if used within a loop; i.e. if you were to naively write something like: ``` DenseElementsAttr attr = ...; for (int i = 0; i < size; ++i) { // We are internally rebuilding the APFloat value range on each iteration!! APFloat it = attr.getFlatValue(i); } ``` Differential Revision: https://reviews.llvm.org/D113229 --- .../mlir/IR/BuiltinAttributeInterfaces.h | 37 ++++++++ .../mlir/IR/BuiltinAttributeInterfaces.td | 76 +++++++-------- mlir/include/mlir/IR/BuiltinAttributes.h | 92 ++++++++----------- mlir/include/mlir/IR/BuiltinAttributes.td | 15 +-- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 28 +++--- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 2 +- .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 2 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 21 ++--- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 4 +- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 42 +++------ .../Dialect/Linalg/Transforms/Transforms.cpp | 4 +- .../SPIRV/IR/SPIRVCanonicalization.cpp | 2 +- mlir/lib/Dialect/Shape/IR/Shape.cpp | 2 +- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 66 +++++++------ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 6 +- mlir/lib/IR/BuiltinAttributeInterfaces.cpp | 12 +-- mlir/lib/IR/BuiltinAttributes.cpp | 61 ++---------- mlir/lib/Interfaces/InferTypeOpInterface.cpp | 2 +- .../LLVMIR/LLVMToLLVMIRTranslation.cpp | 10 +- .../Target/SPIRV/Serialization/Serializer.cpp | 7 +- .../test-linalg-ods-yaml-gen.yaml | 4 +- .../mlir-linalg-ods-yaml-gen.cpp | 38 +++----- .../Dialect/Quant/QuantizationUtilsTest.cpp | 9 +- mlir/unittests/IR/AttributeTest.cpp | 8 +- 25 files changed, 241 insertions(+), 315 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h index c48a359383ff..2ed1c84ee537 100644 --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h @@ -227,6 +227,33 @@ private: ElementsAttrIndexer indexer; ptrdiff_t index; }; + +/// This class provides iterator utilities for an ElementsAttr range. +template +class ElementsAttrRange : public llvm::iterator_range { +public: + using reference = typename IteratorT::reference; + + ElementsAttrRange(Type shapeType, + const llvm::iterator_range &range) + : llvm::iterator_range(range), shapeType(shapeType) {} + ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt) + : ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {} + + /// Return the value at the given index. + reference operator[](ArrayRef index) const; + reference operator[](uint64_t index) const { + return *std::next(this->begin(), index); + } + + /// Return the size of this range. + size_t size() const { return llvm::size(*this); } + +private: + /// The shaped type of the parent ElementsAttr. + Type shapeType; +}; + } // namespace detail //===----------------------------------------------------------------------===// @@ -256,6 +283,16 @@ verifyAffineMapAsLayout(AffineMap m, ArrayRef shape, //===----------------------------------------------------------------------===// namespace mlir { +namespace detail { +/// Return the value at the given index. +template +auto ElementsAttrRange::operator[](ArrayRef index) const + -> reference { + // Skip to the element corresponding to the flattened index. + return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)]; +} +} // namespace detail + /// Return the elements of this attribute as a value of type 'T'. template auto ElementsAttr::value_begin() const -> DefaultValueCheckT> { diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td index 30b3ea7ca09a..45295e874f3b 100644 --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -158,27 +158,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { ]; string ElementsAttrInterfaceAccessors = [{ - /// Return the attribute value at the given index. The index is expected to - /// refer to a valid element. - Attribute getValue(ArrayRef index) const { - return getValue(index); - } - - /// Return the value of type 'T' at the given index, where 'T' corresponds - /// to an Attribute type. - template - std::enable_if_t::value && - std::is_base_of::value> - getValue(ArrayRef index) const { - return getValue(index).template dyn_cast_or_null(); - } - - /// Return the value of type 'T' at the given index. - template - T getValue(ArrayRef index) const { - return getFlatValue(getFlattenedIndex(index)); - } - /// Return the number of elements held by this attribute. int64_t size() const { return getNumElements(); } @@ -281,6 +260,14 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { // Value Iteration //===------------------------------------------------------------------===// + /// The iterator for the given element type T. + template + using iterator = decltype(std::declval().template value_begin()); + /// The iterator range over the given element T. + template + using iterator_range = + decltype(std::declval().template getValues()); + /// Return an iterator to the first element of this attribute as a value of /// type `T`. template @@ -292,11 +279,8 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { template auto getValues() const { auto beginIt = $_attr.template value_begin(); - return llvm::make_range(beginIt, std::next(beginIt, size())); - } - /// Return the value at the given flattened index. - template T getFlatValue(uint64_t index) const { - return *std::next($_attr.template value_begin(), index); + return detail::ElementsAttrRange( + Attribute($_attr).getType(), beginIt, std::next(beginIt, size())); } }] # ElementsAttrInterfaceAccessors; @@ -304,7 +288,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { template using iterator = detail::ElementsAttrIterator; template - using iterator_range = llvm::iterator_range>; + using iterator_range = detail::ElementsAttrRange>; //===------------------------------------------------------------------===// // Accessors @@ -329,8 +313,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { uint64_t getFlattenedIndex(ArrayRef index) const { return getFlattenedIndex(*this, index); } - static uint64_t getFlattenedIndex(Attribute elementsAttr, + static uint64_t getFlattenedIndex(Type type, ArrayRef index); + static uint64_t getFlattenedIndex(Attribute elementsAttr, + ArrayRef index) { + return getFlattenedIndex(elementsAttr.getType(), index); + } /// Returns the number of elements held by this attribute. int64_t getNumElements() const { return getNumElements(*this); } @@ -350,13 +338,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { !std::is_base_of::value, ResultT>; - /// Return the element of this attribute at the given index as a value of - /// type 'T'. - template - T getFlatValue(uint64_t index) const { - return *std::next(value_begin(), index); - } - /// Return the splat value for this attribute. This asserts that the /// attribute corresponds to a splat. template @@ -368,7 +349,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { /// Return the elements of this attribute as a value of type 'T'. template DefaultValueCheckT> getValues() const { - return iterator_range(value_begin(), value_end()); + return {Attribute::getType(), value_begin(), value_end()}; } template DefaultValueCheckT> value_begin() const; @@ -384,12 +365,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { llvm::mapped_iterator, T (*)(Attribute)>; template using DerivedAttrValueIteratorRange = - llvm::iterator_range>; + detail::ElementsAttrRange>; template > DerivedAttrValueIteratorRange getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return llvm::map_range(getValues(), - static_cast(castFn)); + return {Attribute::getType(), llvm::map_range(getValues(), + static_cast(castFn))}; } template > DerivedAttrValueIterator value_begin() const { @@ -407,8 +388,10 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { /// return the iterable range. Otherwise, return llvm::None. template DefaultValueCheckT>> tryGetValues() const { - if (Optional> beginIt = try_value_begin()) - return iterator_range(*beginIt, value_end()); + if (Optional> beginIt = try_value_begin()) { + return iterator_range(Attribute::getType(), *beginIt, + value_end()); + } return llvm::None; } template @@ -418,10 +401,15 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { /// return the iterable range. Otherwise, return llvm::None. template > Optional> tryGetValues() const { + auto values = tryGetValues(); + if (!values) + return llvm::None; + auto castFn = [](Attribute attr) { return attr.template cast(); }; - if (auto values = tryGetValues()) - return llvm::map_range(*values, static_cast(castFn)); - return llvm::None; + return DerivedAttrValueIteratorRange( + Attribute::getType(), + llvm::map_range(*values, static_cast(castFn)) + ); } template > Optional> try_value_begin() const { diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index cc1d1d74615e..37da2eb9150b 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -61,8 +61,7 @@ protected: }; /// Type trait detector that checks if a given type T is a complex type. -template -struct is_complex_t : public std::false_type {}; +template struct is_complex_t : public std::false_type {}; template struct is_complex_t> : public std::true_type {}; } // namespace detail @@ -82,8 +81,7 @@ public: /// floating point type that can be used to access the underlying element /// types of a DenseElementsAttr. // TODO: Use std::disjunction when C++17 is supported. - template - struct is_valid_cpp_fp_type { + template struct is_valid_cpp_fp_type { /// The type is a valid floating point type if it is a builtin floating /// point type, or is a potentially user defined floating point type. The /// latter allows for supporting users that have custom types defined for @@ -219,6 +217,18 @@ public: // Iterators //===--------------------------------------------------------------------===// + /// The iterator range over the given iterator type T. + template + using iterator_range_impl = detail::ElementsAttrRange; + + /// The iterator for the given element type T. + template + using iterator = decltype(std::declval().template value_begin()); + /// The iterator range over the given element T. + template + using iterator_range = + decltype(std::declval().template getValues()); + /// A utility iterator that allows walking over the internal Attribute values /// of a DenseElementsAttr. class AttributeElementIterator @@ -358,22 +368,7 @@ public: !std::is_same::value, T>::type getSplatValue() const { - return getSplatValue().template cast(); - } - - /// Return the value at the given index. The 'index' is expected to refer to a - /// valid element. - Attribute getValue(ArrayRef index) const { - return getValue(index); - } - template - T getValue(ArrayRef index) const { - // Skip to the element corresponding to the flattened index. - return getFlatValue(ElementsAttr::getFlattenedIndex(*this, index)); - } - /// Return the value at the given flattened index. - template T getFlatValue(uint64_t index) const { - return *std::next(value_begin(), index); + return getSplatValue().template cast(); } /// Return the held element values as a range of integer or floating-point @@ -384,12 +379,12 @@ public: std::numeric_limits::is_integer) || is_valid_cpp_fp_type::value>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {ElementIterator(rawData, splat, 0), + return {Attribute::getType(), ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } template > @@ -413,12 +408,12 @@ public: is_valid_cpp_fp_type::value)>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); const char *rawData = getRawData().data(); bool splat = isSplat(); - return {ElementIterator(rawData, splat, 0), + return {Attribute::getType(), ElementIterator(rawData, splat, 0), ElementIterator(rawData, splat, getNumElements())}; } template ::value>::type; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { auto stringRefs = getRawStringData(); const char *ptr = reinterpret_cast(stringRefs.data()); bool splat = isSplat(); - return {ElementIterator(ptr, splat, 0), + return {Attribute::getType(), ElementIterator(ptr, splat, 0), ElementIterator(ptr, splat, getNumElements())}; } template > @@ -464,8 +459,9 @@ public: using AttributeValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { - return {value_begin(), value_end()}; + iterator_range_impl getValues() const { + return {Attribute::getType(), value_begin(), + value_end()}; } template > AttributeElementIterator value_begin() const { @@ -486,10 +482,11 @@ public: using DerivedAttributeElementIterator = llvm::mapped_iterator; template > - llvm::iterator_range> getValues() const { + iterator_range_impl> getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return llvm::map_range(getValues(), - static_cast(castFn)); + return {Attribute::getType(), + llvm::map_range(getValues(), + static_cast(castFn))}; } template > DerivedAttributeElementIterator value_begin() const { @@ -508,9 +505,9 @@ public: using BoolValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { assert(isValidBool() && "bool is not the value of this elements attribute"); - return {BoolElementIterator(*this, 0), + return {Attribute::getType(), BoolElementIterator(*this, 0), BoolElementIterator(*this, getNumElements())}; } template > @@ -530,9 +527,9 @@ public: using APIntValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { assert(getElementType().isIntOrIndex() && "expected integral type"); - return {raw_int_begin(), raw_int_end()}; + return {Attribute::getType(), raw_int_begin(), raw_int_end()}; } template > IntElementIterator value_begin() const { @@ -551,7 +548,7 @@ public: using ComplexAPIntValueTemplateCheckT = typename std::enable_if< std::is_same>::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getComplexIntValues(); } template > @@ -569,7 +566,7 @@ public: using APFloatValueTemplateCheckT = typename std::enable_if::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getFloatValues(); } template > @@ -587,7 +584,7 @@ public: using ComplexAPFloatValueTemplateCheckT = typename std::enable_if< std::is_same>::value>::type; template > - llvm::iterator_range getValues() const { + iterator_range_impl getValues() const { return getComplexFloatValues(); } template > @@ -660,13 +657,13 @@ protected: IntElementIterator raw_int_end() const { return IntElementIterator(*this, getNumElements()); } - llvm::iterator_range getComplexIntValues() const; + iterator_range_impl getComplexIntValues() const; ComplexIntElementIterator complex_value_begin() const; ComplexIntElementIterator complex_value_end() const; - llvm::iterator_range getFloatValues() const; + iterator_range_impl getFloatValues() const; FloatElementIterator float_value_begin() const; FloatElementIterator float_value_end() const; - llvm::iterator_range + iterator_range_impl getComplexFloatValues() const; ComplexFloatElementIterator complex_float_value_begin() const; ComplexFloatElementIterator complex_float_value_end() const; @@ -872,8 +869,7 @@ public: //===----------------------------------------------------------------------===// template -auto SparseElementsAttr::getValues() const - -> llvm::iterator_range> { +auto SparseElementsAttr::value_begin() const -> iterator { auto zeroValue = getZeroValue(); auto valueIt = getValues().value_begin(); const std::vector flatSparseIndices(getFlattenedSparseIndices()); @@ -888,15 +884,7 @@ auto SparseElementsAttr::getValues() const // Otherwise, return the zero value. return zeroValue; }; - return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); -} -template -auto SparseElementsAttr::value_begin() const -> iterator { - return getValues().begin(); -} -template -auto SparseElementsAttr::value_end() const -> iterator { - return getValues().end(); + return iterator(llvm::seq(0, getNumElements()).begin(), mapFn); } } // end namespace mlir. diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 01af84c421e9..c6631cd79fe5 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -174,9 +174,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< "ArrayRef":$rawData); let extraClassDeclaration = [{ using DenseElementsAttr::empty; - using DenseElementsAttr::getFlatValue; using DenseElementsAttr::getNumElements; - using DenseElementsAttr::getValue; using DenseElementsAttr::getValues; using DenseElementsAttr::isSplat; using DenseElementsAttr::size; @@ -313,9 +311,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr< ]; let extraClassDeclaration = [{ using DenseElementsAttr::empty; - using DenseElementsAttr::getFlatValue; using DenseElementsAttr::getNumElements; - using DenseElementsAttr::getValue; using DenseElementsAttr::getValues; using DenseElementsAttr::isSplat; using DenseElementsAttr::size; @@ -712,10 +708,6 @@ def Builtin_OpaqueElementsAttr : Builtin_Attr< let extraClassDeclaration = [{ using ValueType = StringRef; - /// Return the value at the given index. The 'index' is expected to refer to - /// a valid element. - Attribute getValue(ArrayRef index) const; - /// Decodes the attribute value using dialect-specific decoding hook. /// Returns false if decoding is successful. If not, returns true and leaves /// 'result' argument unspecified. @@ -802,6 +794,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr< // String types. StringRef >; + using ElementsAttr::Trait::getValues; /// Provide a `value_begin_impl` to enable iteration within ElementsAttr. template @@ -817,13 +810,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr< /// Return the values of this attribute in the form of the given type 'T'. /// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types, /// etc. - template llvm::iterator_range> getValues() const; template iterator value_begin() const; - template iterator value_end() const; - - /// Return the value of the element at the given index. The 'index' is - /// expected to refer to a valid element. - Attribute getValue(ArrayRef index) const; private: /// Get a zero APFloat for the given sparse attribute. diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 3b15212e3002..8d6c4ccf6a8b 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -288,8 +288,9 @@ bool mlirAttributeIsAElements(MlirAttribute attr) { MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return wrap(unwrap(attr).cast().getValue( - llvm::makeArrayRef(idxs, rank))); + return wrap(unwrap(attr) + .cast() + .getValues()[llvm::makeArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, @@ -482,7 +483,8 @@ bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getSplatValue()); + return wrap( + unwrap(attr).cast().getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); @@ -520,36 +522,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { // Indexed accessors. bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - unwrap(attr).cast().getFlatValue(pos)); + unwrap(attr).cast().getValues()[pos]); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 88eca4600a42..d55deb5ce84a 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -169,7 +169,7 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite( return failure(); auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op); - auto val = workGroupSizeAttr.getValue(index.getValue()); + auto val = workGroupSizeAttr.getValues()[index.getValue()]; auto convertedType = getTypeConverter()->convertType(op.getResult().getType()); if (!convertedType) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 46c689d9b177..b4ea696f80d0 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -451,7 +451,7 @@ struct GlobalMemrefOpLowering // For scalar memrefs, the global variable created is of the element type, // so unpack the elements attribute to extract the value. if (type.getRank() == 0) - initialValue = elementsAttr.getValue({}); + initialValue = elementsAttr.getValues()[0]; } uint64_t alignment = global.alignment().getValueOr(0); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 2dae2feb2f4c..1acb0a565dae 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2415,8 +2415,7 @@ LogicalResult AffineStoreOp::fold(ArrayRef cstOperands, // AffineMinMaxOpBase //===----------------------------------------------------------------------===// -template -static LogicalResult verifyAffineMinMaxOp(T op) { +template static LogicalResult verifyAffineMinMaxOp(T op) { // Verify that operand count matches affine map dimension and symbol count. if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols()) return op.emitOpError( @@ -2424,8 +2423,7 @@ static LogicalResult verifyAffineMinMaxOp(T op) { return success(); } -template -static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { +template static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { p << ' ' << op->getAttr(T::getMapAttrName()); auto operands = op.getOperands(); unsigned numDims = op.map().getNumDims(); @@ -2532,8 +2530,7 @@ struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern { /// /// %1 = affine.min affine_map< /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1] -template -struct MergeAffineMinMaxOp : public OpRewritePattern { +template struct MergeAffineMinMaxOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(T affineOp, @@ -2890,19 +2887,19 @@ AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() { } AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) { + auto values = lowerBoundsGroups().getValues(); unsigned start = 0; for (unsigned i = 0; i < pos; ++i) - start += lowerBoundsGroups().getValue(i); - return lowerBoundsMap().getSliceMap( - start, lowerBoundsGroups().getValue(pos)); + start += values[i]; + return lowerBoundsMap().getSliceMap(start, values[pos]); } AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) { + auto values = upperBoundsGroups().getValues(); unsigned start = 0; for (unsigned i = 0; i < pos; ++i) - start += upperBoundsGroups().getValue(i); - return upperBoundsMap().getSliceMap( - start, upperBoundsGroups().getValue(pos)); + start += values[i]; + return upperBoundsMap().getSliceMap(start, values[pos]); } AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index d9560bf9139d..ea4d7a69c063 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -163,8 +163,8 @@ constexpr const static unsigned kBitsInByte = 8; /// Returns the value that corresponds to named position `pos` from the /// attribute `attr` assuming it's a dense integer elements attribute. static unsigned extractPointerSpecValue(Attribute attr, DLEntryPos pos) { - return attr.cast().getValue( - static_cast(pos)); + return attr.cast() + .getValues()[static_cast(pos)]; } /// Returns the part of the data layout entry that corresponds to `pos` for the diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index ee5622d2662d..fa836ed9577a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1184,7 +1184,7 @@ public: if (matchPattern(def, m_Constant(&splatAttr)) && splatAttr.isSplat() && splatAttr.getType().getElementType().isIntOrFloat()) { - constantAttr = splatAttr.getSplatValue(); + constantAttr = splatAttr.getSplatValue(); return true; } } @@ -1455,10 +1455,9 @@ public: bool isFloat = elementType.isa(); if (isFloat) { - SmallVector> - inputFpIterators; + SmallVector> inFpRanges; for (int i = 0; i < numInputs; ++i) - inputFpIterators.push_back(inputValues[i].getValues()); + inFpRanges.push_back(inputValues[i].getValues()); computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); @@ -1469,22 +1468,17 @@ public: computeRemappedLinearIndex(linearIndex); // Collect constant elements for all inputs at this loop iteration. - for (int i = 0; i < numInputs; ++i) { - computeFnInputs.apFloats[i] = - *(inputFpIterators[i].begin() + srcLinearIndices[i]); - } + for (int i = 0; i < numInputs; ++i) + computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]]; // Invoke the computation to get the corresponding constant output // element. - APIntOrFloat outputs = computeFn(computeFnInputs); - - fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue(); + fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat; } } else { - SmallVector> - inputIntIterators; + SmallVector> inIntRanges; for (int i = 0; i < numInputs; ++i) - inputIntIterators.push_back(inputValues[i].getValues()); + inIntRanges.push_back(inputValues[i].getValues()); computeFnInputs.apInts.resize(numInputs); @@ -1495,25 +1489,19 @@ public: computeRemappedLinearIndex(linearIndex); // Collect constant elements for all inputs at this loop iteration. - for (int i = 0; i < numInputs; ++i) { - computeFnInputs.apInts[i] = - *(inputIntIterators[i].begin() + srcLinearIndices[i]); - } + for (int i = 0; i < numInputs; ++i) + computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]]; // Invoke the computation to get the corresponding constant output // element. - APIntOrFloat outputs = computeFn(computeFnInputs); - - intOutputValues[dstLinearIndex] = outputs.apInt.getValue(); + intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt; } } - DenseIntOrFPElementsAttr outputAttr; - if (isFloat) { - outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues); - } else { - outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues); - } + DenseElementsAttr outputAttr = + isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) + : DenseElementsAttr::get(outputType, intOutputValues); + rewriter.replaceOpWithNewOp(genericOp, outputAttr); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 758aeecf380d..af3e528212f7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -913,9 +913,9 @@ struct DownscaleSizeOneWindowed2DConvolution final loc, newOutputType, output, ioReshapeIndices); // We need to shrink the strides and dilations too. - auto stride = convOp.strides().getFlatValue(removeH ? 1 : 0); + auto stride = convOp.strides().getValues()[removeH ? 1 : 0]; auto stridesAttr = rewriter.getI64VectorAttr(stride); - auto dilation = convOp.dilations().getFlatValue(removeH ? 1 : 0); + auto dilation = convOp.dilations().getValues()[removeH ? 1 : 0]; auto dilationsAttr = rewriter.getI64VectorAttr(dilation); auto conv1DOp = rewriter.create( diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 437c762b5f55..c5125860d437 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -56,7 +56,7 @@ static Attribute extractCompositeElement(Attribute composite, if (auto vector = composite.dyn_cast()) { assert(indices.size() == 1 && "must have exactly one index for a vector"); - return vector.getValue({indices[0]}); + return vector.getValues()[indices[0]]; } if (auto array = composite.dyn_cast()) { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 635d3c05a72e..27d60b5f02f3 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1138,7 +1138,7 @@ OpFoldResult GetExtentOp::fold(ArrayRef operands) { return nullptr; if (dim.getValue() >= elements.getNumElements()) return nullptr; - return elements.getValue({(uint64_t)dim.getValue()}); + return elements.getValues()[(uint64_t)dim.getValue()]; } void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 91470be5e7a5..6bc2d7fd436d 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1304,13 +1304,14 @@ static void printSwitchOpCases( if (!caseValues) return; - for (int64_t i = 0, size = caseValues.size(); i < size; ++i) { + for (const auto &it : llvm::enumerate(caseValues.getValues())) { p << ','; p.printNewline(); p << " "; - p << caseValues.getValue(i).getLimitedValue(); + p << it.value().getLimitedValue(); p << ": "; - p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]); + p.printSuccessorAndUseList(caseDestinations[it.index()], + caseOperands[it.index()]); } p.printNewline(); } @@ -1353,9 +1354,9 @@ Block *SwitchOp::getSuccessorForOperands(ArrayRef operands) { SuccessorRange caseDests = getCaseDestinations(); if (auto value = operands.front().dyn_cast_or_null()) { - for (int64_t i = 0, size = getCaseValues()->size(); i < size; ++i) - if (value == caseValues->getValue(i)) - return caseDests[i]; + for (const auto &it : llvm::enumerate(caseValues->getValues())) + if (it.value() == value.getValue()) + return caseDests[it.index()]; return getDefaultDestination(); } return nullptr; @@ -1394,15 +1395,15 @@ dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { auto caseValues = op.getCaseValues(); auto caseDests = op.getCaseDestinations(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseDests[i] == op.getDefaultDestination() && - op.getCaseOperands(i) == op.getDefaultOperands()) { + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (caseDests[it.index()] == op.getDefaultDestination() && + op.getCaseOperands(it.index()) == op.getDefaultOperands()) { requiresChange = true; continue; } - newCaseDestinations.push_back(caseDests[i]); - newCaseOperands.push_back(op.getCaseOperands(i)); - newCaseValues.push_back(caseValues->getValue(i)); + newCaseDestinations.push_back(caseDests[it.index()]); + newCaseOperands.push_back(op.getCaseOperands(it.index())); + newCaseValues.push_back(it.value()); } if (!requiresChange) @@ -1424,10 +1425,11 @@ dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, APInt caseValue) { auto caseValues = op.getCaseValues(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseValues->getValue(i) == caseValue) { - rewriter.replaceOpWithNewOp(op, op.getCaseDestinations()[i], - op.getCaseOperands(i)); + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (it.value() == caseValue) { + rewriter.replaceOpWithNewOp( + op, op.getCaseDestinations()[it.index()], + op.getCaseOperands(it.index())); return; } } @@ -1551,22 +1553,16 @@ simplifySwitchFromSwitchOnSameCondition(SwitchOp op, return failure(); // Fold this switch to an unconditional branch. - APInt caseValue; - bool isDefault = true; SuccessorRange predDests = predSwitch.getCaseDestinations(); - Optional predCaseValues = predSwitch.getCaseValues(); - for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) { - if (currentBlock == predDests[i]) { - caseValue = predCaseValues->getValue(i); - isDefault = false; - break; - } - } - if (isDefault) + auto it = llvm::find(predDests, currentBlock); + if (it != predDests.end()) { + Optional predCaseValues = predSwitch.getCaseValues(); + foldSwitch(op, rewriter, + predCaseValues->getValues()[it - predDests.begin()]); + } else { rewriter.replaceOpWithNewOp(op, op.getDefaultDestination(), op.getDefaultOperands()); - else - foldSwitch(op, rewriter, caseValue); + } return success(); } @@ -1613,7 +1609,7 @@ simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, auto predCaseValues = predSwitch.getCaseValues(); for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) if (currentBlock != predDests[i]) - caseValuesToRemove.insert(predCaseValues->getValue(i)); + caseValuesToRemove.insert(predCaseValues->getValues()[i]); SmallVector newCaseDestinations; SmallVector newCaseOperands; @@ -1622,14 +1618,14 @@ simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, auto caseValues = op.getCaseValues(); auto caseDests = op.getCaseDestinations(); - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - if (caseValuesToRemove.contains(caseValues->getValue(i))) { + for (const auto &it : llvm::enumerate(caseValues->getValues())) { + if (caseValuesToRemove.contains(it.value())) { requiresChange = true; continue; } - newCaseDestinations.push_back(caseDests[i]); - newCaseOperands.push_back(op.getCaseOperands(i)); - newCaseValues.push_back(caseValues->getValue(i)); + newCaseDestinations.push_back(caseDests[it.index()]); + newCaseOperands.push_back(op.getCaseOperands(it.index())); + newCaseValues.push_back(it.value()); } if (!requiresChange) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 287289b4cd6d..1d8d8e2d5068 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -340,7 +340,7 @@ OpFoldResult ExtractOp::fold(ArrayRef operands) { // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatTensor = tensor.dyn_cast()) - return splatTensor.getSplatValue(); + return splatTensor.getSplatValue(); // Otherwise, collect the constant indices into the tensor. SmallVector indices; @@ -353,7 +353,7 @@ OpFoldResult ExtractOp::fold(ArrayRef operands) { // If this is an elements attribute, query the value at the given indices. auto elementsAttr = tensor.dyn_cast(); if (elementsAttr && elementsAttr.isValidIndex(indices)) - return elementsAttr.getValue(indices); + return elementsAttr.getValues()[indices]; return {}; } @@ -440,7 +440,7 @@ OpFoldResult InsertOp::fold(ArrayRef operands) { Attribute dest = operands[1]; if (scalar && dest) if (auto splatDest = dest.dyn_cast()) - if (scalar == splatDest.getSplatValue()) + if (scalar == splatDest.getSplatValue()) return dest; return {}; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 85415d92bd1b..2a435476e5be 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -230,6 +230,7 @@ struct ConstantTransposeOptimization // Transpose the input constant. Because we don't know its rank in advance, // we need to loop over the range [0, element count) and delinearize the // index. + auto attrValues = inputValues.getValues(); for (int srcLinearIndex = 0; srcLinearIndex < numElements; ++srcLinearIndex) { SmallVector srcIndices(inputType.getRank(), 0); @@ -247,7 +248,7 @@ struct ConstantTransposeOptimization for (int dim = 1; dim < outputType.getRank(); ++dim) dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; - outputValues[dstLinearIndex] = inputValues.getValue(srcIndices); + outputValues[dstLinearIndex] = attrValues[srcIndices]; } rewriter.replaceOpWithNewOp( @@ -424,8 +425,7 @@ OpFoldResult TransposeOp::fold(ArrayRef operands) { // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// -template -static LogicalResult verifyConvOp(T op) { +template static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). auto inputType = op.input().getType().template dyn_cast(); auto weightType = op.weight().getType().template dyn_cast(); diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp index 96992c219aa0..fd289917c64c 100644 --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -56,15 +56,15 @@ bool ElementsAttr::isValidIndex(Attribute elementsAttr, return isValidIndex(elementsAttr.getType().cast(), index); } -uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr, - ArrayRef index) { - ShapedType type = elementsAttr.getType().cast(); - assert(isValidIndex(type, index) && "expected valid multi-dimensional index"); +uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { + ShapedType shapeType = type.cast(); + assert(isValidIndex(shapeType, index) && + "expected valid multi-dimensional index"); // Reduce the provided multidimensional index into a flattended 1D row-major // index. - auto rank = type.getRank(); - auto shape = type.getShape(); + auto rank = shapeType.getRank(); + ArrayRef shape = shapeType.getShape(); uint64_t valueIndex = 0; uint64_t dimMultiplier = 1; for (int i = rank - 1; i >= 0; --i) { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 41a8c46c0c6d..38c843026898 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -902,10 +902,10 @@ LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) { } auto DenseElementsAttr::getComplexIntValues() const - -> llvm::iterator_range { + -> iterator_range_impl { assert(isComplexOfIntType(getElementType()) && "expected complex integral type"); - return {ComplexIntElementIterator(*this, 0), + return {getType(), ComplexIntElementIterator(*this, 0), ComplexIntElementIterator(*this, getNumElements())}; } auto DenseElementsAttr::complex_value_begin() const @@ -923,10 +923,10 @@ auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { /// Return the held element values as a range of APFloat. The element type of /// this attribute must be of float type. auto DenseElementsAttr::getFloatValues() const - -> llvm::iterator_range { + -> iterator_range_impl { auto elementType = getElementType().cast(); const auto &elementSemantics = elementType.getFloatSemantics(); - return {FloatElementIterator(elementSemantics, raw_int_begin()), + return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), FloatElementIterator(elementSemantics, raw_int_end())}; } auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { @@ -939,11 +939,12 @@ auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { } auto DenseElementsAttr::getComplexFloatValues() const - -> llvm::iterator_range { + -> iterator_range_impl { Type eltTy = getElementType().cast().getElementType(); assert(eltTy.isa() && "expected complex float type"); const auto &semantics = eltTy.cast().getFloatSemantics(); - return {{semantics, {*this, 0}}, + return {getType(), + {semantics, {*this, 0}}, {semantics, {*this, static_cast(getNumElements())}}}; } auto DenseElementsAttr::complex_float_value_begin() const @@ -1248,13 +1249,6 @@ bool DenseIntElementsAttr::classof(Attribute attr) { // OpaqueElementsAttr //===----------------------------------------------------------------------===// -/// Return the value at the given index. If index does not refer to a valid -/// element, then a null attribute is returned. -Attribute OpaqueElementsAttr::getValue(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - return Attribute(); -} - bool OpaqueElementsAttr::decode(ElementsAttr &result) { Dialect *dialect = getDialect().getDialect(); if (!dialect) @@ -1279,47 +1273,6 @@ OpaqueElementsAttr::verify(function_ref emitError, // SparseElementsAttr //===----------------------------------------------------------------------===// -/// Return the value of the element at the given index. -Attribute SparseElementsAttr::getValue(ArrayRef index) const { - assert(isValidIndex(index) && "expected valid multi-dimensional index"); - auto type = getType(); - - // The sparse indices are 64-bit integers, so we can reinterpret the raw data - // as a 1-D index array. - auto sparseIndices = getIndices(); - auto sparseIndexValues = sparseIndices.getValues(); - - // Check to see if the indices are a splat. - if (sparseIndices.isSplat()) { - // If the index is also not a splat of the index value, we know that the - // value is zero. - auto splatIndex = *sparseIndexValues.begin(); - if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; })) - return getZeroAttr(); - - // If the indices are a splat, we also expect the values to be a splat. - assert(getValues().isSplat() && "expected splat values"); - return getValues().getSplatValue(); - } - - // Build a mapping between known indices and the offset of the stored element. - llvm::SmallDenseMap, size_t> mappedIndices; - auto numSparseIndices = sparseIndices.getType().getDimSize(0); - size_t rank = type.getRank(); - for (size_t i = 0, e = numSparseIndices; i != e; ++i) - mappedIndices.try_emplace( - {&*std::next(sparseIndexValues.begin(), i * rank), rank}, i); - - // Look for the provided index key within the mapped indices. If the provided - // index is not found, then return a zero attribute. - auto it = mappedIndices.find(index); - if (it == mappedIndices.end()) - return getZeroAttr(); - - // Otherwise, return the held sparse value element. - return getValues().getValue(it->second); -} - /// Get a zero APFloat for the given sparse attribute. APFloat SparseElementsAttr::getZeroAPFloat() const { auto eltType = getElementType().cast(); diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index 9676553aeaa5..67c9ccbaec5b 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -71,7 +71,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const { return t.cast().getDimSize(index); if (auto attr = val.dyn_cast()) return attr.cast() - .getFlatValue(index) + .getValues()[index] .getSExtValue(); auto *stc = val.get(); return stc->getDims()[index]; diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 864bba19b2ca..6c10e61dcc85 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -386,14 +386,12 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } if (auto condbrOp = dyn_cast(opInst)) { - auto weights = condbrOp.getBranchWeights(); llvm::MDNode *branchWeights = nullptr; - if (weights) { + if (auto weights = condbrOp.getBranchWeights()) { // Map weight attributes to LLVM metadata. - auto trueWeight = - weights.getValue().getValue(0).cast().getInt(); - auto falseWeight = - weights.getValue().getValue(1).cast().getInt(); + auto weightValues = weights->getValues(); + auto trueWeight = weightValues[0].getSExtValue(); + auto falseWeight = weightValues[1].getSExtValue(); branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext()) .createBranchWeights(static_cast(trueWeight), diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 24eea6c31711..5701b44a2a9f 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -706,11 +706,12 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, if (shapedType.getRank() == dim) { if (auto attr = valueAttr.dyn_cast()) { return attr.getType().getElementType().isInteger(1) - ? prepareConstantBool(loc, attr.getValue(index)) - : prepareConstantInt(loc, attr.getValue(index)); + ? prepareConstantBool(loc, attr.getValues()[index]) + : prepareConstantInt(loc, + attr.getValues()[index]); } if (auto attr = valueAttr.dyn_cast()) { - return prepareConstantFp(loc, attr.getValue(index)); + return prepareConstantFp(loc, attr.getValues()[index]); } return 0; } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml index b1b807701676..b8edbea19a40 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -154,9 +154,9 @@ structured_op: !LinalgStructuredOpConfig # ODS-NEXT: LogicalResult verifyIndexingMapRequiredAttributes(); # IMPL: getSymbolBindings(Test2Op self) -# IMPL: cst2 = self.strides().getValue({ 0 }); +# IMPL: cst2 = self.strides().getValues()[0]; # IMPL-NEXT: getAffineConstantExpr(cst2, context) -# IMPL: cst3 = self.strides().getValue({ 1 }); +# IMPL: cst3 = self.strides().getValues()[1]; # IMPL-NEXT: getAffineConstantExpr(cst3, context) # IMPL: Test2Op::indexing_maps() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index bcf3616f8b0a..507713e25678 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -142,8 +142,7 @@ namespace yaml { /// Top-level type containing op metadata and one of a concrete op type. /// Currently, the only defined op type is `structured_op` (maps to /// `LinalgStructuredOpConfig`). -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgOpConfig &info) { io.mapOptional("metadata", info.metadata); io.mapOptional("structured_op", info.structuredOp); @@ -156,8 +155,7 @@ struct MappingTraits { /// - List of indexing maps (see `LinalgIndexingMaps`). /// - Iterator types (see `LinalgIteratorTypeDef`). /// - List of scalar level assignment (see `ScalarAssign`). -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgStructuredOpConfig &info) { io.mapRequired("args", info.args); io.mapRequired("indexing_maps", info.indexingMaps); @@ -180,8 +178,7 @@ struct MappingTraits { /// attribute symbols. During op creation these symbols are replaced by the /// corresponding `name` attribute values. Only attribute arguments have /// an `attribute_map`. -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgOperandDef &info) { io.mapRequired("name", info.name); io.mapRequired("usage", info.usage); @@ -192,8 +189,7 @@ struct MappingTraits { }; /// Usage enum for a named argument. -template <> -struct ScalarEnumerationTraits { +template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgOperandDefUsage &value) { io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input); io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output); @@ -202,8 +198,7 @@ struct ScalarEnumerationTraits { }; /// Iterator type enum. -template <> -struct ScalarEnumerationTraits { +template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgIteratorTypeDef &value) { io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel); io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction); @@ -211,8 +206,7 @@ struct ScalarEnumerationTraits { }; /// Metadata about the op (name, C++ name, and documentation). -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgOpMetadata &info) { io.mapRequired("name", info.name); io.mapRequired("cpp_class_name", info.cppClassName); @@ -226,8 +220,7 @@ struct MappingTraits { /// some symbols that bind to attributes of the op. Each indexing map must /// be normalized over the same list of dimensions, and its symbols must /// match the symbols for argument shapes. -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, LinalgIndexingMapsConfig &info) { io.mapOptional("static_indexing_maps", info.staticIndexingMaps); } @@ -237,8 +230,7 @@ struct MappingTraits { /// - The `arg` name must match a named output. /// - The `value` is a scalar expression for computing the value to /// assign (see `ScalarExpression`). -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, ScalarAssign &info) { io.mapRequired("arg", info.arg); io.mapRequired("value", info.value); @@ -250,8 +242,7 @@ struct MappingTraits { /// - `scalar_apply`: Result of evaluating a named function (see /// `ScalarApply`). /// - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere. -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, ScalarExpression &info) { io.mapOptional("scalar_arg", info.arg); io.mapOptional("scalar_const", info.constant); @@ -266,16 +257,14 @@ struct MappingTraits { /// functions include: /// - `add(lhs, rhs)` /// - `mul(lhs, rhs)` -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, ScalarApply &info) { io.mapRequired("fn_name", info.fnName); io.mapRequired("operands", info.operands); } }; -template <> -struct MappingTraits { +template <> struct MappingTraits { static void mapping(IO &io, ScalarSymbolicCast &info) { io.mapRequired("type_var", info.typeVar); io.mapRequired("operands", info.operands); @@ -285,8 +274,7 @@ struct MappingTraits { /// Helper mapping which accesses an AffineMapAttr as a serialized string of /// the same. -template <> -struct ScalarTraits { +template <> struct ScalarTraits { static void output(const SerializedAffineMap &value, void *rawYamlContext, raw_ostream &out) { assert(value.affineMapAttr); @@ -726,7 +714,7 @@ static SmallVector getSymbolBindings({0} self) { // {1}: Symbol position // {2}: Attribute index static const char structuredOpAccessAttrFormat[] = R"FMT( -int64_t cst{1} = self.{0}().getValue({ {2} }); +int64_t cst{1} = self.{0}().getValues()[{2}]; exprs.push_back(getAffineConstantExpr(cst{1}, context)); )FMT"; // Update all symbol bindings mapped to an attribute. diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp index 33c6360a4e90..5125413a6c11 100644 --- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp @@ -113,7 +113,8 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } @@ -138,7 +139,8 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } @@ -162,7 +164,8 @@ TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { EXPECT_TRUE(returnedValue.isa()); // Check Elements attribute element value is expected. - auto firstValue = returnedValue.cast().getValue({0, 0}); + auto firstValue = + returnedValue.cast().getValues()[{0, 0}]; EXPECT_EQ(firstValue.cast().getInt(), 5); } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index aaff61e7d5f9..19b57fa754df 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -202,7 +202,7 @@ TEST(DenseScalarTest, ExtractZeroRankElement) { RankedTensorType shape = RankedTensorType::get({}, intTy); auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); - EXPECT_TRUE(attr.getValue({0}) == value); + EXPECT_TRUE(attr.getValues()[0] == value); } TEST(SparseElementsAttrTest, GetZero) { @@ -238,15 +238,15 @@ TEST(SparseElementsAttrTest, GetZero) { // Only index (0, 0) contains an element, others are supposed to return // the zero/empty value. - auto zeroIntValue = sparseInt.getValue({1, 1}); + auto zeroIntValue = sparseInt.getValues()[{1, 1}]; EXPECT_EQ(zeroIntValue.cast().getInt(), 0); EXPECT_TRUE(zeroIntValue.getType() == intTy); - auto zeroFloatValue = sparseFloat.getValue({1, 1}); + auto zeroFloatValue = sparseFloat.getValues()[{1, 1}]; EXPECT_EQ(zeroFloatValue.cast().getValueAsDouble(), 0.0f); EXPECT_TRUE(zeroFloatValue.getType() == floatTy); - auto zeroStringValue = sparseString.getValue({1, 1}); + auto zeroStringValue = sparseString.getValues()[{1, 1}]; EXPECT_TRUE(zeroStringValue.cast().getValue().empty()); EXPECT_TRUE(zeroStringValue.getType() == stringTy); }