MaheshRavishankar fc093f1361
[mlir][Interfaces] Add interface methods to allow reifying single result/single dim of result. (#162924)
The current implementation of `reifyResultShapes` forces all
implementations to return all dimensions of all results. This can be
wasteful when you only require dimensions of one result, or a single
dimension of a result. Further, this also creates issues with using
patterns to resolve the `tensor.dim` and `memref.dim` operations since
the extra operations created result in the pattern rewriter entering
an infinite loop (eventually breaking out of the loop due to the
iteration limit on the pattern rewriter). This is demonstrated by some
of the test cases added here that hit this limit when using
`--resolve-shaped-type-result-dims` and
`--resolve-ranked-shaped-type-result-dims`. To resolve this issue, the
interface should allow creating just the operations needed. This
change is the first step toward that.

The original implementation was done with the restriction in mind that
it might not always be possible to compute the dimensions of a single
result or one dimension of a single result in all cases. To account
for such cases, two additional interface methods are added:

- `reifyShapeOfResult` (which allows reifying dimensions of
  just one result) has a default implementation that calls
  `reifyResultShapes` and returns the dimensions of a single result.
- `reifyDimOfResult` (which allows reifying a single dimension of a
  single result) has a default implementation that calls
  `reifyShapeOfResult` and returns the value for the dimension of the
  result (which in turn for the default case would call
  `reifyResultShapes`).
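
The fallback chain can be illustrated with a small standalone model. This
is only a sketch, assuming the defaults chain as `reifyDimOfResult` ->
`reifyShapeOfResult` -> `reifyResultShapes`. It uses plain standard-library
types rather than the MLIR interface: `ReifyModel`, `AllShapesOp`,
`SingleDimOp`, `Dim`, `Shape` and the `dimsComputed` counter are
hypothetical names introduced for illustration, and `std::optional` stands
in for `LogicalResult`/`FailureOr`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Simplified stand-ins for illustration only; these are not MLIR types.
using Dim = std::int64_t;
using Shape = std::vector<Dim>;

struct ReifyModel {
  virtual ~ReifyModel() = default;

  // Models reifyResultShapes: every dimension of every result.
  virtual std::optional<std::vector<Shape>> reifyResultShapes() = 0;

  // Default: reify all results, then keep only the requested one.
  virtual std::optional<Shape> reifyShapeOfResult(int resultIndex) {
    std::optional<std::vector<Shape>> all = reifyResultShapes();
    if (!all)
      return std::nullopt;
    assert(resultIndex >= 0 &&
           static_cast<std::size_t>(resultIndex) < all->size());
    return (*all)[resultIndex];
  }

  // Default: reify one result's shape, then keep only the requested dim.
  virtual std::optional<Dim> reifyDimOfResult(int resultIndex, int dim) {
    std::optional<Shape> shape = reifyShapeOfResult(resultIndex);
    if (!shape)
      return std::nullopt;
    assert(dim >= 0 && static_cast<std::size_t>(dim) < shape->size());
    return (*shape)[dim];
  }
};

// Implements only the "all results" method and relies on the defaults.
struct AllShapesOp : ReifyModel {
  std::vector<Shape> shapes{{2, 3}, {4, 5, 6}};
  int dimsComputed = 0; // counts how many dimensions were materialized

  std::optional<std::vector<Shape>> reifyResultShapes() override {
    for (const Shape &s : shapes)
      dimsComputed += static_cast<int>(s.size());
    return shapes;
  }
};

// Implements only the single-dimension method; the coarser methods fail.
struct SingleDimOp : ReifyModel {
  std::vector<Shape> shapes{{2, 3}, {4, 5, 6}};
  int dimsComputed = 0;

  std::optional<std::vector<Shape>> reifyResultShapes() override {
    return std::nullopt;
  }
  std::optional<Shape> reifyShapeOfResult(int) override {
    return std::nullopt;
  }
  std::optional<Dim> reifyDimOfResult(int resultIndex, int dim) override {
    ++dimsComputed;
    return shapes[resultIndex][dim];
  }
};
```

In this model, asking `AllShapesOp` for dimension 2 of result 1 goes through
both defaults and materializes all five dimensions, while `SingleDimOp`
materializes only the one requested. That difference is the waste the new
methods are meant to avoid.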

While this change sets up the interface, ideally most operations will
implement `reifyDimOfResult` when possible. For almost all
operations in tree this is possible. Subsequent commits will change those
incrementally.

Some of the tests added here, which check that the default
implementations of the above methods work as expected, also end up
hitting the pattern rewriter limit when using
`--resolve-shaped-type-result-dims`/
`--resolve-ranked-shaped-type-result-dims`. For testing purposes, a
flag is added to these passes that ignores the error returned by the
pattern application (this flag is left on by default to maintain
current state).

Changes required downstream to integrate this change

1. In operation definitions in .td files, for those operations that
   implement the `ReifyRankedShapedTypeOpInterface`, a declaration of
   the form

```
def <op-name> : Op<..., [...,
    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>
```

should be changed to

```
def <op-name> : Op<..., [...,
    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
        "reifyResultShapes"]>]>
```

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
2025-11-10 09:01:01 -08:00

//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
// OverriddenSymbolVisibilityOp
//===----------------------------------------------------------------------===//
SymbolTable::Visibility OverriddenSymbolVisibilityOp::getVisibility() {
return SymbolTable::Visibility::Private;
}
static StringLiteral getVisibilityString(SymbolTable::Visibility visibility) {
switch (visibility) {
case SymbolTable::Visibility::Private:
return "private";
case SymbolTable::Visibility::Nested:
return "nested";
case SymbolTable::Visibility::Public:
return "public";
}
}
void OverriddenSymbolVisibilityOp::setVisibility(
SymbolTable::Visibility visibility) {
emitOpError("cannot change visibility of symbol to ")
<< getVisibilityString(visibility);
}
//===----------------------------------------------------------------------===//
// TestBranchOp
//===----------------------------------------------------------------------===//
SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return SuccessorOperands(getTargetOperandsMutable());
}
//===----------------------------------------------------------------------===//
// TestProducingBranchOp
//===----------------------------------------------------------------------===//
SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
assert(index <= 1 && "invalid successor index");
if (index == 1)
return SuccessorOperands(getFirstOperandsMutable());
return SuccessorOperands(getSecondOperandsMutable());
}
//===----------------------------------------------------------------------===//
// TestInternalBranchOp
//===----------------------------------------------------------------------===//
SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
assert(index <= 1 && "invalid successor index");
if (index == 0)
return SuccessorOperands(0, getSuccessOperandsMutable());
return SuccessorOperands(1, getErrorOperandsMutable());
}
//===----------------------------------------------------------------------===//
// TestCallOp
//===----------------------------------------------------------------------===//
LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the callee attribute was specified.
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return emitOpError("requires a 'callee' symbol reference attribute");
if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
return success();
}
//===----------------------------------------------------------------------===//
// FoldToCallOp
//===----------------------------------------------------------------------===//
namespace {
struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
LogicalResult matchAndRewrite(FoldToCallOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
op.getCalleeAttr(), ValueRange());
return success();
}
};
} // namespace
void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldToCallOpPattern>(context);
}
//===----------------------------------------------------------------------===//
// IsolatedRegionOp - test parsing passthrough operands
//===----------------------------------------------------------------------===//
ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse the input operand.
OpAsmParser::Argument argInfo;
argInfo.type = parser.getBuilder().getIndexType();
if (parser.parseOperand(argInfo.ssaName) ||
parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
return failure();
// Parse the body region, and reuse the operand info as the argument info.
Region *body = result.addRegion();
return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
}
void IsolatedRegionOp::print(OpAsmPrinter &p) {
p << ' ';
p.printOperand(getOperand());
p.shadowRegionArgs(getRegion(), getOperand());
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// SSACFGRegionOp
//===----------------------------------------------------------------------===//
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
return RegionKind::SSACFG;
}
//===----------------------------------------------------------------------===//
// GraphRegionOp
//===----------------------------------------------------------------------===//
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
return RegionKind::Graph;
}
//===----------------------------------------------------------------------===//
// IsolatedGraphRegionOp
//===----------------------------------------------------------------------===//
RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) {
return RegionKind::Graph;
}
//===----------------------------------------------------------------------===//
// AffineScopeOp
//===----------------------------------------------------------------------===//
ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the body region, which takes no arguments.
Region *body = result.addRegion();
return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
}
void AffineScopeOp::print(OpAsmPrinter &p) {
p << " ";
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// TestRemoveOpWithInnerOps
//===----------------------------------------------------------------------===//
namespace {
struct TestRemoveOpWithInnerOps
: public OpRewritePattern<TestOpWithRegionPattern> {
using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// TestOpWithRegionPattern
//===----------------------------------------------------------------------===//
void TestOpWithRegionPattern::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<TestRemoveOpWithInnerOps>(context);
}
//===----------------------------------------------------------------------===//
// TestOpWithRegionFold
//===----------------------------------------------------------------------===//
OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
return getOperand();
}
//===----------------------------------------------------------------------===//
// TestOpConstant
//===----------------------------------------------------------------------===//
OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
//===----------------------------------------------------------------------===//
// TestOpWithVariadicResultsAndFolder
//===----------------------------------------------------------------------===//
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
for (Value input : this->getOperands()) {
results.push_back(input);
}
return success();
}
//===----------------------------------------------------------------------===//
// TestOpInPlaceFold
//===----------------------------------------------------------------------===//
OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
// Exercise the fact that an operation created with createOrFold should be
// allowed to access its parent block.
assert(getOperation()->getBlock() &&
"expected that operation is not unlinked");
if (adaptor.getOp() && !getProperties().attr) {
// The folder adds "attr" if not present.
getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
return getResult();
}
return {};
}
//===----------------------------------------------------------------------===//
// OpWithInferTypeInterfaceOp
//===----------------------------------------------------------------------===//
LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return emitOptionalError(location, "operand type mismatch ",
operands[0].getType(), " vs ",
operands[1].getType());
}
inferredReturnTypes.assign({operands[0].getType()});
return success();
}
//===----------------------------------------------------------------------===//
// OpWithShapedTypeInferTypeInterfaceOp
//===----------------------------------------------------------------------===//
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Create a return type whose single dimension is the first dimension of the
// first operand (dynamic if the operand is unranked).
auto operandType = operands.front().getType();
auto sval = dyn_cast<ShapedType>(operandType);
if (!sval)
return emitOptionalError(location, "only shaped type operands allowed");
int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
auto type = IntegerType::get(context, 17);
Attribute encoding;
if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
encoding = rankedTy.getEncoding();
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
return success();
}
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {
shapes = SmallVector<Value, 1>{
builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
return success();
}
//===----------------------------------------------------------------------===//
// OpWithResultShapeInterfaceOp
//===----------------------------------------------------------------------===//
LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {
Location loc = getLoc();
shapes.reserve(operands.size());
for (Value operand : llvm::reverse(operands)) {
auto rank = cast<RankedTensorType>(operand.getType()).getRank();
auto currShape = llvm::to_vector<4>(
llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
}));
shapes.push_back(tensor::FromElementsOp::create(
builder, getLoc(),
RankedTensorType::get({rank}, builder.getIndexType()), currShape));
}
return success();
}
//===----------------------------------------------------------------------===//
// ReifyShapedTypeUsingReifyResultShapesOp
//===----------------------------------------------------------------------===//
LogicalResult ReifyShapedTypeUsingReifyResultShapesOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
Location loc = getLoc();
shapes.reserve(getNumOperands());
for (Value operand : llvm::reverse(getOperands())) {
auto tensorType = cast<RankedTensorType>(operand.getType());
auto currShape = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(0, tensorType.getRank()),
[&](int64_t dim) -> OpFoldResult {
return tensorType.isDynamicDim(dim)
? static_cast<OpFoldResult>(
builder.createOrFold<tensor::DimOp>(loc, operand,
dim))
: static_cast<OpFoldResult>(
builder.getIndexAttr(tensorType.getDimSize(dim)));
}));
shapes.emplace_back(std::move(currShape));
}
return success();
}
//===----------------------------------------------------------------------===//
// ReifyShapedTypeUsingReifyShapeOfResultOp
//===----------------------------------------------------------------------===//
LogicalResult ReifyShapedTypeUsingReifyShapeOfResultOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
return failure();
}
FailureOr<SmallVector<OpFoldResult>>
ReifyShapedTypeUsingReifyShapeOfResultOp::reifyShapeOfResult(OpBuilder &builder,
int resultIndex) {
Location loc = getLoc();
Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex);
SmallVector<OpFoldResult> shape =
tensor::getMixedSizes(builder, loc, sourceOperand);
return shape;
}
//===----------------------------------------------------------------------===//
// ReifyShapedTypeUsingReifyDimOfResultOp
//===----------------------------------------------------------------------===//
LogicalResult ReifyShapedTypeUsingReifyDimOfResultOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
return failure();
}
FailureOr<SmallVector<OpFoldResult>>
ReifyShapedTypeUsingReifyDimOfResultOp::reifyShapeOfResult(OpBuilder &builder,
int resultIndex) {
return failure();
}
FailureOr<OpFoldResult>
ReifyShapedTypeUsingReifyDimOfResultOp::reifyDimOfResult(OpBuilder &builder,
int resultIndex,
int dim) {
Location loc = getLoc();
Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex);
OpFoldResult shape = tensor::getMixedSize(builder, loc, sourceOperand, dim);
return shape;
}
//===----------------------------------------------------------------------===//
// UnreifiableResultShapesOp
//===----------------------------------------------------------------------===//
LogicalResult UnreifiableResultShapesOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
Location loc = getLoc();
shapes.resize(1);
shapes[0] = {tensor::getMixedSize(builder, loc, getOperand(), 0),
OpFoldResult()};
return success();
}
//===----------------------------------------------------------------------===//
// UnreifiableResultShapeOp
//===----------------------------------------------------------------------===//
LogicalResult UnreifiableResultShapeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
return failure();
}
FailureOr<SmallVector<OpFoldResult>>
UnreifiableResultShapeOp::reifyShapeOfResult(OpBuilder &builder,
int resultIndex) {
SmallVector<OpFoldResult> shape = {
tensor::getMixedSize(builder, getLoc(), getOperand(), 0), OpFoldResult()};
return shape;
}
//===----------------------------------------------------------------------===//
// UnreifiableDimOfResultShapeOp
//===----------------------------------------------------------------------===//
LogicalResult UnreifiableDimOfResultShapeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
return failure();
}
FailureOr<SmallVector<OpFoldResult>>
UnreifiableDimOfResultShapeOp::reifyShapeOfResult(OpBuilder &builder,
int resultIndex) {
return failure();
}
FailureOr<OpFoldResult>
UnreifiableDimOfResultShapeOp::reifyDimOfResult(OpBuilder &builder,
int resultIndex, int dim) {
if (dim == 0)
return tensor::getMixedSize(builder, getLoc(), getOperand(), 0);
return failure();
}
//===----------------------------------------------------------------------===//
// SideEffectOp
//===----------------------------------------------------------------------===//
namespace {
/// A test resource for side effects.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
StringRef getName() final { return "<Test>"; }
};
} // namespace
void SideEffectOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Check for an effects attribute on the op instance.
ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
if (!effectsAttr)
return;
for (Attribute element : effectsAttr) {
DictionaryAttr effectElement = cast<DictionaryAttr>(element);
// Get the specific memory effect.
MemoryEffects::Effect *effect =
StringSwitch<MemoryEffects::Effect *>(
cast<StringAttr>(effectElement.get("effect")).getValue())
.Case("allocate", MemoryEffects::Allocate::get())
.Case("free", MemoryEffects::Free::get())
.Case("read", MemoryEffects::Read::get())
.Case("write", MemoryEffects::Write::get());
// Check for a non-default resource to use.
SideEffects::Resource *resource = SideEffects::DefaultResource::get();
if (effectElement.get("test_resource"))
resource = TestResource::get();
// Check for a result to affect.
if (effectElement.get("on_result"))
effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
else if (Attribute ref = effectElement.get("on_reference"))
effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
else
effects.emplace_back(effect, resource);
}
}
void SideEffectOp::getEffects(
SmallVectorImpl<TestEffects::EffectInstance> &effects) {
testSideEffectOpGetEffect(getOperation(), effects);
}
void SideEffectWithRegionOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Check for an effects attribute on the op instance.
ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
if (!effectsAttr)
return;
for (Attribute element : effectsAttr) {
DictionaryAttr effectElement = cast<DictionaryAttr>(element);
// Get the specific memory effect.
MemoryEffects::Effect *effect =
StringSwitch<MemoryEffects::Effect *>(
cast<StringAttr>(effectElement.get("effect")).getValue())
.Case("allocate", MemoryEffects::Allocate::get())
.Case("free", MemoryEffects::Free::get())
.Case("read", MemoryEffects::Read::get())
.Case("write", MemoryEffects::Write::get());
// Check for a non-default resource to use.
SideEffects::Resource *resource = SideEffects::DefaultResource::get();
if (effectElement.get("test_resource"))
resource = TestResource::get();
// Check for a result to affect.
if (effectElement.get("on_result"))
effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
else if (effectElement.get("on_operand"))
effects.emplace_back(effect, &getOperation()->getOpOperands()[0],
resource);
else if (effectElement.get("on_argument"))
effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0),
resource);
else if (Attribute ref = effectElement.get("on_reference"))
effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
else
effects.emplace_back(effect, resource);
}
}
void SideEffectWithRegionOp::getEffects(
SmallVectorImpl<TestEffects::EffectInstance> &effects) {
testSideEffectOpGetEffect(getOperation(), effects);
}
//===----------------------------------------------------------------------===//
// StringAttrPrettyNameOp
//===----------------------------------------------------------------------===//
// This op has fancy handling of its SSA result name.
ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
OperationState &result) {
// Add the result types.
for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
result.addTypes(parser.getBuilder().getIntegerType(32));
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
// If the attribute dictionary contains no 'names' attribute, infer it from
// the SSA name (if specified).
bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
return attr.getName() == "names";
});
// If there was no name specified, check to see if there was a useful name
// specified in the asm file.
if (hadNames || parser.getNumResults() == 0)
return success();
SmallVector<StringRef, 4> names;
auto *context = result.getContext();
for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
auto resultName = parser.getResultName(i);
StringRef nameStr;
if (!resultName.first.empty() && !isdigit(resultName.first[0]))
nameStr = resultName.first;
names.push_back(nameStr);
}
auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
return success();
}
void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
// Note that we only need to print the "names" attribute if the asmprinter
// result name disagrees with it. This can happen in strange cases, e.g.
// when there are conflicts.
bool namesDisagree = getNames().size() != getNumResults();
SmallString<32> resultNameStr;
for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
resultNameStr.clear();
llvm::raw_svector_ostream tmpStream(resultNameStr);
p.printOperand(getResult(i), tmpStream);
auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
if (!expectedName ||
tmpStream.str().drop_front() != expectedName.getValue()) {
namesDisagree = true;
}
}
if (namesDisagree)
p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
else
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
}
// We set the SSA name in the asm syntax to the contents of the name
// attribute.
void StringAttrPrettyNameOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
auto value = getNames();
for (size_t i = 0, e = value.size(); i != e; ++i)
if (auto str = dyn_cast<StringAttr>(value[i]))
if (!str.getValue().empty())
setNameFn(getResult(i), str.getValue());
}
//===----------------------------------------------------------------------===//
// CustomResultsNameOp
//===----------------------------------------------------------------------===//
void CustomResultsNameOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
ArrayAttr value = getNames();
for (size_t i = 0, e = value.size(); i != e; ++i)
if (auto str = dyn_cast<StringAttr>(value[i]))
if (!str.empty())
setNameFn(getResult(i), str.getValue());
}
//===----------------------------------------------------------------------===//
// ResultNameFromTypeOp
//===----------------------------------------------------------------------===//
void ResultNameFromTypeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
auto result = getResult();
auto setResultNameFn = [&](::llvm::StringRef name) {
setNameFn(result, name);
};
auto opAsmTypeInterface =
::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
opAsmTypeInterface.getAsmName(setResultNameFn);
}
//===----------------------------------------------------------------------===//
// BlockArgumentNameFromTypeOp
//===----------------------------------------------------------------------===//
void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto &block : region) {
for (auto arg : block.getArguments()) {
if (auto opAsmTypeInterface =
::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
opAsmTypeInterface.getAsmName(setArgNameFn);
}
}
}
}
//===----------------------------------------------------------------------===//
// ResultTypeWithTraitOp
//===----------------------------------------------------------------------===//
LogicalResult ResultTypeWithTraitOp::verify() {
if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
return success();
return emitError("result type should have trait 'TestTypeTrait'");
}
//===----------------------------------------------------------------------===//
// AttrWithTraitOp
//===----------------------------------------------------------------------===//
LogicalResult AttrWithTraitOp::verify() {
if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
return success();
return emitError("'attr' attribute should have trait 'TestAttrTrait'");
}
//===----------------------------------------------------------------------===//
// RegionIfOp
//===----------------------------------------------------------------------===//
void RegionIfOp::print(OpAsmPrinter &p) {
p << " ";
p.printOperands(getOperands());
p << ": " << getOperandTypes();
p.printArrowTypeList(getResultTypes());
p << " then ";
p.printRegion(getThenRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
p << " else ";
p.printRegion(getElseRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
p << " join ";
p.printRegion(getJoinRegion(),
/*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
}
ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
SmallVector<Type, 2> operandTypes;
result.regions.reserve(3);
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
Region *joinRegion = result.addRegion();
// Parse operand, type and arrow type lists.
if (parser.parseOperandList(operandInfos) ||
parser.parseColonTypeList(operandTypes) ||
parser.parseArrowTypeList(result.types))
return failure();
// Parse all attached regions.
if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
return failure();
return parser.resolveOperands(operandInfos, operandTypes,
parser.getCurrentLocation(), result.operands);
}
OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) {
assert(llvm::is_contained({&getThenRegion(), &getElseRegion()},
successor.getSuccessor()) &&
"invalid region index");
return getOperands();
}
void RegionIfOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
if (!point.isParent()) {
if (point.getTerminatorPredecessorOrNull()->getParentRegion() !=
&getJoinRegion())
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
// The then and else regions are the entry regions of this op.
regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
}
void RegionIfOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands,
SmallVectorImpl<InvocationBounds> &invocationBounds) {
// Each region is invoked at most once.
invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
}
//===----------------------------------------------------------------------===//
// AnyCondOp
//===----------------------------------------------------------------------===//
void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op branches into the only region, and the region branches back
// to the parent op.
if (point.isParent())
regions.emplace_back(&getRegion());
else
regions.emplace_back(getOperation(), getResults());
}
void AnyCondOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands,
SmallVectorImpl<InvocationBounds> &invocationBounds) {
invocationBounds.emplace_back(1, 1);
}
//===----------------------------------------------------------------------===//
// SingleBlockImplicitTerminatorOp
//===----------------------------------------------------------------------===//
/// Testing the correctness of some traits.
static_assert(
llvm::is_detected<OpTrait::has_implicit_terminator_t,
SingleBlockImplicitTerminatorOp>::value,
"has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
SingleBlockImplicitTerminatorOp>::value,
"hasSingleBlockImplicitTerminator does not match "
"SingleBlockImplicitTerminatorOp");
//===----------------------------------------------------------------------===//
// SingleNoTerminatorCustomAsmOp
//===----------------------------------------------------------------------===//
ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
OperationState &state) {
Region *body = state.addRegion();
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
return success();
}
void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
printer.printRegion(
getRegion(), /*printEntryBlockArgs=*/false,
// This op has a single block without terminators. But explicitly mark
// as not printing block terminators for testing.
/*printBlockTerminators=*/false);
}
//===----------------------------------------------------------------------===//
// TestVerifiersOp
//===----------------------------------------------------------------------===//
LogicalResult TestVerifiersOp::verify() {
if (!getRegion().hasOneBlock())
return emitOpError("`hasOneBlock` trait hasn't been verified");
Operation *definingOp = getInput().getDefiningOp();
if (definingOp && failed(mlir::verify(definingOp)))
return emitOpError("operand hasn't been verified");
// Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
// loop.
mlir::emitRemark(getLoc(), "success run of verifier");
return success();
}
LogicalResult TestVerifiersOp::verifyRegions() {
if (!getRegion().hasOneBlock())
return emitOpError("`hasOneBlock` trait hasn't been verified");
for (Block &block : getRegion())
for (Operation &op : block)
if (failed(mlir::verify(&op)))
return emitOpError("nested op hasn't been verified");
// Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
// loop.
mlir::emitRemark(getLoc(), "success run of region verifier");
return success();
}
//===----------------------------------------------------------------------===//
// Test InferIntRangeInterface
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// TestWithBoundsOp
//===----------------------------------------------------------------------===//
void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
}
//===----------------------------------------------------------------------===//
// TestWithBoundsRegionOp
//===----------------------------------------------------------------------===//
ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
OperationState &result) {
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
// Parse the input argument.
OpAsmParser::Argument argInfo;
if (failed(parser.parseArgument(argInfo, /*allowType=*/true)))
return failure();
// Parse the body region, and reuse the operand info as the argument info.
Region *body = result.addRegion();
return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
}
void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs());
p << ' ';
p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
/*omitType=*/false);
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}
void TestWithBoundsRegionOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
Value arg = getRegion().getArgument(0);
setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
}
//===----------------------------------------------------------------------===//
// TestIncrementOp
//===----------------------------------------------------------------------===//
void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
const ConstantIntRanges &range = argRanges[0];
APInt one(range.umin().getBitWidth(), 1);
setResultRanges(getResult(),
{range.umin().uadd_sat(one), range.umax().uadd_sat(one),
range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
}
//===----------------------------------------------------------------------===//
// TestReflectBoundsOp
//===----------------------------------------------------------------------===//
void TestReflectBoundsOp::inferResultRanges(
ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
const ConstantIntRanges &range = argRanges[0];
MLIRContext *ctx = getContext();
Builder b(ctx);
Type sIntTy, uIntTy;
// For plain `IntegerType`s, we can derive the appropriate signed and unsigned
// Types for the Attributes.
Type type = getElementTypeOrSelf(getType());
if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
unsigned bitwidth = intTy.getWidth();
sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
} else {
sIntTy = uIntTy = type;
}
setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
setResultRanges(getResult(), range);
}
//===----------------------------------------------------------------------===//
// ConversionFuncOp
//===----------------------------------------------------------------------===//
ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
OperationState &result) {
auto buildFuncType =
[](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
function_interface_impl::VariadicFlag,
std::string &) { return builder.getFunctionType(argTypes, results); };
return function_interface_impl::parseFunctionOp(
parser, result, /*allowVariadic=*/false,
getFunctionTypeAttrName(result.name), buildFuncType,
getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
void ConversionFuncOp::print(OpAsmPrinter &p) {
function_interface_impl::printFunctionOp(
p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
getArgAttrsAttrName(), getResAttrsAttrName());
}
//===----------------------------------------------------------------------===//
// TestValueWithBoundsOp
//===----------------------------------------------------------------------===//
void TestValueWithBoundsOp::populateBoundsForIndexValue(
Value v, ValueBoundsConstraintSet &cstr) {
cstr.bound(v) >= getMin().getSExtValue();
cstr.bound(v) <= getMax().getSExtValue();
}
//===----------------------------------------------------------------------===//
// ReifyBoundOp
//===----------------------------------------------------------------------===//
mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
if (getType() == "EQ")
return mlir::presburger::BoundType::EQ;
if (getType() == "LB")
return mlir::presburger::BoundType::LB;
if (getType() == "UB")
return mlir::presburger::BoundType::UB;
llvm_unreachable("invalid bound type");
}
LogicalResult ReifyBoundOp::verify() {
if (isa<ShapedType>(getVar().getType())) {
if (!getDim().has_value())
return emitOpError("expected 'dim' attribute for shaped type variable");
} else if (getVar().getType().isIndex()) {
if (getDim().has_value())
return emitOpError("unexpected 'dim' attribute for index variable");
} else {
return emitOpError("expected index-typed variable or shaped type variable");
}
if (getConstant() && getScalable())
return emitOpError("'scalable' and 'constant' are mutually exclusive");
if (getScalable() != getVscaleMin().has_value())
return emitOpError("expected 'vscale_min' if and only if 'scalable'");
if (getScalable() != getVscaleMax().has_value())
return emitOpError("expected 'vscale_max' if and only if 'scalable'");
return success();
}
ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
if (getDim().has_value())
return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
return ValueBoundsConstraintSet::Variable(getVar());
}
//===----------------------------------------------------------------------===//
// CompareOp
//===----------------------------------------------------------------------===//
ValueBoundsConstraintSet::ComparisonOperator
CompareOp::getComparisonOperator() {
if (getCmp() == "EQ")
return ValueBoundsConstraintSet::ComparisonOperator::EQ;
if (getCmp() == "LT")
return ValueBoundsConstraintSet::ComparisonOperator::LT;
if (getCmp() == "LE")
return ValueBoundsConstraintSet::ComparisonOperator::LE;
if (getCmp() == "GT")
return ValueBoundsConstraintSet::ComparisonOperator::GT;
if (getCmp() == "GE")
return ValueBoundsConstraintSet::ComparisonOperator::GE;
llvm_unreachable("invalid comparison operator");
}
mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
if (!getLhsMap())
return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
SmallVector<Value> mapOperands(
getVarOperands().slice(0, getLhsMap()->getNumInputs()));
return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
}
mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
if (!getRhsMap())
return ValueBoundsConstraintSet::Variable(
getVarOperands()[rhsOperandsBegin]);
SmallVector<Value> mapOperands(
getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
}
LogicalResult CompareOp::verify() {
if (getCompose() && (getLhsMap() || getRhsMap()))
return emitOpError(
"'compose' not supported when 'lhs_map' or 'rhs_map' is present");
int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
if (getVarOperands().size() != size_t(expectedNumOperands))
return emitOpError("expected ")
<< expectedNumOperands << " operands, but got "
<< getVarOperands().size();
return success();
}
//===----------------------------------------------------------------------===//
// TestOpInPlaceSelfFold
//===----------------------------------------------------------------------===//
OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
if (!getFolded()) {
// The folder adds the "folded" attribute if it is not present.
setFolded(true);
return getResult();
}
return {};
}
//===----------------------------------------------------------------------===//
// TestOpFoldWithFoldAdaptor
//===----------------------------------------------------------------------===//
OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
int64_t sum = 0;
if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
sum += value.getValue().getSExtValue();
for (Attribute attr : adaptor.getVariadic())
if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
sum += 2 * value.getValue().getSExtValue();
for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
for (Attribute attr : attrs)
if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
sum += 3 * value.getValue().getSExtValue();
sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
return IntegerAttr::get(getType(), sum);
}
//===----------------------------------------------------------------------===//
// OpWithInferTypeAdaptorInterfaceOp
//===----------------------------------------------------------------------===//
LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location,
OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (adaptor.getX().getType() != adaptor.getY().getType()) {
return emitOptionalError(location, "operand type mismatch ",
adaptor.getX().getType(), " vs ",
adaptor.getY().getType());
}
inferredReturnTypes.assign({adaptor.getX().getType()});
return success();
}
//===----------------------------------------------------------------------===//
// OpWithRefineTypeInterfaceOp
//===----------------------------------------------------------------------===//
// TODO: We should be able to define only one of inferReturnTypes or
// refineReturnTypes; currently only refineReturnTypes can be omitted.
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &returnTypes) {
returnTypes.clear();
return OpWithRefineTypeInterfaceOp::refineReturnTypes(
context, location, operands, attributes, properties, regions,
returnTypes);
}
LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &returnTypes) {
if (operands[0].getType() != operands[1].getType()) {
return emitOptionalError(location, "operand type mismatch ",
operands[0].getType(), " vs ",
operands[1].getType());
}
// TODO: Add helper to make this more concise to write.
if (returnTypes.empty())
returnTypes.resize(1, nullptr);
if (returnTypes[0] && returnTypes[0] != operands[0].getType())
return emitOptionalError(location,
"required first operand and result to match");
returnTypes[0] = operands[0].getType();
return success();
}
//===----------------------------------------------------------------------===//
// TilingNoDpsOp
//===----------------------------------------------------------------------===//
SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
return {};
}
SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
return {};
}
FailureOr<TilingResult>
TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
return failure();
}
LogicalResult TilingNoDpsOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
return failure();
}
//===----------------------------------------------------------------------===//
// OpWithShapedTypeInferTypeAdaptorInterfaceOp
//===----------------------------------------------------------------------===//
LogicalResult
OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
// Create a rank-1 return type whose size is the first dimension of the first
// operand (dynamic if the operand is unranked).
auto operandType = adaptor.getOperand1().getType();
auto sval = dyn_cast<ShapedType>(operandType);
if (!sval)
return emitOptionalError(location, "only shaped type operands allowed");
int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
auto type = IntegerType::get(context, 17);
Attribute encoding;
if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
encoding = rankedTy.getEncoding();
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
return success();
}
LogicalResult
OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {
shapes = SmallVector<Value, 1>{
builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
return success();
}
//===----------------------------------------------------------------------===//
// TestOpWithPropertiesAndInferredType
//===----------------------------------------------------------------------===//
LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
MLIRContext *context, std::optional<Location>, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
Adaptor adaptor(operands, attributes, properties, regions);
inferredReturnTypes.push_back(IntegerType::get(
context, adaptor.getLhs() + adaptor.getProperties().rhs));
return success();
}
//===----------------------------------------------------------------------===//
// LoopBlockOp
//===----------------------------------------------------------------------===//
void LoopBlockOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
regions.emplace_back(&getBody(), getBody().getArguments());
if (point.isParent())
return;
regions.emplace_back(getOperation(), getOperation()->getResults());
}
OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) {
assert(successor.getSuccessor() == &getBody());
return MutableOperandRange(getInitMutable());
}
//===----------------------------------------------------------------------===//
// LoopBlockTerminatorOp
//===----------------------------------------------------------------------===//
MutableOperandRange
LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) {
if (successor.isParent())
return getExitArgMutable();
return getNextIterArgMutable();
}
//===----------------------------------------------------------------------===//
// TestNoTerminatorOp
//===----------------------------------------------------------------------===//
void TestNoTerminatorOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
//===----------------------------------------------------------------------===//
// ManualCppOpWithFold
//===----------------------------------------------------------------------===//
OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
// A simple fold for testing purposes that reads the constant value of the
// first operand and returns it.
if (!attributes.empty())
return attributes.front();
return nullptr;
}
//===----------------------------------------------------------------------===//
// Tensor/Buffer Ops
//===----------------------------------------------------------------------===//
void ReadBufferOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
// The buffer operand is read.
effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(),
SideEffects::DefaultResource::get());
// The buffer contents are dumped.
effects.emplace_back(MemoryEffects::Write::get(),
SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
// Test Dataflow
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// TestCallAndStoreOp
//===----------------------------------------------------------------------===//
CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
return getCallee();
}
void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
setCalleeAttr(cast<SymbolRefAttr>(callee));
}
Operation::operand_range TestCallAndStoreOp::getArgOperands() {
return getCalleeOperands();
}
MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
return getCalleeOperandsMutable();
}
//===----------------------------------------------------------------------===//
// TestCallOnDeviceOp
//===----------------------------------------------------------------------===//
CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
return getCallee();
}
void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
setCalleeAttr(cast<SymbolRefAttr>(callee));
}
Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
return getForwardedOperands();
}
MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
return getForwardedOperandsMutable();
}
//===----------------------------------------------------------------------===//
// TestStoreWithARegion
//===----------------------------------------------------------------------===//
void TestStoreWithARegion::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (point.isParent())
regions.emplace_back(&getBody(), getBody().front().getArguments());
else
regions.emplace_back(getOperation(), getOperation()->getResults());
}
//===----------------------------------------------------------------------===//
// TestStoreWithALoopRegion
//===----------------------------------------------------------------------===//
void TestStoreWithALoopRegion::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for the operation not to
// enter the body.
regions.emplace_back(
RegionSuccessor(&getBody(), getBody().front().getArguments()));
regions.emplace_back(getOperation(), getOperation()->getResults());
}
//===----------------------------------------------------------------------===//
// TestVersionedOpA
//===----------------------------------------------------------------------===//
LogicalResult
TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
mlir::OperationState &state) {
auto &prop = state.getOrAddProperties<Properties>();
if (mlir::failed(reader.readAttribute(prop.dims)))
return mlir::failure();
// Check if we have a version. If not, assume we are parsing the current
// version.
auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
if (succeeded(maybeVersion)) {
// If version is less than 2.0, there is no additional attribute to parse.
// We can materialize missing properties post parsing before verification.
const auto *version =
reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
if ((version->major_ < 2)) {
return success();
}
}
if (mlir::failed(reader.readAttribute(prop.modifier)))
return mlir::failure();
return mlir::success();
}
void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
auto &prop = getProperties();
writer.writeAttribute(prop.dims);
auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
if (succeeded(maybeVersion)) {
// If version is less than 2.0, there is no additional attribute to write.
const auto *version =
reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
if ((version->major_ < 2)) {
llvm::outs() << "downgrading op properties...\n";
return;
}
}
writer.writeAttribute(prop.modifier);
}
//===----------------------------------------------------------------------===//
// TestOpWithVersionedProperties
//===----------------------------------------------------------------------===//
llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
uint64_t value1, value2 = 0;
if (failed(reader.readVarInt(value1)))
return failure();
// Check if we have a version. If not, assume we are parsing the current
// version.
auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
bool needToParseAnotherInt = true;
if (succeeded(maybeVersion)) {
// If version is less than 2.0, there is no additional attribute to parse.
// We can materialize missing properties post parsing before verification.
const auto *version =
reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
if ((version->major_ < 2))
needToParseAnotherInt = false;
}
if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
return failure();
prop.value1 = value1;
prop.value2 = value2;
return success();
}
void TestOpWithVersionedProperties::writeToMlirBytecode(
mlir::DialectBytecodeWriter &writer,
const test::VersionedProperties &prop) {
writer.writeVarInt(prop.value1);
writer.writeVarInt(prop.value2);
}
//===----------------------------------------------------------------------===//
// TestMultiSlotAlloca
//===----------------------------------------------------------------------===//
llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
SmallVector<MemorySlot> slots;
for (Value result : getResults()) {
slots.push_back(MemorySlot{
result, cast<MemRefType>(result.getType()).getElementType()});
}
return slots;
}
Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
OpBuilder &builder) {
return TestOpConstant::create(builder, getLoc(), slot.elemType,
builder.getI32IntegerAttr(42));
}
void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
BlockArgument argument,
OpBuilder &builder) {
// Not relevant for testing.
}
/// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
static std::optional<TestMultiSlotAlloca>
createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
TestMultiSlotAlloca oldOp) {
if (oldOp.getNumResults() == 1) {
oldOp.erase();
return std::nullopt;
}
SmallVector<Type> newTypes;
SmallVector<Value> remainingValues;
for (Value oldResult : oldOp.getResults()) {
if (oldResult == slot.ptr)
continue;
remainingValues.push_back(oldResult);
newTypes.push_back(oldResult.getType());
}
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(oldOp);
auto replacement =
TestMultiSlotAlloca::create(builder, oldOp->getLoc(), newTypes);
for (auto [oldResult, newResult] :
llvm::zip_equal(remainingValues, replacement.getResults()))
oldResult.replaceAllUsesWith(newResult);
oldOp.erase();
return replacement;
}
std::optional<PromotableAllocationOpInterface>
TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
Value defaultValue,
OpBuilder &builder) {
if (defaultValue && defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
SmallVector<DestructurableMemorySlot>
TestMultiSlotAlloca::getDestructurableSlots() {
SmallVector<DestructurableMemorySlot> slots;
for (Value result : getResults()) {
auto memrefType = cast<MemRefType>(result.getType());
auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType);
if (!destructurable)
continue;
std::optional<DenseMap<Attribute, Type>> destructuredType =
destructurable.getSubelementIndexMap();
if (!destructuredType)
continue;
slots.emplace_back(
DestructurableMemorySlot{{result, memrefType}, *destructuredType});
}
return slots;
}
DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
const DestructurableMemorySlot &slot,
const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointAfter(*this);
DenseMap<Attribute, MemorySlot> slotMap;
for (Attribute usedIndex : usedIndices) {
Type elemType = slot.subelementTypes.lookup(usedIndex);
MemRefType elemPtr = MemRefType::get({}, elemType);
auto subAlloca = TestMultiSlotAlloca::create(builder, getLoc(), elemPtr);
newAllocators.push_back(subAlloca);
slotMap.try_emplace<MemorySlot>(usedIndex,
{subAlloca.getResult(0), elemType});
}
return slotMap;
}
std::optional<DestructurableAllocationOpInterface>
TestMultiSlotAlloca::handleDestructuringComplete(
const DestructurableMemorySlot &slot, OpBuilder &builder) {
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
namespace {
/// Returns test dialect's memref layout for test dialect's tensor encoding when
/// applicable.
MemRefLayoutAttrInterface
getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
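// The encoding may be null (e.g. for builtin tensors without an encoding), so
// use a null-tolerant cast.
if (auto encoding = dyn_cast_or_null<test::TestTensorEncodingAttr>(
tensorType.getEncoding())) {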
return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
tensorType.getContext(), encoding.getDummy()));
}
return {};
}
/// Auxiliary bufferization function for test and builtin tensors.
bufferization::BufferLikeType
convertTensorToBuffer(mlir::Operation *op,
const bufferization::BufferizationOptions &options,
bufferization::TensorLikeType tensorLike) {
auto buffer =
*tensorLike.getBufferType(options, [&]() { return op->emitError(); });
if (auto memref = dyn_cast<MemRefType>(buffer)) {
// Note: For the sake of testing, we want to ensure that encoding -> layout
// bufferization happens. This is currently achieved manually.
auto layout =
getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike));
return cast<bufferization::BufferLikeType>(
MemRefType::get(memref.getShape(), memref.getElementType(), layout,
memref.getMemorySpace()));
}
return buffer;
}
} // namespace
::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options,
::mlir::bufferization::BufferizationState &state) {
auto buffer =
mlir::bufferization::getBuffer(rewriter, getInput(), options, state);
if (mlir::failed(buffer))
return failure();
const auto outType = getOutput().getType();
const auto bufferizedOutType =
convertTensorToBuffer(getOperation(), options, outType);
// Replace the op with its memref analogue.
auto dummyMemrefOp = test::TestDummyMemrefOp::create(
rewriter, getLoc(), bufferizedOutType, *buffer);
mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(),
dummyMemrefOp.getResult());
return mlir::success();
}
::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options,
::mlir::bufferization::BufferizationState &state) {
// Note: mlir::bufferization::getBufferType() would internally call
// TestCreateTensorOp::getBufferType()
const auto bufferizedOutType =
mlir::bufferization::getBufferType(getOutput(), options, state);
if (mlir::failed(bufferizedOutType))
return failure();
// Replace the op with its memref analogue.
auto createMemrefOp =
test::TestCreateMemrefOp::create(rewriter, getLoc(), *bufferizedOutType);
mlir::bufferization::replaceOpWithBufferizedValues(
rewriter, getOperation(), createMemrefOp.getResult());
return mlir::success();
}
mlir::FailureOr<mlir::bufferization::BufferLikeType>
test::TestCreateTensorOp::getBufferType(
mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
const mlir::bufferization::BufferizationState &,
llvm::SmallVector<::mlir::Value> &) {
const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
if (type == nullptr)
return failure();
return convertTensorToBuffer(getOperation(), options, type);
}