[MLIR][LLVM] add metadata attrs and llvm.named_metadata op (#186703)

This PR adds some LLVM metadata attributes and an `llvm.named_metadata`
container op (similar to `llvm.module_flags`) for those attributes.

Summary:

- Add MLIR attributes modeling LLVM IR metadata: `#llvm.md_string`,
`#llvm.md_const`, `#llvm.md_func`, and `#llvm.md_node`;
- Add `llvm.named_metadata` container op for module-level named metadata
nodes;
  - Add MLIR-to-LLVM-IR translation for the new attributes and op;
- Add C API functions (`mlirLLVMMDStringAttrGet`,
`mlirLLVMMDNodeAttrGet`, etc.);
- Add Python bindings (`llvm.MDStringAttr`, `llvm.MDConstantAttr`,
`llvm.MDFuncAttr`, `llvm.MDNodeAttr`, `llvm.FunctionType`).
This commit is contained in:
Maksim Levental 2026-03-16 16:33:53 -07:00 committed by GitHub
parent 57568c288d
commit e3d7bf290d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 736 additions and 1 deletions

View File

@ -63,6 +63,12 @@ mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMFunctionTypeGetName(void);
/// Returns `true` if the type is an LLVM dialect function type.
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMFunctionType(MlirType type);
/// Returns the TypeID of an LLVM function type.
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMFunctionTypeGetTypeID(void);
/// Returns the number of input types.
MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type);
@ -70,6 +76,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type,
intptr_t pos);
/// Returns `true` if the function type is variadic.
MLIR_CAPI_EXPORTED bool mlirLLVMFunctionTypeIsVarArg(MlirType type);
/// Returns the return type of the function type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type);
@ -190,6 +199,7 @@ enum MlirLLVMCConv {
MlirLLVMCConvAMDGPU_Gfx = 100,
MlirLLVMCConvM68k_INTR = 101,
};
typedef enum MlirLLVMCConv MlirLLVMCConv;
/// Creates a LLVM CConv attribute.
@ -205,6 +215,7 @@ enum MlirLLVMComdat {
MlirLLVMComdatNoDeduplicate = 3,
MlirLLVMComdatSameSize = 4,
};
typedef enum MlirLLVMComdat MlirLLVMComdat;
/// Creates a LLVM Comdat attribute.
@ -226,6 +237,7 @@ enum MlirLLVMLinkage {
MlirLLVMLinkageExternWeak = 9,
MlirLLVMLinkageCommon = 10,
};
typedef enum MlirLLVMLinkage MlirLLVMLinkage;
/// Creates a LLVM Linkage attribute.
@ -274,6 +286,7 @@ enum MlirLLVMTypeEncoding {
MlirLLVMTypeEncodingLoUser = 0x80,
MlirLLVMTypeEncodingHiUser = 0xff,
};
typedef enum MlirLLVMTypeEncoding MlirLLVMTypeEncoding;
/// Creates a LLVM DIBasicType attribute.
@ -337,6 +350,7 @@ enum MlirLLVMDIEmissionKind {
MlirLLVMDIEmissionKindLineTablesOnly = 2,
MlirLLVMDIEmissionKindDebugDirectivesOnly = 3,
};
typedef enum MlirLLVMDIEmissionKind MlirLLVMDIEmissionKind;
enum MlirLLVMDINameTableKind {
@ -345,6 +359,7 @@ enum MlirLLVMDINameTableKind {
MlirLLVMDINameTableKindNone = 2,
MlirLLVMDINameTableKindApple = 3,
};
typedef enum MlirLLVMDINameTableKind MlirLLVMDINameTableKind;
/// Creates a LLVM DICompileUnit attribute.
@ -456,6 +471,69 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMDIImportedEntityAttrGetName(void);
MLIR_CAPI_EXPORTED MlirAttribute
mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule);
//===----------------------------------------------------------------------===//
// Metadata Attributes
//===----------------------------------------------------------------------===//
/// Creates an LLVM MDStringAttr.
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMMDStringAttrGet(MlirContext ctx,
MlirStringRef value);
/// Returns `true` if the attribute is an LLVM MDStringAttr.
MLIR_CAPI_EXPORTED bool mlirLLVMAttrIsAMDStringAttr(MlirAttribute attr);
/// Returns the TypeID of MDStringAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMMDStringAttrGetTypeID(void);
/// Returns the string value of an LLVM MDStringAttr.
MLIR_CAPI_EXPORTED MlirStringRef
mlirLLVMMDStringAttrGetValue(MlirAttribute attr);
/// Creates an LLVM MDConstantAttr wrapping an attribute.
MLIR_CAPI_EXPORTED MlirAttribute
mlirLLVMMDConstantAttrGet(MlirContext ctx, MlirAttribute valueAttr);
/// Returns `true` if the attribute is an LLVM MDConstantAttr.
MLIR_CAPI_EXPORTED bool mlirLLVMAttrIsAMDConstantAttr(MlirAttribute attr);
/// Returns the TypeID of MDConstantAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMMDConstantAttrGetTypeID(void);
/// Returns the attribute value of an LLVM MDConstantAttr.
MLIR_CAPI_EXPORTED MlirAttribute
mlirLLVMMDConstantAttrGetValue(MlirAttribute attr);
/// Creates an LLVM MDFuncAttr referencing a function symbol.
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMMDFuncAttrGet(MlirContext ctx,
MlirAttribute name);
/// Returns `true` if the attribute is an LLVM MDFuncAttr.
MLIR_CAPI_EXPORTED bool mlirLLVMAttrIsAMDFuncAttr(MlirAttribute attr);
/// Returns the TypeID of MDFuncAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMMDFuncAttrGetTypeID(void);
/// Returns the symbol name of an LLVM MDFuncAttr.
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMMDFuncAttrGetName(MlirAttribute attr);
/// Creates an LLVM MDNodeAttr.
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMMDNodeAttrGet(
MlirContext ctx, intptr_t nOperands, MlirAttribute const *operands);
/// Returns `true` if the attribute is an LLVM MDNodeAttr.
MLIR_CAPI_EXPORTED bool mlirLLVMAttrIsAMDNodeAttr(MlirAttribute attr);
/// Returns the TypeID of MDNodeAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMMDNodeAttrGetTypeID(void);
/// Returns the number of operands in an LLVM MDNodeAttr.
MLIR_CAPI_EXPORTED intptr_t
mlirLLVMMDNodeAttrGetNumOperands(MlirAttribute attr);
/// Returns the operand at the given index of an LLVM MDNodeAttr.
MLIR_CAPI_EXPORTED MlirAttribute
mlirLLVMMDNodeAttrGetOperand(MlirAttribute attr, intptr_t index);
#ifdef __cplusplus
}
#endif

