llvm-project/mlir/lib/Query/Matcher/MatchFinder.cpp
Denzel-Brian Budii 9b63bdd154
[mlir] Improve mlir-query tool by implementing getBackwardSlice and getForwardSlice matchers (#115670)
Improve the mlir-query tool by implementing `getBackwardSlice` and
`getForwardSlice` matchers. In addition, a `SetQuery` had to be added to
enable per-query configuration, e.g. `inclusive`, `omitUsesFromAbove`,
and `omitBlockArguments`.

Note: the backwardSlice and forwardSlice algorithms are the same as the ones
in `mlir/lib/Analysis/SliceAnalysis.cpp`.
Example of the current matcher. The query was run against the file
`mlir/test/mlir-query/complex-test.mlir`:

```console
./mlir-query /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir -c "match getDefinitions(hasOpName(\"arith.addf\"),2)"

Match #1:

/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:5:8:
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
       ^
/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:7:10: note: "root" binds here
    %2 = arith.addf %in, %in : f32
         ^
Match #2:

/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:10:16:
  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
               ^
/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:13:11:
    %c2 = arith.constant 2 : index
          ^
/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:14:18:
    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
                 ^
/home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:15:10: note: "root" binds here
    %2 = arith.addf %extracted, %extracted : f32
         ^
2 matches.
```
2025-05-13 13:18:14 +02:00

69 lines
2.7 KiB
C++

//===- MatchFinder.cpp - --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the method definitions for the `MatchFinder` class
//
//===----------------------------------------------------------------------===//
#include "mlir/Query/Matcher/MatchFinder.h"
namespace mlir::query::matcher {
MatchFinder::MatchResult::MatchResult(Operation *rootOp,
std::vector<Operation *> matchedOps)
: rootOp(rootOp), matchedOps(std::move(matchedOps)) {}
std::vector<MatchFinder::MatchResult>
MatchFinder::collectMatches(Operation *root, DynMatcher matcher) const {
std::vector<MatchResult> results;
llvm::SetVector<Operation *> tempStorage;
root->walk([&](Operation *subOp) {
if (matcher.match(subOp)) {
MatchResult match;
match.rootOp = subOp;
match.matchedOps.push_back(subOp);
results.push_back(std::move(match));
} else if (matcher.match(subOp, tempStorage)) {
results.emplace_back(subOp, std::vector<Operation *>(tempStorage.begin(),
tempStorage.end()));
}
tempStorage.clear();
});
return results;
}
void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
Operation *op) const {
auto fileLoc = cast<FileLineColLoc>(op->getLoc());
SMLoc smloc = qs.getSourceManager().FindLocForLineAndColumn(
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
llvm::SMDiagnostic diag =
qs.getSourceManager().GetMessage(smloc, llvm::SourceMgr::DK_Note, "");
diag.print("", os, true, false, true);
}
void MatchFinder::printMatch(llvm::raw_ostream &os, QuerySession &qs,
Operation *op, const std::string &binding) const {
auto fileLoc = cast<FileLineColLoc>(op->getLoc());
auto smloc = qs.getSourceManager().FindLocForLineAndColumn(
qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn());
qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note,
"\"" + binding + "\" binds here");
}
std::vector<Operation *>
MatchFinder::flattenMatchedOps(std::vector<MatchResult> &matches) const {
std::vector<Operation *> newVector;
for (auto &result : matches) {
newVector.insert(newVector.end(), result.matchedOps.begin(),
result.matchedOps.end());
}
return newVector;
}
} // namespace mlir::query::matcher