llvm-project/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
Groverkss a5a598be44 [MLIR][Presburger] Use PresburgerSpace in constructors
This patch modifies IntegerPolyhedron, IntegerRelation, PresburgerRelation,
PresburgerSet, PWMAFunction, constructors to take PresburgerSpace instead of
dimensions. This allows information present in PresburgerSpace to be carried
better and allows for a general interface.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D122842
2022-04-01 15:07:26 +05:30

138 lines
4.3 KiB
C++

//===- PresburgerSpace.cpp - MLIR PresburgerSpace Class -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include <algorithm>
#include <cassert>
using namespace mlir;
using namespace presburger;
unsigned PresburgerSpace::getNumIdKind(IdKind kind) const {
if (kind == IdKind::Domain)
return getNumDomainIds();
if (kind == IdKind::Range)
return getNumRangeIds();
if (kind == IdKind::Symbol)
return getNumSymbolIds();
if (kind == IdKind::Local)
return numLocals;
llvm_unreachable("IdKind does not exist!");
}
unsigned PresburgerSpace::getIdKindOffset(IdKind kind) const {
if (kind == IdKind::Domain)
return 0;
if (kind == IdKind::Range)
return getNumDomainIds();
if (kind == IdKind::Symbol)
return getNumDimIds();
if (kind == IdKind::Local)
return getNumDimAndSymbolIds();
llvm_unreachable("IdKind does not exist!");
}
unsigned PresburgerSpace::getIdKindEnd(IdKind kind) const {
return getIdKindOffset(kind) + getNumIdKind(kind);
}
unsigned PresburgerSpace::getIdKindOverlap(IdKind kind, unsigned idStart,
unsigned idLimit) const {
unsigned idRangeStart = getIdKindOffset(kind);
unsigned idRangeEnd = getIdKindEnd(kind);
// Compute number of elements in intersection of the ranges [idStart, idLimit)
// and [idRangeStart, idRangeEnd).
unsigned overlapStart = std::max(idStart, idRangeStart);
unsigned overlapEnd = std::min(idLimit, idRangeEnd);
if (overlapStart > overlapEnd)
return 0;
return overlapEnd - overlapStart;
}
IdKind PresburgerSpace::getIdKindAt(unsigned pos) const {
assert(pos < getNumIds() && "`pos` should represent a valid id position");
if (pos < getIdKindEnd(IdKind::Domain))
return IdKind::Domain;
if (pos < getIdKindEnd(IdKind::Range))
return IdKind::Range;
if (pos < getIdKindEnd(IdKind::Symbol))
return IdKind::Symbol;
if (pos < getIdKindEnd(IdKind::Local))
return IdKind::Local;
llvm_unreachable("`pos` should represent a valid id position");
}
unsigned PresburgerSpace::insertId(IdKind kind, unsigned pos, unsigned num) {
assert(pos <= getNumIdKind(kind));
unsigned absolutePos = getIdKindOffset(kind) + pos;
if (kind == IdKind::Domain)
numDomain += num;
else if (kind == IdKind::Range)
numRange += num;
else if (kind == IdKind::Symbol)
numSymbols += num;
else
numLocals += num;
return absolutePos;
}
void PresburgerSpace::removeIdRange(IdKind kind, unsigned idStart,
unsigned idLimit) {
assert(idLimit <= getNumIdKind(kind) && "invalid id limit");
if (idStart >= idLimit)
return;
unsigned numIdsEliminated = idLimit - idStart;
if (kind == IdKind::Domain)
numDomain -= numIdsEliminated;
else if (kind == IdKind::Range)
numRange -= numIdsEliminated;
else if (kind == IdKind::Symbol)
numSymbols -= numIdsEliminated;
else
numLocals -= numIdsEliminated;
}
void PresburgerSpace::truncateIdKind(IdKind kind, unsigned num) {
unsigned curNum = getNumIdKind(kind);
assert(num <= curNum && "Can't truncate to more ids!");
removeIdRange(kind, num, curNum);
}
bool PresburgerSpace::isSpaceCompatible(const PresburgerSpace &other) const {
return getNumDomainIds() == other.getNumDomainIds() &&
getNumRangeIds() == other.getNumRangeIds() &&
getNumSymbolIds() == other.getNumSymbolIds();
}
bool PresburgerSpace::isSpaceEqual(const PresburgerSpace &other) const {
return isSpaceCompatible(other) && getNumLocalIds() == other.getNumLocalIds();
}
void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) {
assert(newSymbolCount <= getNumDimAndSymbolIds() &&
"invalid separation position");
numRange = numRange + numSymbols - newSymbolCount;
numSymbols = newSymbolCount;
}
void PresburgerSpace::print(llvm::raw_ostream &os) const {
os << "Domain: " << getNumDomainIds() << ", "
<< "Range: " << getNumRangeIds() << ", "
<< "Symbols: " << getNumSymbolIds() << ", "
<< "Locals: " << getNumLocalIds() << "\n";
}
void PresburgerSpace::dump() const { print(llvm::errs()); }