A `transform` pass that lowers `BF16`-typed `vector.contract` operations to `vector.fma` using `AVX2` BF16 packed operations: - `vbcstnebf162ps` - Broadcasts BF16 into packed F32. - `vcvtneebf162ps` - Converts packed BF16 even-indexed elements into packed F32. - `vcvtneobf162ps` - Converts packed BF16 odd-indexed elements into packed F32 data.
109 lines
4.1 KiB
C++
109 lines
4.1 KiB
C++
//===- X86VectorUtils.cpp - MLIR Utilities for X86VectorOps -------------===//
|
|
//
|
|
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/X86Vector/Utils/X86VectorUtils.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/IR/Types.h"
|
|
|
|
namespace mlir {
|
|
namespace x86vector {
|
|
|
|
/// Derives one iterator type per dimension of `map`, which is expected to be
/// the indexing map of a contraction's output.
///
/// Every dimension that shows up as a result of `map` is classified as
/// `parallel`; every remaining dimension is classified as `reduction`.
/// Fails when `map` is not a projected permutation.
static FailureOr<SmallVector<mlir::utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  using mlir::utils::IteratorType;

  if (!map.isProjectedPermutation())
    return failure();

  // Start from "all reduction" and promote the dimensions that the output
  // map actually references.
  SmallVector<IteratorType> iterators(map.getNumDims(),
                                      IteratorType::reduction);
  for (AffineExpr resultExpr : map.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
    if (!dimExpr)
      continue;
    iterators[dimExpr.getPosition()] = IteratorType::parallel;
  }
  return iterators;
}
|
|
|
|
// Returns true if the operation is in VNNI layout.
|
|
// Optionally, the check can be constrained to a specific VNNI blocking factor.
|
|
bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
|
|
std::optional<unsigned> blockingFactor) {
|
|
// Narrow down type operations - VNNI only applies to contractions.
|
|
FailureOr<linalg::ContractionDimensions> dims =
|
|
linalg::inferContractionDims(indexingMaps);
|
|
if (failed(dims))
|
|
return false;
|
|
|
|
auto matA = op->getOperand(0);
|
|
auto matB = op->getOperand(1);
|
|
auto typeA = dyn_cast<ShapedType>(matA.getType());
|
|
auto typeB = dyn_cast<ShapedType>(matB.getType());
|
|
unsigned rankA = typeA.getRank();
|
|
unsigned rankB = typeB.getRank();
|
|
// VNNI format requires at least 1 parallel and 2 reduction dimensions.
|
|
if (rankA < 3 || rankB < 3)
|
|
return false;
|
|
|
|
// At least two reduction dimensions are expected:
|
|
// one for the VNNI factor and one for the K dimension
|
|
if (dims->k.size() < 2)
|
|
return false;
|
|
|
|
// Validate affine maps - VNNI computation should be defined by the two
|
|
// innermost reduction iterators.
|
|
// The input matrix dimensions layout must match the following:
|
|
// - matrix A - [...][K/vnniFactor][vnniFactor]
|
|
// - matrix B - [...][K/vnniFactor][N][vnniFactor]
|
|
auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2] /* outs */);
|
|
if (failed(maybeIters))
|
|
return false;
|
|
SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
|
|
AffineMap mapA = indexingMaps[0];
|
|
AffineMap mapB = indexingMaps[1];
|
|
|
|
auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
|
|
auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
|
|
if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
|
|
iteratorTypes[vnniDimA.getPosition()] !=
|
|
mlir::utils::IteratorType::reduction)
|
|
return false;
|
|
auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
|
|
auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
|
|
if (!redDimA || !redDimB || redDimA != redDimB ||
|
|
iteratorTypes[redDimA.getPosition()] !=
|
|
mlir::utils::IteratorType::reduction)
|
|
return false;
|
|
auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
|
|
if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
|
|
mlir::utils::IteratorType::parallel)
|
|
return false;
|
|
|
|
// VNNI factor must be:
|
|
// - the innermost inputs' dimension
|
|
// - statically known
|
|
// - multiple of 2 or equal to the specified factor
|
|
auto vnniDimSize = typeB.getShape().back();
|
|
if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
|
|
vnniDimSize % 2 != 0)
|
|
return false;
|
|
if (typeA.getShape().back() != vnniDimSize)
|
|
return false;
|
|
if (blockingFactor && vnniDimSize != *blockingFactor)
|
|
return false;
|
|
|
|
// The split reduction dimension size should also match.
|
|
if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
} // namespace x86vector
|
|
} // namespace mlir
|