[mlir] Use llvm::filter_to_vector. NFC. (#117655)

This got recently added to SmallVectorExtras:
https://github.com/llvm/llvm-project/pull/117460.
This commit is contained in:
Jakub Kuderski 2024-11-26 09:11:36 -05:00 committed by GitHub
parent 619e4b7154
commit f4d7586343
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 32 additions and 37 deletions

View File

@ -701,9 +701,9 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
// Filter operands with dynamic dimension
auto operandsWithDynamicDim =
llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
llvm::filter_to_vector(operands, [&](Value operand) {
return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
}));
});
// If no operand has a dynamic dimension, it means all sizes were 1
if (operandsWithDynamicDim.empty())

View File

@ -99,10 +99,9 @@ SmallVector<Value> mlir::LLVM::MemsetInlineOp::getAccessedOperands() {
}
SmallVector<Value> mlir::LLVM::CallOp::getAccessedOperands() {
return llvm::to_vector(
llvm::make_filter_range(getArgOperands(), [](Value arg) {
return isa<LLVMPointerType>(arg.getType());
}));
return llvm::filter_to_vector(getArgOperands(), [](Value arg) {
return isa<LLVMPointerType>(arg.getType());
});
}
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.cpp.inc"

View File

@ -375,10 +375,8 @@ static void calculateTileOffsetsAndSizes(
b.setInsertionPointToStart(forallOp.getBody(0));
SmallVector<Value> threadIds = forallOp.getInductionVars();
SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 0);
}));
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
int64_t nLoops = loopRanges.size();
tiledOffsets.reserve(nLoops);
tiledSizes.reserve(nLoops);
@ -656,10 +654,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
Operation *tiledOp = nullptr;
SmallVector<OpFoldResult> nonZeroNumThreads =
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
return !isConstantIntValue(ofr, 0);
}));
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
SmallVector<Value> materializedNonZeroNumThreads =
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);

View File

@ -1090,8 +1090,8 @@ getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
SmallVector<int64_t> vec(rank, kNonTiledMarker);
for (auto [index, value] : llvm::enumerate(perm))
vec[value] = index;
SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
vec, [&](int64_t v) { return v != kNonTiledMarker; }));
SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
vec, [&](int64_t v) { return v != kNonTiledMarker; });
// This inverts the permutation in addition to normalizing so invert back.
return invertPermutationVector(normalizedPerm);
}

View File

@ -695,8 +695,8 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
}
return true;
};
auto newOperands = llvm::to_vector<8>(
llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));
auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
isPotentiallyNonEmptyShape);
// Reduce op to equivalent without empty shape operands.
if (newOperands.size() < op.getNumOperands()) {

View File

@ -1309,10 +1309,10 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
auto isMappableType = llvm::IsaPred<VectorType, TensorType>;
auto resultMappableTypes = llvm::to_vector<1>(
llvm::make_filter_range(op->getResultTypes(), isMappableType));
auto operandMappableTypes = llvm::to_vector<2>(
llvm::make_filter_range(op->getOperandTypes(), isMappableType));
auto resultMappableTypes =
llvm::filter_to_vector<1>(op->getResultTypes(), isMappableType);
auto operandMappableTypes =
llvm::filter_to_vector<2>(op->getOperandTypes(), isMappableType);
// If the op only has scalar operand/result types, then we have nothing to
// check.

View File

@ -141,8 +141,8 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
}
// Remove all unranked shapes
auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
auto shapes = llvm::filter_to_vector<8>(
shapedTypes, [](auto shapedType) { return shapedType.hasRank(); });
if (shapes.empty())
return success();

View File

@ -304,11 +304,11 @@ mlir::detail::getDevicePropertyValue(DataLayoutEntryInterface entry) {
DataLayoutEntryList
mlir::detail::filterEntriesForType(DataLayoutEntryListRef entries,
TypeID typeID) {
return llvm::to_vector<4>(llvm::make_filter_range(
return llvm::filter_to_vector<4>(
entries, [typeID](DataLayoutEntryInterface entry) {
auto type = llvm::dyn_cast_if_present<Type>(entry.getKey());
return type && type.getTypeID() == typeID;
}));
});
}
DataLayoutEntryInterface
@ -393,9 +393,9 @@ static DataLayoutSpecInterface getCombinedDataLayout(Operation *leaf) {
// Create the list of non-null specs (null/missing specs can be safely
// ignored) from the outermost to the innermost.
auto nonNullSpecs = llvm::to_vector<2>(llvm::make_filter_range(
auto nonNullSpecs = llvm::filter_to_vector<2>(
llvm::reverse(specs),
[](DataLayoutSpecInterface iface) { return iface != nullptr; }));
[](DataLayoutSpecInterface iface) { return iface != nullptr; });
// Combine the specs using the innermost as anchor.
if (DataLayoutSpecInterface current = getSpec(leaf))

View File

@ -10,6 +10,7 @@
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
@ -161,13 +162,12 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
return formatv("read{0}", capitalize(name));
};
auto parsedArgs =
llvm::to_vector(make_filter_range(args, [](const Init *const attr) {
const Record *def = cast<DefInit>(attr)->getDef();
if (def->isSubClassOf("Array"))
return true;
return !def->getValueAsString("cParser").empty();
}));
auto parsedArgs = llvm::filter_to_vector(args, [](const Init *const attr) {
const Record *def = cast<DefInit>(attr)->getDef();
if (def->isSubClassOf("Array"))
return true;
return !def->getValueAsString("cParser").empty();
});
interleave(
zip(parsedArgs, argNames),
@ -277,8 +277,8 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
printParseConditional(ios, args, argNames);
// Compute args to pass to create method.
auto passedArgs = llvm::to_vector(make_filter_range(
argNames, [](StringRef str) { return !str.starts_with("_"); }));
auto passedArgs = llvm::filter_to_vector(
argNames, [](StringRef str) { return !str.starts_with("_"); });
std::string argStr;
raw_string_ostream argStream(argStr);
interleaveComma(passedArgs, argStream,