
The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through LLVM's doCast machinery, in addition to defining member methods with the same names. This change begins migrating uses of those member methods to the corresponding free-function calls, which was decided to be the more consistent style. Note that some classes, such as AffineExpr, still define only the member methods directly; this change does not include work to support free-function cast/isa calls for them.

Context:
- https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…"
- Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443

Implementation:
This patch updates all remaining uses of the deprecated functionality in mlir/. This was done with clang-tidy as described below, plus further modifications to GPUBase.td and OpenMPOpsInterfaces.td.

Steps are described per line, as comments are removed by git:
0. Retrieve the change from the following to build clang-tidy with an additional check: main...tpopp:llvm-project:tidy-cast-check
1. Build clang-tidy.
2. Run clang-tidy over your entire codebase with all checks disabled except the one relevant check. Run it on all header files as well.
3. Delete the .inc files that were also modified, so the next build regenerates them in a pristine state.

```
ninja -C $BUILD_DIR clang-tidy

run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\
               -header-filter=mlir/ mlir/* -fix

rm -rf $BUILD_DIR/tools/mlir/**/*.inc
```

Differential Revision: https://reviews.llvm.org/D151542
171 lines
6.9 KiB
C++
171 lines
6.9 KiB
C++
//===- Utils.cpp - Utilities to support the Arith dialect -----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements utilities for the Arith dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
|
|
using namespace mlir;
|
|
|
|
/// Matches a ConstantIndexOp.
|
|
/// TODO: This should probably just be a general matcher that uses matchConstant
|
|
/// and checks the operation for an index type.
|
|
detail::op_matcher<arith::ConstantIndexOp> mlir::matchConstantIndex() {
|
|
return detail::op_matcher<arith::ConstantIndexOp>();
|
|
}
|
|
|
|
// Returns `success` when any of the elements in `ofrs` was produced by
|
|
// arith::ConstantIndexOp. In that case the constant attribute replaces the
|
|
// Value. Returns `failure` when no folding happened.
|
|
LogicalResult mlir::foldDynamicIndexList(Builder &b,
|
|
SmallVectorImpl<OpFoldResult> &ofrs) {
|
|
bool valuesChanged = false;
|
|
for (OpFoldResult &ofr : ofrs) {
|
|
if (ofr.is<Attribute>())
|
|
continue;
|
|
// Newly static, move from Value to constant.
|
|
if (auto cstOp = llvm::dyn_cast_if_present<Value>(ofr)
|
|
.getDefiningOp<arith::ConstantIndexOp>()) {
|
|
ofr = b.getIndexAttr(cstOp.value());
|
|
valuesChanged = true;
|
|
}
|
|
}
|
|
return success(valuesChanged);
|
|
}
|
|
|
|
/// Returns a bit vector with `shape.size()` bits. Scanning `shape` from the
/// front, the bit for each dimension of size 1 is set, until `rank` such
/// bits have been set. Later unit dimensions are left unset.
llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
                                                  ArrayRef<int64_t> shape) {
  llvm::SmallBitVector unitDims(shape.size());
  unsigned remaining = rank;
  for (unsigned idx = 0, end = shape.size(); remaining != 0 && idx < end;
       ++idx) {
    if (shape[idx] != 1)
      continue;
    unitDims.set(idx);
    --remaining;
  }
  return unitDims;
}
|
|
|
|
/// Materializes `ofr` as a Value.
/// - If `ofr` holds a Value, that Value is returned as is.
/// - Otherwise `ofr` must hold an IntegerAttr. A new `arith::ConstantIndexOp`
///   carrying the attribute's sign-extended value is created at `loc`.
Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                            OpFoldResult ofr) {
  if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
    return value;
  // Use `dyn_cast_if_present` (not `dyn_cast`) so that an empty fold result
  // reaches the descriptive assertion below. Plain `dyn_cast` would instead
  // trip the generic "dyn_cast on a non-existent value" assertion.
  auto attr = llvm::dyn_cast_if_present<IntegerAttr>(
      llvm::dyn_cast_if_present<Attribute>(ofr));
  assert(attr && "expect the op fold result casts to an integer attribute");
  return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
}
|
|
|
|
/// Casts `value` to `targetType`, where each type is either `index` or an
/// integer type:
///   - identical types: `value` is returned unchanged;
///   - exactly one side is `index`: an `arith::IndexCastOp` is created;
///   - otherwise: an `arith::ExtSIOp` is created when the target is wider,
///     and an `arith::TruncIOp` in every other case.
/// When neither side is `index`, both types are asserted to be integer types
/// with the same signedness.
Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
                                            Type targetType, Value value) {
  Type sourceType = value.getType();
  if (sourceType == targetType)
    return value;

  // Exactly one side is `index`: a single index_cast bridges the two types.
  if (targetType.isIndex() != sourceType.isIndex())
    return b.create<arith::IndexCastOp>(loc, targetType, value);

  auto dstIntTy = dyn_cast<IntegerType>(targetType);
  auto srcIntTy = dyn_cast<IntegerType>(sourceType);
  assert(dstIntTy && srcIntTy &&
         "unexpected cast between types other than integers and index");
  assert(dstIntTy.getSignedness() == srcIntTy.getSignedness());

  if (dstIntTy.getWidth() <= srcIntTy.getWidth())
    return b.create<arith::TruncIOp>(loc, dstIntTy, value);
  return b.create<arith::ExtSIOp>(loc, dstIntTy, value);
}
|
|
|
|
/// Converts the scalar `operand` to `toType` by emitting one `arith` op:
///   - float   -> integer:          `FPToUIOp` / `FPToSIOp`;
///   - index   -> integer:          `IndexCastOp` (ignores `isUnsignedCast`);
///   - integer -> wider integer:    `ExtUIOp` / `ExtSIOp`;
///   - integer -> narrower integer: `TruncIOp`;
///   - integer -> float:            `UIToFPOp` / `SIToFPOp`;
///   - float   -> wider float:      `ExtFOp`;
///   - float   -> narrower float:   `TruncFOp`.
/// For signedness-sensitive conversions, the unsigned variant is used when
/// `isUnsignedCast` is true.
///
/// If `operand` already has type `toType`, it is returned as is. Any other
/// combination emits a warning at `loc` and returns `operand` unchanged.
/// Examples: equal-width but distinct types, targets other than integer or
/// float, and index -> float.
Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
                                 Type toType, bool isUnsignedCast) {
  Type fromType = operand.getType();
  if (fromType == toType)
    return operand;

  if (auto dstInt = dyn_cast<IntegerType>(toType)) {
    // Floating-point source: convert straight to the integer type.
    if (isa<FloatType>(fromType)) {
      if (isUnsignedCast)
        return b.create<arith::FPToUIOp>(loc, toType, operand);
      return b.create<arith::FPToSIOp>(loc, toType, operand);
    }

    // Index source: a single index_cast reaches the integer type.
    if (fromType.isIndex())
      return b.create<arith::IndexCastOp>(loc, toType, operand);

    // Integer source: truncate or extend depending on the relative widths.
    // Equal widths fall through to the warning below.
    if (auto srcInt = dyn_cast<IntegerType>(fromType)) {
      unsigned srcWidth = srcInt.getWidth();
      unsigned dstWidth = dstInt.getWidth();
      if (dstWidth < srcWidth)
        return b.create<arith::TruncIOp>(loc, toType, operand);
      if (dstWidth > srcWidth) {
        if (isUnsignedCast)
          return b.create<arith::ExtUIOp>(loc, toType, operand);
        return b.create<arith::ExtSIOp>(loc, toType, operand);
      }
    }
  } else if (auto dstFloat = dyn_cast<FloatType>(toType)) {
    // Integer source: convert straight to the float type.
    if (isa<IntegerType>(fromType)) {
      if (isUnsignedCast)
        return b.create<arith::UIToFPOp>(loc, dstFloat, operand);
      return b.create<arith::SIToFPOp>(loc, dstFloat, operand);
    }

    // Float source: extend or truncate. Equal widths (e.g. BF16<->FP16) are
    // not handled, because it is unclear how to cast between them; they fall
    // through to the warning below.
    if (auto srcFloat = dyn_cast<FloatType>(fromType)) {
      unsigned srcWidth = srcFloat.getWidth();
      unsigned dstWidth = dstFloat.getWidth();
      if (dstWidth > srcWidth)
        return b.create<arith::ExtFOp>(loc, dstFloat, operand);
      if (dstWidth < srcWidth)
        return b.create<arith::TruncFOp>(loc, dstFloat, operand);
    }
  }

  emitWarning(loc) << "could not cast operand of type " << fromType
                   << " to " << toType;
  return operand;
}
|
|
|
|
/// Applies the single-OpFoldResult overload to each element of
/// `valueOrAttrVec`, in order. Returns the resulting Values in the same order.
SmallVector<Value>
mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  SmallVector<Value> materialized;
  materialized.reserve(valueOrAttrVec.size());
  for (OpFoldResult ofr : valueOrAttrVec)
    materialized.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
  return materialized;
}
|
|
|
|
/// Creates an `arith::AndIOp` of `lhs` and `rhs`, using this builder's `b`
/// and `loc`. Operand types are not inspected: the integer op is always used.
Value ArithBuilder::_and(Value lhs, Value rhs) {
  Value conjunction = b.create<arith::AndIOp>(loc, lhs, rhs);
  return conjunction;
}
|
|
/// Creates the sum of `lhs` and `rhs`: an `arith::AddFOp` when `lhs` has a
/// FloatType, otherwise an `arith::AddIOp`. Only `lhs`'s type is inspected.
Value ArithBuilder::add(Value lhs, Value rhs) {
  Type operandType = lhs.getType();
  if (!isa<FloatType>(operandType))
    return b.create<arith::AddIOp>(loc, lhs, rhs);
  return b.create<arith::AddFOp>(loc, lhs, rhs);
}
|
|
/// Creates the difference `lhs - rhs`: an `arith::SubFOp` when `lhs` has a
/// FloatType, otherwise an `arith::SubIOp`. Only `lhs`'s type is inspected.
Value ArithBuilder::sub(Value lhs, Value rhs) {
  Type operandType = lhs.getType();
  if (!isa<FloatType>(operandType))
    return b.create<arith::SubIOp>(loc, lhs, rhs);
  return b.create<arith::SubFOp>(loc, lhs, rhs);
}
|
|
/// Creates the product of `lhs` and `rhs`: an `arith::MulFOp` when `lhs` has
/// a FloatType, otherwise an `arith::MulIOp`. Only `lhs`'s type is inspected.
Value ArithBuilder::mul(Value lhs, Value rhs) {
  Type operandType = lhs.getType();
  if (!isa<FloatType>(operandType))
    return b.create<arith::MulIOp>(loc, lhs, rhs);
  return b.create<arith::MulFOp>(loc, lhs, rhs);
}
|
|
/// Creates a "greater than" comparison of `lhs` and `rhs`: an `arith::CmpFOp`
/// with the ordered `OGT` predicate when `lhs` has a FloatType, otherwise an
/// `arith::CmpIOp` with the signed `sgt` predicate.
Value ArithBuilder::sgt(Value lhs, Value rhs) {
  bool isFloat = isa<FloatType>(lhs.getType());
  if (!isFloat)
    return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
}
|
|
/// Creates a "less than" comparison of `lhs` and `rhs`: an `arith::CmpFOp`
/// with the ordered `OLT` predicate when `lhs` has a FloatType, otherwise an
/// `arith::CmpIOp` with the signed `slt` predicate.
Value ArithBuilder::slt(Value lhs, Value rhs) {
  bool isFloat = isa<FloatType>(lhs.getType());
  if (!isFloat)
    return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
}
|
|
/// Creates an `arith::SelectOp` that yields `lhs` or `rhs` depending on the
/// condition `cmp`.
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
  Value chosen = b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
  return chosen;
}
|