[mlir][x86vector] Shuffle FMAs (#172823)
This patch Shuffles FMAs with x86vector operations as operands such that FMAs are grouped with respect to odd/even packed index. Continuation to PR: https://github.com/llvm/llvm-project/pull/170267 to manage register allocation efficiently.
This commit is contained in:
parent
aca2783840
commit
64ecd762e9
@ -60,6 +60,17 @@ def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
|
||||
"apply_patterns.x86vector.shuffle_vector_fma_ops",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
let description = [{
|
||||
Collect patterns to shuffle FMAs with x86vector operations as operands
|
||||
such that FMAs are grouped with respect to odd/even packed index.
|
||||
}];
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
|
||||
#endif // X86VECTOR_TRANSFORM_OPS
|
||||
|
||||
|
||||
@ -100,6 +100,10 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
|
||||
// range by placing them at their earliest legal use site.
|
||||
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
// Shuffles FMAs with x86vector operations as operands such that FMAs are
|
||||
// grouped with respect to odd/even packed index.
|
||||
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Helpers extracted from:
|
||||
/// - clang/lib/Headers/avxintrin.h
|
||||
|
||||
@ -42,6 +42,11 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
|
||||
x86vector::populateSinkVectorProducerOpsPatterns(patterns);
|
||||
}
|
||||
|
||||
void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
x86vector::populateShuffleVectorFMAOpsPatterns(patterns);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Transform op registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
|
||||
VectorContractToPackedTypeDotProduct.cpp
|
||||
VectorContractBF16ToFMA.cpp
|
||||
SinkVectorProducerOps.cpp
|
||||
ShuffleVectorFMAOps.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithDialect
|
||||
|
||||
186
mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
Normal file
186
mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp
Normal file
@ -0,0 +1,186 @@
|
||||
//===- ShuffleVectorFMAOps.cpp --------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/X86Vector/Transforms.h"
|
||||
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
using namespace mlir::x86vector;
|
||||
|
||||
namespace {
|
||||
|
||||
// Validates whether the given operation is an x86vector operation and has only
|
||||
// one consumer.
|
||||
static bool validateFMAOperands(Value op) {
|
||||
if (auto cvt = op.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
|
||||
return cvt.getResult().hasOneUse();
|
||||
|
||||
if (auto bcst = op.getDefiningOp<x86vector::BcstToPackedF32Op>())
|
||||
return bcst.getResult().hasOneUse();
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validates the vector.fma operation on the following conditions:
|
||||
// (i) one of the lhs or rhs defining operation should be
|
||||
// CvtPackedEvenIndexedToF32Op, (ii) the lhs or rhs defining operation should be
|
||||
// an x86vector operation and has only one consumer, (iii) all operations
|
||||
// are in the same block, and (iv) ths FMA has only one user.
|
||||
static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
|
||||
Value lhs = fmaOp.getLhs();
|
||||
Value rhs = fmaOp.getRhs();
|
||||
|
||||
if (!isa<x86vector::CvtPackedEvenIndexedToF32Op>(lhs.getDefiningOp()) &&
|
||||
!isa<x86vector::CvtPackedEvenIndexedToF32Op>(rhs.getDefiningOp()))
|
||||
return false;
|
||||
|
||||
if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
|
||||
return false;
|
||||
|
||||
if (lhs.getDefiningOp()->getBlock() != rhs.getDefiningOp()->getBlock())
|
||||
return false;
|
||||
|
||||
if (lhs.getDefiningOp()->getBlock() != fmaOp->getBlock())
|
||||
return false;
|
||||
|
||||
if (!fmaOp.getResult().hasOneUse())
|
||||
return false;
|
||||
|
||||
Operation *consumer = *fmaOp.getResult().getUsers().begin();
|
||||
if (consumer->getBlock() != fmaOp->getBlock())
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Moves vector.fma along with the lhs and rhs defining operation before its
|
||||
// consumer. If the consumer is vector.ShapeCastOp and has only one user then
|
||||
// move before the consumer of vector.ShapeCastOp.
|
||||
// TODO: Move before first consumer, if there are multiple.
|
||||
static void moveFMA(PatternRewriter &rewriter, vector::FMAOp fmaOp) {
|
||||
Operation *consumer = *fmaOp.getResult().getUsers().begin();
|
||||
|
||||
if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(consumer)) {
|
||||
if (shapeCastOp.getResult().hasOneUse()) {
|
||||
Operation *nxtConsumer = *shapeCastOp.getResult().getUsers().begin();
|
||||
if (nxtConsumer->getBlock() == fmaOp->getBlock()) {
|
||||
consumer = *shapeCastOp.getResult().getUsers().begin();
|
||||
rewriter.moveOpBefore(fmaOp.getLhs().getDefiningOp(), consumer);
|
||||
rewriter.moveOpBefore(fmaOp.getRhs().getDefiningOp(), consumer);
|
||||
rewriter.moveOpBefore(fmaOp.getOperation(), consumer);
|
||||
rewriter.moveOpBefore(shapeCastOp.getOperation(), consumer);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.moveOpBefore(fmaOp.getLhs().getDefiningOp(), consumer);
|
||||
rewriter.moveOpBefore(fmaOp.getRhs().getDefiningOp(), consumer);
|
||||
rewriter.moveOpBefore(fmaOp.getOperation(), consumer);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Shuffle FMAs with x86vector operations as operands such that
|
||||
// FMAs are grouped with respect to odd/even packed index.
|
||||
//
|
||||
// For example:
|
||||
// ```
|
||||
// %1 = x86vector.avx.bcst_to_f32.packed
|
||||
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// %3 = vector.fma %1, %2, %arg1
|
||||
// %4 = x86vector.avx.bcst_to_f32.packed
|
||||
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
|
||||
// %6 = vector.fma %4, %5, %3
|
||||
// %7 = x86vector.avx.bcst_to_f32.packed
|
||||
// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// %9 = vector.fma %7, %8, %arg2
|
||||
// %10 = x86vector.avx.bcst_to_f32.packed
|
||||
// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
|
||||
// %12 = vector.fma %10, %11, %9
|
||||
// yield %6, %12
|
||||
// ```
|
||||
// to
|
||||
// ```
|
||||
// %1 = x86vector.avx.bcst_to_f32.packed
|
||||
// %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// %3 = vector.fma %1, %2, %arg1
|
||||
// %7 = x86vector.avx.bcst_to_f32.packed
|
||||
// %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// %9 = vector.fma %7, %8, %arg2
|
||||
// %4 = x86vector.avx.bcst_to_f32.packed
|
||||
// %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
|
||||
// %6 = vector.fma %4, %5, %3
|
||||
// %10 = x86vector.avx.bcst_to_f32.packed
|
||||
// %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
|
||||
// %12 = vector.fma %10, %11, %9
|
||||
// yield %9, %12
|
||||
// ```
|
||||
// TODO: Shuffling supported only if the FMA, lhs/rhs defining operations
|
||||
// have only one consumer. Have to extend this pass for multiple consumers.
|
||||
struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
|
||||
using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::FMAOp fmaOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
if (!validateVectorFMAOp(fmaOp))
|
||||
return failure();
|
||||
|
||||
llvm::SmallVector<vector::FMAOp> fmaOps;
|
||||
Operation *nextOp = fmaOp;
|
||||
bool stopAtNextDependentFMA = true;
|
||||
|
||||
// Break the loop and return failure if the immediate next FMA op
|
||||
// have CvtPackedEvenIndexedToF32Op in it's lhs/rhs defining ops.
|
||||
while ((nextOp = nextOp->getNextNode())) {
|
||||
auto fma = dyn_cast<vector::FMAOp>(nextOp);
|
||||
if (!fma)
|
||||
continue;
|
||||
|
||||
bool hasX86CvtOperand = isa<x86vector::CvtPackedEvenIndexedToF32Op>(
|
||||
fma.getLhs().getDefiningOp()) ||
|
||||
isa<x86vector::CvtPackedEvenIndexedToF32Op>(
|
||||
fma.getRhs().getDefiningOp());
|
||||
|
||||
if (hasX86CvtOperand && stopAtNextDependentFMA)
|
||||
break;
|
||||
|
||||
if (validateVectorFMAOp(fma))
|
||||
fmaOps.push_back(fma);
|
||||
|
||||
stopAtNextDependentFMA = false;
|
||||
}
|
||||
|
||||
if (fmaOps.empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
fmaOp, "No eligible FMA operations were found: the operation may "
|
||||
"already be shuffled, there may be no following FMAs, or the "
|
||||
"following FMAs do not satisfy the shuffle conditions.");
|
||||
|
||||
fmaOps.push_back(fmaOp);
|
||||
for (auto fmaOp : fmaOps)
|
||||
moveFMA(rewriter, fmaOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void x86vector::populateShuffleVectorFMAOpsPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<ShuffleVectorFMAOps>(patterns.getContext());
|
||||
}
|
||||
312
mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
Normal file
312
mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
Normal file
@ -0,0 +1,312 @@
|
||||
// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
|
||||
|
||||
!vec = vector<8xf32>
|
||||
!memrefA = memref<1x1x1xbf16>
|
||||
!memrefB = memref<1x8x2xbf16>
|
||||
|
||||
func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
|
||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
||||
{
|
||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%5 = vector.fma %3, %4, %2 : !vec
|
||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%11 = vector.fma %9, %10, %8 : !vec
|
||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||
return %12 : !vec
|
||||
}
|
||||
|
||||
// Groups FMAs with respect to even/odd indexed input operands.
|
||||
// The vector.fma at %5 is moved along with its operands after %8.
|
||||
// CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32
|
||||
// Odd-Indexed FMAs
|
||||
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
|
||||
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
||||
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
||||
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
|
||||
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
||||
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
||||
// Even-Indexed FMAs
|
||||
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
|
||||
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
|
||||
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
|
||||
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
|
||||
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
|
||||
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]]
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %0 {
|
||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
||||
} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
!vec = vector<8xf32>
|
||||
!memrefA = memref<1x1x1xbf16>
|
||||
!memrefB = memref<1x8x2xbf16>
|
||||
|
||||
func.func @shuffle_fma_with_lhs_as_even.index_to_f32(
|
||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
||||
{
|
||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%5 = vector.fma %4, %3, %2 : !vec
|
||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%11 = vector.fma %9, %10, %8 : !vec
|
||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||
return %12 : !vec
|
||||
}
|
||||
|
||||
// The vector.fma at %5 is moved along with its operands after %8.
|
||||
// CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32
|
||||
// Odd-Indexed FMAs
|
||||
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
|
||||
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
||||
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
||||
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
|
||||
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
||||
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
||||
// Even-Indexed FMAs
|
||||
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
|
||||
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
|
||||
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]]
|
||||
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
|
||||
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
|
||||
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]]
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %0 {
|
||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
||||
} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
!vec = vector<8xf32>
|
||||
!vecOut = vector<1x8xf32>
|
||||
!memrefA = memref<1x1x1xbf16>
|
||||
!memrefB = memref<1x8x2xbf16>
|
||||
|
||||
func.func @shuffle_fma_with_shape_cast(
|
||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut
|
||||
{
|
||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%5 = vector.fma %3, %4, %2 : !vec
|
||||
%res1 = vector.shape_cast %5 : !vec to !vecOut
|
||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%11 = vector.fma %9, %10, %8 : !vec
|
||||
%res2 = vector.shape_cast %11 : !vec to !vecOut
|
||||
%12 = arith.addf %res1, %res2 : !vecOut
|
||||
return %12 : !vecOut
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @shuffle_fma_with_shape_cast
|
||||
// Odd-Indexed FMAs
|
||||
// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0
|
||||
// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2
|
||||
// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6
|
||||
// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3
|
||||
// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5
|
||||
// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6
|
||||
// Even-Indexed FMAs
|
||||
// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg4
|
||||
// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5
|
||||
// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]]
|
||||
// CHECK: vector.shape_cast
|
||||
// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg1
|
||||
// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2
|
||||
// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]]
|
||||
// CHECK: vector.shape_cast
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %0 {
|
||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
||||
} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
!vec = vector<8xf32>
|
||||
!memrefA = memref<1x1x1xbf16>
|
||||
!memrefB = memref<1x8x2xbf16>
|
||||
|
||||
func.func @negative_fma_operand_has_multiple_consumer(
|
||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB,
|
||||
%arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec
|
||||
{
|
||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%2 = vector.fma %0, %1, %arg5 : !vec
|
||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%5 = vector.fma %3, %4, %2 : !vec
|
||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec
|
||||
%8 = vector.fma %3, %7, %arg5 : !vec
|
||||
%9 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec
|
||||
%11 = vector.fma %9, %10, %8 : !vec
|
||||
%12 = vector.fma %5, %11, %arg5 : !vec
|
||||
return %12 : !vec
|
||||
}
|
||||
|
||||
// The vector.fma at %5 uses %3 as its LHS operand, which has two consumers; therefore,
|
||||
// the rewrite is not applied.
|
||||
// CHECK-LABEL: @negative_fma_operand_has_multiple_consumer
|
||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// CHECK: vector.fma
|
||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
||||
// CHECK: vector.fma
|
||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// CHECK: vector.fma
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %0 {
|
||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
||||
} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
!vec = vector<8xf32>
|
||||
!memrefA = memref<1x1x1xbf16>
|
||||
!memrefB = memref<1x8x2xbf16>
|
||||
|
||||
func.func @negative_fma_has_multiple_consumer(
|
||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
|
||||
{
|
||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%5 = vector.fma %3, %4, %2 : !vec
|
||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%8 = vector.fma %6, %7, %5 : !vec
|
||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%11 = vector.fma %9, %10, %8 : !vec
|
||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||
return %12 : !vec
|
||||
}
|
||||
|
||||
// vector.fma at %5 has two uses; therefore no re-write applied.
|
||||
// CHECK-LABEL: @negative_fma_has_multiple_consumer
|
||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// CHECK: vector.fma
|
||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
||||
// CHECK: vector.fma
|
||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %0 {
|
||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
||||
} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
!vec = vector<8xf32>
|
||||
!memrefA = memref<1x1x1xbf16>
|
||||
!memrefB = memref<1x8x2xbf16>
|
||||
|
||||
func.func @negative_no_shuffle_outside_block(
|
||||
%arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
|
||||
%arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec
|
||||
{
|
||||
%0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
|
||||
%1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%2 = vector.fma %0, %1, %arg6 : !vec
|
||||
%3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
|
||||
%4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
|
||||
%5 = vector.fma %3, %4, %2 : !vec
|
||||
|
||||
%loop = scf.if %arg7 -> (vector<8xf32>) {
|
||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%11 = vector.fma %9, %10, %8 : !vec
|
||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||
scf.yield %12 : vector<8xf32>
|
||||
} else {
|
||||
%6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
|
||||
%7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%8 = vector.fma %6, %7, %arg6 : !vec
|
||||
%9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec
|
||||
%10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec
|
||||
%11 = vector.fma %9, %10, %8 : !vec
|
||||
%12 = vector.fma %5, %11, %arg6 : !vec
|
||||
scf.yield %12 : vector<8xf32>
|
||||
}
|
||||
|
||||
return %loop : !vec
|
||||
}
|
||||
|
||||
// vector.fma at %5 has its consumer in an another block (%12); therefore rewrite is not
|
||||
// applied.
|
||||
// CHECK-LABEL: @negative_no_shuffle_outside_block
|
||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// CHECK: vector.fma
|
||||
// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32
|
||||
// CHECK: vector.fma
|
||||
// CHECK: scf.if
|
||||
// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32
|
||||
// CHECK: vector.fma
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %0 {
|
||||
transform.apply_patterns.x86vector.shuffle_vector_fma_ops
|
||||
} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user