Renato Golin d15280894b
[MLIR][Linalg] Remove matmul_transpose variants (#147961)
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and
replaces them with `matmul affine_maps [...]` wherever appropriate. This is
in line with the
[plan](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863),
and can be done now that #104783 has merged.

See:
https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245

Issues investigated:
* pad transform tests that could use `matmul` instead were changed to
use it.
* The ArmSME test using transpose actually needed it, so it was changed to
`matmul` + affine maps.

Arm tests validated by @banach-space (thanks!!).
2025-08-08 22:20:27 +01:00

326 lines
12 KiB
C++

//===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
namespace mlir {
#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::linalg;
/// Compute the static span of a unit-stride range.
///
/// The span is `size - offset`. It is only produced when the stride is the
/// constant 1 and both the offset and the size are compile-time constants;
/// in every other case std::nullopt is returned.
static std::optional<int64_t> getConstantRange(const Range &range) {
  // Reject anything that is not provably unit-stride.
  const std::optional<int64_t> constStride = getConstantIntValue(range.stride);
  if (!constStride.has_value() || *constStride != 1)
    return std::nullopt;

  const std::optional<int64_t> constOffset = getConstantIntValue(range.offset);
  if (!constOffset.has_value())
    return std::nullopt;

  const std::optional<int64_t> constSize = getConstantIntValue(range.size);
  if (!constSize.has_value())
    return std::nullopt;

  const int64_t span = *constSize - *constOffset;
  return span;
}
/// Return true if all dimensions are fully divisible by the respective tiles.
///
/// `tiles[i]` is checked against iteration dimension `dims[i]`. Positions in
/// `dims` are relative to the first non-batch iteration dimension: each one is
/// shifted by the number of batch dimensions inferred for the contraction, so
/// the same positions work for both matmul and batch_matmul.
///
/// Returns false when:
///  - `dims` and `tiles` differ in length or are empty,
///  - contraction dimensions cannot be inferred for `linalgOp`,
///  - a shifted dimension lies outside the iteration domain,
///  - a tile size or a dimension range is not a compile-time constant,
///  - a tile size is not strictly positive (a zero tile would otherwise cause
///    a division by zero in the divisibility check below), or
///  - a dimension range is not a multiple of its tile size.
static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
                                    ArrayRef<OpFoldResult> tiles,
                                    ArrayRef<int64_t> dims) {
  if (dims.size() != tiles.size() || tiles.empty())
    return false;

  FailureOr<ContractionDimensions> contractDims =
      inferContractionDims(linalgOp);
  if (failed(contractDims))
    return false;
  // Skip the batch dimensions if present: all requested positions are
  // offset past them.
  const int64_t batchDimsOffset =
      static_cast<int64_t>(contractDims->batch.size());

  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
  // The builder is only used to query the iteration domain.
  // NOTE(review): for dynamic shapes getIterationDomain may materialize IR
  // through this builder, outside the rewriter's knowledge — confirm.
  OpBuilder builder(tileOp);
  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
  const int64_t numLoops = static_cast<int64_t>(iterationDomain.size());

  for (size_t idx = 0, e = dims.size(); idx < e; ++idx) {
    const int64_t dimPos = dims[idx] + batchDimsOffset;
    // Reject positions outside the iteration domain on either side; a
    // negative position would otherwise index out of bounds.
    if (dimPos < 0 || dimPos >= numLoops)
      return false;

    std::optional<int64_t> tileSize = getConstantIntValue(tiles[idx]);
    std::optional<int64_t> rangeOnDim =
        getConstantRange(iterationDomain[dimPos]);
    // If the tile factor or the range are non-constant, the tile size is
    // considered to be invalid.
    if (!tileSize || !rangeOnDim)
      return false;
    // A non-positive tile can never describe a full tile (and zero would
    // make the modulo below undefined behavior).
    if (*tileSize <= 0)
      return false;
    // The dimension must be fully divisible by the tile.
    if (*rangeOnDim % *tileSize != 0)
      return false;
  }
  return true;
}
/// Return failure or packed matmul with one of its operands transposed.
///
/// `operandMap` is the indexing map of the packed operand (produced by
/// `packOp`) inside `linalgOp`. Its last four results are treated as two outer
/// (block) dimensions followed by two inner (element) dimensions; any leading
/// results, like batch, are left in place.
///
/// `blocksStartDimPos` must hold at least two iteration-dimension positions:
/// the second-to-last entry is the dimension expected to lead the outer pair,
/// and the last entry the dimension expected to lead the inner pair, when the
/// operand is in its non-transposed layout. A pair is swapped only when its
/// current layout differs from the one requested by `transposeOuterBlocks` /
/// `transposeInnerBlocks`.
///
/// Returns the result of `packTranspose`.
static FailureOr<PackTransposeResult>
transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                      linalg::PackOp packOp, AffineMap operandMap,
                      ArrayRef<unsigned> blocksStartDimPos,
                      bool transposeOuterBlocks, bool transposeInnerBlocks) {
  assert(operandMap.getNumDims() >= 4 &&
         "expected at least 4D prepacked matmul");
  // The block positions below are derived from the number of results with
  // unsigned arithmetic; make sure it cannot wrap around.
  assert(operandMap.getNumResults() >= 4 &&
         "expected at least 4D packed operand");
  assert(blocksStartDimPos.size() >= 2 &&
         "expected starting outer and inner block positions");

  // Bias toward innermost dimensions.
  const unsigned outerBlockPos = operandMap.getNumResults() - 4;
  const unsigned innerBlockPos = operandMap.getNumResults() - 2;

  // Transpose control options define the desired block and element layout.
  // Block transposition (outer dimensions) or element transposition (inner
  // dimensions) may not be necessary depending on the original matmul data
  // layout.
  const unsigned expectedOuterDim =
      blocksStartDimPos[blocksStartDimPos.size() - 2];
  const unsigned expectedInnerDim = blocksStartDimPos.back();
  const bool isOuterTransposed =
      operandMap.getDimPosition(outerBlockPos) != expectedOuterDim;
  const bool isInnerTransposed =
      operandMap.getDimPosition(innerBlockPos) != expectedInnerDim;

  // Transpose only the dimensions that need it to conform to the provided
  // transposition settings.
  const bool swapInner = isInnerTransposed != transposeInnerBlocks;
  const bool swapOuter = isOuterTransposed != transposeOuterBlocks;

  SmallVector<int64_t> innerPerm = {0, 1};
  if (swapInner)
    innerPerm = {1, 0};

  // Leave the outer dimensions, like batch, unchanged: identity for every
  // leading result, then the (possibly swapped) outer block pair.
  SmallVector<int64_t> outerPerm;
  outerPerm.reserve(outerBlockPos + 2);
  for (unsigned pos = 0; pos < outerBlockPos; ++pos)
    outerPerm.push_back(pos);
  outerPerm.push_back(outerBlockPos + (swapOuter ? 1 : 0));
  outerPerm.push_back(outerBlockPos + (swapOuter ? 0 : 1));

  return packTranspose(rewriter, packOp, linalgOp,
                       /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
}
/// Pack a matmul operation into blocked 4D layout.
///
/// The op is first packed with `packMatmulGreedily`, using the block factors,
/// padding sizes and dimension order from the options returned by
/// `controlPackMatmul`. The packed LHS and RHS operands are then transposed,
/// as requested by the `lhs*` / `rhs*` transpose options.
///
/// A match failure is reported when:
///  - the op is a batch_matmul with user-defined indexing maps,
///  - the op has pure buffer semantics,
///  - `controlPackMatmul` returns no options,
///  - the number of block factors is not 3, or
///  - padding is disallowed and the dimensions are not fully divisible by the
///    block factors.
/// Plain failure is returned if packing, contraction inference or operand
/// transposition fails.
FailureOr<PackResult>
linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                        const ControlBlockPackMatmulFn &controlPackMatmul) {
  // Check to not let go the batch_matmul with extended semantic, through this
  // transform. Cast the underlying operation by value instead of
  // reinterpreting a pointer to the interface object as a pointer to the op.
  if (auto batchMatmulOp =
          dyn_cast<linalg::BatchMatmulOp>(linalgOp.getOperation())) {
    if (batchMatmulOp.hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          batchMatmulOp,
          "only batch_matmul ops with non-extended semantics are supported");
    }
  }

  if (linalgOp.hasPureBufferSemantics())
    return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");

  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
  if (!options)
    return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");

  if (options->blockFactors.size() != 3)
    return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");

  SmallVector<OpFoldResult> mnkTiles =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));

  // If padding is disabled, make sure that dimensions can be packed cleanly.
  if (!options->allowPadding &&
      !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "expect packing full tiles only");
  }

  OpBuilder::InsertionGuard guard(rewriter);
  // The op is replaced, we need to set the insertion point after it.
  rewriter.setInsertionPointAfter(linalgOp);

  // Pack the matmul operation into blocked layout with two levels of
  // subdivision:
  //   - major 2D blocks - outer dimensions, consist of minor blocks
  //   - minor 2D blocks - inner dimensions, consist of scalar elements
  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
      rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
      options->mnkOrder);
  if (failed(packedMatmul))
    return failure();

  assert(packedMatmul->packOps.size() == 3 &&
         "invalid number of pack ops after matmul packing");
  assert(packedMatmul->unPackOps.size() == 1 &&
         "invalid number of unpack ops after matmul packing");

  FailureOr<ContractionDimensions> contractDims =
      inferContractionDims(packedMatmul->packedLinalgOp);
  if (failed(contractDims))
    return failure();

  // Query the indexing maps through the LinalgOp interface: this avoids an
  // unchecked downcast to linalg.generic whose null result would be
  // dereferenced. The maps are taken before any transposition; transposing
  // the LHS pack does not affect the RHS map used below.
  SmallVector<AffineMap> maps =
      packedMatmul->packedLinalgOp.getIndexingMapsArray();

  // Transpose LHS matrix according to the options.
  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
      contractDims->m, options->lhsTransposeOuterBlocks,
      options->lhsTransposeInnerBlocks);
  if (failed(packedLhs))
    return failure();

  // Update results.
  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;

  // Transpose RHS matrix according to the options.
  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
      contractDims->k, options->rhsTransposeOuterBlocks,
      options->rhsTransposeInnerBlocks);
  if (failed(packedRhs))
    return failure();

  // Update results.
  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;

  return packedMatmul;
}
namespace {

/// Rewrite pattern that block-packs ops of type `OpTy` through
/// `linalg::blockPackMatmul`, with the packing options supplied by
/// `controlFn`.
template <typename OpTy>
struct BlockPackMatmul : public OpRewritePattern<OpTy> {
  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<PackResult> packedMatmul =
        blockPackMatmul(rewriter, linalgOp, controlFn);
    if (failed(packedMatmul))
      return failure();
    return success();
  }

private:
  /// Callback returning the packing options for a matched op.
  ControlBlockPackMatmulFn controlFn;
};

