[mlir][sparse] add merger support on Batch LevelType. (#83186)

This commit is contained in:
Peiming Liu 2024-02-27 13:18:43 -08:00 committed by GitHub
parent f7a9966468
commit d82e93e7f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 28 deletions

View File

@ -333,16 +333,28 @@ public:
return lvlBits & static_cast<uint64_t>(p);
}
/// Check if the `LevelType` is considered to be sparse.
/// Returns true when the level format is one of Compressed, Singleton,
/// LooseCompressed, or NOutOfM, and false for every other format.
constexpr bool hasSparseSemantic() const {
return isa<LevelFormat::Compressed, LevelFormat::Singleton,
LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
}
/// Check if the `LevelType` is considered to be dense-like.
/// Dense-like covers both the Dense and the Batch level formats.
constexpr bool hasDenseSemantic() const {
return isa<LevelFormat::Dense, LevelFormat::Batch>();
}
/// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT() const {
return isa<LevelFormat::Compressed>() ||
isa<LevelFormat::LooseCompressed>();
assert(!isa<LevelFormat::Undef>());
return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
}
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT() const {
assert(!isa<LevelFormat::Undef>());
// All sparse levels have a coordinate array.
return !isa<LevelFormat::Dense, LevelFormat::Batch>();
return hasSparseSemantic();
}
std::string toMLIRString() const {

View File

@ -509,8 +509,7 @@ public:
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
if (isLvlWithNonTrivialIdxExp(b)) {
auto lt = getLoopDependentLevelType(b);
return isCompressedLT(lt) || isSingletonLT(lt) ||
isLooseCompressedLT(lt) || isNOutOfMLT(lt);
return lt.hasSparseSemantic();
}
return false;
}

View File

@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Starts resetting from a dense level, so that the first bit (if kept)
// is not an undefined level type.
for (unsigned b = 0; b < be; b++) {
if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
offset = be - b - 1; // relative to the end
break;
}
@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
// Slice on dense level has `locate` property as well, and can be optimized.
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto lt = getLvlType(b);
if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
!isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
if (!lt.hasSparseSemantic()) {
if (reset)
simple.reset(b);
reset = true;
@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto lt = getLvlType(b);
if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
isNOutOfMLT(lt))
if (lt.hasSparseSemantic())
return true;
}
return hasSparseIdxReduction(bits);

View File

@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
#undef IMPL_BINOP_PATTERN
class MergerTestBase : public ::testing::Test {
// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
protected:
MergerTestBase(unsigned numTensors, unsigned numLoops)
: merger(numTensors, numLoops, /*maxRank=*/numLoops) {
@ -317,10 +318,14 @@ protected:
// Tensor 1: sparse input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
// Tensor 2: dense output vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
}
};
INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
/// Four tensors (three inputs, one output); and a single loop.
class MergerTest4T1L : public MergerTestBase {
protected:
@ -333,10 +338,14 @@ protected:
// Tensor 2: sparse input vector
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
// Tensor 3: dense output vector
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
}
};
INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
///
/// Tests with both sparse and dense input.
///
@ -349,12 +358,16 @@ protected:
// Tensor 0: sparse input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: dense input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
// Tensor 2: dense output vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
}
};
INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
///
/// Tests with both undef and dense input.
///
@ -367,14 +380,18 @@ protected:
// Tensor 0: undef input vector.
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
// Tensor 1: dense input vector.
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
// Tensor 2: undef input vector.
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
// Tensor 3: dense output vector.
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
}
};
INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
///
/// Tests with operation on sparse output.
///
@ -395,6 +412,11 @@ protected:
}
};
// This test suite does not use any dense-like format, so instantiating it with
// just one of {Dense, Batch} is enough.
INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
::testing::Values(LevelFormat::Dense));
} // namespace
/// Vector multiplication (conjunction) of 3 vectors, i.e.;
@ -409,7 +431,7 @@ protected:
/// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
/// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
/// lat( i_01 / tensor_1 )
/// }
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
TEST_F(MergerTest3T1L, vector_##OP) { \
TEST_P(MergerTest3T1L, vector_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
/// lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
TEST_F(MergerTest3T1L, vector_##OP) { \
TEST_P(MergerTest3T1L, vector_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
/// lat( i_02 / tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
const auto em = CONJ##Expr(tensor(0), tensor(1)); \
const auto e = DISJ##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
/// lat( i_00 / tensor_0 )
/// }
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
const auto e = DISJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
/// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
const auto e = CONJ2##Expr(em, tensor(2)); \
const auto l0 = lid(0); \
@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
/// }
/// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
const auto e = OP##Expr(tensor(0), tensor(1)); \
const auto l0 = lid(0); \
const auto t0 = tid(0); \
@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
/// lat( i_00 / tensor_0 cmp 0 )
/// lat( i_01 / 0 cmp tensor_1 )
/// }
TEST_F(MergerTest3T1L, vector_cmp) {
TEST_P(MergerTest3T1L, vector_cmp) {
const auto e = cmpiExpr(tensor(0), tensor(1));
const auto l0 = lid(0);
const auto t0 = tid(0);
@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
///
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
TEST_F(MergerTest3T1LD, vector_cmp) {
TEST_P(MergerTest3T1LD, vector_cmp) {
const auto e = cmpiExpr(tensor(0), tensor(1));
const auto l0 = lid(0);
const auto t0 = tid(0);