[MLIR][Linalg] Generic to category specialization for unary elementwise ops (#187217)
Handle specialization of `linalg.generic` ops representing a unary elementwise computation to the `linalg.elementwise` category op. This implements a previously absent path in the linalg morphism.
This commit is contained in:
parent
81691d23cd
commit
018e048daf
@ -135,11 +135,16 @@ std::optional<SmallVector<int64_t>>
|
||||
isaTransposeOpInterface(GenericOp genericOp);
|
||||
|
||||
/// Checks whether a given `genericOp` is semantically equivalent to a single
|
||||
/// linalgelementwise unary op. e.g. linalg.exp.
|
||||
/// linalg elementwise unary op, e.g. `linalg.exp` or
|
||||
/// `linalg.elementwise kind=#linalg.elementwise_kind<exp>`.
|
||||
/// If `allowNonIdentityMaps` is true, operations with custom indexing maps are
|
||||
/// included in the check. Note that these operations can only be represented by
|
||||
/// the category op.
|
||||
/// A linalg.generic body could be a series of unary elementwise ops e.g.
|
||||
/// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict it to
|
||||
/// detecting cases where body is is a single computation op.
|
||||
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
|
||||
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp,
|
||||
bool allowNonIdentityMaps = false);
|
||||
|
||||
/// Checks whether `genericOp` is semantically equivalent to a single linalg
|
||||
/// elementwise binary op e.g. linalg.sub.
|
||||
|
||||
@ -227,16 +227,19 @@ linalg::isaTransposeOpInterface(GenericOp op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Elementwise Single Unary/Binary-OpInterface implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
|
||||
unsigned arity) {
|
||||
static bool
|
||||
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity,
|
||||
bool allowNonIdentityMaps) {
|
||||
// Check all loops are parallel.
|
||||
if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
|
||||
return false;
|
||||
|
||||
// Check there are arity-inputs, 1-output and all are identity-maps.
|
||||
// Check there are arity-inputs, 1-output and all are identity-maps (unless
|
||||
// requested otherwise).
|
||||
if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
|
||||
!llvm::all_of(op.getIndexingMapsArray(),
|
||||
[](AffineMap map) { return map.isIdentity(); }))
|
||||
(!allowNonIdentityMaps &&
|
||||
!llvm::all_of(op.getIndexingMapsArray(),
|
||||
[](AffineMap map) { return map.isIdentity(); })))
|
||||
return false;
|
||||
|
||||
// Init should not be referenced for elementwise operations.
|
||||
@ -264,19 +267,21 @@ static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
|
||||
yieldOp->getOperand(0).getDefiningOp() != oper);
|
||||
}
|
||||
|
||||
bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
|
||||
bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op,
|
||||
bool allowNonIdentityMaps) {
|
||||
// All basic elemwise checks.
|
||||
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
|
||||
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1, allowNonIdentityMaps))
|
||||
return false;
|
||||
|
||||
// Check input is actully used.
|
||||
// Check input is actually used.
|
||||
if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
|
||||
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
|
||||
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(
|
||||
op, 2, /*allowNonIdentityMaps=*/false))
|
||||
return false;
|
||||
|
||||
// Check both inputs are used (elementwise).
|
||||
|
||||
@ -35,14 +35,13 @@ namespace mlir {
|
||||
genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
|
||||
ValueRange{genericOp.getDpsInits()[0]}))
|
||||
|
||||
#define REPLACE_UNARY_OP(NEWOP) \
|
||||
static_cast<LinalgOp>(rewriter.replaceOpWithNewOp<NEWOP>( \
|
||||
genericOp, ValueRange{genericOp.getDpsInputs()[0]}, \
|
||||
ValueRange{genericOp.getDpsInits()[0]}))
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Specialize linalg generic to elementwise ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Given a elementwise single binary linalg generic op, checks whether the
|
||||
// binary op accesses operands as swapped. e.g.
|
||||
// this differentiates between a linalg-generic body that contains:
|
||||
@ -67,6 +66,98 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
|
||||
return swapped;
|
||||
}
|
||||
|
||||
// Attempt to specialize linalg.generic to named elementwise ops or
|
||||
// linalg.elementwise.
|
||||
//
|
||||
// Example:
|
||||
// %0 = linalg.generic {
|
||||
// indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
// affine_map<(d0, d1) -> (d0, d1)>],
|
||||
// iterator_types = ["parallel", "parallel"]
|
||||
// } ins(%In : tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) {
|
||||
// ^bb0(%in: f32, %out: f32):
|
||||
// %1 = math.exp %in : f32
|
||||
// linalg.yield %1 : f32
|
||||
// } -> tensor<?x?xf32>
|
||||
//
|
||||
// is specialized to either
|
||||
// linalg.exp ins(...) outs(...) -> ...
|
||||
// or
|
||||
// linalg.elementwise kind=#linalg.elementwise_kind<exp> ...
|
||||
//
|
||||
// Only the category op can carry non-identity indexing maps; these are
|
||||
// transferred verbatim from the `genericOp`.
|
||||
static FailureOr<LinalgOp>
|
||||
specializeLinalgUnaryElementwise(RewriterBase &rewriter, GenericOp genericOp,
|
||||
bool emitCategoryOp) {
|
||||
bool hasNonIdentityMaps =
|
||||
!llvm::all_of(genericOp.getIndexingMapsArray(),
|
||||
[](AffineMap map) { return map.isIdentity(); });
|
||||
|
||||
// Early exit: Named ops cannot carry user-defined maps.
|
||||
if (hasNonIdentityMaps && !emitCategoryOp)
|
||||
return rewriter.notifyMatchFailure(
|
||||
genericOp,
|
||||
"non-identity indexing maps prevent specialization to named op");
|
||||
|
||||
// Helper to dispatch between named op and `linalg.elementwise`.
|
||||
// Lambdas with explicit template parameter list are a C++20 feature, hence
|
||||
// the dummy op object.
|
||||
auto replaceUnaryOp = [&](auto namedOp, ElementwiseKind kind) -> LinalgOp {
|
||||
LinalgOp newOp;
|
||||
if (!emitCategoryOp)
|
||||
newOp = decltype(namedOp)::create(
|
||||
rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
|
||||
genericOp.getDpsInits(), ArrayRef<NamedAttribute>{});
|
||||
else
|
||||
newOp = ElementwiseOp::create(
|
||||
rewriter, genericOp.getLoc(), genericOp.getDpsInputs(),
|
||||
genericOp.getDpsInits(),
|
||||
ElementwiseKindAttr::get(rewriter.getContext(), kind),
|
||||
genericOp.getIndexingMaps());
|
||||
|
||||
rewriter.replaceOp(genericOp, newOp);
|
||||
return newOp;
|
||||
};
|
||||
|
||||
// Inspect body operation to determine named op or elementwise kind.
|
||||
Operation *op = &genericOp.getBody()->front();
|
||||
|
||||
if (isa<math::ExpOp>(op))
|
||||
return replaceUnaryOp(ExpOp{}, ElementwiseKind::exp);
|
||||
if (isa<math::LogOp>(op))
|
||||
return replaceUnaryOp(LogOp{}, ElementwiseKind::log);
|
||||
if (isa<math::AbsFOp>(op))
|
||||
return replaceUnaryOp(AbsOp{}, ElementwiseKind::abs);
|
||||
if (isa<math::CeilOp>(op))
|
||||
return replaceUnaryOp(CeilOp{}, ElementwiseKind::ceil);
|
||||
if (isa<math::FloorOp>(op))
|
||||
return replaceUnaryOp(FloorOp{}, ElementwiseKind::floor);
|
||||
if (isa<arith::NegFOp>(op))
|
||||
return replaceUnaryOp(NegFOp{}, ElementwiseKind::negf);
|
||||
if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
|
||||
if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
|
||||
divOp.getLhs().getDefiningOp()))
|
||||
if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
|
||||
return replaceUnaryOp(ReciprocalOp{}, ElementwiseKind::reciprocal);
|
||||
}
|
||||
if (isa<math::RoundOp>(op))
|
||||
return replaceUnaryOp(RoundOp{}, ElementwiseKind::round);
|
||||
if (isa<math::SqrtOp>(op))
|
||||
return replaceUnaryOp(SqrtOp{}, ElementwiseKind::sqrt);
|
||||
if (isa<math::RsqrtOp>(op))
|
||||
return replaceUnaryOp(RsqrtOp{}, ElementwiseKind::rsqrt);
|
||||
if (auto mulOp = dyn_cast<arith::MulFOp>(op);
|
||||
mulOp && mulOp.getLhs() == mulOp.getRhs())
|
||||
return replaceUnaryOp(SquareOp{}, ElementwiseKind::square);
|
||||
if (isa<math::TanhOp>(op))
|
||||
return replaceUnaryOp(TanhOp{}, ElementwiseKind::tanh);
|
||||
if (isa<math::ErfOp>(op))
|
||||
return replaceUnaryOp(ErfOp{}, ElementwiseKind::erf);
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Specialize linalg generic to matmul variants.
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -455,6 +546,12 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
|
||||
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
|
||||
RewriterBase &rewriter, GenericOp genericOp,
|
||||
const GenericOpSpecializationOptions &options) {
|
||||
// Unary elementwise - e.g. exp
|
||||
if (isaElemwiseSingleUnaryOpInterface(genericOp, options.emitCategoryOps)) {
|
||||
return specializeLinalgUnaryElementwise(rewriter, genericOp,
|
||||
options.emitCategoryOps);
|
||||
}
|
||||
|
||||
// Contraction - e.g. matmul
|
||||
if (isaContractionOpInterface(genericOp)) {
|
||||
return specializeLinalgContractions(rewriter, genericOp,
|
||||
@ -505,42 +602,6 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(
|
||||
return namedOp;
|
||||
}
|
||||
|
||||
// Elementwise Unary
|
||||
if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
|
||||
Operation *op = &genericOp.getBody()->front();
|
||||
if (isa<math::ExpOp>(op))
|
||||
return REPLACE_UNARY_OP(ExpOp);
|
||||
if (isa<math::LogOp>(op))
|
||||
return REPLACE_UNARY_OP(LogOp);
|
||||
if (isa<math::AbsFOp>(op))
|
||||
return REPLACE_UNARY_OP(AbsOp);
|
||||
if (isa<math::CeilOp>(op))
|
||||
return REPLACE_UNARY_OP(CeilOp);
|
||||
if (isa<math::FloorOp>(op))
|
||||
return REPLACE_UNARY_OP(FloorOp);
|
||||
if (isa<arith::NegFOp>(op))
|
||||
return REPLACE_UNARY_OP(NegFOp);
|
||||
if (auto divOp = dyn_cast<arith::DivFOp>(op)) {
|
||||
if (auto constOp = dyn_cast_if_present<arith::ConstantOp>(
|
||||
divOp.getLhs().getDefiningOp()))
|
||||
if (cast<FloatAttr>(constOp.getValue()).getValue().isExactlyValue(1.0))
|
||||
return REPLACE_UNARY_OP(ReciprocalOp);
|
||||
}
|
||||
if (isa<math::RoundOp>(op))
|
||||
return REPLACE_UNARY_OP(RoundOp);
|
||||
if (isa<math::SqrtOp>(op))
|
||||
return REPLACE_UNARY_OP(SqrtOp);
|
||||
if (isa<math::RsqrtOp>(op))
|
||||
return REPLACE_UNARY_OP(RsqrtOp);
|
||||
if (auto mulOp = dyn_cast<arith::MulFOp>(op);
|
||||
mulOp && mulOp.getLhs() == mulOp.getRhs())
|
||||
return REPLACE_UNARY_OP(SquareOp);
|
||||
if (isa<math::TanhOp>(op))
|
||||
return REPLACE_UNARY_OP(TanhOp);
|
||||
if (isa<math::ErfOp>(op))
|
||||
return REPLACE_UNARY_OP(ErfOp);
|
||||
}
|
||||
|
||||
// Elementwise Binary
|
||||
if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
|
||||
bool swap = areBinOpsSwapped(genericOp);
|
||||
|
||||
@ -5,6 +5,101 @@
|
||||
// RUN: | mlir-opt -split-input-file -linalg-morph-ops=generic-to-category \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
func.func @unary_ops(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<exp>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<log>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<abs>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<ceil>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<floor>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<negf>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<round>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<square>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<tanh>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
linalg.elementwise kind=#linalg.elementwise_kind<erf>
|
||||
ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: unary_ops
|
||||
// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<log>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<abs>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<ceil>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<floor>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<negf>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<round>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<square>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<tanh>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<erf>
|
||||
// CHECK-SAME: ins(%[[A]] : memref<7x14x21xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
|
||||
|
||||
// -----
|
||||
|
||||
func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%0 = linalg.elementwise
|
||||
kind=#linalg.elementwise_kind<log>
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>]
|
||||
ins(%A : tensor<?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
|
||||
// CHECK-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
|
||||
// CHECK: unary_ops_non_identity
|
||||
// CHECK-SAME: %[[A:.+]]: tensor<?xf32>, %[[OUT:.+]]: tensor<?x?xf32>)
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: linalg.elementwise kind=#linalg.elementwise_kind<log>
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]]]
|
||||
// CHECK-SAME: ins(%[[A]] : tensor<?xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
@ -147,8 +147,76 @@ func.func @unary_ops(%A: tensor<?x?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?
|
||||
// NAMED-SAME: ins(%[[RES11]] : tensor<?x?x?xf32>)
|
||||
// NAMED-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
|
||||
// Not supported yet.
|
||||
// CATEGORY: linalg.generic
|
||||
// CATEGORY: %[[RES0:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
|
||||
// CATEGORY-SAME: ins(%[[A]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES1:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<log>
|
||||
// CATEGORY-SAME: ins(%[[RES0]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES2:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<abs>
|
||||
// CATEGORY-SAME: ins(%[[RES1]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES3:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<ceil>
|
||||
// CATEGORY-SAME: ins(%[[RES2]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES4:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<floor>
|
||||
// CATEGORY-SAME: ins(%[[RES3]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES5:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<negf>
|
||||
// CATEGORY-SAME: ins(%[[RES4]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES6:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<reciprocal>
|
||||
// CATEGORY-SAME: ins(%[[RES5]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES7:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<round>
|
||||
// CATEGORY-SAME: ins(%[[RES6]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES8:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<sqrt>
|
||||
// CATEGORY-SAME: ins(%[[RES7]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES9:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<rsqrt>
|
||||
// CATEGORY-SAME: ins(%[[RES8]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES10:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<square>
|
||||
// CATEGORY-SAME: ins(%[[RES9]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES11:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<tanh>
|
||||
// CATEGORY-SAME: ins(%[[RES10]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CATEGORY: %[[RES12:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<erf>
|
||||
// CATEGORY-SAME: ins(%[[RES11]] : tensor<?x?x?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @unary_ops_non_identity(%A: tensor<?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%0 = linalg.generic
|
||||
{indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1, d0)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%A : tensor<?xf32>)
|
||||
outs(%Out : tensor<?x?xf32>) {
|
||||
^bb0(%in: f32, %out: f32):
|
||||
%v = math.exp %in : f32
|
||||
linalg.yield %v : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
|
||||
// ALL-DAG: #[[MAP_BC:.+]] = affine_map<(d0, d1) -> (d1)>
|
||||
// ALL-DAG: #[[MAP_TP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
|
||||
// ALL: unary_ops_non_identity
|
||||
// ALL-SAME: %[[A:.+]]: tensor<?xf32>, %[[OUT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
// Named ops cannot carry user-defined indexing maps -> expect no change.
|
||||
// NAMED-NOT: linalg.exp
|
||||
// NAMED: linalg.generic
|
||||
|
||||
// CATEGORY-NOT: linalg.generic
|
||||
// CATEGORY: linalg.elementwise kind=#linalg.elementwise_kind<exp>
|
||||
// CATEGORY-SAME: indexing_maps = [#[[MAP_BC]], #[[MAP_TP]]]
|
||||
// CATEGORY-SAME: ins(%[[A]] : tensor<?xf32>)
|
||||
// CATEGORY-SAME: outs(%[[OUT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user