[mlir][transform] Implement FlattenElementwiseLinalgOp
transform op (#81431)
A `transform.structured.flatten_elementwise` op is implemented for flattening the iteration space and (applicable) operands/results to a single dimension.
This commit is contained in:
parent
95e036956f
commit
b6f4dd9ee8
@ -2295,6 +2295,49 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// FlattenElementwiseLinalgOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
|
||||||
|
"structured.flatten_elementwise",
|
||||||
|
[FunctionalStyleTransformOpTrait,
|
||||||
|
MemoryEffectsOpInterface,
|
||||||
|
TransformOpInterface,
|
||||||
|
TransformEachOpTrait,
|
||||||
|
ReportTrackingListenerFailuresOpTrait]> {
|
||||||
|
let description = [{
|
||||||
|
Flattens the iteration space and (applicable) operands of elementwise
|
||||||
|
linalg ops to a single dimension.
|
||||||
|
|
||||||
|
Returns one handle:
|
||||||
|
- Flattened linalg operation.
|
||||||
|
|
||||||
|
#### Return modes:
|
||||||
|
|
||||||
|
Returns a definite failure if target is not isolated from above.
|
||||||
|
Returns a silenceable failure if the pattern application failed.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins TransformHandleTypeInterface:$target);
|
||||||
|
let results = (outs TransformHandleTypeInterface:$transformed);
|
||||||
|
|
||||||
|
let assemblyFormat =
|
||||||
|
"$target attr-dict `:` functional-type($target, results)";
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins "Value":$target)>
|
||||||
|
];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
::mlir::DiagnosedSilenceableFailure applyToOne(
|
||||||
|
::mlir::transform::TransformRewriter &rewriter,
|
||||||
|
::mlir::linalg::LinalgOp target,
|
||||||
|
::mlir::transform::ApplyToEachResultList &results,
|
||||||
|
::mlir::transform::TransformState &state);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Transpose Conv2D
|
// Transpose Conv2D
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1074,6 +1074,11 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
|
|||||||
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
|
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
|
||||||
ArrayRef<ReassociationIndices> dimSequences);
|
ArrayRef<ReassociationIndices> dimSequences);
|
||||||
|
|
||||||
|
struct CollapseResult {
|
||||||
|
SmallVector<Value> results;
|
||||||
|
LinalgOp collapsedOp;
|
||||||
|
};
|
||||||
|
|
||||||
/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
|
/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
|
||||||
/// to calling this method is that for each list in `foldedIterationDim`, the
|
/// to calling this method is that for each list in `foldedIterationDim`, the
|
||||||
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
|
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
|
||||||
@ -1081,9 +1086,8 @@ bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
|
|||||||
/// When valid, the method also collapses the operands of the op. Returns
|
/// When valid, the method also collapses the operands of the op. Returns
|
||||||
/// replacement values of the results of the original `linalgOp` by inserting
|
/// replacement values of the results of the original `linalgOp` by inserting
|
||||||
/// reshapes to get back values of compatible types.
|
/// reshapes to get back values of compatible types.
|
||||||
template <typename LinalgType>
|
FailureOr<CollapseResult>
|
||||||
FailureOr<SmallVector<Value>>
|
collapseOpIterationDims(LinalgOp op,
|
||||||
collapseOpIterationDims(LinalgType op,
|
|
||||||
ArrayRef<ReassociationIndices> foldedIterationDims,
|
ArrayRef<ReassociationIndices> foldedIterationDims,
|
||||||
RewriterBase &rewriter);
|
RewriterBase &rewriter);
|
||||||
|
|
||||||
|
@ -3244,6 +3244,31 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
|
|||||||
return DiagnosedSilenceableFailure::success();
|
return DiagnosedSilenceableFailure::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// FlattenElementwiseLinalgOp.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
|
||||||
|
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
|
||||||
|
transform::ApplyToEachResultList &results,
|
||||||
|
transform::TransformState &state) {
|
||||||
|
rewriter.setInsertionPoint(target);
|
||||||
|
if (target.getNumLoops() <= 1)
|
||||||
|
return DiagnosedSilenceableFailure::success();
|
||||||
|
ReassociationIndices reassociation(target.getNumLoops());
|
||||||
|
std::iota(reassociation.begin(), reassociation.end(), 0);
|
||||||
|
auto maybeFlattened =
|
||||||
|
(isElementwise(target))
|
||||||
|
? collapseOpIterationDims(target, reassociation, rewriter)
|
||||||
|
: FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
|
||||||
|
target, "only elementwise flattening is supported"));
|
||||||
|
if (failed(maybeFlattened))
|
||||||
|
return emitDefaultSilenceableFailure(target);
|
||||||
|
results.push_back(maybeFlattened->collapsedOp);
|
||||||
|
rewriter.replaceOp(target, maybeFlattened->results);
|
||||||
|
return DiagnosedSilenceableFailure::success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TransposeConv2DOp
|
// TransposeConv2DOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1446,24 +1446,20 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LinalgType>
|
void collapseOperandsAndResults(LinalgOp op,
|
||||||
Operation *createCollapsedOp(LinalgType op,
|
const CollapsingInfo &collapsingInfo,
|
||||||
const CollapsingInfo &collapsingInfo,
|
RewriterBase &rewriter,
|
||||||
RewriterBase &rewriter) {
|
SmallVectorImpl<Value> &inputOperands,
|
||||||
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
|
SmallVectorImpl<Value> &outputOperands,
|
||||||
"unsupported linalg op type to create");
|
SmallVectorImpl<Type> &resultTypes) {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
inputOperands =
|
||||||
// Get the input operands.
|
|
||||||
SmallVector<Value> inputOperands =
|
|
||||||
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
|
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
|
||||||
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
|
return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
|
||||||
rewriter);
|
rewriter);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Get the output operands and result types.
|
// Get the output operands and result types.
|
||||||
SmallVector<Type> resultTypes;
|
|
||||||
SmallVector<Value> outputOperands;
|
|
||||||
resultTypes.reserve(op.getNumDpsInits());
|
resultTypes.reserve(op.getNumDpsInits());
|
||||||
outputOperands.reserve(op.getNumDpsInits());
|
outputOperands.reserve(op.getNumDpsInits());
|
||||||
for (OpOperand &output : op.getDpsInitsMutable()) {
|
for (OpOperand &output : op.getDpsInitsMutable()) {
|
||||||
@ -1475,41 +1471,69 @@ Operation *createCollapsedOp(LinalgType op,
|
|||||||
if (!op.hasPureBufferSemantics())
|
if (!op.hasPureBufferSemantics())
|
||||||
resultTypes.push_back(newOutput.getType());
|
resultTypes.push_back(newOutput.getType());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (isa<linalg::CopyOp>(op)) {
|
/// Clone a `LinalgOp` to a collapsed version of same name
|
||||||
return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
|
template <typename OpTy>
|
||||||
outputOperands[0]);
|
OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
|
||||||
}
|
const CollapsingInfo &collapsingInfo) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
// Get the iterator types for the operand.
|
/// Collapse any `LinalgOp` that does not require any specialization such as
|
||||||
SmallVector<utils::IteratorType> iteratorTypes =
|
/// indexing_maps, iterator_types, etc.
|
||||||
getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
|
template <>
|
||||||
|
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
|
||||||
|
const CollapsingInfo &collapsingInfo) {
|
||||||
|
SmallVector<Value> inputOperands, outputOperands;
|
||||||
|
SmallVector<Type> resultTypes;
|
||||||
|
collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
|
||||||
|
outputOperands, resultTypes);
|
||||||
|
return cast<LinalgOp>(clone(
|
||||||
|
rewriter, origOp, resultTypes,
|
||||||
|
llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands))));
|
||||||
|
}
|
||||||
|
|
||||||
// Get the indexing maps.
|
/// Collapse a `GenericOp`
|
||||||
auto indexingMaps =
|
template <>
|
||||||
llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
|
GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
|
||||||
|
GenericOp origOp,
|
||||||
|
const CollapsingInfo &collapsingInfo) {
|
||||||
|
SmallVector<Value> inputOperands, outputOperands;
|
||||||
|
SmallVector<Type> resultTypes;
|
||||||
|
collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
|
||||||
|
outputOperands, resultTypes);
|
||||||
|
SmallVector<AffineMap> indexingMaps(
|
||||||
|
llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
|
||||||
return getCollapsedOpIndexingMap(map, collapsingInfo);
|
return getCollapsedOpIndexingMap(map, collapsingInfo);
|
||||||
});
|
}));
|
||||||
|
|
||||||
Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
|
SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
|
||||||
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
|
origOp.getIteratorTypesArray(), collapsingInfo));
|
||||||
|
|
||||||
|
GenericOp collapsedOp = rewriter.create<linalg::GenericOp>(
|
||||||
|
origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
|
||||||
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
|
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
|
||||||
Block *origOpBlock = &op->getRegion(0).front();
|
Block *origOpBlock = &origOp->getRegion(0).front();
|
||||||
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
|
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
|
||||||
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
|
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
|
||||||
collapsedOpBlock->getArguments());
|
collapsedOpBlock->getArguments());
|
||||||
|
|
||||||
return collapsedOp;
|
return collapsedOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation of fusion with reshape operation by collapsing dimensions.
|
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
|
||||||
template <typename LinalgType>
|
RewriterBase &rewriter) {
|
||||||
FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
|
if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
|
||||||
LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
|
return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
|
||||||
RewriterBase &rewriter) {
|
} else {
|
||||||
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
|
return cloneToCollapsedOp(rewriter, op, collapsingInfo);
|
||||||
"unsupported linalg op type to collapse");
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implementation of fusion with reshape operation by collapsing dimensions.
|
||||||
|
FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
|
||||||
|
LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
|
||||||
|
RewriterBase &rewriter) {
|
||||||
// Bail on trivial no-op cases.
|
// Bail on trivial no-op cases.
|
||||||
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
|
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
|
||||||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
|
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
|
||||||
@ -1538,8 +1562,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Bail on non-canonical ranges.
|
// Bail on non-canonical ranges.
|
||||||
SmallVector<Range> loopRanges =
|
SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
|
||||||
cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
|
|
||||||
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
|
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
|
||||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
|
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
|
||||||
return cast<IntegerAttr>(attr).getInt() == value;
|
return cast<IntegerAttr>(attr).getInt() == value;
|
||||||
@ -1555,8 +1578,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
|
|||||||
op, "expected all loop ranges to have zero start and unit stride");
|
op, "expected all loop ranges to have zero start and unit stride");
|
||||||
}
|
}
|
||||||
|
|
||||||
LinalgType collapsedOp = cast<LinalgType>(
|
LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
|
||||||
createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
|
|
||||||
|
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
if (collapsedOp.hasIndexSemantics()) {
|
if (collapsedOp.hasIndexSemantics()) {
|
||||||
@ -1597,7 +1619,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
|
|||||||
results.push_back(collapsedOpResult);
|
results.push_back(collapsedOpResult);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return results;
|
return CollapseResult{results, collapsedOp};
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -1629,15 +1651,14 @@ public:
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<SmallVector<Value>> replacements =
|
std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
|
||||||
collapseOpIterationDims<linalg::GenericOp>(
|
genericOp, collapsableIterationDims, rewriter);
|
||||||
genericOp, collapsableIterationDims, rewriter);
|
if (!collapseResult) {
|
||||||
if (!replacements) {
|
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
genericOp, "failed to do the fusion by collapsing transformation");
|
genericOp, "failed to do the fusion by collapsing transformation");
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(genericOp, *replacements);
|
rewriter.replaceOp(genericOp, collapseResult->results);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
return failure();
|
return failure();
|
||||||
@ -1671,13 +1692,12 @@ public:
|
|||||||
op, "specified dimensions cannot be collapsed");
|
op, "specified dimensions cannot be collapsed");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<SmallVector<Value>> replacements =
|
std::optional<CollapseResult> collapseResult =
|
||||||
collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
|
collapseOpIterationDims(op, collapsableIterationDims, rewriter);
|
||||||
rewriter);
|
if (!collapseResult) {
|
||||||
if (!replacements) {
|
|
||||||
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
|
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
|
||||||
}
|
}
|
||||||
rewriter.replaceOp(op, *replacements);
|
rewriter.replaceOp(op, collapseResult->results);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
99
mlir/test/Dialect/Linalg/flatten-elementwise.mlir
Normal file
99
mlir/test/Dialect/Linalg/flatten-elementwise.mlir
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @fill(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: f32,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: memref<32x7xf32>
|
||||||
|
// CHECK-NEXT: %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
|
||||||
|
func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
|
||||||
|
linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
|
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%flattened = transform.structured.flatten_elementwise %0
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @fill_tensor(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: f32,
|
||||||
|
// CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32>
|
||||||
|
// CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
|
||||||
|
func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
|
||||||
|
%0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32>
|
||||||
|
return %0 : tensor<32x7xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
|
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%flattened = transform.structured.flatten_elementwise %0
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @map(
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
|
||||||
|
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
|
||||||
|
func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
|
||||||
|
linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
|
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%flattened = transform.structured.flatten_elementwise %0
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
|
||||||
|
// CHECK-LABEL: func.func @generic
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
|
||||||
|
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
|
||||||
|
// CHECK-NEXT: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
|
||||||
|
// CHECK-NEXT: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
|
||||||
|
// CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
|
||||||
|
// CHECK-NEXT: linalg.yield %[[SUM]]
|
||||||
|
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
func.func @generic( %arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
|
||||||
|
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
|
||||||
|
^bb0(%a: f32, %b: f32, %c: f32):
|
||||||
|
%0 = arith.addf %a, %b : f32
|
||||||
|
linalg.yield %0 : f32
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||||
|
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%flattened = transform.structured.flatten_elementwise %0
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user