/// Specialization for linalg.generic: only contraction generics whose
/// indexing maps describe a plain 2D matmul, or one with a transposed LHS or
/// RHS, are block-packed.
template <>
struct BlockPackMatmul<linalg::GenericOp>
    : public OpRewritePattern<linalg::GenericOp> {
  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<linalg::GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // Match suitable generics.
    if (!linalg::isaContractionOpInterface(linalgOp)) {
      return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
    }

    MLIRContext *ctx = linalgOp.getContext();
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [ctx](MapList m) {
      return AffineMap::inferFromExprList(m, ctx);
    };

    AffineExpr i, j, k;
    bindDims(ctx, i, j, k);
    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();

    // For now, only match simple matmuls:
    //   (i, k) x (k, j) -> (i, j)   plain
    //   (k, i) x (k, j) -> (i, j)   transposed LHS
    //   (i, k) x (j, k) -> (i, j)   transposed RHS
    if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
          maps == infer({{k, i}, {k, j}, {i, j}}) ||
          maps == infer({{i, k}, {j, k}, {i, j}}))) {
      return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
    }

    FailureOr<PackResult> packedMatmul =
        blockPackMatmul(rewriter, linalgOp, controlFn);
    if (failed(packedMatmul))
      return failure();
    return success();
  }

private:
  /// Callback returning the packing options for a matched op.
  ControlBlockPackMatmulFn controlFn;
};

