[mlir][linalg] Migrate Detensorize pass to new dialect conversion driver
This commit is contained in:
parent
0d8aa9d9ec
commit
fdd7e25a18
@ -458,6 +458,22 @@ struct LinalgDetensorize
|
||||
}
|
||||
};
|
||||
|
||||
/// A listener that forwards notifyBlockErased and notifyOperationErased to
|
||||
/// the given callbacks.
|
||||
struct CallbackListener : public RewriterBase::Listener {
|
||||
CallbackListener(std::function<void(Operation *op)> onOperationErased,
|
||||
std::function<void(Block *block)> onBlockErased)
|
||||
: onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
|
||||
|
||||
void notifyBlockErased(Block *block) override { onBlockErased(block); }
|
||||
void notifyOperationErased(Operation *op) override {
|
||||
onOperationErased(op);
|
||||
}
|
||||
|
||||
std::function<void(Operation *op)> onOperationErased;
|
||||
std::function<void(Block *block)> onBlockErased;
|
||||
};
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
DetensorizeTypeConverter typeConverter;
|
||||
@ -551,8 +567,22 @@ struct LinalgDetensorize
|
||||
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
|
||||
shouldConvertBranchOperand);
|
||||
|
||||
if (failed(
|
||||
applyFullConversion(getOperation(), target, std::move(patterns))))
|
||||
ConversionConfig config;
|
||||
auto onOperationErased = [&](Operation *op) {
|
||||
opsToDetensor.erase(op);
|
||||
detensorableBranchOps.erase(op);
|
||||
};
|
||||
auto onBlockErased = [&](Block *block) {
|
||||
for (BlockArgument arg : block->getArguments()) {
|
||||
blockArgsToDetensor.erase(arg);
|
||||
}
|
||||
};
|
||||
CallbackListener listener(onOperationErased, onBlockErased);
|
||||
|
||||
config.listener = &listener;
|
||||
config.allowPatternRollback = false;
|
||||
if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
|
||||
config)))
|
||||
signalPassFailure();
|
||||
|
||||
RewritePatternSet canonPatterns(context);
|
||||
|
@ -53,10 +53,11 @@ func.func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tenso
|
||||
}
|
||||
// CHECK-LABEL: func @detensor_op_sequence
|
||||
// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
|
||||
// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
|
||||
// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
|
||||
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
|
||||
// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
|
||||
// CHECK: %[[arg1_val_1:.*]] = tensor.extract %[[arg1]]
|
||||
// CHECK: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
|
||||
// CHECK: %[[arg1_val_2:.*]] = tensor.extract %[[arg1]]
|
||||
// CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val_2]], %[[arg2_val]]
|
||||
// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val_1]], %[[detensored_res]]
|
||||
// CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
|
||||
// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
|
||||
// CHECK: return %[[new_tensor_res]]
|
||||
|
Loading…
x
Reference in New Issue
Block a user