[mlir] Reapply 141423 mlir-query combinators plus fix (#146156)

An uninitialized variable that caused a crash
(https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was
identified using the memory analyzer, leading to the reversion of
https://github.com/llvm/llvm-project/pull/141423. This pull request
reapplies the previously reverted changes and includes the fix, which
has been tested locally following the steps at
https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild.

Note: the fix is included as part of the second commit
This commit is contained in:
Denzel-Brian Budii 2025-07-01 16:03:17 +03:00 committed by GitHub
parent 771ee8e387
commit 3702d64801
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 473 additions and 19 deletions

View File

@ -108,6 +108,9 @@ public:
const llvm::ArrayRef<ParserValue> args,
Diagnostics *error) const = 0;
// If the matcher is variadic, it can take any number of arguments.
virtual bool isVariadic() const = 0;
// Returns the number of arguments accepted by the matcher.
virtual unsigned getNumArgs() const = 0;
@ -140,6 +143,8 @@ public:
return marshaller(matcherFunc, matcherName, nameRange, args, error);
}
bool isVariadic() const override { return false; }
unsigned getNumArgs() const override { return argKinds.size(); }
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@ -153,6 +158,54 @@ private:
const std::vector<ArgKind> argKinds;
};
class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
public:
using VarOp = DynMatcher::VariadicOperator;
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
VarOp varOp, StringRef matcherName)
: minCount(minCount), maxCount(maxCount), varOp(varOp),
matcherName(matcherName) {}
VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
Diagnostics *error) const override {
if (args.size() < minCount || maxCount < args.size()) {
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
{llvm::Twine("requires between "), llvm::Twine(minCount),
llvm::Twine(" and "), llvm::Twine(maxCount),
llvm::Twine(" args, got "), llvm::Twine(args.size())});
return VariantMatcher();
}
std::vector<VariantMatcher> innerArgs;
for (int64_t i = 0, e = args.size(); i != e; ++i) {
const ParserValue &arg = args[i];
const VariantValue &value = arg.value;
if (!value.isMatcher()) {
addError(error, arg.range, ErrorType::RegistryWrongArgType,
{llvm::Twine(i + 1), llvm::Twine("matcher: "),
llvm::Twine(value.getTypeAsString())});
return VariantMatcher();
}
innerArgs.push_back(value.getMatcher());
}
return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
}
bool isVariadic() const override { return true; }
unsigned getNumArgs() const override { return 0; }
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
kinds.push_back(ArgKind(ArgKind::Matcher));
}
private:
const unsigned minCount;
const unsigned maxCount;
const VarOp varOp;
const StringRef matcherName;
};
// Helper function to check if argument count matches expected count
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
llvm::ArrayRef<ParserValue> args,
@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
}
// Variadic operator overload.
template <unsigned MinCount, unsigned MaxCount>
std::unique_ptr<MatcherDescriptor>
makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
StringRef matcherName) {
return std::make_unique<VariadicOperatorMatcherDescriptor>(
MinCount, MaxCount, func.varOp, matcherName);
}
} // namespace mlir::query::matcher::internal
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

View File

