[mlir][emitc] Update the WrapFuncInClassPass pass (#179184)
Update the `WrapFuncInClassPass` pass so that, by default, the generated
method is named `operator()()` rather than `execute()`. This makes the
pass more generic, instead of catering to specific users expecting an
`execute()` method.
To preserve the original behaviour, add a new pass option to override
the method name: `func-name`. For example:
```bash
mlir-opt file.mlir -wrap-emitc-func-in-class=func-name=execute
```
Additionally, make a couple of small editorial changes:
* Rename `populateFuncPatterns` to `populateWrapFuncInClass` to make it
clear that the corresponding pattern is specific to the
`WrapFuncInClass` pass.
* Remove `// CHECK: module {` to reduce test noise.
For context, this change was proposed on Discourse:
* https://discourse.llvm.org/t/rfc-emitc-support-for-mlgo
This commit is contained in:
parent
36dadddd74
commit
10d80708d9
@ -25,7 +25,8 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
|
||||
let description = [{
|
||||
This pass transforms `emitc.func` operations into `emitc.class` operations.
|
||||
Function arguments become fields of the class, and the function body is moved
|
||||
to a new `execute` method within the class.
|
||||
to a new member method within the class. By default, this is `operator()()`.
|
||||
|
||||
If the corresponding function argument has attributes (accessed via `argAttrs`),
|
||||
these attributes are attached to the field operation.
|
||||
Otherwise, the field is created without additional attributes.
|
||||
@ -41,7 +42,7 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
|
||||
// becomes
|
||||
emitc.class @modelClass {
|
||||
emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
|
||||
emitc.func @execute() {
|
||||
emitc.func @operator() {
|
||||
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
|
||||
%1 = get_field @input_tensor : !emitc.array<1xf32>
|
||||
%2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
|
||||
@ -51,6 +52,12 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
|
||||
```
|
||||
}];
|
||||
let dependentDialects = ["emitc::EmitCDialect"];
|
||||
let options = [
|
||||
Option<"funcName", "func-name", "std::string",
|
||||
/*default=*/[{"operator()"}],
|
||||
"The name of the newly generated member function with body "
|
||||
"matching the input function.">
|
||||
];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
|
||||
|
||||
@ -28,8 +28,11 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
|
||||
/// Populates `patterns` with expression-related patterns.
|
||||
void populateExpressionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Populates 'patterns' with func-related patterns.
|
||||
void populateFuncPatterns(RewritePatternSet &patterns);
|
||||
//===----------------------------------------------------------------------===//
|
||||
// The WrapFuncInClass pass.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef fName);
|
||||
|
||||
} // namespace emitc
|
||||
} // namespace mlir
|
||||
|
||||
@ -31,7 +31,7 @@ struct WrapFuncInClassPass
|
||||
Operation *rootOp = getOperation();
|
||||
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateFuncPatterns(patterns);
|
||||
populateWrapFuncInClass(patterns, funcName);
|
||||
|
||||
walkAndApplyPatterns(rootOp, std::move(patterns));
|
||||
}
|
||||
@ -43,8 +43,8 @@ struct WrapFuncInClassPass
|
||||
|
||||
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
|
||||
public:
|
||||
WrapFuncInClass(MLIRContext *context)
|
||||
: OpRewritePattern<emitc::FuncOp>(context) {}
|
||||
WrapFuncInClass(MLIRContext *context, StringRef funcName)
|
||||
: OpRewritePattern<emitc::FuncOp>(context), funcName(funcName) {}
|
||||
|
||||
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
@ -76,7 +76,7 @@ public:
|
||||
FunctionType funcType = funcOp.getFunctionType();
|
||||
Location loc = funcOp.getLoc();
|
||||
FuncOp newFuncOp =
|
||||
emitc::FuncOp::create(rewriter, loc, ("execute"), funcType);
|
||||
emitc::FuncOp::create(rewriter, loc, (funcName), funcType);
|
||||
|
||||
rewriter.createBlock(&newFuncOp.getBody());
|
||||
newFuncOp.getBody().takeBody(funcOp.getBody());
|
||||
@ -102,8 +102,14 @@ public:
|
||||
rewriter.replaceOp(funcOp, newClassOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
/// Name of the newly generated member function with body matching the input
|
||||
/// function.
|
||||
std::string funcName;
|
||||
};
|
||||
|
||||
void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<WrapFuncInClass>(patterns.getContext());
|
||||
void mlir::emitc::populateWrapFuncInClass(RewritePatternSet &patterns,
|
||||
StringRef funcName) {
|
||||
patterns.add<WrapFuncInClass>(patterns.getContext(), funcName);
|
||||
}
|
||||
|
||||
@ -1,20 +1,22 @@
|
||||
// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -wrap-emitc-func-in-class=func-name=execute -split-input-file | FileCheck %s --check-prefixes=EXECUTE
|
||||
|
||||
emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
|
||||
emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
|
||||
emitc.return
|
||||
}
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK: emitc.class @fooClass {
|
||||
// CHECK: emitc.field @fieldName0 : !emitc.array<1xf32>
|
||||
// CHECK: emitc.func @execute() {
|
||||
// CHECK: emitc.func @"operator()"() {
|
||||
// CHECK: %0 = get_field @fieldName0 : !emitc.array<1xf32>
|
||||
// CHECK: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
// EXECUTE-NOT: operator
|
||||
// EXECUTE: execute()
|
||||
|
||||
// -----
|
||||
|
||||
@ -34,12 +36,11 @@ module attributes { } {
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK: emitc.class @modelClass {
|
||||
// CHECK: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
|
||||
// CHECK: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
|
||||
// CHECK: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
|
||||
// CHECK: emitc.func @execute() {
|
||||
// CHECK: emitc.func @"operator()"() {
|
||||
// CHECK: get_field @fieldName0 : !emitc.array<1xf32>
|
||||
// CHECK: get_field @fieldName1 : !emitc.array<1xf32>
|
||||
// CHECK: get_field @fieldName2 : !emitc.array<1xf32>
|
||||
@ -54,4 +55,6 @@ module attributes { } {
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
// EXECUTE-NOT: operator
|
||||
// EXECUTE: execute()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user