[mlir][linalg] Migrate Detensorize pass to new dialect conversion driver

This commit is contained in:
Matthias Springer 2025-08-10 11:41:51 +00:00
parent 0d8aa9d9ec
commit fdd7e25a18
2 changed files with 37 additions and 6 deletions

View File

@ -458,6 +458,22 @@ struct LinalgDetensorize
}
};
/// A listener that forwards notifyBlockErased and notifyOperationErased to
/// the given callbacks.
struct CallbackListener : public RewriterBase::Listener {
CallbackListener(std::function<void(Operation *op)> onOperationErased,
std::function<void(Block *block)> onBlockErased)
: onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
void notifyBlockErased(Block *block) override { onBlockErased(block); }
void notifyOperationErased(Operation *op) override {
onOperationErased(op);
}
std::function<void(Operation *op)> onOperationErased;
std::function<void(Block *block)> onBlockErased;
};
void runOnOperation() override {
MLIRContext *context = &getContext();
DetensorizeTypeConverter typeConverter;
@ -551,8 +567,22 @@ struct LinalgDetensorize
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
shouldConvertBranchOperand);
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns))))
ConversionConfig config;
auto onOperationErased = [&](Operation *op) {
opsToDetensor.erase(op);
detensorableBranchOps.erase(op);
};
auto onBlockErased = [&](Block *block) {
for (BlockArgument arg : block->getArguments()) {
blockArgsToDetensor.erase(arg);
}
};
CallbackListener listener(onOperationErased, onBlockErased);
config.listener = &listener;
config.allowPatternRollback = false;
if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
config)))
signalPassFailure();
RewritePatternSet canonPatterns(context);

View File

@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso
}
// CHECK-LABEL: func @detensor_op_sequence
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
// CHECK: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]]
// CHECK: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
// CHECK: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]]
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]]
// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]]
// CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
// CHECK: return %[[new_tensor_res]]