@ -21,7 +21,9 @@
namespace mlir::query::matcher {
/// A class that provides utilities to find operations in the IR.
/// Finds and collects matches from the IR. After construction
/// `collectMatches` can be used to traverse the IR and apply
/// matchers.
class MatchFinder {
public:

View File

@ -8,11 +8,11 @@
//
// Implements the base layer of the matcher framework.
//
// Matchers are methods that return a Matcher which provides a method one of the
// following methods: match(Operation *op), match(Operation *op,
// SetVector<Operation *> &matchedOps)
// Matchers are methods that return a Matcher which provides a
// `match(...)` method whose parameters define the context of the match.
// Support includes simple (unary) matchers as well as matcher combinators
// (anyOf, allOf, etc.)
//
// The matcher functions are defined in include/mlir/IR/Matchers.h.
// This file contains the wrapper classes needed to construct matchers for
// mlir-query.
//
@ -25,6 +25,15 @@
#include "llvm/ADT/IntrusiveRefCntPtr.h"
namespace mlir::query::matcher {
class DynMatcher;
namespace internal {
bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
ArrayRef<DynMatcher> innerMatchers);
bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
ArrayRef<DynMatcher> innerMatchers);
} // namespace internal
// Defaults to false if T has no match() method with the signature:
// match(Operation* op).
@ -84,6 +93,27 @@ private:
MatcherFn matcherFn;
};
// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
// match the given operation.
using VariadicOperatorFunction = bool (*)(Operation *op,
SetVector<Operation *> *matchedOps,
ArrayRef<DynMatcher> innerMatchers);
template <VariadicOperatorFunction Func>
class VariadicMatcher : public MatcherInterface {
public:
VariadicMatcher(std::vector<DynMatcher> matchers)
: matchers(std::move(matchers)) {}
bool match(Operation *op) override { return Func(op, nullptr, matchers); }
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
return Func(op, &matchedOps, matchers);
}
private:
std::vector<DynMatcher> matchers;
};
// Matcher wraps a MatcherInterface implementation and provides match()
// methods that redirect calls to the underlying implementation.
class DynMatcher {
@ -92,6 +122,31 @@ public:
DynMatcher(MatcherInterface *implementation)
: implementation(implementation) {}
// Construct from a variadic function.
enum VariadicOperator {
// Matches operations for which all provided matchers match.
AllOf,
// Matches operations for which at least one of the provided matchers
// matches.
AnyOf
};
static std::unique_ptr<DynMatcher>
constructVariadic(VariadicOperator Op,
std::vector<DynMatcher> innerMatchers) {
switch (Op) {
case AllOf:
return std::make_unique<DynMatcher>(
new VariadicMatcher<internal::allOfVariadicOperator>(
std::move(innerMatchers)));
case AnyOf:
return std::make_unique<DynMatcher>(
new VariadicMatcher<internal::anyOfVariadicOperator>(
std::move(innerMatchers)));
}
llvm_unreachable("Invalid Op value.");
}
template <typename MatcherFn>
static std::unique_ptr<DynMatcher>
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@ -113,6 +168,59 @@ private:
std::string functionName;
};
// VariadicOperatorMatcher related types.
template <typename... Ps>
class VariadicOperatorMatcher {
public:
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
: varOp(varOp), params(std::forward<Ps>(params)...) {}
operator std::unique_ptr<DynMatcher>() const & {
return DynMatcher::constructVariadic(
varOp, getMatchers(std::index_sequence_for<Ps...>()));
}
operator std::unique_ptr<DynMatcher>() && {
return DynMatcher::constructVariadic(
varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
}
private:
// Helper method to unpack the tuple into a vector.
template <std::size_t... Is>
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
return {DynMatcher(std::get<Is>(params))...};
}
template <std::size_t... Is>
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
return {DynMatcher(std::get<Is>(std::move(params)))...};
}
const DynMatcher::VariadicOperator varOp;
std::tuple<Ps...> params;
};
// Overloaded function object to generate VariadicOperatorMatcher objects from
// arbitrary matchers.
template <unsigned MinCount, unsigned MaxCount>
struct VariadicOperatorMatcherFunc {
DynMatcher::VariadicOperator varOp;
template <typename... Ms>
VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
"invalid number of parameters for variadic matcher");
return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
}
};
namespace internal {
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
anyOf = {DynMatcher::AnyOf};
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
allOf = {DynMatcher::AllOf};
} // namespace internal
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

View File