View File

@ -1686,4 +1686,75 @@ def UWTableKindAttr : LLVM_Attr<"UWTableKind", "uwtableKind"> {
let assemblyFormat = "`<` $uwtableKind `>`";
}
//===----------------------------------------------------------------------===//
// Metadata Attributes
//===----------------------------------------------------------------------===//
//
// These attributes model LLVM IR metadata nodes (llvm::Metadata and its
// subclasses). They can be nested to form arbitrary metadata trees and are
// translated to their LLVM IR counterparts during MLIR-to-LLVM-IR conversion.
def LLVM_MDStringAttr : LLVM_Attr<"MDString", "md_string"> {
let summary = "LLVM metadata string";
let description = [{
Wraps a string as an LLVM metadata node, corresponding to
`llvm::MDString` in LLVM IR.
Example:
```mlir
#llvm.md_string<"foo.buffer">
```
}];
let parameters = (ins "StringAttr":$value);
let assemblyFormat = "`<` $value `>`";
}
def LLVM_MDConstantAttr : LLVM_Attr<"MDConstant", "md_const"> {
let summary = "LLVM constant-as-metadata";
let description = [{
Wraps an attribute as an LLVM metadata node, corresponding to
`llvm::ConstantAsMetadata` wrapping a `llvm::Constant*` in LLVM IR.
Currently, only integers/IntegerAttrs supported.
Example:
```mlir
#llvm.md_const<42 : i32>
```
}];
let parameters = (ins "Attribute":$value);
let assemblyFormat = "`<` $value `>`";
}
def LLVM_MDFuncAttr : LLVM_Attr<"MDFunc", "md_func"> {
let summary = "LLVM function-as-metadata";
let description = [{
References a function (or global) symbol as LLVM metadata, corresponding
to `llvm::ValueAsMetadata::get(function)` in LLVM IR.
Example:
```mlir
#llvm.md_func<@my_kernel>
```
}];
let parameters = (ins "FlatSymbolRefAttr":$name);
let assemblyFormat = "`<` $name `>`";
}
def LLVM_MDNodeAttr : LLVM_Attr<"MDNode", "md_node"> {
let summary = "LLVM metadata node";
let description = [{
Represents an LLVM metadata node. The operands
can be any combination of metadata attributes: `#llvm.md_string`,
`#llvm.md_const`, `#llvm.md_func`, or nested `#llvm.md_node`.
Example:
```mlir
#llvm.md_node<#llvm.md_const<0 : i32>, #llvm.md_string<"foo.buffer">>
#llvm.md_node<>
```
}];
let parameters = (ins OptionalArrayRefParameter<"Attribute">:$operands);
let assemblyFormat = "`<` (`>`) : ($operands^ `>`)?";
}
#endif // LLVMIR_ATTRDEFS

