[mlir] Refactor ElementsAttr's value access API
There are several aspects of the API that either aren't easy to use, or are deceptively easy to do the wrong thing. The main change of this commit is to remove all of the `getValue<T>`/`getFlatValue<T>` from ElementsAttr and instead provide operator[] methods on the ranges returned by `getValues<T>`. This provides a much more convenient API for the value ranges. It also removes the easy-to-be-inefficient nature of getValue/getFlatValue, which under the hood would construct a new range for the type `T`. Constructing a range is not necessarily cheap in all cases, and could lead to very poor performance if used within a loop; i.e. if you were to naively write something like: ``` DenseElementsAttr attr = ...; for (int i = 0; i < size; ++i) { // We are internally rebuilding the APFloat value range on each iteration!! APFloat it = attr.getFlatValue<APFloat>(i); } ``` Differential Revision: https://reviews.llvm.org/D113229
This commit is contained in:
parent
4a0c89a6cf
commit
ae40d62541
@ -227,6 +227,33 @@ private:
|
||||
ElementsAttrIndexer indexer;
|
||||
ptrdiff_t index;
|
||||
};
|
||||
|
||||
/// This class provides iterator utilities for an ElementsAttr range.
|
||||
template <typename IteratorT>
|
||||
class ElementsAttrRange : public llvm::iterator_range<IteratorT> {
|
||||
public:
|
||||
using reference = typename IteratorT::reference;
|
||||
|
||||
ElementsAttrRange(Type shapeType,
|
||||
const llvm::iterator_range<IteratorT> &range)
|
||||
: llvm::iterator_range<IteratorT>(range), shapeType(shapeType) {}
|
||||
ElementsAttrRange(Type shapeType, IteratorT beginIt, IteratorT endIt)
|
||||
: ElementsAttrRange(shapeType, llvm::make_range(beginIt, endIt)) {}
|
||||
|
||||
/// Return the value at the given index.
|
||||
reference operator[](ArrayRef<uint64_t> index) const;
|
||||
reference operator[](uint64_t index) const {
|
||||
return *std::next(this->begin(), index);
|
||||
}
|
||||
|
||||
/// Return the size of this range.
|
||||
size_t size() const { return llvm::size(*this); }
|
||||
|
||||
private:
|
||||
/// The shaped type of the parent ElementsAttr.
|
||||
Type shapeType;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -256,6 +283,16 @@ verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
/// Return the value at the given index.
|
||||
template <typename IteratorT>
|
||||
auto ElementsAttrRange<IteratorT>::operator[](ArrayRef<uint64_t> index) const
|
||||
-> reference {
|
||||
// Skip to the element corresponding to the flattened index.
|
||||
return (*this)[ElementsAttr::getFlattenedIndex(shapeType, index)];
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/// Return the elements of this attribute as a value of type 'T'.
|
||||
template <typename T>
|
||||
auto ElementsAttr::value_begin() const -> DefaultValueCheckT<T, iterator<T>> {
|
||||
|
@ -158,27 +158,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
];
|
||||
|
||||
string ElementsAttrInterfaceAccessors = [{
|
||||
/// Return the attribute value at the given index. The index is expected to
|
||||
/// refer to a valid element.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const {
|
||||
return getValue<Attribute>(index);
|
||||
}
|
||||
|
||||
/// Return the value of type 'T' at the given index, where 'T' corresponds
|
||||
/// to an Attribute type.
|
||||
template <typename T>
|
||||
std::enable_if_t<!std::is_same<T, ::mlir::Attribute>::value &&
|
||||
std::is_base_of<T, ::mlir::Attribute>::value>
|
||||
getValue(ArrayRef<uint64_t> index) const {
|
||||
return getValue(index).template dyn_cast_or_null<T>();
|
||||
}
|
||||
|
||||
/// Return the value of type 'T' at the given index.
|
||||
template <typename T>
|
||||
T getValue(ArrayRef<uint64_t> index) const {
|
||||
return getFlatValue<T>(getFlattenedIndex(index));
|
||||
}
|
||||
|
||||
/// Return the number of elements held by this attribute.
|
||||
int64_t size() const { return getNumElements(); }
|
||||
|
||||
@ -281,6 +260,14 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
// Value Iteration
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// The iterator for the given element type T.
|
||||
template <typename T, typename AttrT = ConcreteAttr>
|
||||
using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
|
||||
/// The iterator range over the given element T.
|
||||
template <typename T, typename AttrT = ConcreteAttr>
|
||||
using iterator_range =
|
||||
decltype(std::declval<AttrT>().template getValues<T>());
|
||||
|
||||
/// Return an iterator to the first element of this attribute as a value of
|
||||
/// type `T`.
|
||||
template <typename T>
|
||||
@ -292,11 +279,8 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
template <typename T>
|
||||
auto getValues() const {
|
||||
auto beginIt = $_attr.template value_begin<T>();
|
||||
return llvm::make_range(beginIt, std::next(beginIt, size()));
|
||||
}
|
||||
/// Return the value at the given flattened index.
|
||||
template <typename T> T getFlatValue(uint64_t index) const {
|
||||
return *std::next($_attr.template value_begin<T>(), index);
|
||||
return detail::ElementsAttrRange<decltype(beginIt)>(
|
||||
Attribute($_attr).getType(), beginIt, std::next(beginIt, size()));
|
||||
}
|
||||
}] # ElementsAttrInterfaceAccessors;
|
||||
|
||||
@ -304,7 +288,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
template <typename T>
|
||||
using iterator = detail::ElementsAttrIterator<T>;
|
||||
template <typename T>
|
||||
using iterator_range = llvm::iterator_range<iterator<T>>;
|
||||
using iterator_range = detail::ElementsAttrRange<iterator<T>>;
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Accessors
|
||||
@ -329,8 +313,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
uint64_t getFlattenedIndex(ArrayRef<uint64_t> index) const {
|
||||
return getFlattenedIndex(*this, index);
|
||||
}
|
||||
static uint64_t getFlattenedIndex(Attribute elementsAttr,
|
||||
static uint64_t getFlattenedIndex(Type type,
|
||||
ArrayRef<uint64_t> index);
|
||||
static uint64_t getFlattenedIndex(Attribute elementsAttr,
|
||||
ArrayRef<uint64_t> index) {
|
||||
return getFlattenedIndex(elementsAttr.getType(), index);
|
||||
}
|
||||
|
||||
/// Returns the number of elements held by this attribute.
|
||||
int64_t getNumElements() const { return getNumElements(*this); }
|
||||
@ -350,13 +338,6 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
!std::is_base_of<Attribute, T>::value,
|
||||
ResultT>;
|
||||
|
||||
/// Return the element of this attribute at the given index as a value of
|
||||
/// type 'T'.
|
||||
template <typename T>
|
||||
T getFlatValue(uint64_t index) const {
|
||||
return *std::next(value_begin<T>(), index);
|
||||
}
|
||||
|
||||
/// Return the splat value for this attribute. This asserts that the
|
||||
/// attribute corresponds to a splat.
|
||||
template <typename T>
|
||||
@ -368,7 +349,7 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
/// Return the elements of this attribute as a value of type 'T'.
|
||||
template <typename T>
|
||||
DefaultValueCheckT<T, iterator_range<T>> getValues() const {
|
||||
return iterator_range<T>(value_begin<T>(), value_end<T>());
|
||||
return {Attribute::getType(), value_begin<T>(), value_end<T>()};
|
||||
}
|
||||
template <typename T>
|
||||
DefaultValueCheckT<T, iterator<T>> value_begin() const;
|
||||
@ -384,12 +365,12 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
llvm::mapped_iterator<iterator<Attribute>, T (*)(Attribute)>;
|
||||
template <typename T>
|
||||
using DerivedAttrValueIteratorRange =
|
||||
llvm::iterator_range<DerivedAttrValueIterator<T>>;
|
||||
detail::ElementsAttrRange<DerivedAttrValueIterator<T>>;
|
||||
template <typename T, typename = DerivedAttrValueCheckT<T>>
|
||||
DerivedAttrValueIteratorRange<T> getValues() const {
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
return llvm::map_range(getValues<Attribute>(),
|
||||
static_cast<T (*)(Attribute)>(castFn));
|
||||
return {Attribute::getType(), llvm::map_range(getValues<Attribute>(),
|
||||
static_cast<T (*)(Attribute)>(castFn))};
|
||||
}
|
||||
template <typename T, typename = DerivedAttrValueCheckT<T>>
|
||||
DerivedAttrValueIterator<T> value_begin() const {
|
||||
@ -407,8 +388,10 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
/// return the iterable range. Otherwise, return llvm::None.
|
||||
template <typename T>
|
||||
DefaultValueCheckT<T, Optional<iterator_range<T>>> tryGetValues() const {
|
||||
if (Optional<iterator<T>> beginIt = try_value_begin<T>())
|
||||
return iterator_range<T>(*beginIt, value_end<T>());
|
||||
if (Optional<iterator<T>> beginIt = try_value_begin<T>()) {
|
||||
return iterator_range<T>(Attribute::getType(), *beginIt,
|
||||
value_end<T>());
|
||||
}
|
||||
return llvm::None;
|
||||
}
|
||||
template <typename T>
|
||||
@ -418,10 +401,15 @@ def ElementsAttrInterface : AttrInterface<"ElementsAttr"> {
|
||||
/// return the iterable range. Otherwise, return llvm::None.
|
||||
template <typename T, typename = DerivedAttrValueCheckT<T>>
|
||||
Optional<DerivedAttrValueIteratorRange<T>> tryGetValues() const {
|
||||
auto values = tryGetValues<Attribute>();
|
||||
if (!values)
|
||||
return llvm::None;
|
||||
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
if (auto values = tryGetValues<Attribute>())
|
||||
return llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn));
|
||||
return llvm::None;
|
||||
return DerivedAttrValueIteratorRange<T>(
|
||||
Attribute::getType(),
|
||||
llvm::map_range(*values, static_cast<T (*)(Attribute)>(castFn))
|
||||
);
|
||||
}
|
||||
template <typename T, typename = DerivedAttrValueCheckT<T>>
|
||||
Optional<DerivedAttrValueIterator<T>> try_value_begin() const {
|
||||
|
@ -61,8 +61,7 @@ protected:
|
||||
};
|
||||
|
||||
/// Type trait detector that checks if a given type T is a complex type.
|
||||
template <typename T>
|
||||
struct is_complex_t : public std::false_type {};
|
||||
template <typename T> struct is_complex_t : public std::false_type {};
|
||||
template <typename T>
|
||||
struct is_complex_t<std::complex<T>> : public std::true_type {};
|
||||
} // namespace detail
|
||||
@ -82,8 +81,7 @@ public:
|
||||
/// floating point type that can be used to access the underlying element
|
||||
/// types of a DenseElementsAttr.
|
||||
// TODO: Use std::disjunction when C++17 is supported.
|
||||
template <typename T>
|
||||
struct is_valid_cpp_fp_type {
|
||||
template <typename T> struct is_valid_cpp_fp_type {
|
||||
/// The type is a valid floating point type if it is a builtin floating
|
||||
/// point type, or is a potentially user defined floating point type. The
|
||||
/// latter allows for supporting users that have custom types defined for
|
||||
@ -219,6 +217,18 @@ public:
|
||||
// Iterators
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// The iterator range over the given iterator type T.
|
||||
template <typename IteratorT>
|
||||
using iterator_range_impl = detail::ElementsAttrRange<IteratorT>;
|
||||
|
||||
/// The iterator for the given element type T.
|
||||
template <typename T, typename AttrT = DenseElementsAttr>
|
||||
using iterator = decltype(std::declval<AttrT>().template value_begin<T>());
|
||||
/// The iterator range over the given element T.
|
||||
template <typename T, typename AttrT = DenseElementsAttr>
|
||||
using iterator_range =
|
||||
decltype(std::declval<AttrT>().template getValues<T>());
|
||||
|
||||
/// A utility iterator that allows walking over the internal Attribute values
|
||||
/// of a DenseElementsAttr.
|
||||
class AttributeElementIterator
|
||||
@ -358,22 +368,7 @@ public:
|
||||
!std::is_same<Attribute, T>::value,
|
||||
T>::type
|
||||
getSplatValue() const {
|
||||
return getSplatValue().template cast<T>();
|
||||
}
|
||||
|
||||
/// Return the value at the given index. The 'index' is expected to refer to a
|
||||
/// valid element.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const {
|
||||
return getValue<Attribute>(index);
|
||||
}
|
||||
template <typename T>
|
||||
T getValue(ArrayRef<uint64_t> index) const {
|
||||
// Skip to the element corresponding to the flattened index.
|
||||
return getFlatValue<T>(ElementsAttr::getFlattenedIndex(*this, index));
|
||||
}
|
||||
/// Return the value at the given flattened index.
|
||||
template <typename T> T getFlatValue(uint64_t index) const {
|
||||
return *std::next(value_begin<T>(), index);
|
||||
return getSplatValue<Attribute>().template cast<T>();
|
||||
}
|
||||
|
||||
/// Return the held element values as a range of integer or floating-point
|
||||
@ -384,12 +379,12 @@ public:
|
||||
std::numeric_limits<T>::is_integer) ||
|
||||
is_valid_cpp_fp_type<T>::value>::type;
|
||||
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<ElementIterator<T>> getValues() const {
|
||||
iterator_range_impl<ElementIterator<T>> getValues() const {
|
||||
assert(isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer,
|
||||
std::numeric_limits<T>::is_signed));
|
||||
const char *rawData = getRawData().data();
|
||||
bool splat = isSplat();
|
||||
return {ElementIterator<T>(rawData, splat, 0),
|
||||
return {Attribute::getType(), ElementIterator<T>(rawData, splat, 0),
|
||||
ElementIterator<T>(rawData, splat, getNumElements())};
|
||||
}
|
||||
template <typename T, typename = IntFloatValueTemplateCheckT<T>>
|
||||
@ -413,12 +408,12 @@ public:
|
||||
is_valid_cpp_fp_type<ElementT>::value)>::type;
|
||||
template <typename T, typename ElementT = typename T::value_type,
|
||||
typename = ComplexValueTemplateCheckT<T, ElementT>>
|
||||
llvm::iterator_range<ElementIterator<T>> getValues() const {
|
||||
iterator_range_impl<ElementIterator<T>> getValues() const {
|
||||
assert(isValidComplex(sizeof(T), std::numeric_limits<ElementT>::is_integer,
|
||||
std::numeric_limits<ElementT>::is_signed));
|
||||
const char *rawData = getRawData().data();
|
||||
bool splat = isSplat();
|
||||
return {ElementIterator<T>(rawData, splat, 0),
|
||||
return {Attribute::getType(), ElementIterator<T>(rawData, splat, 0),
|
||||
ElementIterator<T>(rawData, splat, getNumElements())};
|
||||
}
|
||||
template <typename T, typename ElementT = typename T::value_type,
|
||||
@ -441,11 +436,11 @@ public:
|
||||
using StringRefValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, StringRef>::value>::type;
|
||||
template <typename T, typename = StringRefValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
|
||||
iterator_range_impl<ElementIterator<StringRef>> getValues() const {
|
||||
auto stringRefs = getRawStringData();
|
||||
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
|
||||
bool splat = isSplat();
|
||||
return {ElementIterator<StringRef>(ptr, splat, 0),
|
||||
return {Attribute::getType(), ElementIterator<StringRef>(ptr, splat, 0),
|
||||
ElementIterator<StringRef>(ptr, splat, getNumElements())};
|
||||
}
|
||||
template <typename T, typename = StringRefValueTemplateCheckT<T>>
|
||||
@ -464,8 +459,9 @@ public:
|
||||
using AttributeValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, Attribute>::value>::type;
|
||||
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<AttributeElementIterator> getValues() const {
|
||||
return {value_begin<Attribute>(), value_end<Attribute>()};
|
||||
iterator_range_impl<AttributeElementIterator> getValues() const {
|
||||
return {Attribute::getType(), value_begin<Attribute>(),
|
||||
value_end<Attribute>()};
|
||||
}
|
||||
template <typename T, typename = AttributeValueTemplateCheckT<T>>
|
||||
AttributeElementIterator value_begin() const {
|
||||
@ -486,10 +482,11 @@ public:
|
||||
using DerivedAttributeElementIterator =
|
||||
llvm::mapped_iterator<AttributeElementIterator, T (*)(Attribute)>;
|
||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<DerivedAttributeElementIterator<T>> getValues() const {
|
||||
iterator_range_impl<DerivedAttributeElementIterator<T>> getValues() const {
|
||||
auto castFn = [](Attribute attr) { return attr.template cast<T>(); };
|
||||
return llvm::map_range(getValues<Attribute>(),
|
||||
static_cast<T (*)(Attribute)>(castFn));
|
||||
return {Attribute::getType(),
|
||||
llvm::map_range(getValues<Attribute>(),
|
||||
static_cast<T (*)(Attribute)>(castFn))};
|
||||
}
|
||||
template <typename T, typename = DerivedAttrValueTemplateCheckT<T>>
|
||||
DerivedAttributeElementIterator<T> value_begin() const {
|
||||
@ -508,9 +505,9 @@ public:
|
||||
using BoolValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, bool>::value>::type;
|
||||
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<BoolElementIterator> getValues() const {
|
||||
iterator_range_impl<BoolElementIterator> getValues() const {
|
||||
assert(isValidBool() && "bool is not the value of this elements attribute");
|
||||
return {BoolElementIterator(*this, 0),
|
||||
return {Attribute::getType(), BoolElementIterator(*this, 0),
|
||||
BoolElementIterator(*this, getNumElements())};
|
||||
}
|
||||
template <typename T, typename = BoolValueTemplateCheckT<T>>
|
||||
@ -530,9 +527,9 @@ public:
|
||||
using APIntValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, APInt>::value>::type;
|
||||
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<IntElementIterator> getValues() const {
|
||||
iterator_range_impl<IntElementIterator> getValues() const {
|
||||
assert(getElementType().isIntOrIndex() && "expected integral type");
|
||||
return {raw_int_begin(), raw_int_end()};
|
||||
return {Attribute::getType(), raw_int_begin(), raw_int_end()};
|
||||
}
|
||||
template <typename T, typename = APIntValueTemplateCheckT<T>>
|
||||
IntElementIterator value_begin() const {
|
||||
@ -551,7 +548,7 @@ public:
|
||||
using ComplexAPIntValueTemplateCheckT = typename std::enable_if<
|
||||
std::is_same<T, std::complex<APInt>>::value>::type;
|
||||
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<ComplexIntElementIterator> getValues() const {
|
||||
iterator_range_impl<ComplexIntElementIterator> getValues() const {
|
||||
return getComplexIntValues();
|
||||
}
|
||||
template <typename T, typename = ComplexAPIntValueTemplateCheckT<T>>
|
||||
@ -569,7 +566,7 @@ public:
|
||||
using APFloatValueTemplateCheckT =
|
||||
typename std::enable_if<std::is_same<T, APFloat>::value>::type;
|
||||
template <typename T, typename = APFloatValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<FloatElementIterator> getValues() const {
|
||||
iterator_range_impl<FloatElementIterator> getValues() const {
|
||||
return getFloatValues();
|
||||
}
|
||||
template <typename T, typename = APFloatValueTemplateCheckT<T>>
|
||||
@ -587,7 +584,7 @@ public:
|
||||
using ComplexAPFloatValueTemplateCheckT = typename std::enable_if<
|
||||
std::is_same<T, std::complex<APFloat>>::value>::type;
|
||||
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
|
||||
llvm::iterator_range<ComplexFloatElementIterator> getValues() const {
|
||||
iterator_range_impl<ComplexFloatElementIterator> getValues() const {
|
||||
return getComplexFloatValues();
|
||||
}
|
||||
template <typename T, typename = ComplexAPFloatValueTemplateCheckT<T>>
|
||||
@ -660,13 +657,13 @@ protected:
|
||||
IntElementIterator raw_int_end() const {
|
||||
return IntElementIterator(*this, getNumElements());
|
||||
}
|
||||
llvm::iterator_range<ComplexIntElementIterator> getComplexIntValues() const;
|
||||
iterator_range_impl<ComplexIntElementIterator> getComplexIntValues() const;
|
||||
ComplexIntElementIterator complex_value_begin() const;
|
||||
ComplexIntElementIterator complex_value_end() const;
|
||||
llvm::iterator_range<FloatElementIterator> getFloatValues() const;
|
||||
iterator_range_impl<FloatElementIterator> getFloatValues() const;
|
||||
FloatElementIterator float_value_begin() const;
|
||||
FloatElementIterator float_value_end() const;
|
||||
llvm::iterator_range<ComplexFloatElementIterator>
|
||||
iterator_range_impl<ComplexFloatElementIterator>
|
||||
getComplexFloatValues() const;
|
||||
ComplexFloatElementIterator complex_float_value_begin() const;
|
||||
ComplexFloatElementIterator complex_float_value_end() const;
|
||||
@ -872,8 +869,7 @@ public:
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T>
|
||||
auto SparseElementsAttr::getValues() const
|
||||
-> llvm::iterator_range<iterator<T>> {
|
||||
auto SparseElementsAttr::value_begin() const -> iterator<T> {
|
||||
auto zeroValue = getZeroValue<T>();
|
||||
auto valueIt = getValues().value_begin<T>();
|
||||
const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
|
||||
@ -888,15 +884,7 @@ auto SparseElementsAttr::getValues() const
|
||||
// Otherwise, return the zero value.
|
||||
return zeroValue;
|
||||
};
|
||||
return llvm::map_range(llvm::seq<ptrdiff_t>(0, getNumElements()), mapFn);
|
||||
}
|
||||
template <typename T>
|
||||
auto SparseElementsAttr::value_begin() const -> iterator<T> {
|
||||
return getValues<T>().begin();
|
||||
}
|
||||
template <typename T>
|
||||
auto SparseElementsAttr::value_end() const -> iterator<T> {
|
||||
return getValues<T>().end();
|
||||
return iterator<T>(llvm::seq<ptrdiff_t>(0, getNumElements()).begin(), mapFn);
|
||||
}
|
||||
} // end namespace mlir.
|
||||
|
||||
|
@ -174,9 +174,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
|
||||
"ArrayRef<char>":$rawData);
|
||||
let extraClassDeclaration = [{
|
||||
using DenseElementsAttr::empty;
|
||||
using DenseElementsAttr::getFlatValue;
|
||||
using DenseElementsAttr::getNumElements;
|
||||
using DenseElementsAttr::getValue;
|
||||
using DenseElementsAttr::getValues;
|
||||
using DenseElementsAttr::isSplat;
|
||||
using DenseElementsAttr::size;
|
||||
@ -313,9 +311,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
using DenseElementsAttr::empty;
|
||||
using DenseElementsAttr::getFlatValue;
|
||||
using DenseElementsAttr::getNumElements;
|
||||
using DenseElementsAttr::getValue;
|
||||
using DenseElementsAttr::getValues;
|
||||
using DenseElementsAttr::isSplat;
|
||||
using DenseElementsAttr::size;
|
||||
@ -712,10 +708,6 @@ def Builtin_OpaqueElementsAttr : Builtin_Attr<
|
||||
let extraClassDeclaration = [{
|
||||
using ValueType = StringRef;
|
||||
|
||||
/// Return the value at the given index. The 'index' is expected to refer to
|
||||
/// a valid element.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
|
||||
/// Decodes the attribute value using dialect-specific decoding hook.
|
||||
/// Returns false if decoding is successful. If not, returns true and leaves
|
||||
/// 'result' argument unspecified.
|
||||
@ -802,6 +794,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
|
||||
// String types.
|
||||
StringRef
|
||||
>;
|
||||
using ElementsAttr::Trait<SparseElementsAttr>::getValues;
|
||||
|
||||
/// Provide a `value_begin_impl` to enable iteration within ElementsAttr.
|
||||
template <typename T>
|
||||
@ -817,13 +810,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
|
||||
/// Return the values of this attribute in the form of the given type 'T'.
|
||||
/// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types,
|
||||
/// etc.
|
||||
template <typename T> llvm::iterator_range<iterator<T>> getValues() const;
|
||||
template <typename T> iterator<T> value_begin() const;
|
||||
template <typename T> iterator<T> value_end() const;
|
||||
|
||||
/// Return the value of the element at the given index. The 'index' is
|
||||
/// expected to refer to a valid element.
|
||||
Attribute getValue(ArrayRef<uint64_t> index) const;
|
||||
|
||||
private:
|
||||
/// Get a zero APFloat for the given sparse attribute.
|
||||
|
@ -288,8 +288,9 @@ bool mlirAttributeIsAElements(MlirAttribute attr) {
|
||||
|
||||
MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
|
||||
uint64_t *idxs) {
|
||||
return wrap(unwrap(attr).cast<ElementsAttr>().getValue(
|
||||
llvm::makeArrayRef(idxs, rank)));
|
||||
return wrap(unwrap(attr)
|
||||
.cast<ElementsAttr>()
|
||||
.getValues<Attribute>()[llvm::makeArrayRef(idxs, rank)]);
|
||||
}
|
||||
|
||||
bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
|
||||
@ -482,7 +483,8 @@ bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
|
||||
}
|
||||
|
||||
MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
|
||||
return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
|
||||
return wrap(
|
||||
unwrap(attr).cast<DenseElementsAttr>().getSplatValue<Attribute>());
|
||||
}
|
||||
int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
|
||||
@ -520,36 +522,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
|
||||
// Indexed accessors.
|
||||
|
||||
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<bool>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<bool>()[pos];
|
||||
}
|
||||
int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int8_t>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>()[pos];
|
||||
}
|
||||
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint8_t>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
|
||||
}
|
||||
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int32_t>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
|
||||
}
|
||||
uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint32_t>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>()[pos];
|
||||
}
|
||||
int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<int64_t>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>()[pos];
|
||||
}
|
||||
uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<uint64_t>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>()[pos];
|
||||
}
|
||||
float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<float>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<float>()[pos];
|
||||
}
|
||||
double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getFlatValue<double>(pos);
|
||||
return unwrap(attr).cast<DenseElementsAttr>().getValues<double>()[pos];
|
||||
}
|
||||
MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
|
||||
intptr_t pos) {
|
||||
return wrap(
|
||||
unwrap(attr).cast<DenseElementsAttr>().getFlatValue<StringRef>(pos));
|
||||
unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>()[pos]);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -169,7 +169,7 @@ LogicalResult WorkGroupSizeConversion::matchAndRewrite(
|
||||
return failure();
|
||||
|
||||
auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
|
||||
auto val = workGroupSizeAttr.getValue<int32_t>(index.getValue());
|
||||
auto val = workGroupSizeAttr.getValues<int32_t>()[index.getValue()];
|
||||
auto convertedType =
|
||||
getTypeConverter()->convertType(op.getResult().getType());
|
||||
if (!convertedType)
|
||||
|
@ -451,7 +451,7 @@ struct GlobalMemrefOpLowering
|
||||
// For scalar memrefs, the global variable created is of the element type,
|
||||
// so unpack the elements attribute to extract the value.
|
||||
if (type.getRank() == 0)
|
||||
initialValue = elementsAttr.getValue({});
|
||||
initialValue = elementsAttr.getValues<Attribute>()[0];
|
||||
}
|
||||
|
||||
uint64_t alignment = global.alignment().getValueOr(0);
|
||||
|
@ -2415,8 +2415,7 @@ LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
|
||||
// AffineMinMaxOpBase
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T>
|
||||
static LogicalResult verifyAffineMinMaxOp(T op) {
|
||||
template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
|
||||
// Verify that operand count matches affine map dimension and symbol count.
|
||||
if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
|
||||
return op.emitOpError(
|
||||
@ -2424,8 +2423,7 @@ static LogicalResult verifyAffineMinMaxOp(T op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
|
||||
template <typename T> static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
|
||||
p << ' ' << op->getAttr(T::getMapAttrName());
|
||||
auto operands = op.getOperands();
|
||||
unsigned numDims = op.map().getNumDims();
|
||||
@ -2532,8 +2530,7 @@ struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
|
||||
///
|
||||
/// %1 = affine.min affine_map<
|
||||
/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
|
||||
template <typename T>
|
||||
struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
|
||||
template <typename T> struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
|
||||
using OpRewritePattern<T>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(T affineOp,
|
||||
@ -2890,19 +2887,19 @@ AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
|
||||
}
|
||||
|
||||
AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
|
||||
auto values = lowerBoundsGroups().getValues<int32_t>();
|
||||
unsigned start = 0;
|
||||
for (unsigned i = 0; i < pos; ++i)
|
||||
start += lowerBoundsGroups().getValue<int32_t>(i);
|
||||
return lowerBoundsMap().getSliceMap(
|
||||
start, lowerBoundsGroups().getValue<int32_t>(pos));
|
||||
start += values[i];
|
||||
return lowerBoundsMap().getSliceMap(start, values[pos]);
|
||||
}
|
||||
|
||||
AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
|
||||
auto values = upperBoundsGroups().getValues<int32_t>();
|
||||
unsigned start = 0;
|
||||
for (unsigned i = 0; i < pos; ++i)
|
||||
start += upperBoundsGroups().getValue<int32_t>(i);
|
||||
return upperBoundsMap().getSliceMap(
|
||||
start, upperBoundsGroups().getValue<int32_t>(pos));
|
||||
start += values[i];
|
||||
return upperBoundsMap().getSliceMap(start, values[pos]);
|
||||
}
|
||||
|
||||
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
|
||||
|
@ -163,8 +163,8 @@ constexpr const static unsigned kBitsInByte = 8;
|
||||
/// Returns the value that corresponds to named position `pos` from the
|
||||
/// attribute `attr` assuming it's a dense integer elements attribute.
|
||||
static unsigned extractPointerSpecValue(Attribute attr, DLEntryPos pos) {
|
||||
return attr.cast<DenseIntElementsAttr>().getValue<unsigned>(
|
||||
static_cast<unsigned>(pos));
|
||||
return attr.cast<DenseIntElementsAttr>()
|
||||
.getValues<unsigned>()[static_cast<unsigned>(pos)];
|
||||
}
|
||||
|
||||
/// Returns the part of the data layout entry that corresponds to `pos` for the
|
||||
|
@ -1184,7 +1184,7 @@ public:
|
||||
if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
|
||||
splatAttr.isSplat() &&
|
||||
splatAttr.getType().getElementType().isIntOrFloat()) {
|
||||
constantAttr = splatAttr.getSplatValue();
|
||||
constantAttr = splatAttr.getSplatValue<Attribute>();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -1455,10 +1455,9 @@ public:
|
||||
|
||||
bool isFloat = elementType.isa<FloatType>();
|
||||
if (isFloat) {
|
||||
SmallVector<iterator_range<DenseElementsAttr::FloatElementIterator>>
|
||||
inputFpIterators;
|
||||
SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
|
||||
for (int i = 0; i < numInputs; ++i)
|
||||
inputFpIterators.push_back(inputValues[i].getValues<APFloat>());
|
||||
inFpRanges.push_back(inputValues[i].getValues<APFloat>());
|
||||
|
||||
computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
|
||||
|
||||
@ -1469,22 +1468,17 @@ public:
|
||||
computeRemappedLinearIndex(linearIndex);
|
||||
|
||||
// Collect constant elements for all inputs at this loop iteration.
|
||||
for (int i = 0; i < numInputs; ++i) {
|
||||
computeFnInputs.apFloats[i] =
|
||||
*(inputFpIterators[i].begin() + srcLinearIndices[i]);
|
||||
}
|
||||
for (int i = 0; i < numInputs; ++i)
|
||||
computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
|
||||
|
||||
// Invoke the computation to get the corresponding constant output
|
||||
// element.
|
||||
APIntOrFloat outputs = computeFn(computeFnInputs);
|
||||
|
||||
fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue();
|
||||
fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
|
||||
}
|
||||
} else {
|
||||
SmallVector<iterator_range<DenseElementsAttr::IntElementIterator>>
|
||||
inputIntIterators;
|
||||
SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
|
||||
for (int i = 0; i < numInputs; ++i)
|
||||
inputIntIterators.push_back(inputValues[i].getValues<APInt>());
|
||||
inIntRanges.push_back(inputValues[i].getValues<APInt>());
|
||||
|
||||
computeFnInputs.apInts.resize(numInputs);
|
||||
|
||||
@ -1495,25 +1489,19 @@ public:
|
||||
computeRemappedLinearIndex(linearIndex);
|
||||
|
||||
// Collect constant elements for all inputs at this loop iteration.
|
||||
for (int i = 0; i < numInputs; ++i) {
|
||||
computeFnInputs.apInts[i] =
|
||||
*(inputIntIterators[i].begin() + srcLinearIndices[i]);
|
||||
}
|
||||
for (int i = 0; i < numInputs; ++i)
|
||||
computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
|
||||
|
||||
// Invoke the computation to get the corresponding constant output
|
||||
// element.
|
||||
APIntOrFloat outputs = computeFn(computeFnInputs);
|
||||
|
||||
intOutputValues[dstLinearIndex] = outputs.apInt.getValue();
|
||||
intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
|
||||
}
|
||||
}
|
||||
|
||||
DenseIntOrFPElementsAttr outputAttr;
|
||||
if (isFloat) {
|
||||
outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
|
||||
} else {
|
||||
outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
|
||||
}
|
||||
DenseElementsAttr outputAttr =
|
||||
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
|
||||
: DenseElementsAttr::get(outputType, intOutputValues);
|
||||
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
|
||||
return success();
|
||||
}
|
||||
|
@ -913,9 +913,9 @@ struct DownscaleSizeOneWindowed2DConvolution final
|
||||
loc, newOutputType, output, ioReshapeIndices);
|
||||
|
||||
// We need to shrink the strides and dilations too.
|
||||
auto stride = convOp.strides().getFlatValue<int64_t>(removeH ? 1 : 0);
|
||||
auto stride = convOp.strides().getValues<int64_t>()[removeH ? 1 : 0];
|
||||
auto stridesAttr = rewriter.getI64VectorAttr(stride);
|
||||
auto dilation = convOp.dilations().getFlatValue<int64_t>(removeH ? 1 : 0);
|
||||
auto dilation = convOp.dilations().getValues<int64_t>()[removeH ? 1 : 0];
|
||||
auto dilationsAttr = rewriter.getI64VectorAttr(dilation);
|
||||
|
||||
auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
|
||||
|
@ -56,7 +56,7 @@ static Attribute extractCompositeElement(Attribute composite,
|
||||
|
||||
if (auto vector = composite.dyn_cast<ElementsAttr>()) {
|
||||
assert(indices.size() == 1 && "must have exactly one index for a vector");
|
||||
return vector.getValue({indices[0]});
|
||||
return vector.getValues<Attribute>()[indices[0]];
|
||||
}
|
||||
|
||||
if (auto array = composite.dyn_cast<ArrayAttr>()) {
|
||||
|
@ -1138,7 +1138,7 @@ OpFoldResult GetExtentOp::fold(ArrayRef<Attribute> operands) {
|
||||
return nullptr;
|
||||
if (dim.getValue() >= elements.getNumElements())
|
||||
return nullptr;
|
||||
return elements.getValue({(uint64_t)dim.getValue()});
|
||||
return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];
|
||||
}
|
||||
|
||||
void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
|
||||
|
@ -1304,13 +1304,14 @@ static void printSwitchOpCases(
|
||||
if (!caseValues)
|
||||
return;
|
||||
|
||||
for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
|
||||
for (const auto &it : llvm::enumerate(caseValues.getValues<APInt>())) {
|
||||
p << ',';
|
||||
p.printNewline();
|
||||
p << " ";
|
||||
p << caseValues.getValue<APInt>(i).getLimitedValue();
|
||||
p << it.value().getLimitedValue();
|
||||
p << ": ";
|
||||
p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]);
|
||||
p.printSuccessorAndUseList(caseDestinations[it.index()],
|
||||
caseOperands[it.index()]);
|
||||
}
|
||||
p.printNewline();
|
||||
}
|
||||
@ -1353,9 +1354,9 @@ Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
|
||||
|
||||
SuccessorRange caseDests = getCaseDestinations();
|
||||
if (auto value = operands.front().dyn_cast_or_null<IntegerAttr>()) {
|
||||
for (int64_t i = 0, size = getCaseValues()->size(); i < size; ++i)
|
||||
if (value == caseValues->getValue<IntegerAttr>(i))
|
||||
return caseDests[i];
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
|
||||
if (it.value() == value.getValue())
|
||||
return caseDests[it.index()];
|
||||
return getDefaultDestination();
|
||||
}
|
||||
return nullptr;
|
||||
@ -1394,15 +1395,15 @@ dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
|
||||
auto caseValues = op.getCaseValues();
|
||||
auto caseDests = op.getCaseDestinations();
|
||||
|
||||
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
|
||||
if (caseDests[i] == op.getDefaultDestination() &&
|
||||
op.getCaseOperands(i) == op.getDefaultOperands()) {
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (caseDests[it.index()] == op.getDefaultDestination() &&
|
||||
op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
|
||||
requiresChange = true;
|
||||
continue;
|
||||
}
|
||||
newCaseDestinations.push_back(caseDests[i]);
|
||||
newCaseOperands.push_back(op.getCaseOperands(i));
|
||||
newCaseValues.push_back(caseValues->getValue<APInt>(i));
|
||||
newCaseDestinations.push_back(caseDests[it.index()]);
|
||||
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
||||
newCaseValues.push_back(it.value());
|
||||
}
|
||||
|
||||
if (!requiresChange)
|
||||
@ -1424,10 +1425,11 @@ dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
|
||||
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
|
||||
APInt caseValue) {
|
||||
auto caseValues = op.getCaseValues();
|
||||
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
|
||||
if (caseValues->getValue<APInt>(i) == caseValue) {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getCaseDestinations()[i],
|
||||
op.getCaseOperands(i));
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (it.value() == caseValue) {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(
|
||||
op, op.getCaseDestinations()[it.index()],
|
||||
op.getCaseOperands(it.index()));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -1551,22 +1553,16 @@ simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
|
||||
return failure();
|
||||
|
||||
// Fold this switch to an unconditional branch.
|
||||
APInt caseValue;
|
||||
bool isDefault = true;
|
||||
SuccessorRange predDests = predSwitch.getCaseDestinations();
|
||||
Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
|
||||
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
|
||||
if (currentBlock == predDests[i]) {
|
||||
caseValue = predCaseValues->getValue<APInt>(i);
|
||||
isDefault = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (isDefault)
|
||||
auto it = llvm::find(predDests, currentBlock);
|
||||
if (it != predDests.end()) {
|
||||
Optional<DenseIntElementsAttr> predCaseValues = predSwitch.getCaseValues();
|
||||
foldSwitch(op, rewriter,
|
||||
predCaseValues->getValues<APInt>()[it - predDests.begin()]);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
||||
op.getDefaultOperands());
|
||||
else
|
||||
foldSwitch(op, rewriter, caseValue);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -1613,7 +1609,7 @@ simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
|
||||
auto predCaseValues = predSwitch.getCaseValues();
|
||||
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
|
||||
if (currentBlock != predDests[i])
|
||||
caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
|
||||
caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
|
||||
|
||||
SmallVector<Block *> newCaseDestinations;
|
||||
SmallVector<ValueRange> newCaseOperands;
|
||||
@ -1622,14 +1618,14 @@ simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
|
||||
|
||||
auto caseValues = op.getCaseValues();
|
||||
auto caseDests = op.getCaseDestinations();
|
||||
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
|
||||
if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
|
||||
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
||||
if (caseValuesToRemove.contains(it.value())) {
|
||||
requiresChange = true;
|
||||
continue;
|
||||
}
|
||||
newCaseDestinations.push_back(caseDests[i]);
|
||||
newCaseOperands.push_back(op.getCaseOperands(i));
|
||||
newCaseValues.push_back(caseValues->getValue<APInt>(i));
|
||||
newCaseDestinations.push_back(caseDests[it.index()]);
|
||||
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
||||
newCaseValues.push_back(it.value());
|
||||
}
|
||||
|
||||
if (!requiresChange)
|
||||
|
@ -340,7 +340,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||
// If this is a splat elements attribute, simply return the value. All of the
|
||||
// elements of a splat attribute are the same.
|
||||
if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
|
||||
return splatTensor.getSplatValue();
|
||||
return splatTensor.getSplatValue<Attribute>();
|
||||
|
||||
// Otherwise, collect the constant indices into the tensor.
|
||||
SmallVector<uint64_t, 8> indices;
|
||||
@ -353,7 +353,7 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||
// If this is an elements attribute, query the value at the given indices.
|
||||
auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
|
||||
if (elementsAttr && elementsAttr.isValidIndex(indices))
|
||||
return elementsAttr.getValue(indices);
|
||||
return elementsAttr.getValues<Attribute>()[indices];
|
||||
return {};
|
||||
}
|
||||
|
||||
@ -440,7 +440,7 @@ OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
|
||||
Attribute dest = operands[1];
|
||||
if (scalar && dest)
|
||||
if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
|
||||
if (scalar == splatDest.getSplatValue())
|
||||
if (scalar == splatDest.getSplatValue<Attribute>())
|
||||
return dest;
|
||||
return {};
|
||||
}
|
||||
|
@ -230,6 +230,7 @@ struct ConstantTransposeOptimization
|
||||
// Transpose the input constant. Because we don't know its rank in advance,
|
||||
// we need to loop over the range [0, element count) and delinearize the
|
||||
// index.
|
||||
auto attrValues = inputValues.getValues<Attribute>();
|
||||
for (int srcLinearIndex = 0; srcLinearIndex < numElements;
|
||||
++srcLinearIndex) {
|
||||
SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
|
||||
@ -247,7 +248,7 @@ struct ConstantTransposeOptimization
|
||||
for (int dim = 1; dim < outputType.getRank(); ++dim)
|
||||
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
|
||||
|
||||
outputValues[dstLinearIndex] = inputValues.getValue(srcIndices);
|
||||
outputValues[dstLinearIndex] = attrValues[srcIndices];
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::ConstOp>(
|
||||
@ -424,8 +425,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||
// TOSA Operator Verifiers.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T>
|
||||
static LogicalResult verifyConvOp(T op) {
|
||||
template <typename T> static LogicalResult verifyConvOp(T op) {
|
||||
// All TOSA conv ops have an input() and weight().
|
||||
auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
|
||||
auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
|
||||
|
@ -56,15 +56,15 @@ bool ElementsAttr::isValidIndex(Attribute elementsAttr,
|
||||
return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index);
|
||||
}
|
||||
|
||||
uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr,
|
||||
ArrayRef<uint64_t> index) {
|
||||
ShapedType type = elementsAttr.getType().cast<ShapedType>();
|
||||
assert(isValidIndex(type, index) && "expected valid multi-dimensional index");
|
||||
uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
|
||||
ShapedType shapeType = type.cast<ShapedType>();
|
||||
assert(isValidIndex(shapeType, index) &&
|
||||
"expected valid multi-dimensional index");
|
||||
|
||||
// Reduce the provided multidimensional index into a flattended 1D row-major
|
||||
// index.
|
||||
auto rank = type.getRank();
|
||||
auto shape = type.getShape();
|
||||
auto rank = shapeType.getRank();
|
||||
ArrayRef<int64_t> shape = shapeType.getShape();
|
||||
uint64_t valueIndex = 0;
|
||||
uint64_t dimMultiplier = 1;
|
||||
for (int i = rank - 1; i >= 0; --i) {
|
||||
|
@ -902,10 +902,10 @@ LLVM_ATTRIBUTE_UNUSED static bool isComplexOfIntType(Type type) {
|
||||
}
|
||||
|
||||
auto DenseElementsAttr::getComplexIntValues() const
|
||||
-> llvm::iterator_range<ComplexIntElementIterator> {
|
||||
-> iterator_range_impl<ComplexIntElementIterator> {
|
||||
assert(isComplexOfIntType(getElementType()) &&
|
||||
"expected complex integral type");
|
||||
return {ComplexIntElementIterator(*this, 0),
|
||||
return {getType(), ComplexIntElementIterator(*this, 0),
|
||||
ComplexIntElementIterator(*this, getNumElements())};
|
||||
}
|
||||
auto DenseElementsAttr::complex_value_begin() const
|
||||
@ -923,10 +923,10 @@ auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
|
||||
/// Return the held element values as a range of APFloat. The element type of
|
||||
/// this attribute must be of float type.
|
||||
auto DenseElementsAttr::getFloatValues() const
|
||||
-> llvm::iterator_range<FloatElementIterator> {
|
||||
-> iterator_range_impl<FloatElementIterator> {
|
||||
auto elementType = getElementType().cast<FloatType>();
|
||||
const auto &elementSemantics = elementType.getFloatSemantics();
|
||||
return {FloatElementIterator(elementSemantics, raw_int_begin()),
|
||||
return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
|
||||
FloatElementIterator(elementSemantics, raw_int_end())};
|
||||
}
|
||||
auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
|
||||
@ -939,11 +939,12 @@ auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
|
||||
}
|
||||
|
||||
auto DenseElementsAttr::getComplexFloatValues() const
|
||||
-> llvm::iterator_range<ComplexFloatElementIterator> {
|
||||
-> iterator_range_impl<ComplexFloatElementIterator> {
|
||||
Type eltTy = getElementType().cast<ComplexType>().getElementType();
|
||||
assert(eltTy.isa<FloatType>() && "expected complex float type");
|
||||
const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
|
||||
return {{semantics, {*this, 0}},
|
||||
return {getType(),
|
||||
{semantics, {*this, 0}},
|
||||
{semantics, {*this, static_cast<size_t>(getNumElements())}}};
|
||||
}
|
||||
auto DenseElementsAttr::complex_float_value_begin() const
|
||||
@ -1248,13 +1249,6 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
|
||||
// OpaqueElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||
assert(isValidIndex(index) && "expected valid multi-dimensional index");
|
||||
return Attribute();
|
||||
}
|
||||
|
||||
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
|
||||
Dialect *dialect = getDialect().getDialect();
|
||||
if (!dialect)
|
||||
@ -1279,47 +1273,6 @@ OpaqueElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
// SparseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the value of the element at the given index.
|
||||
Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||
assert(isValidIndex(index) && "expected valid multi-dimensional index");
|
||||
auto type = getType();
|
||||
|
||||
// The sparse indices are 64-bit integers, so we can reinterpret the raw data
|
||||
// as a 1-D index array.
|
||||
auto sparseIndices = getIndices();
|
||||
auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
|
||||
|
||||
// Check to see if the indices are a splat.
|
||||
if (sparseIndices.isSplat()) {
|
||||
// If the index is also not a splat of the index value, we know that the
|
||||
// value is zero.
|
||||
auto splatIndex = *sparseIndexValues.begin();
|
||||
if (llvm::any_of(index, [=](uint64_t i) { return i != splatIndex; }))
|
||||
return getZeroAttr();
|
||||
|
||||
// If the indices are a splat, we also expect the values to be a splat.
|
||||
assert(getValues().isSplat() && "expected splat values");
|
||||
return getValues().getSplatValue();
|
||||
}
|
||||
|
||||
// Build a mapping between known indices and the offset of the stored element.
|
||||
llvm::SmallDenseMap<llvm::ArrayRef<uint64_t>, size_t> mappedIndices;
|
||||
auto numSparseIndices = sparseIndices.getType().getDimSize(0);
|
||||
size_t rank = type.getRank();
|
||||
for (size_t i = 0, e = numSparseIndices; i != e; ++i)
|
||||
mappedIndices.try_emplace(
|
||||
{&*std::next(sparseIndexValues.begin(), i * rank), rank}, i);
|
||||
|
||||
// Look for the provided index key within the mapped indices. If the provided
|
||||
// index is not found, then return a zero attribute.
|
||||
auto it = mappedIndices.find(index);
|
||||
if (it == mappedIndices.end())
|
||||
return getZeroAttr();
|
||||
|
||||
// Otherwise, return the held sparse value element.
|
||||
return getValues().getValue(it->second);
|
||||
}
|
||||
|
||||
/// Get a zero APFloat for the given sparse attribute.
|
||||
APFloat SparseElementsAttr::getZeroAPFloat() const {
|
||||
auto eltType = getElementType().cast<FloatType>();
|
||||
|
@ -71,7 +71,7 @@ int64_t ShapeAdaptor::getDimSize(int index) const {
|
||||
return t.cast<ShapedType>().getDimSize(index);
|
||||
if (auto attr = val.dyn_cast<Attribute>())
|
||||
return attr.cast<DenseIntElementsAttr>()
|
||||
.getFlatValue<APInt>(index)
|
||||
.getValues<APInt>()[index]
|
||||
.getSExtValue();
|
||||
auto *stc = val.get<ShapedTypeComponents *>();
|
||||
return stc->getDims()[index];
|
||||
|
@ -386,14 +386,12 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
return success();
|
||||
}
|
||||
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
|
||||
auto weights = condbrOp.getBranchWeights();
|
||||
llvm::MDNode *branchWeights = nullptr;
|
||||
if (weights) {
|
||||
if (auto weights = condbrOp.getBranchWeights()) {
|
||||
// Map weight attributes to LLVM metadata.
|
||||
auto trueWeight =
|
||||
weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
|
||||
auto falseWeight =
|
||||
weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
|
||||
auto weightValues = weights->getValues<APInt>();
|
||||
auto trueWeight = weightValues[0].getSExtValue();
|
||||
auto falseWeight = weightValues[1].getSExtValue();
|
||||
branchWeights =
|
||||
llvm::MDBuilder(moduleTranslation.getLLVMContext())
|
||||
.createBranchWeights(static_cast<uint32_t>(trueWeight),
|
||||
|
@ -706,11 +706,12 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
|
||||
if (shapedType.getRank() == dim) {
|
||||
if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
return attr.getType().getElementType().isInteger(1)
|
||||
? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
|
||||
: prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
|
||||
? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
|
||||
: prepareConstantInt(loc,
|
||||
attr.getValues<IntegerAttr>()[index]);
|
||||
}
|
||||
if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
|
||||
return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
@ -154,9 +154,9 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# ODS-NEXT: LogicalResult verifyIndexingMapRequiredAttributes();
|
||||
|
||||
# IMPL: getSymbolBindings(Test2Op self)
|
||||
# IMPL: cst2 = self.strides().getValue<int64_t>({ 0 });
|
||||
# IMPL: cst2 = self.strides().getValues<int64_t>()[0];
|
||||
# IMPL-NEXT: getAffineConstantExpr(cst2, context)
|
||||
# IMPL: cst3 = self.strides().getValue<int64_t>({ 1 });
|
||||
# IMPL: cst3 = self.strides().getValues<int64_t>()[1];
|
||||
# IMPL-NEXT: getAffineConstantExpr(cst3, context)
|
||||
|
||||
# IMPL: Test2Op::indexing_maps()
|
||||
|
@ -142,8 +142,7 @@ namespace yaml {
|
||||
/// Top-level type containing op metadata and one of a concrete op type.
|
||||
/// Currently, the only defined op type is `structured_op` (maps to
|
||||
/// `LinalgStructuredOpConfig`).
|
||||
template <>
|
||||
struct MappingTraits<LinalgOpConfig> {
|
||||
template <> struct MappingTraits<LinalgOpConfig> {
|
||||
static void mapping(IO &io, LinalgOpConfig &info) {
|
||||
io.mapOptional("metadata", info.metadata);
|
||||
io.mapOptional("structured_op", info.structuredOp);
|
||||
@ -156,8 +155,7 @@ struct MappingTraits<LinalgOpConfig> {
|
||||
/// - List of indexing maps (see `LinalgIndexingMaps`).
|
||||
/// - Iterator types (see `LinalgIteratorTypeDef`).
|
||||
/// - List of scalar level assignment (see `ScalarAssign`).
|
||||
template <>
|
||||
struct MappingTraits<LinalgStructuredOpConfig> {
|
||||
template <> struct MappingTraits<LinalgStructuredOpConfig> {
|
||||
static void mapping(IO &io, LinalgStructuredOpConfig &info) {
|
||||
io.mapRequired("args", info.args);
|
||||
io.mapRequired("indexing_maps", info.indexingMaps);
|
||||
@ -180,8 +178,7 @@ struct MappingTraits<LinalgStructuredOpConfig> {
|
||||
/// attribute symbols. During op creation these symbols are replaced by the
|
||||
/// corresponding `name` attribute values. Only attribute arguments have
|
||||
/// an `attribute_map`.
|
||||
template <>
|
||||
struct MappingTraits<LinalgOperandDef> {
|
||||
template <> struct MappingTraits<LinalgOperandDef> {
|
||||
static void mapping(IO &io, LinalgOperandDef &info) {
|
||||
io.mapRequired("name", info.name);
|
||||
io.mapRequired("usage", info.usage);
|
||||
@ -192,8 +189,7 @@ struct MappingTraits<LinalgOperandDef> {
|
||||
};
|
||||
|
||||
/// Usage enum for a named argument.
|
||||
template <>
|
||||
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
||||
template <> struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
||||
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
|
||||
io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
|
||||
io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
|
||||
@ -202,8 +198,7 @@ struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
|
||||
};
|
||||
|
||||
/// Iterator type enum.
|
||||
template <>
|
||||
struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
|
||||
template <> struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
|
||||
static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
|
||||
io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
|
||||
io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
|
||||
@ -211,8 +206,7 @@ struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
|
||||
};
|
||||
|
||||
/// Metadata about the op (name, C++ name, and documentation).
|
||||
template <>
|
||||
struct MappingTraits<LinalgOpMetadata> {
|
||||
template <> struct MappingTraits<LinalgOpMetadata> {
|
||||
static void mapping(IO &io, LinalgOpMetadata &info) {
|
||||
io.mapRequired("name", info.name);
|
||||
io.mapRequired("cpp_class_name", info.cppClassName);
|
||||
@ -226,8 +220,7 @@ struct MappingTraits<LinalgOpMetadata> {
|
||||
/// some symbols that bind to attributes of the op. Each indexing map must
|
||||
/// be normalized over the same list of dimensions, and its symbols must
|
||||
/// match the symbols for argument shapes.
|
||||
template <>
|
||||
struct MappingTraits<LinalgIndexingMapsConfig> {
|
||||
template <> struct MappingTraits<LinalgIndexingMapsConfig> {
|
||||
static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
|
||||
io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
|
||||
}
|
||||
@ -237,8 +230,7 @@ struct MappingTraits<LinalgIndexingMapsConfig> {
|
||||
/// - The `arg` name must match a named output.
|
||||
/// - The `value` is a scalar expression for computing the value to
|
||||
/// assign (see `ScalarExpression`).
|
||||
template <>
|
||||
struct MappingTraits<ScalarAssign> {
|
||||
template <> struct MappingTraits<ScalarAssign> {
|
||||
static void mapping(IO &io, ScalarAssign &info) {
|
||||
io.mapRequired("arg", info.arg);
|
||||
io.mapRequired("value", info.value);
|
||||
@ -250,8 +242,7 @@ struct MappingTraits<ScalarAssign> {
|
||||
/// - `scalar_apply`: Result of evaluating a named function (see
|
||||
/// `ScalarApply`).
|
||||
/// - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere.
|
||||
template <>
|
||||
struct MappingTraits<ScalarExpression> {
|
||||
template <> struct MappingTraits<ScalarExpression> {
|
||||
static void mapping(IO &io, ScalarExpression &info) {
|
||||
io.mapOptional("scalar_arg", info.arg);
|
||||
io.mapOptional("scalar_const", info.constant);
|
||||
@ -266,16 +257,14 @@ struct MappingTraits<ScalarExpression> {
|
||||
/// functions include:
|
||||
/// - `add(lhs, rhs)`
|
||||
/// - `mul(lhs, rhs)`
|
||||
template <>
|
||||
struct MappingTraits<ScalarApply> {
|
||||
template <> struct MappingTraits<ScalarApply> {
|
||||
static void mapping(IO &io, ScalarApply &info) {
|
||||
io.mapRequired("fn_name", info.fnName);
|
||||
io.mapRequired("operands", info.operands);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MappingTraits<ScalarSymbolicCast> {
|
||||
template <> struct MappingTraits<ScalarSymbolicCast> {
|
||||
static void mapping(IO &io, ScalarSymbolicCast &info) {
|
||||
io.mapRequired("type_var", info.typeVar);
|
||||
io.mapRequired("operands", info.operands);
|
||||
@ -285,8 +274,7 @@ struct MappingTraits<ScalarSymbolicCast> {
|
||||
|
||||
/// Helper mapping which accesses an AffineMapAttr as a serialized string of
|
||||
/// the same.
|
||||
template <>
|
||||
struct ScalarTraits<SerializedAffineMap> {
|
||||
template <> struct ScalarTraits<SerializedAffineMap> {
|
||||
static void output(const SerializedAffineMap &value, void *rawYamlContext,
|
||||
raw_ostream &out) {
|
||||
assert(value.affineMapAttr);
|
||||
@ -726,7 +714,7 @@ static SmallVector<AffineExpr> getSymbolBindings({0} self) {
|
||||
// {1}: Symbol position
|
||||
// {2}: Attribute index
|
||||
static const char structuredOpAccessAttrFormat[] = R"FMT(
|
||||
int64_t cst{1} = self.{0}().getValue<int64_t>({ {2} });
|
||||
int64_t cst{1} = self.{0}().getValues<int64_t>()[{2}];
|
||||
exprs.push_back(getAffineConstantExpr(cst{1}, context));
|
||||
)FMT";
|
||||
// Update all symbol bindings mapped to an attribute.
|
||||
|
@ -113,7 +113,8 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
||||
EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>());
|
||||
|
||||
// Check Elements attribute element value is expected.
|
||||
auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
|
||||
auto firstValue =
|
||||
returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
|
||||
EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
|
||||
}
|
||||
|
||||
@ -138,7 +139,8 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
||||
EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>());
|
||||
|
||||
// Check Elements attribute element value is expected.
|
||||
auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
|
||||
auto firstValue =
|
||||
returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
|
||||
EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
|
||||
}
|
||||
|
||||
@ -162,7 +164,8 @@ TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
|
||||
EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());
|
||||
|
||||
// Check Elements attribute element value is expected.
|
||||
auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
|
||||
auto firstValue =
|
||||
returnedValue.cast<ElementsAttr>().getValues<Attribute>()[{0, 0}];
|
||||
EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
|
||||
}
|
||||
|
||||
|
@ -202,7 +202,7 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
|
||||
RankedTensorType shape = RankedTensorType::get({}, intTy);
|
||||
|
||||
auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
|
||||
EXPECT_TRUE(attr.getValue({0}) == value);
|
||||
EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
|
||||
}
|
||||
|
||||
TEST(SparseElementsAttrTest, GetZero) {
|
||||
@ -238,15 +238,15 @@ TEST(SparseElementsAttrTest, GetZero) {
|
||||
|
||||
// Only index (0, 0) contains an element, others are supposed to return
|
||||
// the zero/empty value.
|
||||
auto zeroIntValue = sparseInt.getValue({1, 1});
|
||||
auto zeroIntValue = sparseInt.getValues<Attribute>()[{1, 1}];
|
||||
EXPECT_EQ(zeroIntValue.cast<IntegerAttr>().getInt(), 0);
|
||||
EXPECT_TRUE(zeroIntValue.getType() == intTy);
|
||||
|
||||
auto zeroFloatValue = sparseFloat.getValue({1, 1});
|
||||
auto zeroFloatValue = sparseFloat.getValues<Attribute>()[{1, 1}];
|
||||
EXPECT_EQ(zeroFloatValue.cast<FloatAttr>().getValueAsDouble(), 0.0f);
|
||||
EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
|
||||
|
||||
auto zeroStringValue = sparseString.getValue({1, 1});
|
||||
auto zeroStringValue = sparseString.getValues<Attribute>()[{1, 1}];
|
||||
EXPECT_TRUE(zeroStringValue.cast<StringAttr>().getValue().empty());
|
||||
EXPECT_TRUE(zeroStringValue.getType() == stringTy);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user