[mlir][CAPI] Allow running pass manager on any operation
`mlirPassManagerRun` is currently restricted to running on `builtin.module` ops, but this restriction doesn't exist on the C++ side. This renames it to `mlirPassManagerRunOnOp` and updates it to take `MlirOperation` instead of `MlirModule`. Depends on D143352 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D143354
This commit is contained in:
parent
37107e177e
commit
6f5590ca34
@ -70,9 +70,9 @@ static inline bool mlirPassManagerIsNull(MlirPassManager passManager) {
|
||||
MLIR_CAPI_EXPORTED MlirOpPassManager
|
||||
mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
|
||||
|
||||
/// Run the provided `passManager` on the given `module`.
|
||||
/// Run the provided `passManager` on the given `op`.
|
||||
MLIR_CAPI_EXPORTED MlirLogicalResult
|
||||
mlirPassManagerRun(MlirPassManager passManager, MlirModule module);
|
||||
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
|
||||
|
||||
/// Enable mlir-print-ir-after-all.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
|
@ -117,8 +117,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
|
||||
.def(
|
||||
"run",
|
||||
[](PyPassManager &passManager, PyModule &module) {
|
||||
MlirLogicalResult status =
|
||||
mlirPassManagerRun(passManager.get(), module.get());
|
||||
MlirLogicalResult status = mlirPassManagerRunOnOp(
|
||||
passManager.get(), mlirModuleGetOperation(module.get()));
|
||||
if (mlirLogicalResultIsFailure(status))
|
||||
throw SetPyError(PyExc_RuntimeError,
|
||||
"Failure while executing pass pipeline.");
|
||||
|
@ -39,9 +39,9 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
|
||||
return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
|
||||
}
|
||||
|
||||
MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
|
||||
MlirModule module) {
|
||||
return wrap(unwrap(passManager)->run(unwrap(module)));
|
||||
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
|
||||
MlirOperation op) {
|
||||
return wrap(unwrap(passManager)->run(unwrap(op)));
|
||||
}
|
||||
|
||||
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
|
||||
|
@ -37,7 +37,8 @@ void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
|
||||
mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVMPass());
|
||||
mlirOpPassManagerAddOwnedPass(
|
||||
opm, mlirCreateConversionArithToLLVMConversionPass());
|
||||
MlirLogicalResult status = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult status =
|
||||
mlirPassManagerRunOnOp(pm, mlirModuleGetOperation(module));
|
||||
if (mlirLogicalResultIsFailure(status)) {
|
||||
fprintf(stderr, "Unexpected failure running pass pipeline\n");
|
||||
exit(2);
|
||||
|
@ -33,17 +33,16 @@ void testRunPassOnModule(void) {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
registerAllUpstreamDialects(ctx);
|
||||
|
||||
MlirModule module = mlirModuleCreateParse(
|
||||
ctx,
|
||||
// clang-format off
|
||||
mlirStringRefCreateFromCString(
|
||||
const char *funcAsm = //
|
||||
"func.func @foo(%arg0 : i32) -> i32 { \n"
|
||||
" %res = arith.addi %arg0, %arg0 : i32 \n"
|
||||
" return %res : i32 \n"
|
||||
"}"));
|
||||
// clang-format on
|
||||
if (mlirModuleIsNull(module)) {
|
||||
fprintf(stderr, "Unexpected failure parsing module.\n");
|
||||
"} \n";
|
||||
MlirOperation func =
|
||||
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(funcAsm),
|
||||
mlirStringRefCreateFromCString("funcAsm"));
|
||||
if (mlirOperationIsNull(func)) {
|
||||
fprintf(stderr, "Unexpected failure parsing asm.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
@ -56,14 +55,14 @@ void testRunPassOnModule(void) {
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
|
||||
mlirPassManagerAddOwnedPass(pm, printOpStatPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, func);
|
||||
if (mlirLogicalResultIsFailure(success)) {
|
||||
fprintf(stderr, "Unexpected failure running pass manager.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
mlirPassManagerDestroy(pm);
|
||||
}
|
||||
mlirModuleDestroy(module);
|
||||
mlirOperationDestroy(func);
|
||||
mlirContextDestroy(ctx);
|
||||
}
|
||||
|
||||
@ -71,10 +70,8 @@ void testRunPassOnNestedModule(void) {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
registerAllUpstreamDialects(ctx);
|
||||
|
||||
MlirModule module = mlirModuleCreateParse(
|
||||
ctx,
|
||||
// clang-format off
|
||||
mlirStringRefCreateFromCString(
|
||||
const char *moduleAsm = //
|
||||
"module { \n"
|
||||
" func.func @foo(%arg0 : i32) -> i32 { \n"
|
||||
" %res = arith.addi %arg0, %arg0 : i32 \n"
|
||||
" return %res : i32 \n"
|
||||
@ -84,9 +81,12 @@ void testRunPassOnNestedModule(void) {
|
||||
" %res = arith.addf %arg0, %arg0 : f32 \n"
|
||||
" return %res : f32 \n"
|
||||
" } \n"
|
||||
"}"));
|
||||
// clang-format on
|
||||
if (mlirModuleIsNull(module))
|
||||
" } \n"
|
||||
"} \n";
|
||||
MlirOperation module =
|
||||
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
|
||||
mlirStringRefCreateFromCString("moduleAsm"));
|
||||
if (mlirOperationIsNull(module))
|
||||
exit(1);
|
||||
|
||||
// Run the print-op-stats pass on functions under the top-level module:
|
||||
@ -100,7 +100,7 @@ void testRunPassOnNestedModule(void) {
|
||||
pm, mlirStringRefCreateFromCString("func.func"));
|
||||
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
|
||||
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
|
||||
if (mlirLogicalResultIsFailure(success))
|
||||
exit(2);
|
||||
mlirPassManagerDestroy(pm);
|
||||
@ -118,13 +118,13 @@ void testRunPassOnNestedModule(void) {
|
||||
nestedModulePm, mlirStringRefCreateFromCString("func.func"));
|
||||
MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
|
||||
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
|
||||
if (mlirLogicalResultIsFailure(success))
|
||||
exit(2);
|
||||
mlirPassManagerDestroy(pm);
|
||||
}
|
||||
|
||||
mlirModuleDestroy(module);
|
||||
mlirOperationDestroy(module);
|
||||
mlirContextDestroy(ctx);
|
||||
}
|
||||
|
||||
@ -339,16 +339,17 @@ void testExternalPass(void) {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
registerAllUpstreamDialects(ctx);
|
||||
|
||||
MlirModule module = mlirModuleCreateParse(
|
||||
ctx,
|
||||
// clang-format off
|
||||
mlirStringRefCreateFromCString(
|
||||
const char *moduleAsm = //
|
||||
"module { \n"
|
||||
" func.func @foo(%arg0 : i32) -> i32 { \n"
|
||||
" %res = arith.addi %arg0, %arg0 : i32 \n"
|
||||
" return %res : i32 \n"
|
||||
"}"));
|
||||
// clang-format on
|
||||
if (mlirModuleIsNull(module)) {
|
||||
" } \n"
|
||||
"}";
|
||||
MlirOperation module =
|
||||
mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
|
||||
mlirStringRefCreateFromCString("moduleAsm"));
|
||||
if (mlirOperationIsNull(module)) {
|
||||
fprintf(stderr, "Unexpected failure parsing module.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
@ -377,7 +378,7 @@ void testExternalPass(void) {
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
mlirPassManagerAddOwnedPass(pm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
|
||||
if (mlirLogicalResultIsFailure(success)) {
|
||||
fprintf(stderr, "Unexpected failure running external pass.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
@ -421,7 +422,7 @@ void testExternalPass(void) {
|
||||
MlirOpPassManager nestedFuncPm =
|
||||
mlirPassManagerGetNestedUnder(pm, funcOpName);
|
||||
mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
|
||||
if (mlirLogicalResultIsFailure(success)) {
|
||||
fprintf(stderr, "Unexpected failure running external operation pass.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
@ -469,7 +470,7 @@ void testExternalPass(void) {
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
mlirPassManagerAddOwnedPass(pm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
|
||||
if (mlirLogicalResultIsFailure(success)) {
|
||||
fprintf(stderr, "Unexpected failure running external pass.\n");
|
||||
exit(EXIT_FAILURE);
|
||||
@ -516,7 +517,7 @@ void testExternalPass(void) {
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
mlirPassManagerAddOwnedPass(pm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
|
||||
if (mlirLogicalResultIsSuccess(success)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
@ -564,7 +565,7 @@ void testExternalPass(void) {
|
||||
|
||||
MlirPassManager pm = mlirPassManagerCreate(ctx);
|
||||
mlirPassManagerAddOwnedPass(pm, externalPass);
|
||||
MlirLogicalResult success = mlirPassManagerRun(pm, module);
|
||||
MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
|
||||
if (mlirLogicalResultIsSuccess(success)) {
|
||||
fprintf(
|
||||
stderr,
|
||||
@ -587,7 +588,7 @@ void testExternalPass(void) {
|
||||
}
|
||||
|
||||
mlirTypeIDAllocatorDestroy(typeIDAllocator);
|
||||
mlirModuleDestroy(module);
|
||||
mlirOperationDestroy(module);
|
||||
mlirContextDestroy(ctx);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user