From 7675f541f75baa20e8ec007cd625a837e89fc01f Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 26 Sep 2023 01:53:17 -0700 Subject: [PATCH] [MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes (#66332) This is part of the transition toward properly splitting the two groups. This only introduces new C APIs, the Python bindings are unaffected. No API is removed. --- mlir/include/mlir-c/IR.h | 52 ++++++++++++++++++++++++++++++++ mlir/include/mlir/IR/Operation.h | 17 +++++++++++ mlir/lib/CAPI/IR/IR.cpp | 47 +++++++++++++++++++++++++++++ mlir/test/CAPI/ir.c | 41 ++++++++++++------------- 4 files changed, 136 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 68eccab6dbac..a6408317db69 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -576,25 +576,77 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op); MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos); +/// Returns true if this operation defines an inherent attribute with this name. +/// Note: the attribute can be optional, so +/// `mlirOperationGetInherentAttributeByName` can still return a null attribute. +MLIR_CAPI_EXPORTED bool +mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name); + +/// Returns an inherent attribute attached to the operation given its name. +MLIR_CAPI_EXPORTED MlirAttribute +mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name); + +/// Sets an inherent attribute by name, replacing the existing if it exists. +/// This has no effect if "name" does not match an inherent attribute. +MLIR_CAPI_EXPORTED void +mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name, + MlirAttribute attr); + +/// Returns the number of discardable attributes attached to the operation. +MLIR_CAPI_EXPORTED intptr_t +mlirOperationGetNumDiscardableAttributes(MlirOperation op); + +/// Return `pos`-th discardable attribute of the operation. +MLIR_CAPI_EXPORTED MlirNamedAttribute +mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos); + +/// Returns a discardable attribute attached to the operation given its name. +MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName( + MlirOperation op, MlirStringRef name); + +/// Sets a discardable attribute by name, replacing the existing if it exists or +/// adding a new one otherwise. The new `attr` Attribute is not allowed to be +/// null, use `mlirOperationRemoveDiscardableAttributeByName` to remove an +/// Attribute instead. +MLIR_CAPI_EXPORTED void +mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name, + MlirAttribute attr); + +/// Removes a discardable attribute by name. Returns false if the attribute was +/// not found and true if removed. +MLIR_CAPI_EXPORTED bool +mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, + MlirStringRef name); + /// Returns the number of attributes attached to the operation. +/// Deprecated, please use `mlirOperationGetNumInherentAttributes` or +/// `mlirOperationGetNumDiscardableAttributes`. MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op); /// Return `pos`-th attribute of the operation. +/// Deprecated, please use `mlirOperationGetInherentAttribute` or +/// `mlirOperationGetDiscardableAttribute`. MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos); /// Returns an attribute attached to the operation given its name. +/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or +/// `mlirOperationGetDiscardableAttributeByName`. MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name); /// Sets an attribute by name, replacing the existing if it exists or /// adding a new one otherwise. +/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or +/// `mlirOperationSetDiscardableAttributeByName`. MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr); /// Removes an attribute by name. Returns false if the attribute was not found /// and true if removed. +/// Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or +/// `mlirOperationRemoveDiscardableAttributeByName`. MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index b815eaf8899d..35e9d31a6323 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -457,6 +457,23 @@ public: if (attributes.set(name, value) != value) attrs = attributes.getDictionary(getContext()); } + void setDiscardableAttr(StringRef name, Attribute value) { + setDiscardableAttr(StringAttr::get(getContext(), name), value); + } + + /// Remove the discardable attribute with the specified name if it exists. + /// Return the attribute that was erased, or nullptr if there was no attribute + /// with such name. + Attribute removeDiscardableAttr(StringAttr name) { + NamedAttrList attributes(attrs); + Attribute removedAttr = attributes.erase(name); + if (removedAttr) + attrs = attributes.getDictionary(getContext()); + return removedAttr; + } + Attribute removeDiscardableAttr(StringRef name) { + return removeDiscardableAttr(StringAttr::get(getContext(), name)); + } /// Return all of the discardable attributes on this operation. ArrayRef getDiscardableAttrs() { return attrs.getValue(); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 7f5c2aaee673..04b386b8268e 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -595,6 +595,53 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { return wrap(unwrap(op)->getSuccessor(static_cast(pos))); } +MLIR_CAPI_EXPORTED bool +mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) { + std::optional attr = unwrap(op)->getInherentAttr(unwrap(name)); + return attr.has_value(); +} + +MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op, + MlirStringRef name) { + std::optional attr = unwrap(op)->getInherentAttr(unwrap(name)); + if (attr.has_value()) + return wrap(*attr); + return {}; +} + +void mlirOperationSetInherentAttributeByName(MlirOperation op, + MlirStringRef name, + MlirAttribute attr) { + unwrap(op)->setInherentAttr( + StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr)); +} + +intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) { + return static_cast(unwrap(op)->getDiscardableAttrs().size()); +} + +MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op, + intptr_t pos) { + NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos]; + return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; +} + +MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op, + MlirStringRef name) { + return wrap(unwrap(op)->getDiscardableAttr(unwrap(name))); +} + +void mlirOperationSetDiscardableAttributeByName(MlirOperation op, + MlirStringRef name, + MlirAttribute attr) { + unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr)); +} + +bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, + MlirStringRef name) { + return !!unwrap(op)->removeDiscardableAttr(unwrap(name)); +} + intptr_t mlirOperationGetNumAttributes(MlirOperation op) { return static_cast(unwrap(op)->getAttrs().size()); } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index c031e61945d0..a181332e219d 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -407,24 +407,23 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { fprintf(stderr, "\n"); // CHECK: Terminator: func.return - // Get the attribute by index. - MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0); - fprintf(stderr, "Get attr 0: "); - mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL); - fprintf(stderr, "\n"); - // CHECK: Get attr 0: 0 : index + // Get the attribute by name. + bool hasValueAttr = mlirOperationHasInherentAttributeByName( + operation, mlirStringRefCreateFromCString("value")); + if (hasValueAttr) + // CHECK: Has attr "value" + fprintf(stderr, "Has attr \"value\""); - // Now re-get the attribute by name. - MlirAttribute attr0ByName = mlirOperationGetAttributeByName( - operation, mlirIdentifierStr(namedAttr0.name)); - fprintf(stderr, "Get attr 0 by name: "); - mlirAttributePrint(attr0ByName, printToStderr, NULL); + MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName( + operation, mlirStringRefCreateFromCString("value")); + fprintf(stderr, "Get attr \"value\": "); + mlirAttributePrint(valueAttr0, printToStderr, NULL); fprintf(stderr, "\n"); - // CHECK: Get attr 0 by name: 0 : index + // CHECK: Get attr "value": 0 : index // Get a non-existing attribute and assert that it is null (sanity). fprintf(stderr, "does_not_exist is null: %d\n", - mlirAttributeIsNull(mlirOperationGetAttributeByName( + mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName( operation, mlirStringRefCreateFromCString("does_not_exist")))); // CHECK: does_not_exist is null: 1 @@ -443,10 +442,10 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { fprintf(stderr, "\n"); // CHECK: Result 0 type: index - // Set a custom attribute. - mlirOperationSetAttributeByName(operation, - mlirStringRefCreateFromCString("custom_attr"), - mlirBoolAttrGet(ctx, 1)); + // Set a discardable attribute. + mlirOperationSetDiscardableAttributeByName( + operation, mlirStringRefCreateFromCString("custom_attr"), + mlirBoolAttrGet(ctx, 1)); fprintf(stderr, "Op with set attr: "); mlirOperationPrint(operation, printToStderr, NULL); fprintf(stderr, "\n"); @@ -454,13 +453,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { // Remove the attribute. fprintf(stderr, "Remove attr: %d\n", - mlirOperationRemoveAttributeByName( + mlirOperationRemoveDiscardableAttributeByName( operation, mlirStringRefCreateFromCString("custom_attr"))); fprintf(stderr, "Remove attr again: %d\n", - mlirOperationRemoveAttributeByName( + mlirOperationRemoveDiscardableAttributeByName( operation, mlirStringRefCreateFromCString("custom_attr"))); fprintf(stderr, "Removed attr is null: %d\n", - mlirAttributeIsNull(mlirOperationGetAttributeByName( + mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName( operation, mlirStringRefCreateFromCString("custom_attr")))); // CHECK: Remove attr: 1 // CHECK: Remove attr again: 0 @@ -469,7 +468,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { // Add a large attribute to verify printing flags. int64_t eltsShape[] = {4}; int32_t eltsData[] = {1, 2, 3, 4}; - mlirOperationSetAttributeByName( + mlirOperationSetDiscardableAttributeByName( operation, mlirStringRefCreateFromCString("elts"), mlirDenseElementsAttrInt32Get( mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),