@ -6,7 +6,8 @@
//
//===----------------------------------------------------------------------===//
//
// This file provides matchers for MLIRQuery that peform slicing analysis
// This file defines slicing-analysis matchers that extend and abstract the
// core implementations from `SliceAnalysis.h`.
//
//===----------------------------------------------------------------------===//
@ -16,9 +17,9 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Operation.h"
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
/// Additionally, it limits the slice computation to a certain depth level using
/// a custom filter.
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
/// if `innerMatcher` matches. The traversal stops once the desired depth level
/// is reached.
///
/// Example: starting from node 9, assuming the matcher
/// computes the slice for the first two depth levels:
@ -119,6 +120,77 @@ bool BackwardSliceMatcher<Matcher>::matches(
: backwardSlice.size() >= 1;
}
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
template <typename BaseMatcher, typename Filter>
class PredicateBackwardSliceMatcher {
public:
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
bool inclusive, bool omitBlockArguments,
bool omitUsesFromAbove)
: innerMatcher(std::move(innerMatcher)),
filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
omitBlockArguments(omitBlockArguments),
omitUsesFromAbove(omitUsesFromAbove) {}
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
backwardSlice.clear();
BackwardSliceOptions options;
options.inclusive = inclusive;
options.omitUsesFromAbove = omitUsesFromAbove;
options.omitBlockArguments = omitBlockArguments;
if (innerMatcher.match(rootOp)) {
options.filter = [&](Operation *subOp) {
return !filterMatcher.match(subOp);
};
LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
assert(result.succeeded() && "expected backward slice to succeed");
(void)result;
return options.inclusive ? backwardSlice.size() > 1
: backwardSlice.size() >= 1;
}
return false;
}
private:
BaseMatcher innerMatcher;
Filter filterMatcher;
bool inclusive;
bool omitBlockArguments;
bool omitUsesFromAbove;
};
/// Computes the forward-slice of all users reachable from `rootOp`,
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
template <typename BaseMatcher, typename Filter>
class PredicateForwardSliceMatcher {
public:
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
bool inclusive)
: innerMatcher(std::move(innerMatcher)),
filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
forwardSlice.clear();
ForwardSliceOptions options;
options.inclusive = inclusive;
if (innerMatcher.match(rootOp)) {
options.filter = [&](Operation *subOp) {
return !filterMatcher.match(subOp);
};
getForwardSlice(rootOp, &forwardSlice, options);
return options.inclusive ? forwardSlice.size() > 1
: forwardSlice.size() >= 1;
}
return false;
}
private:
BaseMatcher innerMatcher;
Filter filterMatcher;
bool inclusive;
};
/// Matches transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher>
@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
omitUsesFromAbove);
}
/// Matches all transitive defs of a top-level operation up to N levels
/// Matches all transitive defs of a top-level operation up to N levels.
template <typename Matcher>
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
int64_t maxDepth) {
@ -139,6 +211,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
false, false);
}
/// Matches all transitive defs of a top-level operation and stops where
/// `filterMatcher` rejects.
template <typename BaseMatcher, typename Filter>
inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
bool inclusive, bool omitBlockArguments,
bool omitUsesFromAbove) {
return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
std::move(innerMatcher), std::move(filterMatcher), inclusive,
omitBlockArguments, omitUsesFromAbove);
}
/// Matches all users of a top-level operation and stops where
/// `filterMatcher` rejects.
template <typename BaseMatcher, typename Filter>
inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
bool inclusive) {
return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
std::move(innerMatcher), std::move(filterMatcher), inclusive);
}
} // namespace mlir::query::matcher
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H

View File

@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
// A variant matcher object to abstract simple and complex matchers into a
// single object type.
class VariantMatcher {
class MatcherOps;
class MatcherOps {
public:
std::optional<DynMatcher>
constructVariadicOperator(DynMatcher::VariadicOperator varOp,
ArrayRef<VariantMatcher> innerMatchers) const;
};
// Payload interface to be specialized by each matcher type. It follows a
// similar interface as VariantMatcher itself.
@ -43,6 +48,9 @@ public:
// Clones the provided matcher.
static VariantMatcher SingleMatcher(DynMatcher matcher);
static VariantMatcher
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
ArrayRef<VariantMatcher> args);
// Makes the matcher the "null" matcher.
void reset();
@ -61,6 +69,7 @@ private:
: value(std::move(value)) {}
class SinglePayload;
class VariadicOpPayload;
std::shared_ptr<const Payload> value;
};

View File

@ -1,5 +1,6 @@
add_mlir_library(MLIRQueryMatcher
MatchFinder.cpp
MatchersInternal.cpp
Parser.cpp
RegistryManager.cpp
VariantValue.cpp

View File

@ -0,0 +1,33 @@
//===--- MatchersInternal.cpp----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Query/Matcher/MatchersInternal.h"
#include "llvm/ADT/SetVector.h"
namespace mlir::query::matcher {
namespace internal {
bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
ArrayRef<DynMatcher> innerMatchers) {
return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
if (matchedOps)
return matcher.match(op, *matchedOps);
return matcher.match(op);
});
}
bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
ArrayRef<DynMatcher> innerMatchers) {
return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
if (matchedOps)
return matcher.match(op, *matchedOps);
return matcher.match(op);
});
}
} // namespace internal
} // namespace mlir::query::matcher

View File

@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
unsigned argNumber = ctxEntry.second;
std::vector<ArgKind> nextTypeSet;
if (argNumber < ctor->getNumArgs())
if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
ctor->getArgKinds(argNumber, nextTypeSet);
typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
const internal::MatcherDescriptor &matcher = *m.getValue();
llvm::StringRef name = m.getKey();
unsigned numArgs = matcher.getNumArgs();
unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
std::vector<std::vector<ArgKind>> argKinds(numArgs);
for (const ArgKind &kind : acceptedTypes) {
@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
}
}
if (matcher.isVariadic())
os << ",...";
os << ")";
typedText += "(";

