[mlir][sparse] support type conversion from SoA COO to memrefs. (#82398)
This commit is contained in:
parent
a9b5753220
commit
f740366fa6
@ -303,9 +303,9 @@ public:
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is in the `LevelFormat`.
|
||||
template <LevelFormat fmt>
|
||||
template <LevelFormat... fmt>
|
||||
constexpr bool isa() const {
|
||||
return getLvlFmt() == fmt;
|
||||
return (... || (getLvlFmt() == fmt)) || false;
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` has the properties
|
||||
|
||||
@ -18,6 +18,18 @@
|
||||
namespace mlir {
|
||||
namespace sparse_tensor {
|
||||
|
||||
/// A simple structure that encodes a range of levels in the sparse tensors that
|
||||
/// forms a COO segment.
|
||||
struct COOSegment {
|
||||
std::pair<Level, Level> lvlRange; // [low, high)
|
||||
bool isSoA;
|
||||
|
||||
bool isSegmentStart(Level l) const { return l == lvlRange.first; }
|
||||
bool inSegment(Level l) const {
|
||||
return l >= lvlRange.first && l < lvlRange.second;
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// A wrapper around `RankedTensorType`, which has three goals:
|
||||
///
|
||||
@ -330,6 +342,9 @@ public:
|
||||
/// Returns [un]ordered COO type for this sparse tensor type.
|
||||
RankedTensorType getCOOType(bool ordered) const;
|
||||
|
||||
/// Returns a list of COO segments in the sparse tensor types.
|
||||
SmallVector<COOSegment> getCOOSegments() const;
|
||||
|
||||
private:
|
||||
// These two must be const, to ensure coherence of the memoized fields.
|
||||
const RankedTensorType rtp;
|
||||
|
||||
@ -74,11 +74,12 @@ void StorageLayout::foreachField(
|
||||
callback) const {
|
||||
const auto lvlTypes = enc.getLvlTypes();
|
||||
const Level lvlRank = enc.getLvlRank();
|
||||
const Level cooStart = SparseTensorType(enc).getCOOStart();
|
||||
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
|
||||
SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments();
|
||||
FieldIndex fieldIdx = kDataFieldStartingIdx;
|
||||
|
||||
ArrayRef cooSegsRef = cooSegs;
|
||||
// Per-level storage.
|
||||
for (Level l = 0; l < end; l++) {
|
||||
for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) {
|
||||
const auto lt = lvlTypes[l];
|
||||
if (isWithPosLT(lt)) {
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt)))
|
||||
@ -88,6 +89,21 @@ void StorageLayout::foreachField(
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt)))
|
||||
return;
|
||||
}
|
||||
if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) {
|
||||
if (!cooSegsRef.front().isSoA) {
|
||||
// AoS COO, all singletons are fused into one memrefs. Skips the entire
|
||||
// COO segement.
|
||||
l = cooSegsRef.front().lvlRange.second;
|
||||
} else {
|
||||
// SoA COO, each singleton level has one memref.
|
||||
l++;
|
||||
}
|
||||
// Expire handled COO segment.
|
||||
cooSegsRef = cooSegsRef.drop_front();
|
||||
} else {
|
||||
// Non COO levels.
|
||||
l++;
|
||||
}
|
||||
}
|
||||
// The values array.
|
||||
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
|
||||
@ -796,13 +812,46 @@ bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
|
||||
}
|
||||
|
||||
Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
|
||||
if (hasEncoding() && lvlRank > 1)
|
||||
for (Level l = 0; l < lvlRank - 1; l++)
|
||||
if (isCOOType(l, /*isUnique=*/false))
|
||||
return l;
|
||||
SmallVector<COOSegment> coo = getCOOSegments();
|
||||
if (!coo.empty()) {
|
||||
assert(coo.size() == 1);
|
||||
return coo.front().lvlRange.first;
|
||||
}
|
||||
return lvlRank;
|
||||
}
|
||||
|
||||
SmallVector<COOSegment>
|
||||
mlir::sparse_tensor::SparseTensorType::getCOOSegments() const {
|
||||
SmallVector<COOSegment> ret;
|
||||
if (!hasEncoding() || lvlRank <= 1)
|
||||
return ret;
|
||||
|
||||
ArrayRef<LevelType> lts = getLvlTypes();
|
||||
Level l = 0;
|
||||
while (l < lvlRank) {
|
||||
auto lt = lts[l];
|
||||
if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) {
|
||||
auto cur = lts.begin() + l;
|
||||
auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) {
|
||||
return !lt.isa<LevelFormat::Singleton>();
|
||||
});
|
||||
unsigned cooLen = std::distance(cur, end);
|
||||
if (cooLen > 1) {
|
||||
// To support mixed SoA/AoS COO, we should break the segment when the
|
||||
// storage scheme changes, for now we faithfully assume that all
|
||||
// consecutive singleton levels have the same storage format as verified
|
||||
// STEA.
|
||||
ret.push_back(COOSegment{std::make_pair(l, l + cooLen),
|
||||
lts[l + 1].isa<LevelPropNonDefault::SoA>()});
|
||||
}
|
||||
l += cooLen;
|
||||
} else {
|
||||
l++;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
RankedTensorType
|
||||
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
|
||||
SmallVector<LevelType> lvlTypes;
|
||||
|
||||
@ -48,6 +48,10 @@
|
||||
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
|
||||
}>
|
||||
|
||||
#SoACOO = #sparse_tensor.encoding<{
|
||||
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
|
||||
}>
|
||||
|
||||
#CooPNo = #sparse_tensor.encoding<{
|
||||
map = (d0, d1) -> (d1 : compressed(nonunique), d0 : singleton(nonordered))
|
||||
}>
|
||||
@ -67,6 +71,28 @@ func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #Spa
|
||||
return %arg0 : tensor<?xf64, #SparseVector>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sparse_nop_aos_coo(
|
||||
// CHECK-SAME: %[[POS:.*0]]: memref<?xindex>,
|
||||
// CHECK-SAME: %[[AoS_CRD:.*1]]: memref<?xindex>,
|
||||
// CHECK-SAME: %[[VAL:.*]]: memref<?xf64>,
|
||||
// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
|
||||
// CHECK: return %[[POS]], %[[AoS_CRD]], %[[VAL]], %[[A3]]
|
||||
func.func @sparse_nop_aos_coo(%arg0: tensor<?x?xf64, #Coo>) -> tensor<?x?xf64, #Coo> {
|
||||
return %arg0 : tensor<?x?xf64, #Coo>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sparse_nop_soa_coo(
|
||||
// CHECK-SAME: %[[POS:.*0]]: memref<?xindex>,
|
||||
// CHECK-SAME: %[[SoA_CRD_0:.*1]]: memref<?xindex>,
|
||||
// CHECK-SAME: %[[SoA_CRD_1:.*2]]: memref<?xindex>,
|
||||
// CHECK-SAME: %[[VAL:.*]]: memref<?xf64>,
|
||||
// CHECK-SAME: %[[A3:.*]]: !sparse_tensor.storage_specifier
|
||||
// CHECK: return %[[POS]], %[[SoA_CRD_0]], %[[SoA_CRD_1]], %[[VAL]], %[[A3]]
|
||||
func.func @sparse_nop_soa_coo(%arg0: tensor<?x?xf64, #SoACOO>) -> tensor<?x?xf64, #SoACOO> {
|
||||
return %arg0 : tensor<?x?xf64, #SoACOO>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @sparse_nop_multi_ret(
|
||||
// CHECK-SAME: %[[A0:.*0]]: memref<?xi32>,
|
||||
// CHECK-SAME: %[[A1:.*1]]: memref<?xi64>,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user