diff --git a/mlir/include/mlir-c/ExtensibleDialect.h b/mlir/include/mlir-c/ExtensibleDialect.h new file mode 100644 index 000000000000..98457805f57c --- /dev/null +++ b/mlir/include/mlir-c/ExtensibleDialect.h @@ -0,0 +1,76 @@ +//===-- mlir-c/ExtensibleDialect.h - Extensible dialect APIs -----*- C -*-====// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header provides APIs for extensible dialects. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_EXTENSIBLEDIALECT_H +#define MLIR_C_EXTENSIBLEDIALECT_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +/// Opaque type declarations (see mlir-c/IR.h for more details). +//===----------------------------------------------------------------------===// + +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +DEFINE_C_API_STRUCT(MlirDynamicOpTrait, void); + +/// Attach a dynamic op trait to the given operation name. +/// Note that the operation name must be modeled by dynamic dialect and must be +/// registered. +/// The ownership of the trait will be transferred to the operation name +/// after this call. +MLIR_CAPI_EXPORTED bool +mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait, + MlirStringRef opName, MlirContext context); + +/// Get the dynamic op trait that indicates the operation is a terminator. +MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator(void); + +/// Get the dynamic op trait that indicates regions have no terminator. +MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator(void); + +/// Destroy the dynamic op trait. +MLIR_CAPI_EXPORTED void +mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait); + +typedef struct { + /// Optional constructor for the user data. + /// Set to nullptr to disable it. + void (*construct)(void *userData); + /// Optional destructor for the user data. + /// Set to nullptr to disable it. + void (*destruct)(void *userData); + /// The callback function to verify the operation. + MlirLogicalResult (*verifyTrait)(MlirOperation op, void *userData); + /// The callback function to verify the operation with access to regions. + MlirLogicalResult (*verifyRegionTrait)(MlirOperation op, void *userData); +} MlirDynamicOpTraitCallbacks; + +/// Create a custom dynamic op trait with the given type ID and callbacks. +MLIR_CAPI_EXPORTED MlirDynamicOpTrait mlirDynamicOpTraitCreate( + MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_EXTENSIBLEDIALECT_H diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h index 4bb49e6bc245..e551a49bb34a 100644 --- a/mlir/include/mlir/Bindings/Python/IRCore.h +++ b/mlir/include/mlir/Bindings/Python/IRCore.h @@ -23,6 +23,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" +#include "mlir-c/ExtensibleDialect.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Support.h" @@ -1844,6 +1845,30 @@ private: PyOpAttributeMap attributes; }; +class MLIR_PYTHON_API_EXPORTED PyDynamicOpTrait { +public: + static bool attach(const nanobind::object &opName, + const nanobind::object &target, PyMlirContext &context); + + static void bind(nanobind::module_ &m); +}; + +namespace PyDynamicOpTraits { + +class MLIR_PYTHON_API_EXPORTED IsTerminator : public PyDynamicOpTrait { +public: + static bool attach(const nanobind::object &opName, PyMlirContext &context); + static void bind(nanobind::module_ &m); +}; + +class MLIR_PYTHON_API_EXPORTED NoTerminator : public PyDynamicOpTrait { +public: + static bool attach(const nanobind::object &opName, PyMlirContext &context); + static void bind(nanobind::module_ &m); +}; + +} // namespace PyDynamicOpTraits + MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation); MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m); MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index efe45d648824..6f03f334e34b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -16,6 +16,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" +#include "mlir-c/ExtensibleDialect.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -2513,6 +2514,119 @@ void PyOpAdaptor::bind(nb::module_ &m) { "Returns the attributes of the adaptor."); } +static MlirLogicalResult verifyTraitByMethod(MlirOperation op, void *userData, + const char *methodName) { + nb::handle targetObj(static_cast(userData)); + if (!nb::hasattr(targetObj, methodName)) { + return mlirLogicalResultSuccess(); + } + PyMlirContextRef ctx = PyMlirContext::forContext(mlirOperationGetContext(op)); + nb::object opView = PyOperation::forOperation(ctx, op)->createOpView(); + bool success = nb::cast(targetObj.attr(methodName)(opView)); + return success ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); +}; + +static bool attachOpTrait(const nb::object &opName, MlirDynamicOpTrait trait, + PyMlirContext &context) { + std::string opNameStr; + if (opName.is_type()) { + opNameStr = nb::cast(opName.attr("OPERATION_NAME")); + } else if (nb::isinstance(opName)) { + opNameStr = nb::cast(opName); + } else { + throw nb::type_error("the root argument must be a type or a string"); + } + + return mlirDynamicOpTraitAttach( + trait, MlirStringRef{opNameStr.data(), opNameStr.size()}, context.get()); +} + +bool PyDynamicOpTrait::attach(const nb::object &opName, + const nb::object &target, + PyMlirContext &context) { + if (!nb::hasattr(target, "verify") && !nb::hasattr(target, "verify_region")) + throw nb::type_error( + "the target object must have at least one of 'verify' or " + "'verify_region' methods"); + + MlirDynamicOpTraitCallbacks callbacks; + callbacks.construct = [](void *userData) { + nb::handle(static_cast(userData)).inc_ref(); + }; + callbacks.destruct = [](void *userData) { + nb::handle(static_cast(userData)).dec_ref(); + }; + + callbacks.verifyTrait = [](MlirOperation op, + void *userData) -> MlirLogicalResult { + return verifyTraitByMethod(op, userData, "verify"); + }; + callbacks.verifyRegionTrait = [](MlirOperation op, + void *userData) -> MlirLogicalResult { + return verifyTraitByMethod(op, userData, "verify_region"); + }; + + constexpr const char *typeIDAttr = "_TYPE_ID"; + if (!nb::hasattr(target, typeIDAttr)) { + nb::setattr(target, typeIDAttr, + nb::cast(PyTypeID(PyGlobals::get().allocateTypeID()))); + } + MlirDynamicOpTrait trait = mlirDynamicOpTraitCreate( + nb::cast(target.attr(typeIDAttr)).get(), callbacks, + static_cast(target.ptr())); + return attachOpTrait(opName, trait, context); +} + +void PyDynamicOpTrait::bind(nb::module_ &m) { + nb::class_ cls(m, "DynamicOpTrait"); + cls.attr("attach") = classmethod( + [](const nb::object &cls, const nb::object &opName, nb::object target, + DefaultingPyMlirContext context) { + if (target.is_none()) + target = cls; + return PyDynamicOpTrait::attach(opName, target, *context.get()); + }, + nb::arg("cls"), nb::arg("op_name"), nb::arg("target").none() = nb::none(), + nb::arg("context").none() = nb::none(), + "Attach the dynamic op trait subclass to the given operation name."); +} + +bool PyDynamicOpTraits::IsTerminator::attach(const nb::object &opName, + PyMlirContext &context) { + MlirDynamicOpTrait trait = mlirDynamicOpTraitGetIsTerminator(); + return attachOpTrait(opName, trait, context); +} + +void PyDynamicOpTraits::IsTerminator::bind(nb::module_ &m) { + nb::class_ cls( + m, "IsTerminatorTrait"); + cls.attr("attach") = classmethod( + [](const nb::object &cls, const nb::object &opName, + DefaultingPyMlirContext context) { + return PyDynamicOpTraits::IsTerminator::attach(opName, *context.get()); + }, + "Attach IsTerminator trait to the given operation name.", nb::arg("cls"), + nb::arg("op_name"), nb::arg("context").none() = nb::none()); +} + +bool PyDynamicOpTraits::NoTerminator::attach(const nb::object &opName, + PyMlirContext &context) { + MlirDynamicOpTrait trait = mlirDynamicOpTraitGetNoTerminator(); + return attachOpTrait(opName, trait, context); +} + +void PyDynamicOpTraits::NoTerminator::bind(nb::module_ &m) { + nb::class_ cls( + m, "NoTerminatorTrait"); + cls.attr("attach") = classmethod( + [](const nb::object &cls, const nb::object &opName, + DefaultingPyMlirContext context) { + return PyDynamicOpTraits::NoTerminator::attach(opName, *context.get()); + }, + "Attach NoTerminator trait to the given operation name.", nb::arg("cls"), + nb::arg("op_name"), nb::arg("context").none() = nb::none()); +} + } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir @@ -4836,6 +4950,11 @@ void populateIRCore(nb::module_ &m) { // Attribute builder getter. PyAttrBuilderMap::bind(m); + + // Extensible Dialect + PyDynamicOpTrait::bind(m); + PyDynamicOpTraits::IsTerminator::bind(m); + PyDynamicOpTraits::NoTerminator::bind(m); } } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt index 36f28520d675..d78f9d9735aa 100644 --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR BuiltinTypes.cpp Diagnostics.cpp DialectHandle.cpp + ExtensibleDialect.cpp IntegerSet.cpp IR.cpp Pass.cpp diff --git a/mlir/lib/CAPI/IR/ExtensibleDialect.cpp b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp new file mode 100644 index 000000000000..f3239d996a0e --- /dev/null +++ b/mlir/lib/CAPI/IR/ExtensibleDialect.cpp @@ -0,0 +1,87 @@ +//===- ExtensibleDialect - C API for MLIR Extensible Dialect --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/ExtensibleDialect.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/OperationSupport.h" + +using namespace mlir; + +DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait) + +bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait, + MlirStringRef opName, MlirContext context) { + std::optional opNameFound = + RegisteredOperationName::lookup(unwrap(opName), unwrap(context)); + assert(opNameFound && "operation name must be registered in the context"); + + // The original getImpl() is protected, so we create a small helper struct + // here. + struct RegisteredOperationNameWithImpl : RegisteredOperationName { + Impl *getImpl() { return RegisteredOperationName::getImpl(); } + }; + OperationName::Impl *impl = + static_cast(*opNameFound).getImpl(); + + std::unique_ptr trait(unwrap(dynamicOpTrait)); + // TODO: we should check whether the `impl` is a DynamicOpDefinition here + // via llvm-style RTTI. + return static_cast(impl)->addTrait(std::move(trait)); +} + +MlirDynamicOpTrait mlirDynamicOpTraitGetIsTerminator() { + return wrap(new DynamicOpTraits::IsTerminator()); +} + +MlirDynamicOpTrait mlirDynamicOpTraitGetNoTerminator() { + return wrap(new DynamicOpTraits::NoTerminator()); +} + +void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) { + delete unwrap(dynamicOpTrait); +} + +namespace mlir { + +class ExternalDynamicOpTrait : public DynamicOpTrait { +public: + ExternalDynamicOpTrait(TypeID typeID, MlirDynamicOpTraitCallbacks callbacks, + void *userData) + : typeID(typeID), callbacks(callbacks), userData(userData) { + if (callbacks.construct) + callbacks.construct(userData); + } + ~ExternalDynamicOpTrait() { + if (callbacks.destruct) + callbacks.destruct(userData); + } + + LogicalResult verifyTrait(Operation *op) const override { + return unwrap(callbacks.verifyTrait(wrap(op), userData)); + }; + LogicalResult verifyRegionTrait(Operation *op) const override { + return unwrap(callbacks.verifyRegionTrait(wrap(op), userData)); + }; + + TypeID getTypeID() const override { return typeID; }; + +private: + TypeID typeID; + MlirDynamicOpTraitCallbacks callbacks; + void *userData; +}; + +} // namespace mlir + +MlirDynamicOpTrait mlirDynamicOpTraitCreate( + MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData) { + return wrap( + new mlir::ExternalDynamicOpTrait(unwrap(typeID), callbacks, userData)); +} diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py index 30e705726756..d24a94bc8baf 100644 --- a/mlir/test/python/dialects/ext.py +++ b/mlir/test/python/dialects/ext.py @@ -350,22 +350,55 @@ def testExtDialectWithRegion(): class IfOp(TestRegion.Operation, name="if"): cond: Operand[IntegerType[1]] + result: Result[Any] then: Region else_: Region + class YieldOp(TestRegion.Operation, name="yield"): + value: Operand[Any] + + class NoTermOp(TestRegion.Operation, name="no_term"): + body: Region + with Context(), Location.unknown(): TestRegion.load() # CHECK: irdl.dialect @ext_region { - # CHECK: irdl.operation @if { + # CHECK: irdl.operation @if { # CHECK: %0 = irdl.is i1 # CHECK: irdl.operands(cond: %0) - # CHECK: %1 = irdl.region + # CHECK: %1 = irdl.any + # CHECK: irdl.results(result: %1) # CHECK: %2 = irdl.region - # CHECK: irdl.regions(then: %1, else_: %2) + # CHECK: %3 = irdl.region + # CHECK: irdl.regions(then: %2, else_: %3) + # CHECK: } + # CHECK: irdl.operation @yield { + # CHECK: %0 = irdl.any + # CHECK: irdl.operands(value: %0) + # CHECK: } + # CHECK: irdl.operation @no_term { + # CHECK: %0 = irdl.region + # CHECK: irdl.regions(body: %0) + # CHECK: } # CHECK: } print(TestRegion._mlir_module) - # CHECK: (self, /, cond, *, loc=None, ip=None) + IsTerminatorTrait.attach(YieldOp) + NoTerminatorTrait.attach(NoTermOp) + + class ParentIsIfTrait(DynamicOpTrait): + @staticmethod + def verify(op) -> bool: + if not isinstance(op.parent.opview, IfOp): + op.location.emit_error( + f"{op.name} should be put inside {IfOp.OPERATION_NAME}" + ) + return False + return True + + ParentIsIfTrait.attach(YieldOp) + + # CHECK: (self, /, result, cond, *, loc=None, ip=None) print(IfOp.__init__.__signature__) # CHECK: None None @@ -373,36 +406,44 @@ def testExtDialectWithRegion(): # CHECK: (2, True) print(IfOp._ODS_REGIONS) - from mlir.dialects import llvm - module = Module.create() with InsertionPoint(module.body): i1 = IntegerType.get_signless(1) i32 = IntegerType.get_signless(32) cond = arith.constant(i1, 1) - if_ = IfOp(cond) + if_ = IfOp(i32, cond) if_.then.blocks.append() if_.else_.blocks.append() with InsertionPoint(if_.then.blocks[0]): v = arith.constant(i32, 2) - llvm.unreachable() + YieldOp(v) with InsertionPoint(if_.else_.blocks[0]): v = arith.constant(i32, 3) - llvm.unreachable() + YieldOp(v) + + nt = NoTermOp() + nt.body.blocks.append() + + with InsertionPoint(nt.body.blocks[0]): + arith.constant(i32, 4) + # No terminator here assert module.operation.verify() # CHECK: module { # CHECK: %true = arith.constant true - # CHECK: "ext_region.if"(%true) ({ + # CHECK: %0 = "ext_region.if"(%true) ({ # CHECK: %c2_i32 = arith.constant 2 : i32 - # CHECK: llvm.unreachable + # CHECK: "ext_region.yield"(%c2_i32) : (i32) -> () # CHECK: }, { # CHECK: %c3_i32 = arith.constant 3 : i32 - # CHECK: llvm.unreachable - # CHECK: }) : (i1) -> () + # CHECK: "ext_region.yield"(%c3_i32) : (i32) -> () + # CHECK: }) : (i1) -> i32 + # CHECK: "ext_region.no_term"() ({ + # CHECK: %c4_i32 = arith.constant 4 : i32 + # CHECK: }) : () -> () # CHECK: } print(module) @@ -410,3 +451,41 @@ def testExtDialectWithRegion(): print(if_.then.blocks[0]) # CHECK: %c3_i32 = arith.constant 3 : i32 print(if_.else_.blocks[0]) + + # CHECK-LABEL: Testing violation cases + print("Testing violation cases:") + + module = Module.create() + with InsertionPoint(module.body): + i1 = IntegerType.get_signless(1) + i32 = IntegerType.get_signless(32) + cond = arith.constant(i1, 1) + + if_ = IfOp(i32, cond) + if_.then.blocks.append() + if_.else_.blocks.append() + + with InsertionPoint(if_.then.blocks[0]): + v = arith.constant(i32, 2) + + with InsertionPoint(if_.else_.blocks[0]): + v = arith.constant(i32, 3) + + try: + module.operation.verify() + except Exception as e: + # CHECK: Verification failed: + # CHECK: block with no terminator + print(e) + + module = Module.create() + with InsertionPoint(module.body): + v = arith.constant(i32, 2) + YieldOp(v) + + try: + module.operation.verify() + except Exception as e: + # CHECK: Verification failed: + # CHECK: ext_region.yield should be put inside ext_region.if + print(e)