View File

@ -27,12 +27,64 @@ private:
DynMatcher matcher;
};
class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload {
public:
VariadicOpPayload(DynMatcher::VariadicOperator varOp,
std::vector<VariantMatcher> args)
: varOp(varOp), args(std::move(args)) {}
std::optional<DynMatcher> getDynMatcher() const override {
std::vector<DynMatcher> dynMatchers;
for (auto variantMatcher : args) {
std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher();
if (dynMatcher)
dynMatchers.push_back(dynMatcher.value());
}
auto result = DynMatcher::constructVariadic(varOp, dynMatchers);
return *result;
}
std::string getTypeAsString() const override {
std::string inner;
llvm::interleave(
args, [&](auto const &arg) { inner += arg.getTypeAsString(); },
[&] { inner += " & "; });
return inner;
}
private:
const DynMatcher::VariadicOperator varOp;
const std::vector<VariantMatcher> args;
};
VariantMatcher::VariantMatcher() = default;
VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
return VariantMatcher(std::make_shared<SinglePayload>(std::move(matcher)));
}
VariantMatcher
VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
ArrayRef<VariantMatcher> args) {
return VariantMatcher(
std::make_shared<VariadicOpPayload>(varOp, std::move(args)));
}
std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator(
DynMatcher::VariadicOperator varOp,
ArrayRef<VariantMatcher> innerMatchers) const {
std::vector<DynMatcher> dynMatchers;
for (const auto &innerMatcher : innerMatchers) {
if (!innerMatcher.value)
return std::nullopt;
std::optional<DynMatcher> inner = innerMatcher.value->getDynMatcher();
if (!inner)
return std::nullopt;
dynMatchers.push_back(*inner);
}
return *DynMatcher::constructVariadic(varOp, dynMatchers);
}
std::optional<DynMatcher> VariantMatcher::getDynMatcher() const {
return value ? value->getDynMatcher() : std::nullopt;
}
@ -120,11 +172,11 @@ void VariantValue::setSigned(int64_t newValue) {
// Boolean
bool VariantValue::isBoolean() const { return type == ValueType::Boolean; }
bool VariantValue::getBoolean() const { return value.Signed; }
bool VariantValue::getBoolean() const { return value.Boolean; }
void VariantValue::setBoolean(bool newValue) {
type = ValueType::Boolean;
value.Signed = newValue;
value.Boolean = newValue;
}
bool VariantValue::isString() const { return type == ValueType::String; }

View File

@ -10,6 +10,7 @@
#include "QueryParser.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Query/Matcher/MatchFinder.h"
#include "mlir/Query/QuerySession.h"
#include "llvm/ADT/SetVector.h"
@ -68,6 +69,8 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
// Clone operations and build function body
std::vector<Operation *> clonedOps;
std::vector<Value> clonedVals;
// TODO: Handle extraction of operations with compute payloads defined via
// regions.
for (Operation *slicedOp : slice) {
Operation *clonedOp =
clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
@ -129,6 +132,8 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
finder.flattenMatchedOps(matches);
Operation *function =
extractFunction(flattenedMatches, rootOp->getContext(), functionName);
if (failed(verify(function)))
return mlir::failure();
os << "\n" << *function << "\n\n";
function->erase();
return mlir::success();

View File

@ -1,4 +1,4 @@
// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
// RUN: mlir-query %s -c "m anyOf(getAllDefinitions(hasOpName(\"arith.addf\"),2),getAllDefinitions(hasOpName(\"tensor.extract\"),1))" | FileCheck %s
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
@ -19,14 +19,23 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
}
// CHECK: Match #1:
// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
// CHECK: {{.*}}.mlir:7:10: note: "root" binds here
// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
// CHECK: Match #2:
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
// CHECK: {{.*}}.mlir:14:18: note: "root" binds here
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
// CHECK: Match #3:
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
// CHECK: {{.*}}.mlir:15:10: note: "root" binds here
// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32

View File

@ -0,0 +1,27 @@
// RUN: mlir-query %s -c "m getUsersByPredicate(anyOf(hasOpName(\"memref.alloc\"),isConstantOp()),anyOf(hasOpName(\"affine.load\"), hasOpName(\"memref.dealloc\")),true)" | FileCheck %s
func.func @slice_depth1_loop_nest_with_offsets() {
%0 = memref.alloc() : memref<100xf32>
%cst = arith.constant 7.000000e+00 : f32
affine.for %i0 = 0 to 16 {
%a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0)
affine.store %cst, %0[%a0] : memref<100xf32>
}
affine.for %i1 = 4 to 8 {
%a1 = affine.apply affine_map<(d0) -> (d0 - 1)>(%i1)
%1 = affine.load %0[%a1] : memref<100xf32>
}
return
}
// CHECK: Match #1:
// CHECK: {{.*}}.mlir:4:8: note: "root" binds here
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<100xf32>
// CHECK: affine.store %cst, %0[%a0] : memref<100xf32>
// CHECK: Match #2:
// CHECK: {{.*}}.mlir:5:10: note: "root" binds here
// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
// CHECK: affine.store %[[CST]], %0[%a0] : memref<100xf32>

