[MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes (#66332)

This is part of the transition toward properly splitting the two groups.
This only introduces new C APIs, the Python bindings are unaffected. No
API is removed.
This commit is contained in:
Mehdi Amini 2023-09-26 01:53:17 -07:00 committed by GitHub
parent 5746407a78
commit 7675f541f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 21 deletions

View File

@ -576,25 +576,77 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
intptr_t pos);
/// Returns true if this operation defines an inherent attribute with this name.
/// Note: the attribute can be optional, so
/// `mlirOperationGetInherentAttributeByName` can still return a null attribute.
MLIR_CAPI_EXPORTED bool
mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name);
/// Returns an inherent attribute attached to the operation given its name.
MLIR_CAPI_EXPORTED MlirAttribute
mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name);
/// Sets an inherent attribute by name, replacing the existing if it exists.
/// This has no effect if "name" does not match an inherent attribute.
MLIR_CAPI_EXPORTED void
mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name,
MlirAttribute attr);
/// Returns the number of discardable attributes attached to the operation.
MLIR_CAPI_EXPORTED intptr_t
mlirOperationGetNumDiscardableAttributes(MlirOperation op);
/// Return `pos`-th discardable attribute of the operation.
MLIR_CAPI_EXPORTED MlirNamedAttribute
mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos);
/// Returns a discardable attribute attached to the operation given its name.
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName(
MlirOperation op, MlirStringRef name);
/// Sets a discardable attribute by name, replacing the existing if it exists or
/// adding a new one otherwise. The new `attr` Attribute is not allowed to be
/// null, use `mlirOperationRemoveDiscardableAttributeByName` to remove an
/// Attribute instead.
MLIR_CAPI_EXPORTED void
mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name,
MlirAttribute attr);
/// Removes a discardable attribute by name. Returns false if the attribute was
/// not found and true if removed.
MLIR_CAPI_EXPORTED bool
mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
MlirStringRef name);
/// Returns the number of attributes attached to the operation.
/// Deprecated, please use `mlirOperationGetNumInherentAttributes` or
/// `mlirOperationGetNumDiscardableAttributes`.
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op);
/// Return `pos`-th attribute of the operation.
/// Deprecated, please use `mlirOperationGetInherentAttribute` or
/// `mlirOperationGetDiscardableAttribute`.
MLIR_CAPI_EXPORTED MlirNamedAttribute
mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
/// Returns an attribute attached to the operation given its name.
/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or
/// `mlirOperationGetDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED MlirAttribute
mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name);
/// Sets an attribute by name, replacing the existing if it exists or
/// adding a new one otherwise.
/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or
/// `mlirOperationSetDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr);
/// Removes an attribute by name. Returns false if the attribute was not found
/// and true if removed.
/// Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or
/// `mlirOperationRemoveDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op,
MlirStringRef name);

View File

@ -457,6 +457,23 @@ public:
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
}
void setDiscardableAttr(StringRef name, Attribute value) {
setDiscardableAttr(StringAttr::get(getContext(), name), value);
}
/// Remove the discardable attribute with the specified name if it exists.
/// Return the attribute that was erased, or nullptr if there was no attribute
/// with such name.
Attribute removeDiscardableAttr(StringAttr name) {
NamedAttrList attributes(attrs);
Attribute removedAttr = attributes.erase(name);
if (removedAttr)
attrs = attributes.getDictionary(getContext());
return removedAttr;
}
Attribute removeDiscardableAttr(StringRef name) {
return removeDiscardableAttr(StringAttr::get(getContext(), name));
}
/// Return all of the discardable attributes on this operation.
ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }

View File

@ -595,6 +595,53 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
}
MLIR_CAPI_EXPORTED bool
mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
return attr.has_value();
}
MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
MlirStringRef name) {
std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
if (attr.has_value())
return wrap(*attr);
return {};
}
void mlirOperationSetInherentAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr) {
unwrap(op)->setInherentAttr(
StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
}
intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
}
MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
intptr_t pos) {
NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos];
return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
}
MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
MlirStringRef name) {
return wrap(unwrap(op)->getDiscardableAttr(unwrap(name)));
}
void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr) {
unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr));
}
bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
MlirStringRef name) {
return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
}
intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
}

View File

@ -407,24 +407,23 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "\n");
// CHECK: Terminator: func.return
// Get the attribute by index.
MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
fprintf(stderr, "Get attr 0: ");
mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Get attr 0: 0 : index
// Get the attribute by name.
bool hasValueAttr = mlirOperationHasInherentAttributeByName(
operation, mlirStringRefCreateFromCString("value"));
if (hasValueAttr)
// CHECK: Has attr "value"
fprintf(stderr, "Has attr \"value\"");
// Now re-get the attribute by name.
MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
operation, mlirIdentifierStr(namedAttr0.name));
fprintf(stderr, "Get attr 0 by name: ");
mlirAttributePrint(attr0ByName, printToStderr, NULL);
MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName(
operation, mlirStringRefCreateFromCString("value"));
fprintf(stderr, "Get attr \"value\": ");
mlirAttributePrint(valueAttr0, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Get attr 0 by name: 0 : index
// CHECK: Get attr "value": 0 : index
// Get a non-existing attribute and assert that it is null (sanity).
fprintf(stderr, "does_not_exist is null: %d\n",
mlirAttributeIsNull(mlirOperationGetAttributeByName(
mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("does_not_exist"))));
// CHECK: does_not_exist is null: 1
@ -443,9 +442,9 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "\n");
// CHECK: Result 0 type: index
// Set a custom attribute.
mlirOperationSetAttributeByName(operation,
mlirStringRefCreateFromCString("custom_attr"),
// Set a discardable attribute.
mlirOperationSetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr"),
mlirBoolAttrGet(ctx, 1));
fprintf(stderr, "Op with set attr: ");
mlirOperationPrint(operation, printToStderr, NULL);
@ -454,13 +453,13 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// Remove the attribute.
fprintf(stderr, "Remove attr: %d\n",
mlirOperationRemoveAttributeByName(
mlirOperationRemoveDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr")));
fprintf(stderr, "Remove attr again: %d\n",
mlirOperationRemoveAttributeByName(
mlirOperationRemoveDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr")));
fprintf(stderr, "Removed attr is null: %d\n",
mlirAttributeIsNull(mlirOperationGetAttributeByName(
mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr"))));
// CHECK: Remove attr: 1
// CHECK: Remove attr again: 0
@ -469,7 +468,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// Add a large attribute to verify printing flags.
int64_t eltsShape[] = {4};
int32_t eltsData[] = {1, 2, 3, 4};
mlirOperationSetAttributeByName(
mlirOperationSetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("elts"),
mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),