Groverkss a5a598be44 [MLIR][Presburger] Use PresburgerSpace in constructors
This patch modifies the IntegerPolyhedron, IntegerRelation, PresburgerRelation,
PresburgerSet, and PWMAFunction constructors to take a PresburgerSpace instead
of dimensions. This allows the information present in PresburgerSpace to be
carried along better and provides a more general interface.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D122842
2022-04-01 15:07:26 +05:30

192 lines
7.5 KiB
C++

//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Simplex.h"
using namespace mlir;
using namespace presburger;
/// Computes the pointwise difference `vecA - vecB` of two vectors, which must
/// have the same length, e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
                                        ArrayRef<int64_t> vecB) {
  assert(vecA.size() == vecB.size() &&
         "Cannot subtract vectors of differing lengths!");
  // Size the result up front and fill each slot in place.
  SmallVector<int64_t, 8> difference(vecA.size());
  for (size_t idx = 0; idx < difference.size(); ++idx)
    difference[idx] = vecA[idx] - vecB[idx];
  return difference;
}
/// Returns the union of the domains of all pieces. With no pieces, this is
/// the empty set in this function's space.
PresburgerSet PWMAFunction::getDomain() const {
  PresburgerSet unionOfDomains = PresburgerSet::getEmpty(getSpace());
  for (const auto &p : pieces)
    unionOfDomains.unionInPlace(p.getDomain());
  return unionOfDomains;
}
/// Evaluates this function at `point`, which supplies values for the
/// dimension and symbol ids only.
/// Returns an empty Optional if `point` is not in the domain. Otherwise it
/// returns the output vector, which has one entry per output.
Optional<SmallVector<int64_t, 8>>
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
  assert(point.size() == getNumDimAndSymbolIds() &&
         "Point has incorrect dimensionality!");

  // Membership test. On success it also yields one possible assignment of
  // the local ids, which the output expressions may depend on.
  Optional<SmallVector<int64_t, 8>> localValues =
      getDomain().containsPointNoLocal(point);
  if (!localValues)
    return {};

  // Build the homogeneous vector [point, locals, 1]. Row i of `output` is the
  // affine expression for output i, and its last column holds the constant
  // term. So output * [point, locals, 1] is the desired output vector. The
  // function is not allowed to have local ids that can take more than one
  // value, so using the single assignment computed above is sufficient.
  SmallVector<int64_t, 8> homogeneousPoint(point.begin(), point.end());
  homogeneousPoint.append(localValues->begin(), localValues->end());
  homogeneousPoint.push_back(1);

  SmallVector<int64_t, 8> result =
      output.postMultiplyWithColumn(homogeneousPoint);
  assert(result.size() == getNumOutputs());
  return result;
}
/// Evaluates the piecewise function at `point`. Pieces are tried in order,
/// and the value of the first piece whose domain contains `point` is
/// returned. If no piece's domain contains it, an empty Optional is returned.
Optional<SmallVector<int64_t, 8>>
PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
  assert(point.size() == getNumInputs() &&
         "Point has incorrect dimensionality!");
  for (const MultiAffineFunction &piece : pieces) {
    Optional<SmallVector<int64_t, 8>> pieceValue = piece.valueAt(point);
    if (pieceValue)
      return pieceValue;
  }
  return {};
}
/// Prints the domain (through IntegerPolyhedron::print), then the output
/// matrix, then a trailing newline.
void MultiAffineFunction::print(raw_ostream &os) const {
  IntegerPolyhedron::print(os << "Domain:");
  output.print(os << "Output:\n");
  os << "\n";
}
/// Prints this function to llvm::errs().
void MultiAffineFunction::dump() const {
  print(llvm::errs());
}
/// Returns true iff `other` has a compatible space and the same domain, and
/// it takes the same output value as this function wherever both are defined.
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
  if (!isSpaceCompatible(other))
    return false;
  if (!getDomain().isEqual(other.getDomain()))
    return false;
  return isEqualWhereDomainsOverlap(other);
}
/// Inserts `num` ids of kind `kind` at position `pos`, where `pos` is relative
/// to the first id of that kind. Matching columns are inserted into `output`
/// first, so the output expressions stay aligned with the ids.
/// Returns the result of IntegerPolyhedron::insertId.
unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
                                       unsigned num) {
  assert((kind != IdKind::Domain || num == 0) &&
         "Domain has to be zero in a set");
  // Convert the kind-relative position into an absolute column index.
  const unsigned column = getIdKindOffset(kind) + pos;
  output.insertColumns(column, num);
  return IntegerPolyhedron::insertId(kind, pos, num);
}
/// Swaps the ids at positions `posA` and `posB`. The corresponding columns of
/// `output` are swapped as well, so the output expressions follow the ids.
void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
  // Columns of `output` are indexed by the same positions as the ids.
  output.swapColumns(posA, posB);
  IntegerPolyhedron::swapId(posA, posB);
}
/// Removes the ids of kind `kind` in the kind-relative range
/// [idStart, idLimit), both from `output` and from the underlying polyhedron.
void MultiAffineFunction::removeIdRange(IdKind kind, unsigned idStart,
                                        unsigned idLimit) {
  const unsigned firstColumn = getIdKindOffset(kind) + idStart;
  const unsigned numColumns = idLimit - idStart;
  output.removeColumns(firstColumn, numColumns);
  IntegerPolyhedron::removeIdRange(kind, idStart, idLimit);
}
/// Eliminates a redundant local id. First, the `output` column of local
/// `posB` is combined (scale 1) with the column of local `posA`. Then
/// IntegerPolyhedron::eliminateRedundantLocalId is called. The column update
/// must precede the base call.
/// NOTE(review): assumes Matrix::addToColumn(src, dst, scale) adds the source
/// column into the target column — confirm against Matrix.h.
void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
                                                    unsigned posB) {
  const unsigned localOffset = getIdKindOffset(IdKind::Local);
  const unsigned columnA = localOffset + posA;
  const unsigned columnB = localOffset + posB;
  output.addToColumn(columnB, columnA, /*scale=*/1);
  IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
}
/// Returns true iff this function and `other` have compatible spaces and take
/// the same output value at every point of the intersection of their
/// domains. `other` is taken by value because the local-id merge below
/// modifies it.
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
    MultiAffineFunction other) const {
  if (!isSpaceCompatible(other))
    return false;

  // Work on a copy of `this`; it keeps the same output expressions.
  MultiAffineFunction commonFunc = *this;

  // Merge locals so that `commonFunc` and `other` share the same local ids.
  // `this` may have different local ids from `other`. That is why
  // `commonFunc.output`, rather than `this->output`, is compared below.
  commonFunc.mergeLocalIds(other);

  // Adding the constraints of `other` turns the domain of `commonFunc` into
  // the intersection of both domains.
  commonFunc.IntegerPolyhedron::append(other);

  // Restrict the common domain to the points where every output of the two
  // functions agrees.
  IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
  for (unsigned row = 0, e = getNumOutputs(); row < e; ++row) {
    SmallVector<int64_t, 8> rowDifference =
        subtract(commonFunc.output.getRow(row), other.output.getRow(row));
    commonDomainMatching.addEquality(rowDifference);
  }

  // The functions agree on the whole common domain iff that domain lies
  // inside the subset where their outputs match.
  return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
}
/// Two PWMAFunctions are equal if they have the same dimensionalities,
/// the same domain, and take the same value at every point in the domain.
bool PWMAFunction::isEqual(const PWMAFunction &other) const {
if (!isSpaceCompatible(other))
return false;
if (!this->getDomain().isEqual(other.getDomain()))
return false;
// Check if, whenever the domains of a piece of `this` and a piece of `other`
// overlap, they take the same output value. If `this` and `other` have the
// same domain (checked above), then this check passes iff the two functions
// have the same output at every point in the domain.
for (const MultiAffineFunction &aPiece : this->pieces)
for (const MultiAffineFunction &bPiece : other.pieces)
if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
return false;
return true;
}
void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
assert(piece.isSpaceCompatible(*this) &&
"Piece to be added is not compatible with this PWMAFunction!");
assert(piece.isConsistent() && "Piece is internally inconsistent!");
assert(this->getDomain()
.intersect(PresburgerSet(piece.getDomain()))
.isIntegerEmpty() &&
"New piece's domain overlaps with that of existing pieces!");
pieces.push_back(piece);
}
void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
const Matrix &output) {
addPiece(MultiAffineFunction(domain, output));
}
/// Prints a header line with the number of pieces, then each piece in order.
void PWMAFunction::print(raw_ostream &os) const {
  os << pieces.size() << " pieces:\n";
  for (const auto &p : pieces) {
    p.print(os);
  }
}
/// Prints this piecewise function to llvm::errs().
void PWMAFunction::dump() const {
  print(llvm::errs());
}