diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td index 3c73eadf8216..891829fca017 100644 --- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td @@ -60,6 +60,17 @@ def ApplySinkVectorProducerOpsPatternsOp : Op]> { + let description = [{ + Collect patterns to shuffle FMAs with x86vector operations as operands + such that FMAs are grouped with respect to odd/even packed index. + }]; + + let assemblyFormat = "attr-dict"; +} + #endif // X86VECTOR_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h index c25cdaf2d942..aadca9270890 100644 --- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h +++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h @@ -100,6 +100,10 @@ void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns); // range by placing them at their earliest legal use site. void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns); +// Shuffles FMAs with x86vector operations as operands such that FMAs are +// grouped with respect to odd/even packed index. 
+void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns); + //===----------------------------------------------------------------------===// /// Helpers extracted from: /// - clang/lib/Headers/avxintrin.h diff --git a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp index e77d30c9c5ff..c6be69305da5 100644 --- a/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp +++ b/mlir/lib/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp @@ -42,6 +42,11 @@ void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns( x86vector::populateSinkVectorProducerOpsPatterns(patterns); } +void mlir::transform::ApplyShuffleVectorFMAOpsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + x86vector::populateShuffleVectorFMAOpsPatterns(patterns); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt index bbd9be880eb0..01d2ec4810e2 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms VectorContractToPackedTypeDotProduct.cpp VectorContractBF16ToFMA.cpp SinkVectorProducerOps.cpp + ShuffleVectorFMAOps.cpp LINK_LIBS PUBLIC MLIRArithDialect diff --git a/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp new file mode 100644 index 000000000000..a66546a5d1e4 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/ShuffleVectorFMAOps.cpp @@ -0,0 +1,186 @@ +//===- ShuffleVectorFMAOps.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" + +#include "mlir/IR/PatternMatch.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; + +namespace { + +// Returns true only if `op` is defined by one of the accepted x86vector +// operations and that defining operation's result has exactly one use. +static bool validateFMAOperands(Value op) { + if (auto cvt = op.getDefiningOp()) + return cvt.getResult().hasOneUse(); + + if (auto bcst = op.getDefiningOp()) + return bcst.getResult().hasOneUse(); + + return false; +} + +// Validates the vector.fma operation on the following conditions: +// (i) the lhs or rhs defining operation is a CvtPackedEvenIndexedToF32Op, +// (ii) both the lhs and rhs defining operations are x86vector operations with +// a single consumer, (iii) those operations and the FMA are in the same block, +// and (iv) the FMA has only one user, which is in that same block. 
+static bool validateVectorFMAOp(vector::FMAOp fmaOp) {
+  Value lhs = fmaOp.getLhs();
+  Value rhs = fmaOp.getRhs();
+
+  // (i) Value::getDefiningOp<OpTy>() yields a null op for block arguments and
+  // other producers, whereas isa<>() must never be handed a null Operation *.
+  if (!lhs.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>() &&
+      !rhs.getDefiningOp<x86vector::CvtPackedEvenIndexedToF32Op>())
+    return false;
+
+  // (ii) Passing this check also guarantees both operands have a defining op.
+  if (!validateFMAOperands(lhs) || !validateFMAOperands(rhs))
+    return false;
+
+  // (iii) Both producers must sit in the FMA's own block.
+  Block *fmaBlock = fmaOp->getBlock();
+  if (lhs.getDefiningOp()->getBlock() != fmaBlock ||
+      rhs.getDefiningOp()->getBlock() != fmaBlock)
+    return false;
+
+  // (iv) Exactly one user, located in the same block as the FMA.
+  if (!fmaOp.getResult().hasOneUse())
+    return false;
+  Operation *consumer = *fmaOp.getResult().getUsers().begin();
+  return consumer->getBlock() == fmaBlock;
+}
+
+// Moves a vector.fma that passed validateVectorFMAOp, together with the
+// defining ops of its lhs and rhs, right before its single consumer. If that
+// consumer is a vector.shape_cast with exactly one user in the FMA's block, the
+// shape_cast moves as well and everything lands before the shape_cast's user.
+// TODO: Move before first consumer, if there are multiple.
+static void moveFMA(PatternRewriter &rewriter, vector::FMAOp fmaOp) {
+  Operation *insertionPoint = *fmaOp.getResult().getUsers().begin();
+
+  // Look through a single-use shape_cast whose user shares the FMA's block.
+  Operation *trailingCast = nullptr;
+  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(insertionPoint)) {
+    if (shapeCastOp.getResult().hasOneUse()) {
+      Operation *castUser = *shapeCastOp.getResult().getUsers().begin();
+      if (castUser->getBlock() == fmaOp->getBlock()) {
+        trailingCast = shapeCastOp.getOperation();
+        insertionPoint = castUser;
+      }
+    }
+  }
+
+  // Preserve the relative order: lhs producer, rhs producer, FMA, shape_cast.
+  rewriter.moveOpBefore(fmaOp.getLhs().getDefiningOp(), insertionPoint);
+  rewriter.moveOpBefore(fmaOp.getRhs().getDefiningOp(), insertionPoint);
+  rewriter.moveOpBefore(fmaOp.getOperation(), insertionPoint);
+  if (trailingCast)
+    rewriter.moveOpBefore(trailingCast, insertionPoint);
+}
+
+// Shuffle FMAs with x86vector operations as operands such that
+// FMAs are grouped with respect to odd/even packed index.
+//
+// For example:
+// ```
+//   %1 = x86vector.avx.bcst_to_f32.packed
+//   %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %3 = vector.fma %1, %2, %arg1
+//   %4 = x86vector.avx.bcst_to_f32.packed
+//   %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %6 = vector.fma %4, %5, %3
+//   %7 = x86vector.avx.bcst_to_f32.packed
+//   %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %9 = vector.fma %7, %8, %arg2
+//   %10 = x86vector.avx.bcst_to_f32.packed
+//   %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %12 = vector.fma %10, %11, %9
+//   yield %6, %12
+// ```
+// to
+// ```
+//   %1 = x86vector.avx.bcst_to_f32.packed
+//   %2 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %3 = vector.fma %1, %2, %arg1
+//   %7 = x86vector.avx.bcst_to_f32.packed
+//   %8 = x86vector.avx.cvt.packed.odd.indexed_to_f32
+//   %9 = vector.fma %7, %8, %arg2
+//   %10 = x86vector.avx.bcst_to_f32.packed
+//   %11 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %12 = vector.fma %10, %11, %9
+//   %4 = x86vector.avx.bcst_to_f32.packed
+//   %5 = x86vector.avx.cvt.packed.even.indexed_to_f32
+//   %6 = vector.fma %4, %5, %3
+//   yield %6, %12
+// ```
+// TODO: Shuffling is supported only if the FMA and its lhs/rhs defining
+// operations each have a single consumer; extend this to multiple consumers.
+struct ShuffleVectorFMAOps : public OpRewritePattern<vector::FMAOp> {
+  using OpRewritePattern<vector::FMAOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FMAOp fmaOp,
+                                PatternRewriter &rewriter) const override {
+    using EvenCvtOp = x86vector::CvtPackedEvenIndexedToF32Op;
+    if (!validateVectorFMAOp(fmaOp))
+      return failure();
+
+    llvm::SmallVector<vector::FMAOp> fmaOps;
+    bool stopAtNextDependentFMA = true;
+
+    // Scan the FMAs that follow in the block. If the very next FMA already has
+    // an even-indexed convert operand, the group is considered shuffled: stop
+    // with nothing collected, which is reported as a match failure below.
+    for (Operation *nextOp = fmaOp->getNextNode(); nextOp;
+         nextOp = nextOp->getNextNode()) {
+      auto fma = dyn_cast<vector::FMAOp>(nextOp);
+      if (!fma)
+        continue;
+
+      // getDefiningOp<OpTy>() tolerates operands without a defining op (block
+      // arguments); isa<>() must not be called on a null Operation *.
+      bool hasX86CvtOperand = fma.getLhs().getDefiningOp<EvenCvtOp>() ||
+                              fma.getRhs().getDefiningOp<EvenCvtOp>();
+      if (hasX86CvtOperand && stopAtNextDependentFMA)
+        break;
+
+      if (validateVectorFMAOp(fma))
+        fmaOps.push_back(fma);
+
+      stopAtNextDependentFMA = false;
+    }
+
+    if (fmaOps.empty())
+      return rewriter.notifyMatchFailure(
+          fmaOp, "No eligible FMA operations were found: the operation may "
+                 "already be shuffled, there may be no following FMAs, or the "
+                 "following FMAs do not satisfy the shuffle conditions.");
+
+    // The later FMAs are moved first; the matched FMA is moved last.
+    fmaOps.push_back(fmaOp);
+    for (vector::FMAOp fmaToMove : fmaOps)
+      moveFMA(rewriter, fmaToMove);
+    return success();
+  }
+};
+
+} // namespace
+
+void x86vector::populateShuffleVectorFMAOpsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ShuffleVectorFMAOps>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
new file mode 100644
index 000000000000..4bf930b51c0c
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/shuffle-vector-fmas.mlir
@@ -0,0 +1,312 @@
+// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
+
+!vec = vector<8xf32>
+!memrefA = memref<1x1x1xbf16>
+!memrefB = memref<1x8x2xbf16>
+
+func.func @shuffle_fma_with_rhs_as_even.index_to_f32(
+  %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA,
+  %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec
+{
+  %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec
+  %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %2 = vector.fma %0, %1, %arg6 : !vec
+  %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec
+  %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec
+  %5 = vector.fma %3, %4, %2 : !vec
+  %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec
+  %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB 
-> !vec + %8 = vector.fma %6, %7, %arg6 : !vec + %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec + %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec + %11 = vector.fma %9, %10, %8 : !vec + %12 = vector.fma %5, %11, %arg6 : !vec + return %12 : !vec +} + +// Groups FMAs with respect to even/odd indexed input operands. +// The vector.fma at %5 is moved along with its operands after %8. +// CHECK-LABEL: @shuffle_fma_with_rhs_as_even.index_to_f32 +// Odd-Indexed FMAs +// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0 +// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 +// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6 +// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3 +// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 +// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6 +// Even-Indexed FMAs +// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4 +// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 +// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]] +// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1 +// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 +// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD0]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.x86vector.shuffle_vector_fma_ops + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vec = vector<8xf32> +!memrefA = memref<1x1x1xbf16> +!memrefB = memref<1x8x2xbf16> + +func.func @shuffle_fma_with_lhs_as_even.index_to_f32( + %arg0: !memrefA, %arg1: !memrefA, %arg2: 
!memrefB, %arg3: !memrefA, + %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec +{ + %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec + %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec + %2 = vector.fma %0, %1, %arg6 : !vec + %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec + %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec + %5 = vector.fma %4, %3, %2 : !vec + %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec + %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec + %8 = vector.fma %6, %7, %arg6 : !vec + %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec + %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec + %11 = vector.fma %9, %10, %8 : !vec + %12 = vector.fma %5, %11, %arg6 : !vec + return %12 : !vec +} + +// The vector.fma at %5 is moved along with its operands after %8. +// CHECK-LABEL: @shuffle_fma_with_lhs_as_even.index_to_f32 +// Odd-Indexed FMAs +// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0 +// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 +// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6 +// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3 +// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 +// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6 +// Even-Indexed FMAs +// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg4 +// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 +// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD1]] +// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 +// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg1 +// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[EVEN1]], %[[BCST3]], %[[FMA_ODD0]] + +module attributes {transform.with_named_sequence} { + transform.named_sequence 
@__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.x86vector.shuffle_vector_fma_ops + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vec = vector<8xf32> +!vecOut = vector<1x8xf32> +!memrefA = memref<1x1x1xbf16> +!memrefB = memref<1x8x2xbf16> + +func.func @shuffle_fma_with_shape_cast( + %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA, + %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vecOut +{ + %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec + %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec + %2 = vector.fma %0, %1, %arg6 : !vec + %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec + %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec + %5 = vector.fma %3, %4, %2 : !vec + %res1 = vector.shape_cast %5 : !vec to !vecOut + %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec + %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec + %8 = vector.fma %6, %7, %arg6 : !vec + %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec + %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec + %11 = vector.fma %9, %10, %8 : !vec + %res2 = vector.shape_cast %11 : !vec to !vecOut + %12 = arith.addf %res1, %res2 : !vecOut + return %12 : !vecOut +} + +// CHECK-LABEL: @shuffle_fma_with_shape_cast +// Odd-Indexed FMAs +// CHECK: %[[BCST0:.*]] = x86vector.avx.bcst_to_f32.packed %arg0 +// CHECK: %[[ODD0:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 +// CHECK: %[[FMA_ODD0:.*]] = vector.fma %[[BCST0]], %[[ODD0]], %arg6 +// CHECK: %[[BCST1:.*]] = x86vector.avx.bcst_to_f32.packed %arg3 +// CHECK: %[[ODD1:.*]] = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 +// CHECK: %[[FMA_ODD1:.*]] = vector.fma %[[BCST1]], %[[ODD1]], %arg6 +// 
Even-Indexed FMAs +// CHECK: %[[BCST3:.*]] = x86vector.avx.bcst_to_f32.packed %arg4 +// CHECK: %[[EVEN1:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 +// CHECK: %[[FMA_EVEN1:.*]] = vector.fma %[[BCST3]], %[[EVEN1]], %[[FMA_ODD1]] +// CHECK: vector.shape_cast +// CHECK: %[[BCST2:.*]] = x86vector.avx.bcst_to_f32.packed %arg1 +// CHECK: %[[EVEN0:.*]] = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 +// CHECK: %[[FMA_EVEN0:.*]] = vector.fma %[[BCST2]], %[[EVEN0]], %[[FMA_ODD0]] +// CHECK: vector.shape_cast + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.x86vector.shuffle_vector_fma_ops + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vec = vector<8xf32> +!memrefA = memref<1x1x1xbf16> +!memrefB = memref<1x8x2xbf16> + +func.func @negative_fma_operand_has_multiple_consumer( + %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, + %arg3: !memrefA, %arg4: !memrefB, %arg5: !vec) -> !vec +{ + %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec + %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec + %2 = vector.fma %0, %1, %arg5 : !vec + %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec + %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec + %5 = vector.fma %3, %4, %2 : !vec + %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg4 : !memrefB -> !vec + %8 = vector.fma %3, %7, %arg5 : !vec + %9 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec + %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg4 : !memrefB -> !vec + %11 = vector.fma %9, %10, %8 : !vec + %12 = vector.fma %5, %11, %arg5 : !vec + return %12 : !vec +} + +// The vector.fma at %5 uses %3 as its LHS operand, which has two consumers; therefore, 
+// the rewrite is not applied. +// CHECK-LABEL: @negative_fma_operand_has_multiple_consumer +// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 +// CHECK: vector.fma +// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 +// CHECK: vector.fma +// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 +// CHECK: vector.fma + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.x86vector.shuffle_vector_fma_ops + } : !transform.any_op + transform.yield + } +} + +// ----- + +!vec = vector<8xf32> +!memrefA = memref<1x1x1xbf16> +!memrefB = memref<1x8x2xbf16> + +func.func @negative_fma_has_multiple_consumer( + %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA, + %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec) -> !vec +{ + %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec + %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec + %2 = vector.fma %0, %1, %arg6 : !vec + %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec + %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec + %5 = vector.fma %3, %4, %2 : !vec + %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec + %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec + %8 = vector.fma %6, %7, %5 : !vec + %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec + %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec + %11 = vector.fma %9, %10, %8 : !vec + %12 = vector.fma %5, %11, %arg6 : !vec + return %12 : !vec +} + +// vector.fma at %5 has two uses; therefore no re-write applied. 
+// CHECK-LABEL: @negative_fma_has_multiple_consumer +// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 +// CHECK: vector.fma +// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 +// CHECK: vector.fma +// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.x86vector.shuffle_vector_fma_ops + } : !transform.any_op + transform.yield + } +} + +// ----- +!vec = vector<8xf32> +!memrefA = memref<1x1x1xbf16> +!memrefB = memref<1x8x2xbf16> + +func.func @negative_no_shuffle_outside_block( + %arg0: !memrefA, %arg1: !memrefA, %arg2: !memrefB, %arg3: !memrefA, + %arg4: !memrefA, %arg5: !memrefB, %arg6: !vec, %arg7: i1) -> !vec +{ + %0 = x86vector.avx.bcst_to_f32.packed %arg0 : !memrefA -> !vec + %1 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg2 : !memrefB -> !vec + %2 = vector.fma %0, %1, %arg6 : !vec + %3 = x86vector.avx.bcst_to_f32.packed %arg1 : !memrefA -> !vec + %4 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg2 : !memrefB -> !vec + %5 = vector.fma %3, %4, %2 : !vec + + %loop = scf.if %arg7 -> (vector<8xf32>) { + %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec + %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec + %8 = vector.fma %6, %7, %arg6 : !vec + %9 = x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec + %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec + %11 = vector.fma %9, %10, %8 : !vec + %12 = vector.fma %5, %11, %arg6 : !vec + scf.yield %12 : vector<8xf32> + } else { + %6 = x86vector.avx.bcst_to_f32.packed %arg3 : !memrefA -> !vec + %7 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %arg5 : !memrefB -> !vec + %8 = vector.fma %6, %7, %arg6 : !vec + %9 = 
x86vector.avx.bcst_to_f32.packed %arg4 : !memrefA -> !vec + %10 = x86vector.avx.cvt.packed.even.indexed_to_f32 %arg5 : !memrefB -> !vec + %11 = vector.fma %9, %10, %8 : !vec + %12 = vector.fma %5, %11, %arg6 : !vec + scf.yield %12 : vector<8xf32> + } + + return %loop : !vec +} + +// vector.fma at %5 has its consumer (%12) in another block; therefore, the +// rewrite is not applied. +// CHECK-LABEL: @negative_no_shuffle_outside_block +// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 +// CHECK: vector.fma +// CHECK: x86vector.avx.cvt.packed.even.indexed_to_f32 +// CHECK: vector.fma +// CHECK: scf.if +// CHECK: x86vector.avx.cvt.packed.odd.indexed_to_f32 +// CHECK: vector.fma + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.x86vector.shuffle_vector_fma_ops + } : !transform.any_op + transform.yield + } +}