Arun Thangamani 8390909842
[mlir][x86vector] Lower BF16 vector.contract to FMA using AVX2 BF16 packed ops. (#170267)
A `transform` pass to lower `BF16` type `vector.contract` to
`vector.fma` using `AVX2` BF16 packed operations:

- `vbcstnebf162ps` - Broadcasts BF16 into packed F32.
- `vcvtneebf162ps` - Convert packed BF16 even-indexed elements into
packed F32.
- `vcvtneobf162ps` - Convert packed BF16 odd-indexed elements into
packed F32.
2025-12-17 14:41:58 +01:00

109 lines
4.1 KiB
C++

//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps -------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
namespace mlir {
namespace x86vector {
// Classifies every loop dimension of a contraction as parallel or reduction,
// using only the output (accumulator) indexing map.
//
// A dimension referenced by the output map is parallel; every other dimension
// is a reduction. Fails when `map` is not a projected permutation, because
// then the split is not well defined.
static FailureOr<SmallVector<mlir::utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  if (!map.isProjectedPermutation())
    return failure();

  // Start with "everything is reduced", then promote the dimensions that
  // survive into the output to parallel.
  const unsigned numLoopDims = map.getNumDims();
  SmallVector<mlir::utils::IteratorType> kinds(
      numLoopDims, mlir::utils::IteratorType::reduction);
  for (auto result : map.getResults()) {
    // Projected permutations may also contain constant results; only plain
    // dimension references mark a loop as parallel.
    auto dimRef = dyn_cast<AffineDimExpr>(result);
    if (!dimRef)
      continue;
    kinds[dimRef.getPosition()] = mlir::utils::IteratorType::parallel;
  }
  return kinds;
}
// Returns true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
// Returns true if the operation is in VNNI layout.
// Optionally, the check can be constrained to a specific VNNI blocking factor.
//
// `op` is treated as a contraction whose operands 0 and 1 are the A and B
// inputs. `indexingMaps` holds the {A, B, output} maps in that order.
// Anything that does not fit this form returns false instead of asserting.
// That includes:
// - fewer than two operands;
// - non-shaped or unranked inputs;
// - fewer than three maps;
// - an input map whose result count differs from the operand rank.
bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
                    std::optional<unsigned> blockingFactor) {
  // VNNI layout only applies to contractions; maps [0..2] are indexed below.
  if (indexingMaps.size() < 3)
    return false;
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(indexingMaps);
  if (failed(dims))
    return false;

  // Both inputs must be ranked shaped values. dyn_cast yields a null type for
  // anything else, and that null type must not be dereferenced.
  if (op->getNumOperands() < 2)
    return false;
  auto typeA = dyn_cast<ShapedType>(op->getOperand(0).getType());
  auto typeB = dyn_cast<ShapedType>(op->getOperand(1).getType());
  if (!typeA || !typeB || !typeA.hasRank() || !typeB.hasRank())
    return false;
  unsigned rankA = typeA.getRank();
  unsigned rankB = typeB.getRank();
  // VNNI format requires at least 1 parallel and 2 reduction dimensions.
  if (rankA < 3 || rankB < 3)
    return false;
  // At least two reduction dimensions are expected:
  // one for the VNNI factor and one for the K dimension.
  if (dims->k.size() < 2)
    return false;

  // Validate affine maps - VNNI computation should be defined by the two
  // innermost reduction iterators.
  // The input matrix dimensions layout must match the following:
  // - matrix A - [...][K/vnniFactor][vnniFactor]
  // - matrix B - [...][K/vnniFactor][N][vnniFactor]
  auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2] /* outs */);
  if (failed(maybeIters))
    return false;
  const SmallVector<mlir::utils::IteratorType> &iteratorTypes = *maybeIters;
  AffineMap mapA = indexingMaps[0];
  AffineMap mapB = indexingMaps[1];
  // Each input map must yield one result per operand dimension. Otherwise the
  // rank-based getResult() indexing below would be out of bounds.
  if (mapA.getNumResults() != rankA || mapB.getNumResults() != rankB)
    return false;

  // Innermost dimension of both inputs: the shared VNNI reduction dimension.
  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
      iteratorTypes[vnniDimA.getPosition()] !=
          mlir::utils::IteratorType::reduction)
    return false;
  // A[..][K/v][v] and B[..][K/v][N][v]: the K/v dimension is a shared reduction.
  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
  if (!redDimA || !redDimB || redDimA != redDimB ||
      iteratorTypes[redDimA.getPosition()] !=
          mlir::utils::IteratorType::reduction)
    return false;
  // B's N dimension sits between K/v and v and must be parallel.
  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
  if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
                           mlir::utils::IteratorType::parallel)
    return false;

  // VNNI factor must be:
  // - the innermost inputs' dimension
  // - statically known and non-zero
  // - even (a multiple of 2) and, when `blockingFactor` is given, equal to it
  int64_t vnniDimSize = typeB.getShape().back();
  if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
      vnniDimSize % 2 != 0)
    return false;
  if (typeA.getShape().back() != vnniDimSize)
    return false;
  if (blockingFactor && vnniDimSize != *blockingFactor)
    return false;
  // The split reduction dimension size should also match.
  if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
    return false;
  return true;
}
} // namespace x86vector
} // namespace mlir