diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td index 1893c101e735..40ecef33448d 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td @@ -25,7 +25,8 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> { let description = [{ This pass transforms `emitc.func` operations into `emitc.class` operations. Function arguments become fields of the class, and the function body is moved - to a new `execute` method within the class. + to a new member method within the class. By default, this is `operator()()`. + If the corresponding function argument has attributes (accessed via `argAttrs`), these attributes are attached to the field operation. Otherwise, the field is created without additional attributes. @@ -41,7 +42,7 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> { // becomes emitc.class @modelClass { emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"} - emitc.func @execute() { + emitc.func @operator() { %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t %1 = get_field @input_tensor : !emitc.array<1xf32> %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue @@ -51,6 +52,12 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> { ``` }]; let dependentDialects = ["emitc::EmitCDialect"]; + let options = [ + Option<"funcName", "func-name", "std::string", + /*default=*/[{"operator()"}], + "The name of the newly generated member function with body " + "matching the input function."> + ]; } #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h index bdf6d0985e6d..962bdb3c032b 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h @@ -28,8 +28,11 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder); /// Populates `patterns` with expression-related patterns. void populateExpressionPatterns(RewritePatternSet &patterns); -/// Populates 'patterns' with func-related patterns. -void populateFuncPatterns(RewritePatternSet &patterns); +//===----------------------------------------------------------------------===// +// The WrapFuncInClass pass. +//===----------------------------------------------------------------------===// + +void populateWrapFuncInClass(RewritePatternSet &patterns, StringRef fName); } // namespace emitc } // namespace mlir diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index 06d7e07005f8..fc8acd616ba7 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -31,7 +31,7 @@ struct WrapFuncInClassPass Operation *rootOp = getOperation(); RewritePatternSet patterns(&getContext()); - populateFuncPatterns(patterns); + populateWrapFuncInClass(patterns, funcName); walkAndApplyPatterns(rootOp, std::move(patterns)); } @@ -43,8 +43,8 @@ struct WrapFuncInClassPass class WrapFuncInClass : public OpRewritePattern { public: - WrapFuncInClass(MLIRContext *context) - : OpRewritePattern(context) {} + WrapFuncInClass(MLIRContext *context, StringRef funcName) + : OpRewritePattern(context), funcName(funcName) {} LogicalResult matchAndRewrite(emitc::FuncOp funcOp, PatternRewriter &rewriter) const override { @@ -76,7 +76,7 @@ public: FunctionType funcType = funcOp.getFunctionType(); Location loc = funcOp.getLoc(); FuncOp newFuncOp = - emitc::FuncOp::create(rewriter, loc, ("execute"), funcType); + emitc::FuncOp::create(rewriter, loc, (funcName), funcType); rewriter.createBlock(&newFuncOp.getBody()); newFuncOp.getBody().takeBody(funcOp.getBody()); @@ -102,8 +102,14 @@ public: rewriter.replaceOp(funcOp, newClassOp); return success(); } + +private: + /// Name of the newly generated member function with body matching the input + /// function. + std::string funcName; }; -void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); +void mlir::emitc::populateWrapFuncInClass(RewritePatternSet &patterns, + StringRef funcName) { + patterns.add(patterns.getContext(), funcName); } diff --git a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir index 809febd0267b..cb5f99d31e9d 100644 --- a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir +++ b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir @@ -1,20 +1,22 @@ // RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s +// RUN: mlir-opt %s -wrap-emitc-func-in-class=func-name=execute -split-input-file | FileCheck %s --check-prefixes=EXECUTE emitc.func @foo(%arg0 : !emitc.array<1xf32>) { emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> () emitc.return } -// CHECK: module { // CHECK: emitc.class @fooClass { // CHECK: emitc.field @fieldName0 : !emitc.array<1xf32> -// CHECK: emitc.func @execute() { +// CHECK: emitc.func @"operator()"() { // CHECK: %0 = get_field @fieldName0 : !emitc.array<1xf32> // CHECK: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> () // CHECK: return // CHECK: } // CHECK: } -// CHECK: } + +// EXECUTE-NOT: operator +// EXECUTE: execute() // ----- @@ -34,12 +36,11 @@ module attributes { } { } } -// CHECK: module { // CHECK: emitc.class @modelClass { // CHECK: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"} // CHECK: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"} // CHECK: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"} -// CHECK: emitc.func @execute() { +// CHECK: emitc.func @"operator()"() { // CHECK: get_field @fieldName0 : !emitc.array<1xf32> // CHECK: get_field @fieldName1 : !emitc.array<1xf32> // CHECK: get_field @fieldName2 : !emitc.array<1xf32> @@ -54,4 +55,6 @@ module attributes { } { // CHECK: return // CHECK: } // CHECK: } -// CHECK: } + +// EXECUTE-NOT: operator +// EXECUTE: execute()