
Moves iteration lattice/merger code into new SparseTensor/Utils directory. A follow-up CL will add lattice/merger unit tests. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D104757
139 lines
4.0 KiB
C++
139 lines
4.0 KiB
C++
//===- Merger.cpp - Implementation of iteration lattices ------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
|
|
|
|
namespace mlir {
|
|
namespace sparse_tensor {
|
|
|
|
/// Appends a new tensor expression constructed from `k`, `e0`, `e1`, and `v`
/// to `tensorExps` and returns the index at which it was stored.
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  // The expression just appended occupies the last slot.
  return tensorExps.size() - 1;
}
|
|
|
|
/// Appends a new lattice point for tensor `t`, loop index `i`, and expression
/// `e` to `latPoints` and returns the index at which it was stored.
///
/// Requires `t < numTensors` and `i < numLoops`.
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  // NOTE(review): presumably the first constructor argument is the width of
  // the point's bit vector and the last one the single bit to set -- confirm
  // against the LatPoint constructor.
  const auto totalBits = numLoops * numTensors;
  const auto bit = numTensors * i + t;
  latPoints.push_back(LatPoint(totalBits, e, bit));
  // The point just appended occupies the last slot.
  return latPoints.size() - 1;
}
|
|
|
|
/// Appends a new, empty set of lattice points to `latSets` and returns the
/// index at which it was stored.
unsigned Merger::addSet() {
  latSets.emplace_back(SmallVector<unsigned, 16>());
  // The set just appended occupies the last slot.
  return latSets.size() - 1;
}
|
|
|
|
/// Creates the conjunction of lattice points `p0` and `p1` and returns the
/// index of the newly added point.
///
/// The new point carries the bitwise union of both operands' bits and a
/// fresh expression of the given `kind` whose operands are the expressions
/// of `p0` and `p1`.
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  // Start from a copy of the bits of p0 and merge in the bits of p1.
  llvm::BitVector unionBits(latPoints[p0].bits);
  unionBits |= latPoints[p1].bits;
  const unsigned conjExp = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(unionBits, conjExp));
  // The point just appended occupies the last slot.
  return latPoints.size() - 1;
}
|
|
|
|
/// Builds a new lattice set that contains the conjunction (of the given
/// `kind`) of every point in set `s0` with every point in set `s1`, and
/// returns the index of that new set.
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  const unsigned result = addSet();
  // Bind the source sets only after addSet(), since that call grows latSets.
  const auto &lhs = latSets[s0];
  const auto &rhs = latSets[s1];
  for (unsigned p0 : lhs) {
    for (unsigned p1 : rhs) {
      const unsigned conj = conjLatPoint(kind, p0, p1);
      latSets[result].push_back(conj);
    }
  }
  return result;
}
|
|
|
|
/// Builds a new lattice set for the disjunction of sets `s0` and `s1` and
/// returns its index.
///
/// The new set starts out as the conjunction of the two sets (see
/// takeConj()), followed by all points of `s0` and then all points of `s1`.
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  const unsigned result = takeConj(kind, s0, s1);
  auto &merged = latSets[result];
  // Appends every point of the given source set to the result, in order.
  auto copyPoints = [&merged](const auto &from) {
    for (unsigned p : from)
      merged.push_back(p);
  };
  copyPoints(latSets[s0]);
  copyPoints(latSets[s1]);
  return result;
}
|
|
|
|
unsigned Merger::optimizeSet(unsigned s0) {
|
|
unsigned s = addSet();
|
|
assert(latSets[s0].size() != 0);
|
|
unsigned p0 = latSets[s0][0];
|
|
for (unsigned p1 : latSets[s0]) {
|
|
bool add = true;
|
|
if (p0 != p1) {
|
|
// Is this a straightforward copy?
|
|
unsigned e = latPoints[p1].exp;
|
|
if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
|
|
continue;
|
|
// Conjunction already covered?
|
|
for (unsigned p2 : latSets[s]) {
|
|
assert(!latGT(p1, p2)); // Lj => Li would be bad
|
|
if (onlyDenseDiff(p2, p1)) {
|
|
add = false;
|
|
break;
|
|
}
|
|
}
|
|
assert(!add || latGT(p0, p1));
|
|
}
|
|
if (add)
|
|
latSets[s].push_back(p1);
|
|
}
|
|
for (unsigned p : latSets[s])
|
|
latPoints[p].simple = simplifyCond(s, p);
|
|
return s;
|
|
}
|
|
|
|
/// Returns a simplified copy of the bits of lattice point `p0`, computed
/// relative to the lattice set `s`.
///
/// Only set bits that are not of dimension kind `Dim::kSparse` are ever
/// cleared:
///  - if `p0` is a singleton in `s` (latGT(p0, p1) holds for no other point
///    p1 of the set) and at least one of its set bits is sparse, all of
///    them are cleared;
///  - otherwise the first one is kept and every later one is cleared.
llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool hasSmaller = false;
  for (unsigned other : latSets[s]) {
    if (other == p0)
      continue;
    if (latGT(p0, other)) {
      hasSmaller = true;
      break;
    }
  }
  const bool isSingleton = !hasSmaller;
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  // Once this flag is set, each further non-sparse set bit is cleared.
  bool dropNext = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  const unsigned numBits = simple.size();
  for (unsigned b = 0; b != numBits; ++b) {
    if (!simple[b] || isDim(b, Dim::kSparse))
      continue;
    if (dropNext)
      simple.reset(b);
    else
      dropNext = true; // Keep this bit, drop every one after it.
  }
  return simple;
}
|
|
|
|
bool Merger::latGT(unsigned i, unsigned j) const {
|
|
const llvm::BitVector &bitsi = latPoints[i].bits;
|
|
const llvm::BitVector &bitsj = latPoints[j].bits;
|
|
assert(bitsi.size() == bitsj.size());
|
|
if (bitsi.count() > bitsj.count()) {
|
|
for (unsigned b = 0, be = bitsj.size(); b < be; b++)
|
|
if (bitsj[b] && !bitsi[b])
|
|
return false;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
|
|
llvm::BitVector tmp = latPoints[j].bits;
|
|
tmp ^= latPoints[i].bits;
|
|
return !hasAnyDimOf(tmp, Dim::kSparse);
|
|
}
|
|
|
|
/// Returns true if at least one set bit of `bits` is of dimension kind `d`
/// according to isDim().
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  // Only set bits are examined, in increasing index order.
  for (unsigned b : bits.set_bits()) {
    if (isDim(b, d))
      return true;
  }
  return false;
}
|
|
|
|
} // namespace sparse_tensor
|
|
} // namespace mlir
|