Skip to content

Commit

Permalink
[MLIR] Introduce new C bindings to differentiate between discardable …
Browse files Browse the repository at this point in the history
…and inherent attributes (#66332)

This is part of the transition toward properly splitting the two groups.
This only introduces new C APIs, the Python bindings are unaffected. No
API is removed.
  • Loading branch information
joker-eph authored Sep 26, 2023
1 parent 5746407 commit 7675f54
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 22 deletions.
52 changes: 52 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,25 +576,77 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op);
MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op,
intptr_t pos);

/// Returns true if this operation defines an inherent attribute with this name.
/// Note: the attribute can be optional, so
/// `mlirOperationGetInherentAttributeByName` can still return a null attribute.
MLIR_CAPI_EXPORTED bool
mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name);

/// Returns an inherent attribute attached to the operation given its name.
MLIR_CAPI_EXPORTED MlirAttribute
mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name);

/// Sets an inherent attribute by name, replacing the existing if it exists.
/// This has no effect if "name" does not match an inherent attribute.
MLIR_CAPI_EXPORTED void
mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name,
MlirAttribute attr);

/// Returns the number of discardable attributes attached to the operation.
MLIR_CAPI_EXPORTED intptr_t
mlirOperationGetNumDiscardableAttributes(MlirOperation op);

/// Return `pos`-th discardable attribute of the operation.
MLIR_CAPI_EXPORTED MlirNamedAttribute
mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos);

/// Returns a discardable attribute attached to the operation given its name.
MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName(
MlirOperation op, MlirStringRef name);

/// Sets a discardable attribute by name, replacing the existing if it exists or
/// adding a new one otherwise. The new `attr` Attribute is not allowed to be
/// null, use `mlirOperationRemoveDiscardableAttributeByName` to remove an
/// Attribute instead.
MLIR_CAPI_EXPORTED void
mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name,
MlirAttribute attr);

/// Removes a discardable attribute by name. Returns false if the attribute was
/// not found and true if removed.
MLIR_CAPI_EXPORTED bool
mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
MlirStringRef name);

/// Returns the number of attributes attached to the operation.
/// Deprecated, please use `mlirOperationGetNumInherentAttributes` or
/// `mlirOperationGetNumDiscardableAttributes`.
MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op);

/// Return `pos`-th attribute of the operation.
/// Deprecated, please use `mlirOperationGetInherentAttribute` or
/// `mlirOperationGetDiscardableAttribute`.
MLIR_CAPI_EXPORTED MlirNamedAttribute
mlirOperationGetAttribute(MlirOperation op, intptr_t pos);

/// Returns an attribute attached to the operation given its name.
/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or
/// `mlirOperationGetDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED MlirAttribute
mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name);

/// Sets an attribute by name, replacing the existing if it exists or
/// adding a new one otherwise.
/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or
/// `mlirOperationSetDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr);

/// Removes an attribute by name. Returns false if the attribute was not found
/// and true if removed.
/// Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or
/// `mlirOperationRemoveDiscardableAttributeByName`.
MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op,
MlirStringRef name);

Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/IR/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,23 @@ class alignas(8) Operation final
if (attributes.set(name, value) != value)
attrs = attributes.getDictionary(getContext());
}
void setDiscardableAttr(StringRef name, Attribute value) {
setDiscardableAttr(StringAttr::get(getContext(), name), value);
}

/// Remove the discardable attribute with the specified name if it exists.
/// Return the attribute that was erased, or nullptr if there was no attribute
/// with such name.
Attribute removeDiscardableAttr(StringAttr name) {
NamedAttrList attributes(attrs);
Attribute removedAttr = attributes.erase(name);
if (removedAttr)
attrs = attributes.getDictionary(getContext());
return removedAttr;
}
Attribute removeDiscardableAttr(StringRef name) {
return removeDiscardableAttr(StringAttr::get(getContext(), name));
}

/// Return all of the discardable attributes on this operation.
ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
Expand Down
47 changes: 47 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,53 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos)));
}

