
The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: - https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" - Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This patch updates all remaining uses of the deprecated functionality in mlir/. This was done with clang-tidy as described below and further modifications to GPUBase.td and OpenMPOpsInterfaces.td. Steps are described per line, as comments are removed by git: 0. Retrieve the change from the following to build clang-tidy with an additional check: main...tpopp:llvm-project:tidy-cast-check 1. Build clang-tidy 2. Run clang-tidy over your entire codebase while disabling all checks and enabling the one relevant one. Run on all header files also. 3. Delete .inc files that were also modified, so the next build rebuilds them to a pure state. ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -header-filter=mlir/ mlir/* -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D151542
170 lines
6.3 KiB
C++
170 lines
6.3 KiB
C++
|
|
|
|
//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir-c/Interfaces.h"
|
|
|
|
#include "mlir/CAPI/IR.h"
|
|
#include "mlir/CAPI/Interfaces.h"
|
|
#include "mlir/CAPI/Support.h"
|
|
#include "mlir/CAPI/Wrap.h"
|
|
#include "mlir/IR/ValueRange.h"
|
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
|
#include "llvm/ADT/ScopeExit.h"
|
|
#include <optional>
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
std::optional<RegisteredOperationName>
|
|
getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
|
|
StringRef name(opName.data, opName.length);
|
|
std::optional<RegisteredOperationName> info =
|
|
RegisteredOperationName::lookup(name, unwrap(context));
|
|
return info;
|
|
}
|
|
|
|
/// Converts a C API location into an optional C++ `Location`.
///
/// A null `location` (as reported by `mlirLocationIsNull`) maps to an empty
/// optional; any other handle is unwrapped and returned.
std::optional<Location> maybeGetLocation(MlirLocation location) {
  if (mlirLocationIsNull(location))
    return std::nullopt;
  return std::optional<Location>(unwrap(location));
}
|
|
|
|
/// Builds a vector of C++ `Value`s from an array of `nOperands` C API values.
///
/// The conversion is delegated to `unwrapList`, with the returned vector
/// passed as its output storage. The value `unwrapList` itself returns is
/// deliberately discarded; only the populated storage is handed back.
SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
  SmallVector<Value> values;
  static_cast<void>(unwrapList(nOperands, operands, values));
  return values;
}
|
|
|
|
/// Converts a C API attribute into a `DictionaryAttr`.
///
/// A null `attributes` handle yields a default-constructed `DictionaryAttr`.
/// Any other handle is unwrapped and converted with `llvm::cast`, so the
/// caller is expected to pass a dictionary attribute.
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
  if (mlirAttributeIsNull(attributes))
    return DictionaryAttr();
  return llvm::cast<DictionaryAttr>(unwrap(attributes));
}
|
|
|
|
/// Wraps `nRegions` C API regions into a vector of `std::unique_ptr<Region>`.
///
/// \param nRegions number of entries in `regions`.
/// \param regions  array of C API region handles to unwrap.
/// \returns a vector with one `unique_ptr` slot per input region.
SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
                                                   MlirRegion *regions) {
  // Create a vector of unique pointers to regions and make sure they are not
  // deleted when exiting the scope. This is a hack caused by C++ API expecting
  // a list of unique pointers to regions (without ownership transfer
  // semantics) and C API making ownership transfer explicit.
  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
  unwrappedRegions.reserve(nRegions);
  for (intptr_t i = 0; i < nRegions; ++i)
    unwrappedRegions.emplace_back(unwrap(*(regions + i)));
  // NOTE(review): `cleaner` is a local of this helper, so the release below
  // runs when this function returns, not when the caller has finished using
  // the vector. If the return is elided, the caller likely receives already
  // released (null) pointers; if it is moved, the guard sees a moved-from
  // vector and the caller's copy owns the regions. Confirm the intended
  // lifetime and consider moving the guard into the callers.
  auto cleaner = llvm::make_scope_exit([&]() {
    for (auto &region : unwrappedRegions)
      region.release();
  });
  return unwrappedRegions;
}
|
|
|
|
} // namespace
|
|
|
|
/// Reports whether `operation` implements the interface identified by
/// `interfaceTypeID`.
///
/// Returns false when the operation has no registered info; otherwise
/// forwards the query to `RegisteredOperationName::hasInterface`.
bool mlirOperationImplementsInterface(MlirOperation operation,
                                      MlirTypeID interfaceTypeID) {
  if (std::optional<RegisteredOperationName> registered =
          unwrap(operation)->getRegisteredInfo())
    return registered->hasInterface(unwrap(interfaceTypeID));
  return false;
}
|
|
|
|
/// Reports whether the operation named `operationName` implements the
/// interface identified by `interfaceTypeID`, without needing an operation
/// instance.
///
/// The name is looked up in `context`; when the lookup yields nothing the
/// function returns false, otherwise it forwards to
/// `RegisteredOperationName::hasInterface`.
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
                                            MlirContext context,
                                            MlirTypeID interfaceTypeID) {
  const StringRef name(operationName.data, operationName.length);
  std::optional<RegisteredOperationName> registered =
      RegisteredOperationName::lookup(name, unwrap(context));
  if (!registered)
    return false;
  return registered->hasInterface(unwrap(interfaceTypeID));
}
|
|
|
|
/// Returns the C API handle for the type ID of `InferTypeOpInterface`.
MlirTypeID mlirInferTypeOpInterfaceTypeID() {
  const auto interfaceID = InferTypeOpInterface::getInterfaceID();
  return wrap(interfaceID);
}
|
|
|
|
/// Runs `InferTypeOpInterface::inferReturnTypes` for the registered operation
/// named `opName` and reports the inferred types through `callback`.
///
/// \param opName     name of the operation to look up in `context`.
/// \param context    context used for the lookup and passed to inference.
/// \param location   optional location; a null handle is passed on as an
///                   empty optional.
/// \param nOperands  number of entries in `operands`.
/// \param operands   operand values handed to inference.
/// \param attributes dictionary attribute, or a null handle for none.
/// \param properties opaque pointer forwarded unchanged to inference.
/// \param nRegions   number of entries in `regions`.
/// \param regions    regions handed to inference.
/// \param callback   invoked once on success with the number of inferred
///                   types, a pointer to them and `userData`. The pointed-to
///                   storage is local to this function, so it is only valid
///                   for the duration of the call.
/// \param userData   opaque pointer forwarded to `callback`.
/// \returns failure when the operation is not registered, when it does not
///          provide `InferTypeOpInterface`, or when inference itself fails;
///          success otherwise.
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
    MlirStringRef opName, MlirContext context, MlirLocation location,
    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
    void *properties, intptr_t nRegions, MlirRegion *regions,
    MlirTypesCallback callback, void *userData) {
  std::optional<RegisteredOperationName> info =
      getRegisteredOperationName(context, opName);
  if (!info)
    return mlirLogicalResultFailure();

  // A registered operation is not required to implement the interface; bail
  // out instead of dereferencing a null interface pointer.
  auto *inferInterface = info->getInterface<InferTypeOpInterface>();
  if (!inferInterface)
    return mlirLogicalResultFailure();

  std::optional<Location> maybeLocation = maybeGetLocation(location);
  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
  DictionaryAttr attributeDict = unwrapAttributes(attributes);
  SmallVector<std::unique_ptr<Region>> unwrappedRegions =
      unwrapRegions(nRegions, regions);

  SmallVector<Type> inferredTypes;
  if (failed(inferInterface->inferReturnTypes(
          unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
          properties, unwrappedRegions, inferredTypes)))
    return mlirLogicalResultFailure();

  // Convert the results to C API handles before handing them to the callback.
  SmallVector<MlirType> wrappedInferredTypes;
  wrappedInferredTypes.reserve(inferredTypes.size());
  for (Type t : inferredTypes)
    wrappedInferredTypes.push_back(wrap(t));
  callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
  return mlirLogicalResultSuccess();
}
|
|
|
|
/// Returns the C API handle for the type ID of `InferShapedTypeOpInterface`.
MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
  const auto interfaceID = InferShapedTypeOpInterface::getInterfaceID();
  return wrap(interfaceID);
}
|
|
|
|
/// Runs `InferShapedTypeOpInterface::inferReturnTypeComponents` for the
/// registered operation named `opName` and reports each inferred component
/// through `callback`.
///
/// \param opName     name of the operation to look up in `context`.
/// \param context    context used for the lookup and passed to inference.
/// \param location   optional location; a null handle is passed on as an
///                   empty optional.
/// \param nOperands  number of entries in `operands`.
/// \param operands   operand values handed to inference.
/// \param attributes dictionary attribute, or a null handle for none.
/// \param properties opaque pointer forwarded unchanged to inference.
/// \param nRegions   number of entries in `regions`.
/// \param regions    regions handed to inference.
/// \param callback   invoked once per inferred component with: whether the
///                   component has a rank, the rank (0 if unranked), a pointer
///                   to the dimensions (null if unranked), the element type,
///                   the attribute and `userData`.
/// \param userData   opaque pointer forwarded to `callback`.
/// \returns failure when the operation is not registered, when it does not
///          provide `InferShapedTypeOpInterface`, or when inference itself
///          fails; success otherwise.
MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
    MlirStringRef opName, MlirContext context, MlirLocation location,
    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
    void *properties, intptr_t nRegions, MlirRegion *regions,
    MlirShapedTypeComponentsCallback callback, void *userData) {
  std::optional<RegisteredOperationName> info =
      getRegisteredOperationName(context, opName);
  if (!info)
    return mlirLogicalResultFailure();

  // A registered operation is not required to implement the interface; bail
  // out instead of dereferencing a null interface pointer.
  auto *inferInterface = info->getInterface<InferShapedTypeOpInterface>();
  if (!inferInterface)
    return mlirLogicalResultFailure();

  std::optional<Location> maybeLocation = maybeGetLocation(location);
  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
  DictionaryAttr attributeDict = unwrapAttributes(attributes);
  SmallVector<std::unique_ptr<Region>> unwrappedRegions =
      unwrapRegions(nRegions, regions);

  SmallVector<ShapedTypeComponents> inferredTypeComponents;
  if (failed(inferInterface->inferReturnTypeComponents(
          unwrap(context), maybeLocation,
          mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)), attributeDict,
          properties, unwrappedRegions, inferredTypeComponents)))
    return mlirLogicalResultFailure();

  for (const ShapedTypeComponents &t : inferredTypeComponents) {
    // Unranked components are reported with rank 0 and no shape data.
    const bool hasRank = t.hasRank();
    intptr_t rank = 0;
    const int64_t *shapeData = nullptr;
    if (hasRank) {
      rank = static_cast<intptr_t>(t.getDims().size());
      shapeData = t.getDims().data();
    }
    callback(hasRank, rank, shapeData, wrap(t.getElementType()),
             wrap(t.getAttribute()), userData);
  }
  return mlirLogicalResultSuccess();
}
|