View File

@ -0,0 +1,11 @@
// RUN: mlir-query %s -c "m allOf(hasOpName(\"memref.alloca\"), hasOpAttrName(\"alignment\"))" | FileCheck %s
func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
%0 = memref.alloca(%arg0, %arg1) : memref<?x?xf32>
memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>
return %0 : memref<?x?xf32>
}
// CHECK: Match #1:
// CHECK: {{.*}}.mlir:5:3: note: "root" binds here
// CHECK: memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>

View File

@ -0,0 +1,29 @@
// RUN: mlir-query %s -c "m getDefinitionsByPredicate(hasOpName(\"memref.store\"),hasOpName(\"memref.alloc\"),true,false,false).extract(\"backward_slice\")" | FileCheck %s
// CHECK: func.func @backward_slice(%{{.*}}: memref<10xf32>) -> (f32, index, index, f32, index, index, f32) {
// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[I0:.*]] = affine.apply affine_map<()[s0] -> (s0)>()[%[[C0]]]
// CHECK-NEXT: memref.store %[[CST0]], %{{.*}}[%[[I0]]] : memref<10xf32>
// CHECK-NEXT: %[[CST2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[I1:.*]] = affine.apply affine_map<() -> (0)>()
// CHECK-NEXT: memref.store %[[CST2]], %{{.*}}[%[[I1]]] : memref<10xf32>
// CHECK-NEXT: %[[C1:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[LOAD:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<10xf32>
// CHECK-NEXT: memref.store %[[LOAD]], %{{.*}}[%[[C1]]] : memref<10xf32>
// CHECK-NEXT: return %[[CST0]], %[[C0]], %[[I0]], %[[CST2]], %[[I1]], %[[C1]], %[[LOAD]] : f32, index, index, f32, index, index, f32
func.func @slicing_memref_store_trivial() {
%0 = memref.alloc() : memref<10xf32>
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
affine.for %i1 = 0 to 10 {
%1 = affine.apply affine_map<()[s0] -> (s0)>()[%c0]
memref.store %cst, %0[%1] : memref<10xf32>
%2 = memref.load %0[%c0] : memref<10xf32>
%3 = affine.apply affine_map<()[] -> (0)>()[]
memref.store %cst, %0[%3] : memref<10xf32>
memref.store %2, %0[%c0] : memref<10xf32>
}
return
}

View File

@ -40,12 +40,22 @@ int main(int argc, char **argv) {
query::matcher::Registry matcherRegistry;
// Matchers registered in alphabetical order for consistency:
matcherRegistry.registerMatcher("allOf", query::matcher::internal::allOf);
matcherRegistry.registerMatcher("anyOf", query::matcher::internal::anyOf);
matcherRegistry.registerMatcher(
"getAllDefinitions",
query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
matcherRegistry.registerMatcher(
"getDefinitions",
query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
matcherRegistry.registerMatcher(
"getAllDefinitions",
query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
"getDefinitionsByPredicate",
query::matcher::m_GetDefinitionsByPredicate<query::matcher::DynMatcher,
query::matcher::DynMatcher>);
matcherRegistry.registerMatcher(
"getUsersByPredicate",
query::matcher::m_GetUsersByPredicate<query::matcher::DynMatcher,
query::matcher::DynMatcher>);
matcherRegistry.registerMatcher("hasOpAttrName",
static_cast<HasOpAttrName *>(m_Attr));
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));