/// Convert linalg matmul ops to block layout and back.
struct LinalgBlockPackMatmul
    : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(&getContext());

    // Every matched op gets the same options, built from the pass options.
    // The op argument is intentionally unnamed: it is not inspected, and
    // naming it `op` would shadow the local above.
    ControlBlockPackMatmulFn controlFn =
        [&](linalg::LinalgOp) -> BlockPackMatmulOptions {
      BlockPackMatmulOptions options;
      options.blockFactors = SmallVector<int64_t>{*blockFactors};
      options.allowPadding = allowPadding;
      options.mnkPaddedSizesNextMultipleOf =
          SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
      // Keep the default dimension order unless one was given explicitly.
      if (!mnkOrder.empty())
        options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
      options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
      options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
      options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
      options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
      return options;
    };

    linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace
/// Add the block-packing rewrite patterns for linalg.generic, linalg.matmul
/// and linalg.batch_matmul to `patterns`. All three patterns share the same
/// options callback `controlFn`.
void linalg::populateBlockPackMatmulPatterns(
    RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<BlockPackMatmul<linalg::GenericOp>>(ctx, controlFn);
  patterns.add<BlockPackMatmul<linalg::MatmulOp>>(ctx, controlFn);
  patterns.add<BlockPackMatmul<linalg::BatchMatmulOp>>(ctx, controlFn);
}