[mlir][sparse] add merger support on Batch LevelType. (#83186)

This commit is contained in:
Peiming Liu 2024-02-27 13:18:43 -08:00 committed by GitHub
parent f7a9966468
commit d82e93e7f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 28 deletions

View File

@ -333,16 +333,28 @@ public:
return lvlBits & static_cast<uint64_t>(p); return lvlBits & static_cast<uint64_t>(p);
} }
/// Returns true when this `LevelType` is one of the sparse level formats:
/// Compressed, Singleton, LooseCompressed, or NOutOfM.
constexpr bool hasSparseSemantic() const {
  return isa<LevelFormat::Compressed>() || isa<LevelFormat::Singleton>() ||
         isa<LevelFormat::LooseCompressed>() || isa<LevelFormat::NOutOfM>();
}
/// Returns true when this `LevelType` is dense-like, i.e. its format is
/// either Dense or Batch.
constexpr bool hasDenseSemantic() const {
  return isa<LevelFormat::Dense>() || isa<LevelFormat::Batch>();
}
/// Check if the `LevelType` needs positions array. /// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT() const { constexpr bool isWithPosLT() const {
return isa<LevelFormat::Compressed>() || assert(!isa<LevelFormat::Undef>());
isa<LevelFormat::LooseCompressed>(); return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
} }
/// Check if the `LevelType` needs coordinates array. /// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT() const { constexpr bool isWithCrdLT() const {
assert(!isa<LevelFormat::Undef>());
// All sparse levels has coordinate array. // All sparse levels has coordinate array.
return !isa<LevelFormat::Dense, LevelFormat::Batch>(); return hasSparseSemantic();
} }
std::string toMLIRString() const { std::string toMLIRString() const {

View File

@ -509,8 +509,7 @@ public:
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const { bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
if (isLvlWithNonTrivialIdxExp(b)) { if (isLvlWithNonTrivialIdxExp(b)) {
auto lt = getLoopDependentLevelType(b); auto lt = getLoopDependentLevelType(b);
return isCompressedLT(lt) || isSingletonLT(lt) || return lt.hasSparseSemantic();
isLooseCompressedLT(lt) || isNOutOfMLT(lt);
} }
return false; return false;
} }

View File

@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Starts resetting from a dense level, so that the first bit (if kept) // Starts resetting from a dense level, so that the first bit (if kept)
// is not undefined level-type. // is not undefined level-type.
for (unsigned b = 0; b < be; b++) { for (unsigned b = 0; b < be; b++) {
if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) { if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
offset = be - b - 1; // relative to the end offset = be - b - 1; // relative to the end
break; break;
} }
@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Slice on dense level has `locate` property as well, and can be optimized. // Slice on dense level has `locate` property as well, and can be optimized.
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) { if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto lt = getLvlType(b); const auto lt = getLvlType(b);
if (!isCompressedLT(lt) && !isSingletonLT(lt) && if (!lt.hasSparseSemantic()) {
!isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
if (reset) if (reset)
simple.reset(b); simple.reset(b);
reset = true; reset = true;
@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
bool Merger::hasAnySparse(const BitVector &bits) const { bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) { for (TensorLoopId b : bits.set_bits()) {
const auto lt = getLvlType(b); const auto lt = getLvlType(b);
if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) || if (lt.hasSparseSemantic())
isNOutOfMLT(lt))
return true; return true;
} }
return hasSparseIdxReduction(bits); return hasSparseIdxReduction(bits);

View File

@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
FOREVERY_BINOP(IMPL_BINOP_PATTERN) FOREVERY_BINOP(IMPL_BINOP_PATTERN)
#undef IMPL_BINOP_PATTERN #undef IMPL_BINOP_PATTERN
class MergerTestBase : public ::testing::Test { // Parameterize LevelFormat to test both Dense and Batch LevelFormat.
class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
protected: protected:
MergerTestBase(unsigned numTensors, unsigned numLoops) MergerTestBase(unsigned numTensors, unsigned numLoops)
: merger(numTensors, numLoops, /*maxRank=*/numLoops) { : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
@ -317,10 +318,14 @@ protected:
// Tensor 1: sparse input vector. // Tensor 1: sparse input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed); merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
// Tensor 2: dense output vector. // Tensor 2: dense output vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense); merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
} }
}; };
// Runs every MergerTest3T1L test twice: once with LevelFormat::Dense and once
// with LevelFormat::Batch as the test parameter.
INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
/// Four tensors (three inputs, one output); and a single loop. /// Four tensors (three inputs, one output); and a single loop.
class MergerTest4T1L : public MergerTestBase { class MergerTest4T1L : public MergerTestBase {
protected: protected:
@ -333,10 +338,14 @@ protected:
// Tensor 2: sparse input vector // Tensor 2: sparse input vector
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed); merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
// Tensor 3: dense output vector // Tensor 3: dense output vector
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense); merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
} }
}; };
// Runs every MergerTest4T1L test twice: once with LevelFormat::Dense and once
// with LevelFormat::Batch as the test parameter.
INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
/// ///
/// Tests with both sparse and dense input. /// Tests with both sparse and dense input.
/// ///
@ -349,12 +358,16 @@ protected:
// Tensor 0: sparse input vector. // Tensor 0: sparse input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed); merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: dense input vector. // Tensor 1: dense input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense); merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
// Tensor 2: dense output vector. // Tensor 2: dense output vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense); merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
} }
}; };
// Runs every MergerTest3T1LD test twice: once with LevelFormat::Dense and once
// with LevelFormat::Batch as the test parameter.
INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
/// ///
/// Tests with both undef and dense input. /// Tests with both undef and dense input.
/// ///
@ -367,14 +380,18 @@ protected:
// Tensor 0: undef input vector. // Tensor 0: undef input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef); merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
// Tensor 1: dense input vector. // Tensor 1: dense input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense); merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
// Tensor 2: undef input vector. // Tensor 2: undef input vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef); merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
// Tensor 3: dense output vector. // Tensor 3: dense output vector.
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense); merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
} }
}; };
// Runs every MergerTest4T1LU test twice: once with LevelFormat::Dense and once
// with LevelFormat::Batch as the test parameter.
INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
/// ///
/// Tests with operation on sparse output. /// Tests with operation on sparse output.
/// ///
@ -395,6 +412,11 @@ protected:
} }
}; };
// This test suite does not use any dense-like level format, so instantiating
// it with a single parameter (either one of {Dense, Batch}) is enough.
INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
::testing::Values(LevelFormat::Dense));
} // namespace } // namespace
/// Vector multiplication (conjunction) of 3 vectors, i.e.; /// Vector multiplication (conjunction) of 3 vectors, i.e.;
@ -409,7 +431,7 @@ protected:
/// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) ) /// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
/// } /// }
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \ #define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \ TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \ const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \ const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
/// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) ) /// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
/// } /// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \ #define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \ TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \ const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \ const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
/// lat( i_01 / tensor_1 ) /// lat( i_01 / tensor_1 )
/// } /// }
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \ #define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
TEST_F(MergerTest3T1L, vector_##OP) { \ TEST_P(MergerTest3T1L, vector_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \ const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
const auto t0 = tid(0); \ const auto t0 = tid(0); \
@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
/// lat( i_00 i_01 / (tensor_0 * tensor_1) ) /// lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// } /// }
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \ #define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
TEST_F(MergerTest3T1L, vector_##OP) { \ TEST_P(MergerTest3T1L, vector_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \ const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
const auto t0 = tid(0); \ const auto t0 = tid(0); \
@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
/// lat( i_02 / tensor_2 ) /// lat( i_02 / tensor_2 )
/// } /// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \ #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \ TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
const auto em = CONJ##Expr(tensor(0), tensor(1)); \ const auto em = CONJ##Expr(tensor(0), tensor(1)); \
const auto e = DISJ##Expr(em, tensor(2)); \ const auto e = DISJ##Expr(em, tensor(2)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
/// lat( i_00 / tensor_0 ) /// lat( i_00 / tensor_0 )
/// } /// }
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \ #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \ TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
const auto em = DISJ1##Expr(tensor(0), tensor(1)); \ const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
const auto e = DISJ2##Expr(em, tensor(2)); \ const auto e = DISJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
/// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 ) /// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// } /// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \ #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \ TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \ const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \ const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ). /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \ #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \ TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \ const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
const auto t0 = tid(0); \ const auto t0 = tid(0); \
@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
/// } /// }
/// since i_01 is a dense dimension. /// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \ #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \ TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \ const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \ const auto l0 = lid(0); \
const auto t0 = tid(0); \ const auto t0 = tid(0); \
@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
/// lat( i_00 / tensor_0 cmp 0 ) /// lat( i_00 / tensor_0 cmp 0 )
/// lat( i_01 / 0 cmp tensor_1 ) /// lat( i_01 / 0 cmp tensor_1 )
/// } /// }
TEST_F(MergerTest3T1L, vector_cmp) { TEST_P(MergerTest3T1L, vector_cmp) {
const auto e = cmpiExpr(tensor(0), tensor(1)); const auto e = cmpiExpr(tensor(0), tensor(1));
const auto l0 = lid(0); const auto l0 = lid(0);
const auto t0 = tid(0); const auto t0 = tid(0);
@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
/// ///
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ). /// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
TEST_F(MergerTest3T1LD, vector_cmp) { TEST_P(MergerTest3T1LD, vector_cmp) {
const auto e = cmpiExpr(tensor(0), tensor(1)); const auto e = cmpiExpr(tensor(0), tensor(1));
const auto l0 = lid(0); const auto l0 = lid(0);
const auto t0 = tid(0); const auto t0 = tid(0);