MLIR_CAPI_EXPORTED bool
mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) {
std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
return attr.has_value();
}

MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op,
MlirStringRef name) {
std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name));
if (attr.has_value())
return wrap(*attr);
return {};
}

void mlirOperationSetInherentAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr) {
unwrap(op)->setInherentAttr(
StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr));
}

intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
}

MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
intptr_t pos) {
NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos];
return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
}

MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op,
MlirStringRef name) {
return wrap(unwrap(op)->getDiscardableAttr(unwrap(name)));
}

void mlirOperationSetDiscardableAttributeByName(MlirOperation op,
MlirStringRef name,
MlirAttribute attr) {
unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr));
}

bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op,
MlirStringRef name) {
return !!unwrap(op)->removeDiscardableAttr(unwrap(name));
}

intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
return static_cast<intptr_t>(unwrap(op)->getAttrs().size());
}
Expand Down
43 changes: 21 additions & 22 deletions mlir/test/CAPI/ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -407,24 +407,23 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "\n");
// CHECK: Terminator: func.return

// Get the attribute by index.
MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0);
fprintf(stderr, "Get attr 0: ");
mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL);
// Get the attribute by name.
bool hasValueAttr = mlirOperationHasInherentAttributeByName(
operation, mlirStringRefCreateFromCString("value"));
if (hasValueAttr)
// CHECK: Has attr "value"
fprintf(stderr, "Has attr \"value\"");

MlirAttribute valueAttr0 = mlirOperationGetInherentAttributeByName(
operation, mlirStringRefCreateFromCString("value"));
fprintf(stderr, "Get attr \"value\": ");
mlirAttributePrint(valueAttr0, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Get attr 0: 0 : index

// Now re-get the attribute by name.
MlirAttribute attr0ByName = mlirOperationGetAttributeByName(
operation, mlirIdentifierStr(namedAttr0.name));
fprintf(stderr, "Get attr 0 by name: ");
mlirAttributePrint(attr0ByName, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Get attr 0 by name: 0 : index
// CHECK: Get attr "value": 0 : index

// Get a non-existing attribute and assert that it is null (sanity).
fprintf(stderr, "does_not_exist is null: %d\n",
mlirAttributeIsNull(mlirOperationGetAttributeByName(
mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("does_not_exist"))));
// CHECK: does_not_exist is null: 1

Expand All @@ -443,24 +442,24 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "\n");
// CHECK: Result 0 type: index

// Set a custom attribute.
mlirOperationSetAttributeByName(operation,
mlirStringRefCreateFromCString("custom_attr"),
mlirBoolAttrGet(ctx, 1));
// Set a discardable attribute.
mlirOperationSetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr"),
mlirBoolAttrGet(ctx, 1));
fprintf(stderr, "Op with set attr: ");
mlirOperationPrint(operation, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: Op with set attr: {{.*}} {custom_attr = true}

// Remove the attribute.
fprintf(stderr, "Remove attr: %d\n",
mlirOperationRemoveAttributeByName(
mlirOperationRemoveDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr")));
fprintf(stderr, "Remove attr again: %d\n",
mlirOperationRemoveAttributeByName(
mlirOperationRemoveDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr")));
fprintf(stderr, "Removed attr is null: %d\n",
mlirAttributeIsNull(mlirOperationGetAttributeByName(
mlirAttributeIsNull(mlirOperationGetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("custom_attr"))));
// CHECK: Remove attr: 1
// CHECK: Remove attr again: 0
Expand All @@ -469,7 +468,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// Add a large attribute to verify printing flags.
int64_t eltsShape[] = {4};
int32_t eltsData[] = {1, 2, 3, 4};
mlirOperationSetAttributeByName(
mlirOperationSetDiscardableAttributeByName(
operation, mlirStringRefCreateFromCString("elts"),
mlirDenseElementsAttrInt32Get(
mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
Expand Down

0 comments on commit 7675f54

Please sign in to comment.