[mlir][sparse] provide an AoS "view" into sparse runtime support lib (#87116)

Note that even though the sparse runtime support lib always uses SoA
storage for COO (and provides correct codegen by means of views into
this storage), in some rare cases we need the true physical AoS storage
as a single coordinate buffer. This PR provides that functionality
through a (costly) coordinate buffer call.

Since this is currently only used for testing/debugging through the
sparse_tensor.print method, this solution is acceptable. If we ever
need a performant version, we should add true AoS storage of COO to
the support library alongside the SoA storage used right now.
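
For context, a minimal standalone sketch (plain C++ with illustrative data;
not the library's actual classes) of the SoA layout the support library keeps
versus the AoS buffer this change materializes:

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // SoA: one coordinate vector per level (what the support library stores).
  std::vector<uint64_t> crd0 = {0, 0, 3, 3, 3};
  std::vector<uint64_t> crd1 = {0, 2, 2, 3, 5};

  // AoS: a single interleaved buffer (d0, d1, d0, d1, ...), which is what
  // the new coordinate buffer call builds on demand.
  std::vector<uint64_t> aos;
  aos.reserve(crd0.size() * 2);
  for (std::size_t i = 0; i < crd0.size(); ++i) {
    aos.push_back(crd0[i]);
    aos.push_back(crd1[i]);
  }
  // aos == {0, 0, 0, 2, 3, 2, 3, 3, 3, 5}
  return 0;
}
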
Aart Bik, 2024-03-29 15:30:36 -07:00
parent 038e66fe59
commit dc4cfdbb8f
7 changed files with 152 additions and 23 deletions


@ -143,6 +143,12 @@ public:
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
#undef DECL_GETCOORDINATES
/// Gets coordinates-overhead storage buffer for the given level.
#define DECL_GETCOORDINATESBUFFER(INAME, C) \
virtual void getCoordinatesBuffer(std::vector<C> **, uint64_t);
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATESBUFFER)
#undef DECL_GETCOORDINATESBUFFER
/// Gets primary storage.
#define DECL_GETVALUES(VNAME, V) virtual void getValues(std::vector<V> **);
MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
@ -251,6 +257,31 @@ public:
assert(lvl < getLvlRank());
*out = &coordinates[lvl];
}
void getCoordinatesBuffer(std::vector<C> **out, uint64_t lvl) final {
assert(out && "Received nullptr for out parameter");
assert(lvl < getLvlRank());
// Note that the sparse tensor support library always stores COO in SoA
// format, even when AoS is requested. This is never an issue, since all
// actual code/library generation requests "views" into the coordinate
// storage for the individual levels, which is trivially provided for
// both AoS and SoA (as well as all the other storage formats). The only
// exception is when the buffer version of coordinate storage is requested
// (currently only for printing). In that case, we do the following
// potentially expensive transformation to provide that view. If this
// operation becomes more common beyond debugging, we should consider
// implementing proper AoS in the support library as well.
uint64_t lvlRank = getLvlRank();
uint64_t nnz = values.size();
crdBuffer.clear();
crdBuffer.reserve(nnz * (lvlRank - lvl));
for (uint64_t i = 0; i < nnz; i++) {
for (uint64_t l = lvl; l < lvlRank; l++) {
assert(i < coordinates[l].size());
crdBuffer.push_back(coordinates[l][i]);
}
}
*out = &crdBuffer;
}
void getValues(std::vector<V> **out) final {
assert(out && "Received nullptr for out parameter");
*out = &values;
@ -529,10 +560,14 @@ private:
return -1u;
}
// Sparse tensor storage components.
std::vector<std::vector<P>> positions;
std::vector<std::vector<C>> coordinates;
std::vector<V> values;
// Auxiliary data structures.
std::vector<uint64_t> lvlCursor;
std::vector<C> crdBuffer; // just for AoS view
};
//===----------------------------------------------------------------------===//
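
A hedged usage sketch of the new method; the header path, namespace, and the
choice of 64-bit coordinates are assumptions for illustration, not part of
the patch:

#include "mlir/ExecutionEngine/SparseTensor/Storage.h"

#include <cstdint>
#include <cstdio>
#include <vector>

// Print the AoS view of a rank-2 COO tensor's coordinates.
void dumpAoSCoordinates(mlir::sparse_tensor::SparseTensorStorageBase &tensor) {
  std::vector<uint64_t> *crds = nullptr;
  // The out-parameter receives a pointer to a library-owned buffer that is
  // rebuilt on every call, so this remains a debugging-grade (costly) API.
  tensor.getCoordinatesBuffer(&crds, /*lvl=*/0);
  // For nnz entries the buffer holds nnz * 2 interleaved coordinates:
  // (d0_0, d1_0, d0_1, d1_1, ...).
  for (uint64_t c : *crds)
    std::printf("%llu ", static_cast<unsigned long long>(c));
  std::printf("\n");
}
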


@ -77,6 +77,14 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSEPOSITIONS)
MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
#undef DECL_SPARSECOORDINATES
/// Tensor-storage method to obtain direct access to the coordinates array
/// buffer for the given level (provides an AoS view into the library).
#define DECL_SPARSECOORDINATESBUFFER(CNAME, C) \
MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_sparseCoordinatesBuffer##CNAME( \
StridedMemRefType<C, 1> *out, void *tensor, index_type lvl);
MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATESBUFFER)
#undef DECL_SPARSECOORDINATESBUFFER
/// Tensor-storage method to insert elements in lexicographical
/// level-coordinate order.
#define DECL_LEXINSERT(VNAME, V) \


@ -275,7 +275,7 @@ static Value genPositionsCall(OpBuilder &builder, Location loc,
.getResult(0);
}
/// Generates a call to obtain the coordindates array.
/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
SparseTensorType stt, Value ptr, Level l) {
Type crdTp = stt.getCrdType();
@ -287,6 +287,20 @@ static Value genCoordinatesCall(OpBuilder &builder, Location loc,
.getResult(0);
}
/// Generates a call to obtain the coordinates array (AoS view).
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
SparseTensorType stt, Value ptr,
Level l) {
Type crdTp = stt.getCrdType();
auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
Value lvl = constantIndex(builder, loc, l);
SmallString<25> name{"sparseCoordinatesBuffer",
overheadTypeFunctionSuffix(crdTp)};
return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
EmitCInterface::On)
.getResult(0);
}
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@ -518,13 +532,35 @@ public:
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getTensor());
auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
adaptor.getTensor(), op.getLevel());
auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
op.getLevel());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
if (op.getType() != crds.getType())
crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
rewriter.replaceOp(op, crds);
return success();
}
};
/// Sparse conversion rule for coordinate accesses (AoS style).
class SparseToCoordinatesBufferConverter
: public OpConversionPattern<ToCoordinatesBufferOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getTensor());
auto crds = genCoordinatesBufferCall(
rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
if (op.getType() != crds.getType())
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
rewriter.replaceOp(op, crds);
return success();
}
@ -878,10 +914,10 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
SparseTensorAllocConverter, SparseTensorEmptyConverter,
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorAssembleConverter, SparseTensorDisassembleConverter,
SparseHasRuntimeLibraryConverter>(typeConverter,
patterns.getContext());
SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
SparseTensorInsertConverter, SparseTensorExpandConverter,
SparseTensorCompressConverter, SparseTensorAssembleConverter,
SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
typeConverter, patterns.getContext());
}
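
A hedged sketch of how the extended pattern set is typically driven from a
conversion pass (assumes the usual mlir namespace is in scope and an already
configured ConversionTarget; only the populate call itself comes from this
file):

LogicalResult convertSparseOps(ModuleOp module, TypeConverter &typeConverter,
                               ConversionTarget &target) {
  RewritePatternSet patterns(module.getContext());
  populateSparseTensorConversionPatterns(typeConverter, patterns);
  // With SparseToCoordinatesBufferConverter registered, a
  // sparse_tensor.coordinates_buffer op is lowered to the
  // sparseCoordinatesBuffer runtime call generated above.
  return applyPartialConversion(module, target, std::move(patterns));
}
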


@ -648,7 +648,9 @@ public:
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
Value crd = nullptr;
// TODO: eliminates ToCoordinateBufferOp!
// For COO AoS storage, we want to print a single, linear view of
// the full coordinate storage at this level. For any other storage,
// we show the coordinate storage for every individual level.
if (stt.getAoSCOOStart() == l)
crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
else


@ -68,6 +68,14 @@ MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS)
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
#undef IMPL_GETCOORDINATES
#define IMPL_GETCOORDINATESBUFFER(CNAME, C) \
void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **, \
uint64_t) { \
FATAL_PIV("getCoordinatesBuffer" #CNAME); \
}
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER)
#undef IMPL_GETCOORDINATESBUFFER
#define IMPL_GETVALUES(VNAME, V) \
void SparseTensorStorageBase::getValues(std::vector<V> **) { \
FATAL_PIV("getValues" #VNAME); \


@ -311,6 +311,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
assert(v); \
aliasIntoMemref(v->size(), v->data(), *ref); \
}
#define IMPL_SPARSEPOSITIONS(PNAME, P) \
IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
@ -320,6 +321,12 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
#undef IMPL_SPARSECOORDINATES
#define IMPL_SPARSECOORDINATESBUFFER(CNAME, C) \
IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER)
#undef IMPL_SPARSECOORDINATESBUFFER
#undef IMPL_GETOVERHEAD
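
For illustration, a hedged sketch of what one expansion of the
coordinates-buffer wrapper amounts to (64-bit coordinates as an example; the
IMPL_GETOVERHEAD body is reconstructed from the hunk above and may differ in
detail):

void _mlir_ciface_sparseCoordinatesBuffer64(
    StridedMemRefType<uint64_t, 1> *ref, void *tensor, index_type lvl) {
  assert(ref && tensor);
  std::vector<uint64_t> *v = nullptr;
  // Ask the support library for the (freshly rebuilt) AoS coordinate buffer.
  static_cast<SparseTensorStorageBase *>(tensor)->getCoordinatesBuffer(&v, lvl);
  assert(v);
  // Alias the memref descriptor to the library-owned buffer; no extra copy
  // beyond the flattening already done inside getCoordinatesBuffer.
  aliasIntoMemref(v->size(), v->data(), *ref);
}
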
#define IMPL_LEXINSERT(VNAME, V) \


@ -120,6 +120,14 @@
)
}>
#COOAoS = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>
#COOSoA = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
}>
module {
//
@ -161,6 +169,8 @@ module {
%h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>
%i = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR0>
%j = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC0>
%AoS = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #COOAoS>
%SoA = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #COOSoA>
// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
@ -274,19 +284,42 @@ module {
// CHECK-NEXT: ----
sparse_tensor.print %j : tensor<4x8xi32, #BSC0>
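// Note: for AoS COO all coordinates are printed from one interleaved
// crd[0] buffer (d0,d1 pairs), whereas the SoA COO print further below
// shows crd[0] and crd[1] as separate per-level arrays.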
// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
// CHECK-NEXT: dim = ( 4, 8 )
// CHECK-NEXT: lvl = ( 4, 8 )
// CHECK-NEXT: pos[0] : ( 0, 5,
// CHECK-NEXT: crd[0] : ( 0, 0, 0, 2, 3, 2, 3, 3, 3, 5,
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
// CHECK-NEXT: ----
sparse_tensor.print %AoS : tensor<4x8xi32, #COOAoS>
// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
// CHECK-NEXT: dim = ( 4, 8 )
// CHECK-NEXT: lvl = ( 4, 8 )
// CHECK-NEXT: pos[0] : ( 0, 5,
// CHECK-NEXT: crd[0] : ( 0, 0, 3, 3, 3,
// CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
// CHECK-NEXT: ----
sparse_tensor.print %SoA : tensor<4x8xi32, #COOSoA>
// Release the resources.
bufferization.dealloc_tensor %XO : tensor<4x8xi32, #AllDense>
bufferization.dealloc_tensor %XT : tensor<4x8xi32, #AllDenseT>
bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
bufferization.dealloc_tensor %i : tensor<4x8xi32, #BSR0>
bufferization.dealloc_tensor %j : tensor<4x8xi32, #BSC0>
bufferization.dealloc_tensor %AoS : tensor<4x8xi32, #COOAoS>
bufferization.dealloc_tensor %SoA : tensor<4x8xi32, #COOSoA>
return
}