[mlir][sparse] add merger support on Batch LevelType. (#83186)
This commit is contained in:
parent
f7a9966468
commit
d82e93e7f1
@ -333,16 +333,28 @@ public:
|
||||
return lvlBits & static_cast<uint64_t>(p);
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is considered to be sparse.
|
||||
constexpr bool hasSparseSemantic() const {
|
||||
return isa<LevelFormat::Compressed, LevelFormat::Singleton,
|
||||
LevelFormat::LooseCompressed, LevelFormat::NOutOfM>();
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` is considered to be dense-like.
|
||||
constexpr bool hasDenseSemantic() const {
|
||||
return isa<LevelFormat::Dense, LevelFormat::Batch>();
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` needs positions array.
|
||||
constexpr bool isWithPosLT() const {
|
||||
return isa<LevelFormat::Compressed>() ||
|
||||
isa<LevelFormat::LooseCompressed>();
|
||||
assert(!isa<LevelFormat::Undef>());
|
||||
return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>();
|
||||
}
|
||||
|
||||
/// Check if the `LevelType` needs coordinates array.
|
||||
constexpr bool isWithCrdLT() const {
|
||||
assert(!isa<LevelFormat::Undef>());
|
||||
// All sparse levels has coordinate array.
|
||||
return !isa<LevelFormat::Dense, LevelFormat::Batch>();
|
||||
return hasSparseSemantic();
|
||||
}
|
||||
|
||||
std::string toMLIRString() const {
|
||||
|
@ -509,8 +509,7 @@ public:
|
||||
bool isSparseLvlWithNonTrivialIdxExp(TensorLoopId b) const {
|
||||
if (isLvlWithNonTrivialIdxExp(b)) {
|
||||
auto lt = getLoopDependentLevelType(b);
|
||||
return isCompressedLT(lt) || isSingletonLT(lt) ||
|
||||
isLooseCompressedLT(lt) || isNOutOfMLT(lt);
|
||||
return lt.hasSparseSemantic();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -476,7 +476,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
|
||||
// Starts resetting from a dense level, so that the first bit (if kept)
|
||||
// is not undefined level-type.
|
||||
for (unsigned b = 0; b < be; b++) {
|
||||
if (simple[b] && isDenseLT(getLvlType(TensorLoopId{b}))) {
|
||||
if (simple[b] && getLvlType(TensorLoopId{b}).hasDenseSemantic()) {
|
||||
offset = be - b - 1; // relative to the end
|
||||
break;
|
||||
}
|
||||
@ -489,8 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
|
||||
// Slice on dense level has `locate` property as well, and can be optimized.
|
||||
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
|
||||
const auto lt = getLvlType(b);
|
||||
if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
|
||||
!isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
|
||||
if (!lt.hasSparseSemantic()) {
|
||||
if (reset)
|
||||
simple.reset(b);
|
||||
reset = true;
|
||||
@ -670,8 +669,7 @@ bool Merger::isSingleCondition(TensorId t, ExprId e) const {
|
||||
bool Merger::hasAnySparse(const BitVector &bits) const {
|
||||
for (TensorLoopId b : bits.set_bits()) {
|
||||
const auto lt = getLvlType(b);
|
||||
if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
|
||||
isNOutOfMLT(lt))
|
||||
if (lt.hasSparseSemantic())
|
||||
return true;
|
||||
}
|
||||
return hasSparseIdxReduction(bits);
|
||||
|
@ -120,7 +120,8 @@ static Match synZeroMatch() { return Match(); }
|
||||
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
|
||||
#undef IMPL_BINOP_PATTERN
|
||||
|
||||
class MergerTestBase : public ::testing::Test {
|
||||
// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
|
||||
class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
|
||||
protected:
|
||||
MergerTestBase(unsigned numTensors, unsigned numLoops)
|
||||
: merger(numTensors, numLoops, /*maxRank=*/numLoops) {
|
||||
@ -317,10 +318,14 @@ protected:
|
||||
// Tensor 1: sparse input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 2: dense output vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
|
||||
::testing::Values(LevelFormat::Dense,
|
||||
LevelFormat::Batch));
|
||||
|
||||
/// Four tensors (three inputs, one output); and a single loop.
|
||||
class MergerTest4T1L : public MergerTestBase {
|
||||
protected:
|
||||
@ -333,10 +338,14 @@ protected:
|
||||
// Tensor 2: sparse input vector
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 3: dense output vector
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
|
||||
::testing::Values(LevelFormat::Dense,
|
||||
LevelFormat::Batch));
|
||||
|
||||
///
|
||||
/// Tests with both sparse and dense input.
|
||||
///
|
||||
@ -349,12 +358,16 @@ protected:
|
||||
// Tensor 0: sparse input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
|
||||
// Tensor 1: dense input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
|
||||
// Tensor 2: dense output vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
|
||||
::testing::Values(LevelFormat::Dense,
|
||||
LevelFormat::Batch));
|
||||
|
||||
///
|
||||
/// Tests with both undef and dense input.
|
||||
///
|
||||
@ -367,14 +380,18 @@ protected:
|
||||
// Tensor 0: undef input vector.
|
||||
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
|
||||
// Tensor 1: dense input vector.
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
|
||||
merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
|
||||
// Tensor 2: undef input vector.
|
||||
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
|
||||
// Tensor 3: dense output vector.
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
|
||||
merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
|
||||
::testing::Values(LevelFormat::Dense,
|
||||
LevelFormat::Batch));
|
||||
|
||||
///
|
||||
/// Tests with operation on sparse output.
|
||||
///
|
||||
@ -395,6 +412,11 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
// This testsuite does not use any dense-like format, just one of {Dense, Batch}
|
||||
// is enough.
|
||||
INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
|
||||
::testing::Values(LevelFormat::Dense));
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Vector multiplication (conjunction) of 3 vectors, i.e.;
|
||||
@ -409,7 +431,7 @@ protected:
|
||||
/// lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
|
||||
TEST_F(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
|
||||
TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
|
||||
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
|
||||
const auto e = CONJ2##Expr(em, tensor(2)); \
|
||||
const auto l0 = lid(0); \
|
||||
@ -443,7 +465,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
|
||||
/// lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
|
||||
TEST_F(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
|
||||
TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
|
||||
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
|
||||
const auto e = CONJ2##Expr(em, tensor(2)); \
|
||||
const auto l0 = lid(0); \
|
||||
@ -482,7 +504,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
|
||||
/// lat( i_01 / tensor_1 )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
|
||||
TEST_F(MergerTest3T1L, vector_##OP) { \
|
||||
TEST_P(MergerTest3T1L, vector_##OP) { \
|
||||
const auto e = OP##Expr(tensor(0), tensor(1)); \
|
||||
const auto l0 = lid(0); \
|
||||
const auto t0 = tid(0); \
|
||||
@ -514,7 +536,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
|
||||
/// lat( i_00 i_01 / (tensor_0 * tensor_1) )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
|
||||
TEST_F(MergerTest3T1L, vector_##OP) { \
|
||||
TEST_P(MergerTest3T1L, vector_##OP) { \
|
||||
const auto e = OP##Expr(tensor(0), tensor(1)); \
|
||||
const auto l0 = lid(0); \
|
||||
const auto t0 = tid(0); \
|
||||
@ -544,7 +566,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
|
||||
/// lat( i_02 / tensor_2 )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
|
||||
TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
|
||||
TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
|
||||
const auto em = CONJ##Expr(tensor(0), tensor(1)); \
|
||||
const auto e = DISJ##Expr(em, tensor(2)); \
|
||||
const auto l0 = lid(0); \
|
||||
@ -587,7 +609,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
|
||||
/// lat( i_00 / tensor_0 )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
|
||||
TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
|
||||
TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
|
||||
const auto em = DISJ1##Expr(tensor(0), tensor(1)); \
|
||||
const auto e = DISJ2##Expr(em, tensor(2)); \
|
||||
const auto l0 = lid(0); \
|
||||
@ -636,7 +658,7 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
|
||||
/// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
|
||||
/// }
|
||||
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
|
||||
TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
|
||||
TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
|
||||
const auto em = CONJ1##Expr(tensor(0), tensor(1)); \
|
||||
const auto e = CONJ2##Expr(em, tensor(2)); \
|
||||
const auto l0 = lid(0); \
|
||||
@ -675,7 +697,7 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
|
||||
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
|
||||
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
|
||||
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
|
||||
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
|
||||
TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
|
||||
const auto e = OP##Expr(tensor(0), tensor(1)); \
|
||||
const auto l0 = lid(0); \
|
||||
const auto t0 = tid(0); \
|
||||
@ -711,7 +733,7 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
|
||||
/// }
|
||||
/// since i_01 is a dense dimension.
|
||||
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
|
||||
TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
|
||||
TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
|
||||
const auto e = OP##Expr(tensor(0), tensor(1)); \
|
||||
const auto l0 = lid(0); \
|
||||
const auto t0 = tid(0); \
|
||||
@ -746,7 +768,7 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
|
||||
/// lat( i_00 / tensor_0 cmp 0 )
|
||||
/// lat( i_01 / 0 cmp tensor_1 )
|
||||
/// }
|
||||
TEST_F(MergerTest3T1L, vector_cmp) {
|
||||
TEST_P(MergerTest3T1L, vector_cmp) {
|
||||
const auto e = cmpiExpr(tensor(0), tensor(1));
|
||||
const auto l0 = lid(0);
|
||||
const auto t0 = tid(0);
|
||||
@ -784,7 +806,7 @@ TEST_F(MergerTest3T1L, vector_cmp) {
|
||||
///
|
||||
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
|
||||
/// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
|
||||
TEST_F(MergerTest3T1LD, vector_cmp) {
|
||||
TEST_P(MergerTest3T1LD, vector_cmp) {
|
||||
const auto e = cmpiExpr(tensor(0), tensor(1));
|
||||
const auto l0 = lid(0);
|
||||
const auto t0 = tid(0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user