Introduces a dataflow analysis for tracking offset, size, and stride
ranges of operations.
Inference of the metadata is accomplished through the implementation of
the interface
`InferStridedMetadataOpInterface`.
To keep the size of the patch small, this patch only implements the
interface for the
`memref.subview` operation. It's future work to add more operations.
Example:
```mlir
func.func @memref_subview(%arg0: memref<8x16x4xf32, strided<[64, 4, 1]>>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%0 = test.with_bounds {smax = 13 : index, smin = 11 : index, umax = 13 : index, umin = 11 : index} : index
%1 = test.with_bounds {smax = 7 : index, smin = 5 : index, umax = 7 : index, umin = 5 : index} : index
%subview = memref.subview %arg0[%c0, %c0, %c1] [%1, %0, %c2] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
return
}
```
Applying `mlir-opt --test-strided-metadata-range-analysis` prints:
```
Op: %subview = memref.subview %arg0[%c0, %c0, %c1] [%1, %0, %c2] [%c1, %c1, %c1] : memref<8x16x4xf32, strided<[64, 4, 1]>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
result[0]: strided_metadata<offset = [{unsigned : [1, 1] signed : [1, 1]}], sizes = [{unsigned : [5, 7] signed : [5, 7]}, {unsigned : [11, 13] signed : [11, 13]}, {unsigned : [2, 2] signed : [2, 2]}], strides = [{unsigned : [64, 64] signed : [64, 64]}, {unsigned : [4, 4] signed : [4, 4]}, {unsigned : [1, 1] signed : [1, 1]}]>
```
---------
Signed-off-by: Fabian Mora <fabian.mora-cordero@amd.com>
128 lines
4.7 KiB
C++
128 lines
4.7 KiB
C++
//===- StridedMetadataRangeAnalysis.cpp - Strided metadata range analysis -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines the dataflow analysis class for strided metadata range
// inference, which tracks ranges for the offset, sizes, and strides of values
// through operations implementing `InferStridedMetadataOpInterface`.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/DataFlow/StridedMetadataRangeAnalysis.h"
|
|
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Support/DebugStringHelper.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/DebugLog.h"
|
|
|
|
#define DEBUG_TYPE "strided-metadata-range-analysis"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::dataflow;
|
|
|
|
/// Get the entry state for a value. For any value that is not a ranked memref,
|
|
/// this function sets the metadata to a top state with no offsets, sizes, or
|
|
/// strides. For `memref` types, this function will use the metadata in the type
|
|
/// to try to deduce as much informaiton as possible.
|
|
static StridedMetadataRange getEntryStateImpl(Value v, int32_t indexBitwidth) {
|
|
// TODO: generalize this method with a type interface.
|
|
auto mTy = dyn_cast<BaseMemRefType>(v.getType());
|
|
|
|
// If not a memref or it's un-ranked, don't infer any metadata.
|
|
if (!mTy || !mTy.hasRank())
|
|
return StridedMetadataRange::getMaxRanges(indexBitwidth, 0, 0, 0);
|
|
|
|
// Get the top state.
|
|
auto metadata =
|
|
StridedMetadataRange::getMaxRanges(indexBitwidth, mTy.getRank());
|
|
|
|
// Compute the offset and strides.
|
|
int64_t offset;
|
|
SmallVector<int64_t> strides;
|
|
if (failed(cast<MemRefType>(mTy).getStridesAndOffset(strides, offset)))
|
|
return metadata;
|
|
|
|
// Refine the metadata if we know it from the type.
|
|
if (!ShapedType::isDynamic(offset)) {
|
|
metadata.getOffsets()[0] =
|
|
ConstantIntRanges::constant(APInt(indexBitwidth, offset));
|
|
}
|
|
for (auto &&[size, range] :
|
|
llvm::zip_equal(mTy.getShape(), metadata.getSizes())) {
|
|
if (ShapedType::isDynamic(size))
|
|
continue;
|
|
range = ConstantIntRanges::constant(APInt(indexBitwidth, size));
|
|
}
|
|
for (auto &&[stride, range] :
|
|
llvm::zip_equal(strides, metadata.getStrides())) {
|
|
if (ShapedType::isDynamic(stride))
|
|
continue;
|
|
range = ConstantIntRanges::constant(APInt(indexBitwidth, stride));
|
|
}
|
|
|
|
return metadata;
|
|
}
|
|
|
|
/// Construct the analysis on top of `solver`.
///
/// `indexBitwidth` is stored and later forwarded both to the entry-state
/// computation and to the metadata inference interface; it must be strictly
/// positive.
StridedMetadataRangeAnalysis::StridedMetadataRangeAnalysis(
    DataFlowSolver &solver, int32_t indexBitwidth)
    : SparseForwardDataFlowAnalysis(solver), indexBitwidth(indexBitwidth) {
  assert(indexBitwidth > 0 && "invalid bitwidth");
}
|
|
|
|
/// Join the entry state of the lattice's anchor value into `lattice`, and
/// propagate the update if the lattice changed.
void StridedMetadataRangeAnalysis::setToEntryState(
    StridedMetadataRangeLattice *lattice) {
  // Entry state derived from the value the lattice is attached to.
  const StridedMetadataRange entryState =
      getEntryStateImpl(lattice->getAnchor(), indexBitwidth);
  ChangeResult changed = lattice->join(entryState);
  propagateIfChanged(lattice, changed);
}
|
|
|
|
/// Transfer function of the analysis.
///
/// If `op` does not implement `InferStridedMetadataOpInterface`, every result
/// lattice is set to its entry state. Otherwise the interface is invoked with
/// the operand metadata, a callback to query integer ranges of values, and a
/// callback that joins inferred metadata into the matching result lattice.
/// Always returns success.
LogicalResult StridedMetadataRangeAnalysis::visitOperation(
    Operation *op, ArrayRef<const StridedMetadataRangeLattice *> operands,
    ArrayRef<StridedMetadataRangeLattice *> results) {
  // Ops without the inference interface cannot be reasoned about.
  auto inferOp = dyn_cast<InferStridedMetadataOpInterface>(op);
  if (!inferOp) {
    setAllToEntryStates(results);
    return success();
  }

  LDBG() << "Inferring metadata for: "
         << OpWithFlags(op, OpPrintingFlags().skipRegions());

  // Snapshot the metadata currently held by each operand lattice.
  SmallVector<StridedMetadataRange> operandMetadata;
  operandMetadata.reserve(operands.size());
  for (const StridedMetadataRangeLattice *operandLattice : operands)
    operandMetadata.push_back(operandLattice->getValue());

  // Query the integer range lattice of a value at the program point after
  // `op`; fall back to a default-constructed range when there is no lattice.
  auto lookupIntRange = [&](Value value) -> IntegerValueRange {
    const auto *intLattice = getOrCreateFor<IntegerValueRangeLattice>(
        getProgramPointAfter(op), value);
    if (!intLattice)
      return IntegerValueRange();
    return intLattice->getValue();
  };

  // Join inferred metadata into the lattice of the corresponding result of
  // `op`, and propagate if it changed.
  auto recordResultMetadata = [&](Value v,
                                  const StridedMetadataRange &inferred) {
    auto opResult = cast<OpResult>(v);
    assert(llvm::is_contained(op->getResults(), opResult));
    LDBG() << "- Inferred metadata: " << inferred;
    StridedMetadataRangeLattice *resultLattice =
        results[opResult.getResultNumber()];
    ChangeResult changed = resultLattice->join(inferred);
    LDBG() << "- Joined metadata: " << resultLattice->getValue();
    propagateIfChanged(resultLattice, changed);
  };

  // Let the op infer its metadata.
  inferOp.inferStridedMetadataRanges(operandMetadata, lookupIntRange,
                                     recordResultMetadata, indexBitwidth);
  return success();
}
|