View File

@ -2576,4 +2576,53 @@ def LLVM_ModuleFlagsOp
let hasVerifier = 1;
}
//===--------------------------------------------------------------------===//
// NamedMetadataOp
//===--------------------------------------------------------------------===//
def LLVM_NamedMetadataOp
: LLVM_Op<"named_metadata"> {
let summary = "Module-level named metadata";
let description = [{
Represents an LLVM named metadata node (`llvm::NamedMDNode`). Named
metadata nodes are module-level metadata that associate a name string
with a list of metadata nodes. Each operand must be an `#llvm.md_node`.
Note: cyclic metadata graphs are not supported. Because metadata attributes
are represented as MLIR attributes (which form a tree), there is no way to
express a metadata node that directly or transitively references itself.
LLVM IR permits such cycles (e.g. `!0 = !{!0}`), but they cannot be
represented here and will not round-trip through this op.
Example:
```mlir
llvm.named_metadata "foo.version" [
#llvm.md_node<#llvm.md_const<2 : i32>,
#llvm.md_const<9 : i32>,
#llvm.md_const<0 : i32>
>
]
llvm.named_metadata "foo.kernel" [
#llvm.md_node<
#llvm.md_func<@my_kernel>,
#llvm.md_node<>,
#llvm.md_node<
#llvm.md_node<#llvm.md_const<0 : i32>,
#llvm.md_string<"foo.buffer">
>
>
>
]
```
}];
let arguments = (ins StrAttr:$metadata_name, ArrayAttr:$nodes);
let assemblyFormat = [{
$metadata_name $nodes attr-dict
}];
let llvmBuilder = [{
convertNamedMetadataOp($metadata_name, $nodes, builder, moduleTranslation);
}];
}
#endif // LLVMIR_OPS

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include <string>
#include <vector>
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/IR.h"
@ -222,10 +223,175 @@ struct PointerType : PyConcreteType<PointerType> {
}
};
//===--------------------------------------------------------------------===//
// FunctionType
//===--------------------------------------------------------------------===//
struct FunctionType : PyConcreteType<FunctionType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMFunctionType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirLLVMFunctionTypeGetTypeID;
static constexpr const char *pyClassName = "FunctionType";
static inline const MlirStringRef name = mlirLLVMFunctionTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](PyType &resultType, const std::vector<PyType> &argumentTypes,
bool isVarArg) {
std::vector<MlirType> argTypes(argumentTypes.size());
std::copy(argumentTypes.begin(), argumentTypes.end(),
argTypes.begin());
return FunctionType(
resultType.getContext(),
mlirLLVMFunctionTypeGet(resultType, argTypes.size(),
argTypes.data(), isVarArg));
},
"result_type"_a, "argument_types"_a, nb::kw_only(),
"is_var_arg"_a = false);
c.def_prop_ro("return_type", [](const FunctionType &type) {
return mlirLLVMFunctionTypeGetReturnType(type);
});
c.def_prop_ro("num_inputs", [](const FunctionType &type) {
return mlirLLVMFunctionTypeGetNumInputs(type);
});
c.def_prop_ro("inputs", [](const FunctionType &type) {
nb::list inputs;
for (intptr_t i = 0, e = mlirLLVMFunctionTypeGetNumInputs(type); i < e;
++i) {
inputs.append(mlirLLVMFunctionTypeGetInput(type, i));
}
return inputs;
});
c.def_prop_ro("is_var_arg", [](const FunctionType &type) {
return mlirLLVMFunctionTypeIsVarArg(type);
});
}
};
//===--------------------------------------------------------------------===//
// Metadata Attributes
//===--------------------------------------------------------------------===//
struct MDStringAttr : PyConcreteAttribute<MDStringAttr> {
static constexpr IsAFunctionTy isaFunction = mlirLLVMAttrIsAMDStringAttr;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirLLVMMDStringAttrGetTypeID;
static constexpr const char *pyClassName = "MDStringAttr";
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](const std::string &value, DefaultingPyMlirContext context) {
return MDStringAttr(
context->getRef(),
mlirLLVMMDStringAttrGet(
context.get()->get(),
mlirStringRefCreate(value.data(), value.size())));
},
"value"_a, nb::kw_only(), "context"_a = nb::none());
c.def_prop_ro("value", [](const MDStringAttr &self) {
MlirStringRef ref = mlirLLVMMDStringAttrGetValue(self);
return nb::str(ref.data, ref.length);
});
}
};
struct MDConstantAttr : PyConcreteAttribute<MDConstantAttr> {
static constexpr IsAFunctionTy isaFunction = mlirLLVMAttrIsAMDConstantAttr;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirLLVMMDConstantAttrGetTypeID;
static constexpr const char *pyClassName = "MDConstantAttr";
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](PyAttribute &valueAttr, DefaultingPyMlirContext context) {
return MDConstantAttr(
context->getRef(),
mlirLLVMMDConstantAttrGet(context.get()->get(), valueAttr));
},
"value"_a, nb::kw_only(), "context"_a = nb::none());
c.def_prop_ro("value", [](const MDConstantAttr &self) {
return mlirLLVMMDConstantAttrGetValue(self);
});
}
};
struct MDFuncAttr : PyConcreteAttribute<MDFuncAttr> {
static constexpr IsAFunctionTy isaFunction = mlirLLVMAttrIsAMDFuncAttr;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirLLVMMDFuncAttrGetTypeID;
static constexpr const char *pyClassName = "MDFuncAttr";
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](const std::string &name, DefaultingPyMlirContext context) {
MlirAttribute symRef = mlirFlatSymbolRefAttrGet(
context.get()->get(),
mlirStringRefCreate(name.data(), name.size()));
return MDFuncAttr(
context->getRef(),
mlirLLVMMDFuncAttrGet(context.get()->get(), symRef));
},
"name"_a, nb::kw_only(), "context"_a = nb::none());
c.def_prop_ro("name", [](const MDFuncAttr &self) {
MlirAttribute symRef = mlirLLVMMDFuncAttrGetName(self);
MlirStringRef ref = mlirFlatSymbolRefAttrGetValue(symRef);
return nb::str(ref.data, ref.length);
});
}
};
struct MDNodeAttr : PyConcreteAttribute<MDNodeAttr> {
static constexpr IsAFunctionTy isaFunction = mlirLLVMAttrIsAMDNodeAttr;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirLLVMMDNodeAttrGetTypeID;
static constexpr const char *pyClassName = "MDNodeAttr";
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](const std::vector<PyAttribute> &operands,
DefaultingPyMlirContext context) {
std::vector<MlirAttribute> operands_(operands.size());
std::copy(operands.begin(), operands.end(), operands_.begin());
return MDNodeAttr(context->getRef(),
mlirLLVMMDNodeAttrGet(context.get()->get(),
operands_.size(),
operands_.data()));
},
"operands"_a, nb::kw_only(), "context"_a = nb::none());
c.def_prop_ro("num_operands", [](const MDNodeAttr &self) {
return mlirLLVMMDNodeAttrGetNumOperands(self);
});
c.def("__getitem__", [](const MDNodeAttr &self, intptr_t index) {
intptr_t n = mlirLLVMMDNodeAttrGetNumOperands(self);
if (index < 0 || index >= n)
throw nb::index_error("MDNodeAttr operand index out of range");
return mlirLLVMMDNodeAttrGetOperand(self, index);
});
c.def("__len__", [](const MDNodeAttr &self) {
return mlirLLVMMDNodeAttrGetNumOperands(self);
});
}
};
static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
StructType::bind(m);
ArrayType::bind(m);
PointerType::bind(m);
FunctionType::bind(m);
MDStringAttr::bind(m);
MDConstantAttr::bind(m);
MDFuncAttr::bind(m);
MDNodeAttr::bind(m);
m.def(
"translate_module_to_llvmir",

View File

@ -85,6 +85,14 @@ MlirStringRef mlirLLVMFunctionTypeGetName(void) {
return wrap(LLVMFunctionType::name);
}
bool mlirTypeIsALLVMFunctionType(MlirType type) {
return isa<LLVM::LLVMFunctionType>(unwrap(type));
}
MlirTypeID mlirLLVMFunctionTypeGetTypeID(void) {
return wrap(LLVM::LLVMFunctionType::getTypeID());
}
intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type) {
return llvm::cast<LLVM::LLVMFunctionType>(unwrap(type)).getNumParams();
}
@ -99,6 +107,10 @@ MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type) {
return wrap(llvm::cast<LLVM::LLVMFunctionType>(unwrap(type)).getReturnType());
}
bool mlirLLVMFunctionTypeIsVarArg(MlirType type) {
return llvm::cast<LLVM::LLVMFunctionType>(unwrap(type)).isVarArg();
}
bool mlirTypeIsALLVMStructType(MlirType type) {
return isa<LLVM::LLVMStructType>(unwrap(type));
}
@ -523,3 +535,82 @@ MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name,
MlirStringRef mlirLLVMDIAnnotationAttrGetName(void) {
return wrap(DIAnnotationAttr::name);
}
//===----------------------------------------------------------------------===//
// Metadata Attributes
//===----------------------------------------------------------------------===//
MlirAttribute mlirLLVMMDStringAttrGet(MlirContext ctx, MlirStringRef value) {
return wrap(MDStringAttr::get(unwrap(ctx),
StringAttr::get(unwrap(ctx), unwrap(value))));
}
bool mlirLLVMAttrIsAMDStringAttr(MlirAttribute attr) {
return isa<MDStringAttr>(unwrap(attr));
}
MlirTypeID mlirLLVMMDStringAttrGetTypeID(void) {
return wrap(MDStringAttr::getTypeID());
}
MlirStringRef mlirLLVMMDStringAttrGetValue(MlirAttribute attr) {
return wrap(cast<MDStringAttr>(unwrap(attr)).getValue().getValue());
}
MlirAttribute mlirLLVMMDConstantAttrGet(MlirContext ctx,
MlirAttribute valueAttr) {
return wrap(MDConstantAttr::get(unwrap(ctx), unwrap(valueAttr)));
}
bool mlirLLVMAttrIsAMDConstantAttr(MlirAttribute attr) {
return isa<MDConstantAttr>(unwrap(attr));
}
MlirTypeID mlirLLVMMDConstantAttrGetTypeID(void) {
return wrap(MDConstantAttr::getTypeID());
}
MlirAttribute mlirLLVMMDConstantAttrGetValue(MlirAttribute attr) {
return wrap((Attribute)cast<MDConstantAttr>(unwrap(attr)).getValue());
}
MlirAttribute mlirLLVMMDFuncAttrGet(MlirContext ctx, MlirAttribute name) {
return wrap(
MDFuncAttr::get(unwrap(ctx), cast<FlatSymbolRefAttr>(unwrap(name))));
}
bool mlirLLVMAttrIsAMDFuncAttr(MlirAttribute attr) {
return isa<MDFuncAttr>(unwrap(attr));
}
MlirTypeID mlirLLVMMDFuncAttrGetTypeID(void) {
return wrap(MDFuncAttr::getTypeID());
}
MlirAttribute mlirLLVMMDFuncAttrGetName(MlirAttribute attr) {
return wrap((Attribute)cast<MDFuncAttr>(unwrap(attr)).getName());
}
MlirAttribute mlirLLVMMDNodeAttrGet(MlirContext ctx, intptr_t nOperands,
MlirAttribute const *operands) {
SmallVector<Attribute> attrStorage;
attrStorage.reserve(nOperands);
return wrap(MDNodeAttr::get(unwrap(ctx),
unwrapList(nOperands, operands, attrStorage)));
}
bool mlirLLVMAttrIsAMDNodeAttr(MlirAttribute attr) {
return isa<MDNodeAttr>(unwrap(attr));
}
MlirTypeID mlirLLVMMDNodeAttrGetTypeID(void) {
return wrap(MDNodeAttr::getTypeID());
}
intptr_t mlirLLVMMDNodeAttrGetNumOperands(MlirAttribute attr) {
return cast<MDNodeAttr>(unwrap(attr)).getOperands().size();
}
MlirAttribute mlirLLVMMDNodeAttrGetOperand(MlirAttribute attr, intptr_t index) {
return wrap(cast<MDNodeAttr>(unwrap(attr)).getOperands()[index]);
}

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
@ -215,6 +216,54 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
return success();
}
/// Recursively converts an MLIR metadata attribute to an LLVM metadata node.
static llvm::Metadata *
convertMetadataAttr(Attribute attr, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
return llvm::TypeSwitch<Attribute, llvm::Metadata *>(attr)
.Case<LLVM::MDStringAttr>([&](auto a) -> llvm::Metadata * {
return llvm::MDString::get(builder.getContext(),
a.getValue().getValue());
})
.Case<LLVM::MDConstantAttr>([&](auto a) -> llvm::Metadata * {
IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(a.getValue());
if (!intAttr)
return nullptr;
return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
llvm::Type::getIntNTy(builder.getContext(),
intAttr.getType().getIntOrFloatBitWidth()),
intAttr.getValue()));
})
.Case<LLVM::MDFuncAttr>([&](auto a) -> llvm::Metadata * {
if (llvm::Function *fn =
moduleTranslation.lookupFunction(a.getName().getValue()))
return llvm::ValueAsMetadata::get(fn);
return nullptr;
})
.Case<LLVM::MDNodeAttr>([&](auto a) -> llvm::Metadata * {
SmallVector<llvm::Metadata *> operands;
for (Attribute op : a.getOperands())
operands.push_back(
convertMetadataAttr(op, builder, moduleTranslation));
return llvm::MDNode::get(builder.getContext(), operands);
})
.Default([](auto) -> llvm::Metadata * { return nullptr; });
}
static void convertNamedMetadataOp(StringRef metadataName, ArrayAttr nodes,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
llvm::NamedMDNode *namedMD =
llvmModule->getOrInsertNamedMetadata(metadataName);
for (Attribute nodeAttr : nodes) {
llvm::Metadata *md =
convertMetadataAttr(nodeAttr, builder, moduleTranslation);
if (auto *mdNode = llvm::dyn_cast_or_null<llvm::MDNode>(md))
namedMD->addOperand(mdNode);
}
}
static void convertLinkerOptionsOp(ArrayAttr options,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {

View File

@ -6,7 +6,7 @@ from ._llvm_ops_gen import *
from ._llvm_ops_gen import _Dialect
from ._llvm_enum_gen import *
from .._mlir_libs._mlirDialectsLLVM import *
from ..ir import Value
from ..ir import Value, IntegerType, IntegerAttr
from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
@ -14,3 +14,16 @@ def mlir_constant(value, *, loc=None, ip=None) -> Value:
return _get_op_result_or_op_results(
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
)
def md_const(val, *, width=32, context=None):
if not isinstance(val, int):
raise NotImplementedError(
f"{val=} not supported; only integers currently supported."
)
i_type = IntegerType.get_signless(width, context=context)
return MDConstantAttr.get(IntegerAttr.get(i_type, val), context=context)
def md_str(s, *, context=None):
return MDStringAttr.get(s, context=context)

View File

@ -1117,3 +1117,39 @@ llvm.func @escapedtypename() {
%1 = llvm.alloca %0 x !llvm.struct<"bucket<string, double, '\\b'>::Iterator", (ptr, i64, i64)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
llvm.return
}
// Metadata attributes and llvm.named_metadata op.
llvm.func @md_kernel() {
llvm.return
}
// CHECK: llvm.named_metadata "foo.version" [#llvm.md_node<#llvm.md_const<1 : i32>, #llvm.md_const<0 : i32>, #llvm.md_const<0 : i32>>]
llvm.named_metadata "foo.version" [
#llvm.md_node<
#llvm.md_const<1 : i32>,
#llvm.md_const<0 : i32>,
#llvm.md_const<0 : i32>
>
]
// CHECK: llvm.named_metadata "foo.language" [#llvm.md_node<#llvm.md_string<"Bar">, #llvm.md_const<1 : i32>, #llvm.md_const<2 : i32>>]
llvm.named_metadata "foo.language" [
#llvm.md_node<
#llvm.md_string<"Bar">,
#llvm.md_const<1 : i32>,
#llvm.md_const<2 : i32>
>
]
// CHECK: llvm.named_metadata "foo.kernel" [#llvm.md_node<#llvm.md_func<@md_kernel>, #llvm.md_node<>, #llvm.md_node<#llvm.md_const<0 : i32>, #llvm.md_string<"foo.buffer">>>]
llvm.named_metadata "foo.kernel" [
#llvm.md_node<
#llvm.md_func<@md_kernel>,
#llvm.md_node<>,
#llvm.md_node<
#llvm.md_const<0 : i32>,
#llvm.md_string<"foo.buffer">
>
>
]

View File

@ -0,0 +1,45 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// Tests LLVM named metadata translation with deeply nested metadata trees.
// CHECK: !foo.version = !{![[VERSION:[0-9]+]]}
// CHECK: !foo.language_version = !{![[LANG:[0-9]+]]}
// CHECK: !foo.kernel = !{![[KERNEL:[0-9]+]]}
llvm.func @my_kernel() {
llvm.return
}
llvm.named_metadata "foo.version" [
#llvm.md_node<#llvm.md_const<1 : i32>,
#llvm.md_const<0 : i32>,
#llvm.md_const<0 : i32>>
]
// CHECK-DAG: ![[VERSION]] = !{i32 1, i32 0, i32 0}
llvm.named_metadata "foo.language_version" [
#llvm.md_node<#llvm.md_string<"Bar">,
#llvm.md_const<1 : i32>,
#llvm.md_const<2 : i32>,
#llvm.md_const<3 : i32>>
]
// CHECK-DAG: ![[LANG]] = !{!"Bar", i32 1, i32 2, i32 3}
#buf0 = #llvm.md_node<
#llvm.md_const<0 : i32>, #llvm.md_string<"foo.buffer">,
#llvm.md_string<"foo.idx">, #llvm.md_const<0 : i32>,
#llvm.md_const<1 : i32>, #llvm.md_string<"foo.read">,
#llvm.md_string<"foo.address_space">, #llvm.md_const<1 : i32>,
#llvm.md_string<"foo.size">, #llvm.md_const<4 : i32>,
#llvm.md_string<"foo.align_size">, #llvm.md_const<4 : i32>>
// CHECK-DAG: ![[A0:[0-9]+]] = !{i32 0, !"foo.buffer", !"foo.idx", i32 0, i32 1, !"foo.read", !"foo.address_space", i32 1, !"foo.size", i32 4, !"foo.align_size", i32 4}
llvm.named_metadata "foo.kernel" [
#llvm.md_node<
#llvm.md_func<@my_kernel>,
#llvm.md_node<>,
#llvm.md_node<#buf0>>
]
// CHECK-DAG: ![[KERNEL]] = !{ptr @my_kernel, ![[EMPTY:[0-9]+]], ![[ARGS:[0-9]+]]}
// CHECK-DAG: ![[EMPTY]] = !{}
// CHECK-DAG: ![[ARGS]] = !{![[A0]]}

View File

@ -215,3 +215,140 @@ def testTranslateToLLVMIR():
# CHECK: ret i64 %3
# CHECK: }
print(llvm.translate_module_to_llvmir(module.operation))
# CHECK-LABEL: testMetadataAttrs
@constructAndPrintInModule
def testMetadataAttrs():
# MDStringAttr
md_str = llvm.MDStringAttr.get("foo.buffer")
# CHECK: #llvm.md_string<"foo.buffer">
print(md_str)
assert md_str.value == "foo.buffer"
# MDConstantAttr
i32 = IntegerType.get_signless(32)
md_const = llvm.MDConstantAttr.get(IntegerAttr.get(i32, 42))
# CHECK: #llvm.md_const<42 : i32>
print(md_const)
# MDFuncAttr
md_func = llvm.MDFuncAttr.get("my_kernel")
# CHECK: #llvm.md_func<@my_kernel>
print(md_func)
assert md_func.name == "my_kernel"
# MDNodeAttr - empty
md_empty = llvm.MDNodeAttr.get([])
# CHECK: #llvm.md_node<>
print(md_empty)
assert len(md_empty) == 0
# MDNodeAttr - with operands
md_node = llvm.MDNodeAttr.get([md_const, md_str])
# CHECK: #llvm.md_node<#llvm.md_const<42 : i32>, #llvm.md_string<"foo.buffer">>
print(md_node)
assert len(md_node) == 2
# MDNodeAttr - __getitem__
# CHECK: #llvm.md_const<42 : i32>
print(md_node[0])
# CHECK: #llvm.md_string<"foo.buffer">
print(md_node[1])
assert str(md_node[0]) == str(md_const)
assert str(md_node[1]) == str(md_str)
# MDNodeAttr - nested
md_nested = llvm.MDNodeAttr.get([md_node, md_empty])
# CHECK: #llvm.md_node<#llvm.md_node<#llvm.md_const<42 : i32>, #llvm.md_string<"foo.buffer">>, #llvm.md_node<>>
print(md_nested)
assert len(md_nested) == 2
# CHECK-LABEL: testNamedMetadata
@constructAndPrintInModule
def testNamedMetadata():
void = Type.parse("!llvm.void")
func_ty = llvm.FunctionType.get(void, [])
llvm.LLVMFuncOp("my_kernel", TypeAttr.get(func_ty))
# CHECK-LABEL: llvm.func @my_kernel()
llvm.NamedMetadataOp(
metadata_name="foo.version",
nodes=ArrayAttr.get(
[
llvm.MDNodeAttr.get(
[llvm.md_const(1), llvm.md_const(0), llvm.md_const(0)]
)
]
),
)
# CHECK: llvm.named_metadata "foo.version" [#llvm.md_node<#llvm.md_const<1 : i32>, #llvm.md_const<0 : i32>, #llvm.md_const<0 : i32>>]
llvm.NamedMetadataOp(
metadata_name="foo.language_version",
nodes=ArrayAttr.get(
[
llvm.MDNodeAttr.get(
[
llvm.md_str("Bar"),
llvm.md_const(1),
llvm.md_const(2),
llvm.md_const(3),
]
)
]
),
)
# CHECK: llvm.named_metadata "foo.language_version" [#llvm.md_node<#llvm.md_string<"Bar">, #llvm.md_const<1 : i32>, #llvm.md_const<2 : i32>, #llvm.md_const<3 : i32>>]
buf0 = llvm.MDNodeAttr.get(
[
llvm.md_const(0),
llvm.md_str("foo.buffer"),
llvm.md_str("foo.idx"),
llvm.md_const(0),
llvm.md_const(1),
llvm.md_str("foo.read"),
llvm.md_str("foo.address_space"),
llvm.md_const(1),
llvm.md_str("foo.size"),
llvm.md_const(4),
llvm.md_str("foo.align_size"),
llvm.md_const(4),
]
)
llvm.NamedMetadataOp(
metadata_name="foo.kernel",
nodes=ArrayAttr.get(
[
llvm.MDNodeAttr.get(
[
llvm.MDFuncAttr.get("my_kernel"),
llvm.MDNodeAttr.get([]),
buf0,
]
)
]
),
)
# CHECK: llvm.named_metadata "foo.kernel" [
# CHECK-SAME: #llvm.md_node<
# CHECK-SAME: #llvm.md_func<@my_kernel>,
# CHECK-SAME: #llvm.md_node<>,
# CHECK-SAME: #llvm.md_node<
# CHECK-SAME: #llvm.md_const<0 : i32>,
# CHECK-SAME: #llvm.md_string<"foo.buffer">,
# CHECK-SAME: #llvm.md_string<"foo.idx">,
# CHECK-SAME: #llvm.md_const<0 : i32>,
# CHECK-SAME: #llvm.md_const<1 : i32>,
# CHECK-SAME: #llvm.md_string<"foo.read">,
# CHECK-SAME: #llvm.md_string<"foo.address_space">,
# CHECK-SAME: #llvm.md_const<1 : i32>,
# CHECK-SAME: #llvm.md_string<"foo.size">,
# CHECK-SAME: #llvm.md_const<4 : i32>,
# CHECK-SAME: #llvm.md_string<"foo.align_size">,
# CHECK-SAME: #llvm.md_const<4 : i32>>
# CHECK-SAME: >]