[mlir][CAPI] Allow running pass manager on any operation

`mlirPassManagerRun` is currently restricted to running on
`builtin.module` ops, but this restriction doesn't exist on the C++
side. This renames it to `mlirPassManagerRunOnOp` and updates it to take
`MlirOperation` instead of `MlirModule`.

Depends on D143352

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D143354
This commit is contained in:
rkayaith 2022-11-08 22:39:18 -05:00 committed by Rahul Kayaith
parent 37107e177e
commit 6f5590ca34
5 changed files with 58 additions and 56 deletions

View File

@ -70,9 +70,9 @@ static inline bool mlirPassManagerIsNull(MlirPassManager passManager) {
MLIR_CAPI_EXPORTED MlirOpPassManager
mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
/// Run the provided `passManager` on the given `module`.
/// Run the provided `passManager` on the given `op`.
MLIR_CAPI_EXPORTED MlirLogicalResult
mlirPassManagerRun(MlirPassManager passManager, MlirModule module);
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
/// Enable mlir-print-ir-after-all.
MLIR_CAPI_EXPORTED void

View File

@ -117,8 +117,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
.def(
"run",
[](PyPassManager &passManager, PyModule &module) {
MlirLogicalResult status =
mlirPassManagerRun(passManager.get(), module.get());
MlirLogicalResult status = mlirPassManagerRunOnOp(
passManager.get(), mlirModuleGetOperation(module.get()));
if (mlirLogicalResultIsFailure(status))
throw SetPyError(PyExc_RuntimeError,
"Failure while executing pass pipeline.");

View File

@ -39,9 +39,9 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
}
MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
MlirModule module) {
return wrap(unwrap(passManager)->run(unwrap(module)));
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
MlirOperation op) {
return wrap(unwrap(passManager)->run(unwrap(op)));
}
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {

View File

@ -37,7 +37,8 @@ void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVMPass());
mlirOpPassManagerAddOwnedPass(
opm, mlirCreateConversionArithToLLVMConversionPass());
MlirLogicalResult status = mlirPassManagerRun(pm, module);
MlirLogicalResult status =
mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module));
if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr, "Unexpected failure running pass pipeline\n");
exit(2);

View File

@ -33,17 +33,16 @@ void testRunPassOnModule(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
MlirModule module = mlirModuleCreateParse(
ctx,
// clang-format off
mlirStringRefCreateFromCString(
const char *funcAsm = //
"func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
"}"));
// clang-format on
if (mlirModuleIsNull(module)) {
fprintf(stderr, "Unexpected failure parsing module.\n");
"} \n";
MlirOperation func =
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(funcAsm),
mlirStringRefCreateFromCString("funcAsm"));
if (mlirOperationIsNull(func)) {
fprintf(stderr, "Unexpected failure parsing asm.\n");
exit(EXIT_FAILURE);
}
@ -56,14 +55,14 @@ void testRunPassOnModule(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirPassManagerAddOwnedPass(pm, printOpStatPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, func);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running pass manager.\n");
exit(EXIT_FAILURE);
}
mlirPassManagerDestroy(pm);
}
mlirModuleDestroy(module);
mlirOperationDestroy(func);
mlirContextDestroy(ctx);
}
@ -71,10 +70,8 @@ void testRunPassOnNestedModule(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
MlirModule module = mlirModuleCreateParse(
ctx,
// clang-format off
mlirStringRefCreateFromCString(
const char *moduleAsm = //
"module { \n"
" func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
@ -84,9 +81,12 @@ void testRunPassOnNestedModule(void) {
" %res = arith.addf %arg0, %arg0 : f32 \n"
" return %res : f32 \n"
" } \n"
"}"));
// clang-format on
if (mlirModuleIsNull(module))
" } \n"
"} \n";
MlirOperation module =
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
mlirStringRefCreateFromCString("moduleAsm"));
if (mlirOperationIsNull(module))
exit(1);
// Run the print-op-stats pass on functions under the top-level module:
@ -100,7 +100,7 @@ void testRunPassOnNestedModule(void) {
pm, mlirStringRefCreateFromCString("func.func"));
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success))
exit(2);
mlirPassManagerDestroy(pm);
@ -118,13 +118,13 @@ void testRunPassOnNestedModule(void) {
nestedModulePm, mlirStringRefCreateFromCString("func.func"));
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success))
exit(2);
mlirPassManagerDestroy(pm);
}
mlirModuleDestroy(module);
mlirOperationDestroy(module);
mlirContextDestroy(ctx);
}
@ -339,16 +339,17 @@ void testExternalPass(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
MlirModule module = mlirModuleCreateParse(
ctx,
// clang-format off
mlirStringRefCreateFromCString(
const char *moduleAsm = //
"module { \n"
" func.func @foo(%arg0 : i32) -> i32 { \n"
" %res = arith.addi %arg0, %arg0 : i32 \n"
" return %res : i32 \n"
"}"));
// clang-format on
if (mlirModuleIsNull(module)) {
" } \n"
"}";
MlirOperation module =
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
mlirStringRefCreateFromCString("moduleAsm"));
if (mlirOperationIsNull(module)) {
fprintf(stderr, "Unexpected failure parsing module.\n");
exit(EXIT_FAILURE);
}
@ -377,7 +378,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
exit(EXIT_FAILURE);
@ -421,7 +422,7 @@ void testExternalPass(void) {
MlirOpPassManager nestedFuncPm =
mlirPassManagerGetNestedUnder(pm, funcOpName);
mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external operation pass.\n");
exit(EXIT_FAILURE);
@ -469,7 +470,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsFailure(success)) {
fprintf(stderr, "Unexpected failure running external pass.\n");
exit(EXIT_FAILURE);
@ -516,7 +517,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
stderr,
@ -564,7 +565,7 @@ void testExternalPass(void) {
MlirPassManager pm = mlirPassManagerCreate(ctx);
mlirPassManagerAddOwnedPass(pm, externalPass);
MlirLogicalResult success = mlirPassManagerRun(pm, module);
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
if (mlirLogicalResultIsSuccess(success)) {
fprintf(
stderr,
@ -587,7 +588,7 @@ void testExternalPass(void) {
}
mlirTypeIDAllocatorDestroy(typeIDAllocator);
mlirModuleDestroy(module);
mlirOperationDestroy(module);
mlirContextDestroy(ctx);
}