llvm-project/mlir/test/lib/Dialect/Test/TestDialect.cpp
Twice 2b9ad865f7
[MLIR] Support dynamic traits in DynamicDialect (#177735)
Unlike Interfaces, Traits in MLIR are static: they are defined via CRTP
templates and used as base classes of an `Op`, which makes them
difficult to attach to an op dynamically.

However, in IRDL and the Python bindings, we define operations
dynamically through `DynamicDialect`, which means the traditional static
traits cannot be applied to them. Traits are important: for example,
they are how MLIR marks an op as a terminator or a non-terminator.

If `DynamicDialect` does not support traits, then even though we can
define an op with regions, we cannot define new terminators or mark an
op as a non-terminator. This makes `DynamicDialect` very limited in
region-related scenarios.

In this PR, we introduce a `DynamicOpTrait` type that “dynamizes”
`OpTrait`, enabling traits to be attached to ops in `DynamicDialect`.
The key design goal is that existing checks in the MLIR codebase such as
`op->hasTrait<XXX>()` work seamlessly on ops defined by
`DynamicOpDefinition`, without requiring any changes.

Note that this PR currently supports only two traits: `IsTerminator`
and `NoTerminator`.

This PR aims to lay the groundwork for adding trait support to IRDL
and the Python bindings (and possibly other bindings) in the future.

Related to #158066.
2026-01-24 16:14:07 +08:00

498 lines
18 KiB
C++

//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "TestOps.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Base64.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
// Include this before the using namespace lines below to test that we don't
// have namespace dependencies.
#include "TestOpsDialect.cpp.inc"
using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
// PropertiesWithCustomPrint
//===----------------------------------------------------------------------===//
// Fills `prop` from `attr`, which must be a DictionaryAttr carrying a
// StringAttr under `label` and an IntegerAttr under `value`. On any mismatch a
// diagnostic is emitted through `emitError` and failure is returned; `prop` is
// only written once every check has passed.
LogicalResult
test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
                                 Attribute attr,
                                 function_ref<InFlightDiagnostic()> emitError) {
  auto dictAttr = dyn_cast<DictionaryAttr>(attr);
  if (!dictAttr)
    return emitError() << "expected DictionaryAttr to set TestProperties";

  auto labelAttr = dictAttr.getAs<mlir::StringAttr>("label");
  if (!labelAttr)
    return emitError() << "expected StringAttr for key `label`";

  auto intAttr = dictAttr.getAs<IntegerAttr>("value");
  if (!intAttr)
    return emitError() << "expected IntegerAttr for key `value`";

  // The label is stored behind a shared pointer; the integer is sign-extended.
  prop.label = std::make_shared<std::string>(labelAttr.getValue().str());
  prop.value = intAttr.getValue().getSExtValue();
  return success();
}
// Converts `prop` into a DictionaryAttr with a string entry `label` and an
// i32 integer entry `value`.
DictionaryAttr
test::getPropertiesAsAttribute(MLIRContext *ctx,
                               const PropertiesWithCustomPrint &prop) {
  Builder builder(ctx);
  const NamedAttribute entries[] = {
      builder.getNamedAttr("label", builder.getStringAttr(*prop.label)),
      builder.getNamedAttr("value", builder.getI32IntegerAttr(prop.value))};
  return builder.getDictionaryAttr(entries);
}
// Hashes the integer value together with the label's string contents.
llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) {
  StringRef labelText(*prop.label);
  return llvm::hash_combine(prop.value, labelText);
}
// Prints the properties as `<label> is <value>`, where the label is emitted
// via `printKeywordOrString`.
void test::customPrintProperties(OpAsmPrinter &p,
                                 const PropertiesWithCustomPrint &prop) {
  p.printKeywordOrString(*prop.label);
  p << " is ";
  p << prop.value;
}
// Parses `<keyword-or-string> is <integer>` into `prop`. The integer is
// written straight into `prop.value`; the label is only stored after all three
// pieces parsed successfully.
ParseResult test::customParseProperties(OpAsmParser &parser,
                                        PropertiesWithCustomPrint &prop) {
  std::string parsedLabel;
  if (failed(parser.parseKeywordOrString(&parsedLabel)))
    return failure();
  if (failed(parser.parseKeyword("is")))
    return failure();
  if (failed(parser.parseInteger(prop.value)))
    return failure();

  prop.label = std::make_shared<std::string>(std::move(parsedLabel));
  return success();
}
//===----------------------------------------------------------------------===//
// MyPropStruct
//===----------------------------------------------------------------------===//
// Represents this property as a StringAttr holding `content`.
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
  StringAttr contentAttr = StringAttr::get(ctx, content);
  return contentAttr;
}
// Sets `prop.content` from `attr` when it is a StringAttr; otherwise emits a
// diagnostic that includes the offending attribute and returns failure.
LogicalResult
MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
                          function_ref<InFlightDiagnostic()> emitError) {
  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
    prop.content = stringAttr.getValue();
    return success();
  }
  emitError() << "Expect StringAttr but got " << attr;
  return failure();
}
// Hashes the property by the string contents of `content`.
llvm::hash_code MyPropStruct::hash() const {
  StringRef contentView(content);
  return hash_value(contentView);
}
// Reads a string from the bytecode stream into `prop.content`. `prop` is left
// untouched when the read fails.
LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
                                         MyPropStruct &prop) {
  StringRef decoded;
  LogicalResult status = reader.readString(decoded);
  if (succeeded(status))
    prop.content = decoded.str();
  return status;
}
// Writes `prop.content` to the bytecode stream as an owned string.
void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
                               MyPropStruct &prop) {
  StringRef payload(prop.content);
  writer.writeOwnedString(payload);
}
//===----------------------------------------------------------------------===//
// VersionedProperties
//===----------------------------------------------------------------------===//
// Fills `prop` from `attr`, which must be a DictionaryAttr carrying
// IntegerAttr entries under `value1` and `value2`. On any mismatch a
// diagnostic is emitted through `emitError` and failure is returned; `prop` is
// only written once every check has passed.
LogicalResult
test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
                                 function_ref<InFlightDiagnostic()> emitError) {
  auto dictAttr = dyn_cast<DictionaryAttr>(attr);
  if (!dictAttr)
    return emitError() << "expected DictionaryAttr to set VersionedProperties";

  auto firstAttr = dictAttr.getAs<IntegerAttr>("value1");
  if (!firstAttr)
    return emitError() << "expected IntegerAttr for key `value1`";

  auto secondAttr = dictAttr.getAs<IntegerAttr>("value2");
  if (!secondAttr)
    return emitError() << "expected IntegerAttr for key `value2`";

  // Both integers are sign-extended when stored.
  prop.value1 = firstAttr.getValue().getSExtValue();
  prop.value2 = secondAttr.getValue().getSExtValue();
  return success();
}
// Converts `prop` into a DictionaryAttr with i32 integer entries `value1` and
// `value2`.
DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx,
                                              const VersionedProperties &prop) {
  Builder builder(ctx);
  const NamedAttribute entries[] = {
      builder.getNamedAttr("value1", builder.getI32IntegerAttr(prop.value1)),
      builder.getNamedAttr("value2", builder.getI32IntegerAttr(prop.value2))};
  return builder.getDictionaryAttr(entries);
}
// Hashes the two stored values together.
llvm::hash_code test::computeHash(const VersionedProperties &prop) {
  const auto &first = prop.value1;
  const auto &second = prop.value2;
  return llvm::hash_combine(first, second);
}
// Prints the properties as `<value1> | <value2>`.
void test::customPrintProperties(OpAsmPrinter &p,
                                 const VersionedProperties &prop) {
  p << prop.value1;
  p << " | ";
  p << prop.value2;
}
// Parses `<integer> | <integer>` directly into `prop.value1` and
// `prop.value2`. Parsing stops at the first piece that fails.
ParseResult test::customParseProperties(OpAsmParser &parser,
                                        VersionedProperties &prop) {
  const bool anyFailed = failed(parser.parseInteger(prop.value1)) ||
                         failed(parser.parseVerticalBar()) ||
                         failed(parser.parseInteger(prop.value2));
  return failure(anyFailed);
}
//===----------------------------------------------------------------------===//
// Bytecode Support
//===----------------------------------------------------------------------===//
// Reads a var-int element count followed by that many var-int elements into
// `prop`. The encoded count must equal `prop.size()`; otherwise an error is
// emitted through the reader. Elements already read stay written to `prop`
// if a later element fails to decode.
LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader,
                                         MutableArrayRef<int64_t> prop) {
  uint64_t encodedCount;
  if (failed(reader.readVarInt(encodedCount)))
    return failure();

  const size_t expectedCount = prop.size();
  if (encodedCount != expectedCount)
    return reader.emitError("array size mismach when reading properties: ")
           << encodedCount << " vs expected " << expectedCount;

  for (size_t idx = 0; idx != expectedCount; ++idx) {
    uint64_t raw;
    if (failed(reader.readVarInt(raw)))
      return failure();
    prop[idx] = raw;
  }
  return success();
}
// Writes the element count as a var-int, followed by each element as a
// var-int, in order.
void test::writeToMlirBytecode(DialectBytecodeWriter &writer,
                               ArrayRef<int64_t> prop) {
  const size_t count = prop.size();
  writer.writeVarInt(count);
  for (size_t idx = 0; idx < count; ++idx)
    writer.writeVarInt(prop[idx]);
}
//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//
// Builds the definition of the dynamic op `dynamic_generic`, whose verifier
// and region verifier both accept every operation.
static std::unique_ptr<DynamicOpDefinition>
getDynamicGenericOp(TestDialect *dialect) {
  auto acceptAll = [](Operation *) { return success(); };
  return DynamicOpDefinition::get("dynamic_generic", dialect, acceptAll,
                                  acceptAll);
}
// Builds the definition of the dynamic op `dynamic_terminator`: both verifiers
// accept everything, and the dynamic `IsTerminator` trait is attached.
static std::unique_ptr<DynamicOpDefinition>
getDynamicTerminatorOp(TestDialect *dialect) {
  auto acceptAll = [](Operation *) { return success(); };
  std::unique_ptr<DynamicOpDefinition> opDef = DynamicOpDefinition::get(
      "dynamic_terminator", dialect, acceptAll, acceptAll);
  opDef->addTrait(std::make_unique<DynamicOpTraits::IsTerminator>());
  return opDef;
}
// Builds the definition of the dynamic op `dynamic_noterminator`: both
// verifiers accept everything, and the dynamic `NoTerminator` trait is
// attached.
static std::unique_ptr<DynamicOpDefinition>
getDynamicNoTerminatorOp(TestDialect *dialect) {
  auto acceptAll = [](Operation *) { return success(); };
  std::unique_ptr<DynamicOpDefinition> opDef = DynamicOpDefinition::get(
      "dynamic_noterminator", dialect, acceptAll, acceptAll);
  opDef->addTrait(std::make_unique<DynamicOpTraits::NoTerminator>());
  return opDef;
}
// Builds the definition of the dynamic op `dynamic_one_operand_two_results`.
// Its verifier requires exactly one operand and exactly two results (checked
// in that order); its region verifier accepts everything.
static std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
  auto verifyArity = [](Operation *op) -> LogicalResult {
    const unsigned operandCount = op->getNumOperands();
    if (operandCount != 1)
      return op->emitOpError()
             << "expected 1 operand, but had " << operandCount;

    const unsigned resultCount = op->getNumResults();
    if (resultCount != 2)
      return op->emitOpError()
             << "expected 2 results, but had " << resultCount;

    return success();
  };
  auto acceptAllRegions = [](Operation *) { return success(); };
  return DynamicOpDefinition::get("dynamic_one_operand_two_results", dialect,
                                  verifyArity, acceptAllRegions);
}
// Builds the definition of the dynamic op `dynamic_custom_parser_printer`.
// The op must have no operands and no results. Its custom syntax is the op
// name followed by the keyword `custom_keyword`.
static std::unique_ptr<DynamicOpDefinition>
getDynamicCustomParserPrinterOp(TestDialect *dialect) {
  auto checkNoOperandsNoResults = [](Operation *op) -> LogicalResult {
    if (op->getNumOperands() != 0 || op->getNumResults() != 0)
      return op->emitError()
             << "operation should have no operands and no results";
    return success();
  };
  auto acceptAllRegions = [](Operation *) { return success(); };
  auto parseFn = [](OpAsmParser &asmParser, OperationState &) {
    return asmParser.parseKeyword("custom_keyword");
  };
  auto printFn = [](Operation *op, OpAsmPrinter &asmPrinter, llvm::StringRef) {
    asmPrinter << op->getName();
    asmPrinter << " custom_keyword";
  };
  return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
                                  checkNoOperandsNoResults, acceptAllRegions,
                                  parseFn, printFn);
}
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
// Inserts `TestDialect` into the given dialect registry.
void test::registerTestDialect(DialectRegistry &registry) {
registry.insert<TestDialect>();
}
// Appends a `TestEffects::Concrete` effect to `effects`, parameterized by the
// op's `effect_parameter` attribute. Does nothing when the op has no
// AffineMapAttr under that name.
void test::testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
        &effects) {
  if (auto paramAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter"))
    effects.emplace_back(TestEffects::Concrete::get(), paramAttr);
}
// Dialect-level fallback model for `TestEffectOpInterface`. It is only meant
// to be dispatched to for the op named `test.unregistered_side_effect_op`.
struct TestOpEffectInterfaceFallback
    : public TestEffectOpInterface::FallbackModel<
          TestOpEffectInterfaceFallback> {
  // Returns true for the single supported op name; asserts on any other op.
  static bool classof(Operation *op) {
    const bool nameMatches =
        op->getName().getStringRef() == "test.unregistered_side_effect_op";
    assert(nameMatches && "Unexpected dispatch");
    return nameMatches;
  }

  // Forwards to the shared effect-collection helper.
  void
  getEffects(Operation *op,
             SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
                 &effects) const {
    testSideEffectOpGetEffect(op, effects);
  }
};
// Sets up the test dialect: registers attributes, types, op syntax, the
// statically defined and dynamically defined operations, and interfaces;
// allows unknown operations; and allocates the fallback effect interface.
void TestDialect::initialize() {
registerAttributes();
registerTypes();
registerOpsSyntax();
addOperations<ManualCppOpWithFold>();
registerTestDialectOperations(this);
// Register the dynamic ops built by the `getDynamic*Op` helpers defined
// earlier in this file.
registerDynamicOp(getDynamicGenericOp(this));
registerDynamicOp(getDynamicTerminatorOp(this));
registerDynamicOp(getDynamicNoTerminatorOp(this));
registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
registerDynamicOp(getDynamicCustomParserPrinterOp(this));
registerInterfaces();
allowUnknownOperations();
// Instantiate the fallback op interface that is handed out for one specific
// unregistered op (see `getRegisteredInterfaceForOp`). It is held as a raw
// pointer and released in the destructor.
fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
// Releases the fallback interface object allocated in `initialize()`. The
// stored pointer is cast back to `TestOpEffectInterfaceFallback *` before it
// is deleted.
TestDialect::~TestDialect() {
delete static_cast<TestOpEffectInterfaceFallback *>(
fallbackEffectOpInterfaces);
}
// Materializes a constant by creating a `TestOpConstant` at `loc` with the
// given result `type` and `value` attribute.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
Type type, Location loc) {
return TestOpConstant::create(builder, loc, type, value);
}
// Returns the fallback effect interface when `opName` is
// `test.unregistered_side_effect_op` and the requested interface is
// `TestEffectOpInterface`; returns null for every other combination.
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
                                               OperationName opName) {
  const bool wantsFallback =
      opName.getIdentifier() == "test.unregistered_side_effect_op" &&
      typeID == TypeID::get<TestEffectOpInterface>();
  return wantsFallback ? fallbackEffectOpInterfaces : nullptr;
}
// Rejects the operation attribute named `test.invalid_attr` with an error on
// `op`; every other attribute is accepted.
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute namedAttr) {
  const bool isForbidden = namedAttr.getName() == "test.invalid_attr";
  if (!isForbidden)
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}
// Rejects the region-argument attribute named `test.invalid_attr` with an
// error on `op`; every other attribute is accepted. The region and argument
// indices are not inspected.
LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute namedAttr) {
  const bool isForbidden = namedAttr.getName() == "test.invalid_attr";
  if (!isForbidden)
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}
// Rejects the region-result attribute named `test.invalid_attr` with an error
// on `op`; every other attribute is accepted. The region and result indices
// are not inspected.
LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                         unsigned resultIndex,
                                         NamedAttribute namedAttr) {
  const bool isForbidden = namedAttr.getName() == "test.invalid_attr";
  if (!isForbidden)
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}
std::optional<Dialect::ParseOpHook>
TestDialect::getParseOperationHook(StringRef opName) const {
if (opName == "test.dialect_custom_printer") {
return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
return parser.parseKeyword("custom_format");
}};
}
if (opName == "test.dialect_custom_format_fallback") {
return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
return parser.parseKeyword("custom_format_fallback");
}};
}
if (opName == "test.dialect_custom_printer.with.dot") {
return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
return ParseResult::success();
}};
}
return std::nullopt;
}
// Returns a dialect-level printer for the two ops that have one: each printer
// writes a fixed keyword to the printer's stream. For any other op an empty
// function is returned.
llvm::unique_function<void(Operation *, OpAsmPrinter &)>
TestDialect::getOperationPrinter(Operation *op) const {
  using PrinterFn = llvm::unique_function<void(Operation *, OpAsmPrinter &)>;
  const StringRef name = op->getName().getStringRef();

  if (name == "test.dialect_custom_printer")
    return PrinterFn([](Operation *, OpAsmPrinter &asmPrinter) {
      asmPrinter.getStream() << " custom_format";
    });

  if (name == "test.dialect_custom_format_fallback")
    return PrinterFn([](Operation *, OpAsmPrinter &asmPrinter) {
      asmPrinter.getStream() << " custom_format_fallback";
    });

  return PrinterFn();
}
// Canonicalization pattern that replaces a `TestDialectCanonicalizerOp` with
// an `arith::ConstantOp` holding the i32 value 42. Always succeeds.
static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
                               PatternRewriter &rewriter) {
  IntegerAttr replacementValue = rewriter.getI32IntegerAttr(42);
  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, replacementValue);
  return success();
}
// Adds this dialect's canonicalization pattern,
// `dialectCanonicalizationPattern`, to `results`.
void TestDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add(&dialectCanonicalizationPattern);
}
//===----------------------------------------------------------------------===//
// TestCallWithSegmentsOp
//===----------------------------------------------------------------------===//
// The op `test.call_with_segments` models a call-like operation whose operands
// are divided into 3 variadic segments: `prefix`, `args`, and `suffix`.
// Only the middle segment represents the actual call arguments. The op uses
// the AttrSizedOperandSegments trait, so we can derive segment boundaries from
// the generated `operandSegmentSizes` attribute. We provide custom helpers to
// expose the logical call arguments as both a read-only range and a mutable
// range bound to the proper segment so that insertion/erasure updates the
// attribute automatically.
// Segment layout indices in the DenseI32ArrayAttr: [prefix, args, suffix].
// This constant is the position of the `args` segment in that layout.
static constexpr unsigned kTestCallWithSegmentsArgsSegIndex = 1;
// Returns the call-argument operands as a read-only range: the slice of the
// operand list that starts right after the `prefix` segment and has the
// length of the `args` segment.
Operation::operand_range CallWithSegmentsOp::getArgOperands() {
  // The generated segment getters give the sizes needed to locate the slice.
  const size_t prefixCount = getPrefix().size();
  const size_t argCount = getArgs().size();
  return getOperation()->getOperands().slice(prefixCount, argCount);
}
// Returns the call-argument operands as a mutable range tied to the `args`
// entry of the operand segment-size attribute. Asserts that the segment-size
// attribute is present on the op.
MutableOperandRange CallWithSegmentsOp::getArgOperandsMutable() {
  Operation *self = getOperation();

  // Look up the segment-size attribute under its canonical name for this op.
  auto segAttrName =
      CallWithSegmentsOp::getOperandSegmentSizesAttrName(self->getName());
  auto segSizesAttr = self->getAttrOfType<DenseI32ArrayAttr>(segAttrName);
  assert(segSizesAttr && "missing operandSegmentSizes attribute on op");

  // The args segment begins after the prefix (entry 0) and spans entry 1.
  auto segSizes = segSizesAttr.asArrayRef();
  const auto prefixCount = static_cast<unsigned>(segSizes[0]);
  const auto argCount = static_cast<unsigned>(segSizes[1]);

  // Pair the args segment index with the named segment-size attribute.
  MutableOperandRange::OperandSegment argsSegment{
      kTestCallWithSegmentsArgsSegIndex,
      NamedAttribute(segAttrName, segSizesAttr)};
  return MutableOperandRange(self, prefixCount, argCount, {argsSegment});
}