llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Andrzej Warzyński c45cc3e420
[mlir][vector] Standardize base Naming Across Vector Ops (NFC) (#137859)

This change standardizes the naming convention for the argument
representing the value to read from or write to in Vector ops that
interface with Tensors or MemRefs. Specifically, it ensures that all
such ops use the name `base` (i.e., the base address or location to
which offsets are applied).

Updated operations:

* `vector.transfer_read`,
* `vector.transfer_write`.

For reference, these ops already use `base`:

* `vector.load`, `vector.store`, `vector.scatter`, `vector.gather`,
  `vector.expandload`, `vector.compressstore`, `vector.maskedstore`,
  `vector.maskedload`.

This is a non-functional change (NFC) and does not alter the semantics of these
operations. However, it does require users of the XFer ops to switch from
`op.getSource()` to `op.getBase()`.
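
For example (illustrative), for an `xferOp` that is a `vector.transfer_read`
or `vector.transfer_write`:

  Value base = xferOp.getSource(); // before
  Value base = xferOp.getBase();   // after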

To ease the transition, this PR temporarily adds a `getSource()` interface
method for compatibility. This is intended for downstream use only and should
not be relied on upstream. The method will be removed prior to the LLVM 21
release.

Implements #131602
2025-05-12 09:44:50 +01:00


//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::gpu;
/// Currently the distribution map is implicit based on the vector shape. In the
/// future it will be part of the op.
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
/// ...
/// gpu.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
VectorType distributedType) {
SmallVector<AffineExpr> perm;
perm.reserve(1);
// Check which dimensions of the sequential type differ from the dimensions
// of the distributed type: these are the distributed dimensions. Then
// associate each distributed dimension with an ID, in order.
for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
}
auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
distributedType.getContext());
return map;
}
namespace {
/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
Value laneId, Value zero)
: sequentialVal(sequentialVal), distributedVal(distributedVal),
laneId(laneId), zero(zero) {
sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
if (sequentialVectorType && distributedVectorType)
distributionMap =
calculateImplicitMap(sequentialVectorType, distributedVectorType);
}
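/// Build the offset applied along distributed dimension `index`, i.e.
/// `laneId * distributedVectorType.getDimSize(index)`.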
Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
int64_t distributedSize = distributedVectorType.getDimSize(index);
AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
ArrayRef<Value>{laneId});
}
/// Create a store during the process of distributing the
/// `gpu.warp_execute_on_lane_0` op.
/// Vector distribution assumes the following convention regarding the
/// temporary buffers that are created to transition values. This **must**
/// be properly specified in the `options.warpAllocationFn`:
/// 1. scalars of type T transit through a memref<1xT>.
/// 2. vectors of type V<shapexT> transit through a memref<shapexT>
Operation *buildStore(RewriterBase &b, Location loc, Value val,
Value buffer) {
assert((val == distributedVal || val == sequentialVal) &&
"Must store either the preregistered distributed or the "
"preregistered sequential value.");
// Scalar case can directly use memref.store.
if (!isa<VectorType>(val.getType()))
return b.create<memref::StoreOp>(loc, val, buffer, zero);
// Vector case must use vector::TransferWriteOp which will later lower to
// vector.store or memref.store depending on further lowerings.
int64_t rank = sequentialVectorType.getRank();
SmallVector<Value> indices(rank, zero);
if (val == distributedVal) {
for (auto dimExpr : distributionMap.getResults()) {
int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
indices[index] = buildDistributedOffset(b, loc, index);
}
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferWriteOp>(
loc, val, buffer, indices,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
/// Create a load during the process of distributing the
/// `gpu.warp_execute_on_lane_0` op.
/// Vector distribution assumes the following convention regarding the
/// temporary buffers that are created to transition values. This **must**
/// be properly specified in the `options.warpAllocationFn`:
/// 1. scalars of type T transit through a memref<1xT>.
/// 2. vectors of type V<shapexT> transit through a memref<shapexT>
///
/// When the requested type matches the sequential type (or is a scalar), the
/// load is not distributed, to account for the broadcast semantics of the
/// `gpu.warp_execute_on_lane_0` op.
///
/// Example:
///
/// ```
/// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
/// gpu.yield %cst : f32
/// }
/// // Both types are f32. The constant %cst is broadcasted to all lanes.
/// ```
/// This behavior is described in more detail in the documentation of the op.
Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
// Scalar case can directly use memref.load.
if (!isa<VectorType>(type))
return b.create<memref::LoadOp>(loc, buffer, zero);
// Other cases must be vectors for now.
// Vector case must use vector::TransferReadOp which will later lower to
// vector.load or memref.load depending on further lowerings.
assert((type == distributedVectorType || type == sequentialVectorType) &&
"Must load either the preregistered distributed or the "
"preregistered sequential type.");
SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
if (type == distributedVectorType) {
for (auto dimExpr : distributionMap.getResults()) {
int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
indices[index] = buildDistributedOffset(b, loc, index);
}
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
loc, cast<VectorType>(type), buffer, indices,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
Value sequentialVal, distributedVal, laneId, zero;
VectorType sequentialVectorType, distributedVectorType;
AffineMap distributionMap;
};
} // namespace
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
Location loc, Operation *op,
ArrayRef<Value> operands,
ArrayRef<Type> resultTypes) {
OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
op->getAttrs());
return rewriter.create(res);
}
namespace {
/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where only the
/// single thread with lane id 0 executes the entirety of the computation.
///
/// After the transformation:
/// - the IR within the scf.if op can be thought of as executing sequentially
/// (from the point of view of threads along `laneId`).
/// - the IR outside of the scf.if op can be thought of as executing in
/// parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
/// 1. Create the scf.if op.
/// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
/// within the scf.if to transit the values captured from above.
/// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
/// consistent within the scf.if.
/// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
/// 5. Insert appropriate writes within scf.if and reads after the scf.if to
/// transit the values returned by the op.
/// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
/// consistent after the scf.if.
/// 7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
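///
/// For illustration, a rough sketch of the resulting IR for a warp op with
/// one captured argument and one yielded value (buffer allocation and
/// synchronization are delegated to `options.warpAllocationFn` and
/// `options.warpSyncronizationFn`, respectively):
/// ```
/// vector.transfer_write %distributed_arg, %buffer0[%laneid_offset]
/// // <sync>
/// %cnd = arith.cmpi eq, %laneid, %c0
/// scf.if %cnd {
/// %sequential_arg = vector.transfer_read %buffer0[%c0]
/// ... // former body of the warp op
/// vector.transfer_write %sequential_res, %buffer1[%c0]
/// }
/// // <sync>
/// %distributed_res = vector.transfer_read %buffer1[%laneid_offset]
/// ```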
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
WarpOpToScfIfPattern(MLIRContext *context,
const WarpExecuteOnLane0LoweringOptions &options,
PatternBenefit benefit = 1)
: WarpDistributionPattern(context, benefit), options(options) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
assert(warpOp.getBodyRegion().hasOneBlock() &&
"expected WarpOp with single block");
Block *warpOpBody = &warpOp.getBodyRegion().front();
Location loc = warpOp.getLoc();
// Passed all checks. Start rewriting.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(warpOp);
// Step 1: Create scf.if op.
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value isLane0 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
/*withElseRegion=*/false);
rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
// Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
// reads within the scf.if to transit the values captured from above.
SmallVector<Value> bbArgReplacements;
for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
Value sequentialVal = warpOpBody->getArgument(it.index());
Value distributedVal = it.value();
DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
warpOp.getLaneid(), c0);
// Create buffer before the ifOp.
rewriter.setInsertionPoint(ifOp);
Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
sequentialVal.getType());
// Store distributed vector into buffer, before the ifOp.
helper.buildStore(rewriter, loc, distributedVal, buffer);
// Load sequential vector from buffer, inside the ifOp.
rewriter.setInsertionPointToStart(ifOp.thenBlock());
bbArgReplacements.push_back(
helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
}
// Step 3. Insert sync after all the stores and before all the loads.
if (!warpOp.getArgs().empty()) {
rewriter.setInsertionPoint(ifOp);
options.warpSyncronizationFn(loc, rewriter, warpOp);
}
// Step 4. Move body of warpOp to ifOp.
rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
// Step 5. Insert appropriate writes within scf.if and reads after the
// scf.if to transit the values returned by the op.
// TODO: at this point, we can reuse the shared memory from previous
// buffers.
SmallVector<Value> replacements;
auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
Location yieldLoc = yieldOp.getLoc();
for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
Value sequentialVal = it.value();
Value distributedVal = warpOp->getResult(it.index());
DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
warpOp.getLaneid(), c0);
// Create buffer before the ifOp.
rewriter.setInsertionPoint(ifOp);
Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
sequentialVal.getType());
// Store yielded value into buffer, inside the ifOp, before the
// terminator.
rewriter.setInsertionPoint(yieldOp);
helper.buildStore(rewriter, loc, sequentialVal, buffer);
// Load distributed value from buffer, after the warpOp.
rewriter.setInsertionPointAfter(ifOp);
// Result type and yielded value type are the same. This is a broadcast.
// E.g.:
// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
// gpu.yield %cst : f32
// }
// Both types are f32. The constant %cst is broadcasted to all lanes.
// This is described in more detail in the documentation of the op.
replacements.push_back(
helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
}
// Step 6. Insert sync after all the stores and before all the loads.
if (!yieldOp.getOperands().empty()) {
rewriter.setInsertionPointAfter(ifOp);
options.warpSyncronizationFn(loc, rewriter, warpOp);
}
// Step 7. Delete terminator and add empty scf.yield.
rewriter.eraseOp(yieldOp);
rewriter.setInsertionPointToEnd(ifOp.thenBlock());
rewriter.create<scf::YieldOp>(yieldLoc);
// Compute replacements for WarpOp results.
rewriter.replaceOp(warpOp, replacements);
return success();
}
private:
const WarpExecuteOnLane0LoweringOptions &options;
};
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have as many dimensions as the
/// rank of the original type and to be a projection whose results are the
/// distributed dimensions. The number of results should be equal to the
/// number of warp sizes, which is currently limited to 1.
/// Example: distributing vector<16x32x64> with the map (d0, d1, d2) -> (d1)
/// and a warp size of 16 distributes the second dimension (associated with
/// d1), returning vector<16x2x64>.
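/// When a distributed dimension does not divide evenly, distribution can
/// fail (illustrative): for vector<16x32x64>, the map (d0, d1, d2) -> (d0)
/// and a warp size of 32, d0 is fully distributed (size 1) but a leftover
/// factor of 2 remains with no further distributed dimension to absorb it,
/// so a null VectorType is returned.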
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
SmallVector<int64_t> targetShape(originalType.getShape());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
if (targetShape[position] % warpSize != 0) {
if (warpSize % targetShape[position] != 0) {
return VectorType();
}
warpSize /= targetShape[position];
targetShape[position] = 1;
continue;
}
targetShape[position] = targetShape[position] / warpSize;
warpSize = 1;
break;
}
if (warpSize != 1) {
return VectorType();
}
VectorType targetType =
VectorType::get(targetShape, originalType.getElementType());
return targetType;
}
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract`
/// elements will not be distributed (this limit should be less than the warp
/// size).
///
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%id){
/// ...
/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
/// gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
/// ...
/// gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public WarpDistributionPattern {
WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
unsigned maxNumElementsToExtract, PatternBenefit b = 1)
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
maxNumElementsToExtract(maxNumElementsToExtract) {}
/// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
/// are multiples of the distribution ratio are supported at the moment.
LogicalResult tryDistributeOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
VectorType writtenVectorType = writeOp.getVectorType();
// 1. If the write is 0-D, distribution does not apply; bail out here and
// let `tryExtractOp` clone it into a separate WarpExecuteOnLane0Op.
if (writtenVectorType.getRank() == 0)
return failure();
// 2. Compute the distributed type.
AffineMap map = distributionMapFn(writeOp.getVector());
VectorType targetType =
getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
if (!targetType)
return failure();
// 2.5 Compute the distributed type for the new mask.
VectorType maskType;
if (writeOp.getMask()) {
// TODO: Distribution of masked writes with non-trivial permutation maps
// requires the distribution of the mask to elementwise match the
// distribution of the permuted written vector. Currently the details
// of which lane is responsible for which element is captured strictly
// by shape information on the warp op, and thus requires materializing
// the permutation in IR.
if (!writeOp.getPermutationMap().isMinorIdentity())
return failure();
maskType =
getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
}
// 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
// the rest.
vector::TransferWriteOp newWriteOp =
cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
// 4. Reindex the write using the distribution map.
auto newWarpOp =
newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
// Delinearize the lane id based on the way threads are divided across the
// vector. To get the number of threads per vector dimension, divide the
// sequential size by the distributed size along each dim.
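// Illustrative example: a sequential write of vector<4x32xf32> distributed
// as vector<1x4xf32> (both dims distributed) yields delinearized id sizes
// [4, 8]; the lane id is split into two ids that reindex dims 0 and 1.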
rewriter.setInsertionPoint(newWriteOp);
SmallVector<OpFoldResult> delinearizedIdSizes;
for (auto [seqSize, distSize] :
llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
}
SmallVector<Value> delinearized;
if (map.getNumResults() > 1) {
delinearized = rewriter
.create<mlir::affine::AffineDelinearizeIndexOp>(
newWarpOp.getLoc(), newWarpOp.getLaneid(),
delinearizedIdSizes)
.getResults();
} else {
// If there is only one map result, we can elide the delinearization
// op and use the lane id directly.
delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
}
AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
Location loc = newWriteOp.getLoc();
SmallVector<Value> indices(newWriteOp.getIndices().begin(),
newWriteOp.getIndices().end());
for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
AffineExpr d0, d1;
bindDims(newWarpOp.getContext(), d0, d1);
auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
Value laneId = delinearized[vectorPos];
auto scale =
rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
indices[indexPos] = affine::makeComposedAffineApply(
rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
}
newWriteOp.getIndicesMutable().assign(indices);
return success();
}
/// Extract TransferWriteOps of single-element vectors (more generally, of at
/// most `maxNumElementsToExtract` elements) into a separate warp op.
LogicalResult tryExtractOp(RewriterBase &rewriter,
vector::TransferWriteOp writeOp,
WarpExecuteOnLane0Op warpOp) const {
Location loc = writeOp.getLoc();
VectorType vecType = writeOp.getVectorType();
if (vecType.getNumElements() > maxNumElementsToExtract) {
return rewriter.notifyMatchFailure(
warpOp,
llvm::formatv(
"writes more elements ({0}) than allowed to extract ({1})",
vecType.getNumElements(), maxNumElementsToExtract));
}
// Do not process warp ops that contain only TransferWriteOps.
if (llvm::all_of(warpOp.getOps(),
llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
return failure();
SmallVector<Value> yieldValues = {writeOp.getVector()};
SmallVector<Type> retTypes = {vecType};
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, yieldValues, retTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
// Create a second warp op that contains only writeOp.
auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
Block &body = secondWarpOp.getBodyRegion().front();
rewriter.setInsertionPointToStart(&body);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
newWriteOp.getValueToStoreMutable().assign(
newWarpOp.getResult(newRetIndices[0]));
rewriter.eraseOp(writeOp);
rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
return success();
}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Operation *lastNode = yield->getPrevNode();
auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
if (!writeOp)
return failure();
Value maybeMask = writeOp.getMask();
if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
return writeOp.getVector() == value ||
(maybeMask && maybeMask == value) ||
warpOp.isDefinedOutsideOfRegion(value);
}))
return failure();
if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
return success();
// Masked writes not supported for extraction.
if (writeOp.getMask())
return failure();
if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
return success();
return failure();
}
private:
/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
/// execute op with the proper return type. The new write op is updated to
/// write the result of the new warp execute op. The old `writeOp` is deleted.
vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
WarpExecuteOnLane0Op warpOp,
vector::TransferWriteOp writeOp,
VectorType targetType,
VectorType maybeMaskType) const {
assert(writeOp->getParentOp() == warpOp &&
"write must be nested immediately under warp");
OpBuilder::InsertionGuard g(rewriter);
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp;
if (maybeMaskType) {
newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
TypeRange{targetType, maybeMaskType}, newRetIndices);
} else {
newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, ValueRange{{writeOp.getVector()}},
TypeRange{targetType}, newRetIndices);
}
rewriter.setInsertionPointAfter(newWarpOp);
auto newWriteOp =
cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
rewriter.eraseOp(writeOp);
newWriteOp.getValueToStoreMutable().assign(
newWarpOp.getResult(newRetIndices[0]));
if (maybeMaskType)
newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
return newWriteOp;
}
DistributionMapFn distributionMapFn;
unsigned maxNumElementsToExtract = 1;
};
/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
/// ...
/// %3 = arith.addf %1, %2 : vector<32xf32>
/// gpu.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
/// ...
/// %4 = arith.addf %2, %3 : vector<32xf32>
/// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
/// vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
return OpTrait::hasElementwiseMappableTraits(op);
});
if (!yieldOperand)
return failure();
Operation *elementWise = yieldOperand->get().getDefiningOp();
unsigned operandIndex = yieldOperand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);
SmallVector<Value> yieldValues;
SmallVector<Type> retTypes;
Location loc = warpOp.getLoc();
for (OpOperand &operand : elementWise->getOpOperands()) {
Type targetType;
if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
// If the result type is a vector, the operands must also be vectors.
auto operandType = cast<VectorType>(operand.get().getType());
targetType =
VectorType::get(vecType.getShape(), operandType.getElementType());
} else {
auto operandType = operand.get().getType();
assert(!isa<VectorType>(operandType) &&
"unexpected yield of vector from op with scalar result type");
targetType = operandType;
}
retTypes.push_back(targetType);
yieldValues.push_back(operand.get());
}
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, yieldValues, retTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newOperands(elementWise->getOperands().begin(),
elementWise->getOperands().end());
for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
}
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, elementWise, newOperands,
{newWarpOp.getResult(operandIndex).getType()});
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
newOp->getResult(0));
return success();
}
};
/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
/// ...
/// %cst = arith.constant dense<2.0> : vector<32xf32>
/// gpu.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
/// ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *yieldOperand =
getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
if (!yieldOperand)
return failure();
auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
if (!dense)
return failure();
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
rewriter.startOpModification(warpOp);
unsigned operandIndex = yieldOperand->getOperandNumber();
Attribute scalarAttr = dense.getSplatValue<Attribute>();
auto newAttr = DenseElementsAttr::get(
cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
Location loc = warpOp.getLoc();
rewriter.setInsertionPointAfter(warpOp);
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
rewriter.finalizeOpModification(warpOp);
return success();
}
};
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
/// ...
/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
/// vector<32xf32>
/// gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
/// ...
/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
/// vector<32xf32>
/// gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
// Try to find a distributable yielded read. Note that this pattern can
// still fail at the end after distribution, in which case this might have
// missed another distributable read.
OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
// Don't duplicate transfer_read ops when distributing.
return isa<vector::TransferReadOp>(op) && op->hasOneUse();
});
if (!operand)
return rewriter.notifyMatchFailure(
warpOp, "warp result is not a vector.transfer_read op");
auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
// Source must be defined outside of the region.
if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
return rewriter.notifyMatchFailure(
read, "source must be defined outside of the region");
unsigned operandIndex = operand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);
SmallVector<Value, 4> indices(read.getIndices().begin(),
read.getIndices().end());
auto sequentialType = cast<VectorType>(read.getResult().getType());
auto distributedType = cast<VectorType>(distributedVal.getType());
AffineMap map = calculateImplicitMap(sequentialType, distributedType);
AffineMap indexMap = map.compose(read.getPermutationMap());
// Try to delinearize the lane ID to match the rank expected for
// distribution.
SmallVector<Value> delinearizedIds;
if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
distributedType.getShape(), warpOp.getWarpSize(),
warpOp.getLaneid(), delinearizedIds)) {
return rewriter.notifyMatchFailure(
read, "cannot delinearize lane ID for distribution");
}
assert(!delinearizedIds.empty() || map.getNumResults() == 0);
// Distribute indices and the mask (if present).
OpBuilder::InsertionGuard g(rewriter);
SmallVector<Value> additionalResults(indices.begin(), indices.end());
SmallVector<Type> additionalResultTypes(indices.size(),
rewriter.getIndexType());
additionalResults.push_back(read.getPadding());
additionalResultTypes.push_back(read.getPadding().getType());
bool hasMask = false;
if (read.getMask()) {
hasMask = true;
// TODO: Distribution of masked reads with non-trivial permutation maps
// requires the distribution of the mask to elementwise match the
// distribution of the permuted written vector. Currently the details
// of which lane is responsible for which element is captured strictly
// by shape information on the warp op, and thus requires materializing
// the permutation in IR.
if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
return rewriter.notifyMatchFailure(
read, "non-trivial permutation maps not supported");
VectorType maskType =
getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
additionalResults.push_back(read.getMask());
additionalResultTypes.push_back(maskType);
}
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, additionalResults, additionalResultTypes,
newRetIndices);
distributedVal = newWarpOp.getResult(operandIndex);
// Distributed indices were appended first.
SmallVector<Value> newIndices;
for (int64_t i = 0, e = indices.size(); i < e; ++i)
newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
rewriter.setInsertionPointAfter(newWarpOp);
for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
int64_t scale = distributedType.getDimSize(vectorPos);
newIndices[indexPos] = affine::makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
{newIndices[indexPos], delinearizedIds[vectorPos]});
}
// Distributed padding value was appended right after the indices.
Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
// Distributed mask value was added at the end (if the op has a mask).
Value newMask =
hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
: Value();
auto newRead = rewriter.create<vector::TransferReadOp>(
read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
read.getPermutationMapAttr(), newPadding, newMask,
read.getInBoundsAttr());
rewriter.replaceAllUsesWith(distributedVal, newRead);
return success();
}
};
/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
SmallVector<Type> newResultTypes;
newResultTypes.reserve(warpOp->getNumResults());
SmallVector<Value> newYieldValues;
newYieldValues.reserve(warpOp->getNumResults());
DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
DenseMap<OpResult, int64_t> dedupResultPositionMap;
auto yield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
// yielded value, and:
// 1. recording the unique first position at which the value is yielded.
// 2. recording for the result, the first position at which the dedup'ed
// value is yielded.
// 3. skipping from the new result types / new yielded values any result
// that has no use or whose yielded value has already been seen.
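// For example, if %a is yielded at positions 0 and 2 and result 1 is dead,
// %a is yielded once (together with any other live values), and both results
// 0 and 2 are replaced by that single new result.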
for (OpResult result : warpOp.getResults()) {
Value yieldOperand = yield.getOperand(result.getResultNumber());
auto it = dedupYieldOperandPositionMap.insert(
std::make_pair(yieldOperand, newResultTypes.size()));
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
if (result.use_empty() || !it.second)
continue;
newResultTypes.push_back(result.getType());
newYieldValues.push_back(yieldOperand);
}
// No modification, exit early.
if (yield.getNumOperands() == newYieldValues.size())
return failure();
// Move the body of the old warpOp to a new warpOp.
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
rewriter, warpOp, newYieldValues, newResultTypes);
// Simplify the new warp op after dropping dead results.
newWarpOp.getBody()->walk([&](Operation *op) {
if (isOpTriviallyDead(op))
rewriter.eraseOp(op);
});
// Replace results of the old warpOp by the new, deduplicated results.
SmallVector<Value> newValues;
newValues.reserve(warpOp->getNumResults());
for (OpResult result : warpOp.getResults()) {
if (result.use_empty())
newValues.push_back(Value());
else
newValues.push_back(
newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
}
rewriter.replaceOp(warpOp, newValues);
return success();
}
};
// If an operand is directly yielded out of the region, we can forward it
// as-is; it does not need to transit through the region.
struct WarpOpForwardOperand : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
Value valForwarded;
unsigned resultIndex;
for (OpOperand &operand : yield->getOpOperands()) {
Value result = warpOp.getResult(operand.getOperandNumber());
if (result.use_empty())
continue;
// Assume all the values coming from above are uniform.
if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
if (result.getType() != operand.get().getType())
continue;
valForwarded = operand.get();
resultIndex = operand.getOperandNumber();
break;
}
auto arg = dyn_cast<BlockArgument>(operand.get());
if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
continue;
Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
if (result.getType() != warpOperand.getType())
continue;
valForwarded = warpOperand;
resultIndex = operand.getOperandNumber();
break;
}
if (!valForwarded)
return failure();
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
rewriter.startOpModification(warpOp);
rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
rewriter.finalizeOpModification(warpOp);
return success();
}
};
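/// Sink out a vector.broadcast feeding into a warp op yield. The source is
/// yielded from the warp op unchanged and re-broadcast to the distributed
/// result type outside of the region, provided each lane can rebuild its
/// part of the broadcast locally.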
struct WarpOpBroadcast : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
Location loc = broadcastOp.getLoc();
auto destVecType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
Value broadcastSrc = broadcastOp.getSource();
Type broadcastSrcType = broadcastSrc.getType();
// Check that the broadcast actually spans a set of values uniformly across
// all threads. In other words, check that each thread can reconstruct
// its own broadcast.
// For that, we simply check that the broadcast we want to build is valid.
if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
vector::BroadcastableToResult::Success)
return failure();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value broadcasted = rewriter.create<vector::BroadcastOp>(
loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
broadcasted);
return success();
}
};
/// Pattern to move a shape cast out of the warp op. A shape cast is
/// basically a no-op for warp distribution; only the shape needs handling.
struct WarpOpShapeCast : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
if (!operand)
return failure();
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
unsigned int operandNumber = operand->getOperandNumber();
auto castDistributedType =
cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
VectorType castOriginalType = oldCastOp.getSourceVectorType();
VectorType castResultType = castDistributedType;
// The distributed type may have a smaller rank than the original type.
// Prepend size-one dimensions to make the ranks match.
unsigned castDistributedRank = castDistributedType.getRank();
unsigned castOriginalRank = castOriginalType.getRank();
if (castDistributedRank < castOriginalRank) {
SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
llvm::append_range(shape, castDistributedType.getShape());
castDistributedType =
VectorType::get(shape, castDistributedType.getElementType());
}
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value newCast = rewriter.create<vector::ShapeCastOp>(
oldCastOp.getLoc(), castResultType,
newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
return success();
}
};
/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
/// ...
/// %mask = vector.create_mask %0 : vector<32xi1>
/// gpu.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// gpu.warp_execute_on_lane_0(%arg0) {
/// ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *yieldOperand =
getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
if (!yieldOperand)
return failure();
auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
// Early exit if any values needed for calculating the new mask indices
// are defined inside the warp op.
if (!llvm::all_of(mask->getOperands(), [&](Value value) {
return warpOp.isDefinedOutsideOfRegion(value);
}))
return failure();
Location loc = mask.getLoc();
unsigned operandIndex = yieldOperand->getOperandNumber();
auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
VectorType seqType = mask.getVectorType();
ArrayRef<int64_t> seqShape = seqType.getShape();
ArrayRef<int64_t> distShape = distType.getShape();
rewriter.setInsertionPointAfter(warpOp);
// Delinearize the lane ID for constructing the distributed mask sizes.
SmallVector<Value> delinearizedIds;
if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
warpOp.getWarpSize(), warpOp.getLaneid(),
delinearizedIds))
return rewriter.notifyMatchFailure(
mask, "cannot delinearize lane ID for distribution");
assert(!delinearizedIds.empty());
// Notify the rewriter that the warp op is changing (see the comment on
// the WarpOpTransferRead pattern).
rewriter.startOpModification(warpOp);
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
SmallVector<Value> newOperands;
for (int i = 0, e = distShape.size(); i < e; ++i) {
// Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
// find the distance from the largest mask index owned by this lane to the
// original mask size. `vector.create_mask` implicitly clamps mask
// operands to the range [0, mask_vector_size[i]], so the effective mask
// sizes always lie in that inclusive range.
Value maskDimIdx = affine::makeComposedAffineApply(
rewriter, loc, s1 - s0 * distShape[i],
{delinearizedIds[i], mask.getOperand(i)});
newOperands.push_back(maskDimIdx);
}
auto newMask =
rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
rewriter.finalizeOpModification(warpOp);
return success();
}
};
/// Pattern to move out vector.extract of a source vector of rank 2 or
/// higher. 0-D and 1-D sources are handled by WarpOpExtractScalar instead.
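///
/// For illustration (assuming a warp size of 32 distributing the innermost
/// dimension):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid) -> (vector<3x1xf32>) {
/// %e = vector.extract %v[2] : vector<3x32xf32> from vector<5x3x32xf32>
/// gpu.yield %e : vector<3x32xf32>
/// }
/// ```
/// becomes a warp op yielding %v as vector<5x3x1xf32>, with
/// `vector.extract %w[2]` moved outside of the region.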
struct WarpOpExtract : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
VectorType extractSrcType = extractOp.getSourceVectorType();
Location loc = extractOp.getLoc();
// For 1-D or 0-D source cases, we rely on the WarpOpExtractScalar pattern.
if (extractSrcType.getRank() <= 1) {
return failure();
}
// All following cases handle 2-D or higher-dimensional source vectors.
if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
// There is no distribution, this is a broadcast. Simply move the extract
// out of the warp op.
// TODO: This could be optimized. E.g., in case of a scalar result, let
// one lane extract and shuffle the result to all other lanes (same as
// the 1d case).
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {extractOp.getVector()},
{extractOp.getSourceVectorType()}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Extract from distributed vector.
Value newExtract = rewriter.create<vector::ExtractOp>(
loc, distributedVec, extractOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
}
// Find the distributed dimension. There should be exactly one.
auto distributedType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distributedDim = -1;
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
// Keep this assert here in case WarpExecuteOnLane0Op gets extended to
// support distributing multiple dimensions in the future.
assert(distributedDim == -1 && "found multiple distributed dims");
distributedDim = i;
}
}
assert(distributedDim != -1 && "could not find distributed dimension");
(void)distributedDim;
// Yield source vector from warp op.
SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
for (int i = 0; i < distributedType.getRank(); ++i)
newDistributedShape[i + extractOp.getNumIndices()] =
distributedType.getDimSize(i);
auto newDistributedType =
VectorType::get(newDistributedShape, distributedType.getElementType());
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Extract from distributed vector.
Value newExtract = rewriter.create<vector::ExtractOp>(
loc, distributedVec, extractOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
}
};
/// Pattern to move out vector.extract with a scalar result.
/// Only supports 1-D and 0-D sources for now.
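/// For illustration, with a warp size of 32 and a vector<64xf32> source
/// (2 elements per lane), an extract at position `pos` is rewritten so that
/// the lane at `pos / 2` extracts element `pos % 2` of its local
/// vector<2xf32>, and the result is shuffled to all other lanes via
/// `warpShuffleFromIdxFn`.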
struct WarpOpExtractScalar : public WarpDistributionPattern {
WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
PatternBenefit b = 1)
: WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
VectorType extractSrcType = extractOp.getSourceVectorType();
// Only supports 1-D or 0-D sources for now.
if (extractSrcType.getRank() > 1) {
return rewriter.notifyMatchFailure(
extractOp, "only 0-D or 1-D source supported for now");
}
// TODO: Supported shuffle types should be parameterizable, similar to
// `WarpShuffleFromIdxFn`.
if (!extractSrcType.getElementType().isF32() &&
!extractSrcType.getElementType().isInteger(32))
return rewriter.notifyMatchFailure(
extractOp, "only f32/i32 element types are supported");
bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
Type elType = extractSrcType.getElementType();
VectorType distributedVecType;
if (!is0dOrVec1Extract) {
assert(extractSrcType.getRank() == 1 &&
"expected that extract src rank is 0 or 1");
if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
return failure();
int64_t elementsPerLane =
extractSrcType.getShape()[0] / warpOp.getWarpSize();
distributedVecType = VectorType::get({elementsPerLane}, elType);
} else {
distributedVecType = extractSrcType;
}
// Yield source vector and position (if present) from warp op.
SmallVector<Value> additionalResults{extractOp.getVector()};
SmallVector<Type> additionalResultTypes{distributedVecType};
additionalResults.append(
SmallVector<Value>(extractOp.getDynamicPosition()));
additionalResultTypes.append(
SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, additionalResults, additionalResultTypes,
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// 0-D or single-element extract: the new warp op broadcasts the source
// vector to all lanes (it is not distributed). All lanes extract the
// scalar.
if (is0dOrVec1Extract) {
Value newExtract;
SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
newExtract =
rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newExtract);
return success();
}
int64_t staticPos = extractOp.getStaticPosition()[0];
OpFoldResult pos = ShapedType::isDynamic(staticPos)
? (newWarpOp->getResult(newRetIndices[1]))
: OpFoldResult(rewriter.getIndexAttr(staticPos));
// 1d extract: Distribute the source vector. One lane extracts and shuffles
// the value to all other lanes.
int64_t elementsPerLane = distributedVecType.getShape()[0];
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of extracting thread: pos / elementsPerLane
Value broadcastFromTid = affine::makeComposedAffineApply(
rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
// Extract at position: pos % elementsPerLane
Value newPos =
elementsPerLane == 1
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
: affine::makeComposedAffineApply(rewriter, loc,
sym0 % elementsPerLane, pos);
Value extracted =
rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
// Shuffle the extracted value to all lanes.
Value shuffled = warpShuffleFromIdxFn(
loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
return success();
}
private:
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
/// Pattern to convert vector.extractelement to vector.extract.
struct WarpOpExtractElement : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
if (!operand)
return failure();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
SmallVector<OpFoldResult> indices;
if (auto pos = extractOp.getPosition()) {
indices.push_back(pos);
}
rewriter.setInsertionPoint(extractOp);
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
extractOp, extractOp.getVector(), indices);
return success();
}
};
/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
struct WarpOpInsertScalar : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
VectorType vecType = insertOp.getDestVectorType();
VectorType distrType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
// Only supports 1-D or 0-D destinations for now.
if (vecType.getRank() > 1) {
return rewriter.notifyMatchFailure(
insertOp, "only 0-D or 1-D destination supported for now");
}
// Yield destination vector, source scalar and position from warp op.
SmallVector<Value> additionalResults{insertOp.getDest(),
insertOp.getValueToStore()};
SmallVector<Type> additionalResultTypes{
distrType, insertOp.getValueToStore().getType()};
additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
additionalResultTypes.append(
SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
Location loc = insertOp.getLoc();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, additionalResults, additionalResultTypes,
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
Value newSource = newWarpOp->getResult(newRetIndices[1]);
rewriter.setInsertionPointAfter(newWarpOp);
OpFoldResult pos;
if (vecType.getRank() != 0) {
int64_t staticPos = insertOp.getStaticPosition()[0];
pos = ShapedType::isDynamic(staticPos)
? (newWarpOp->getResult(newRetIndices[2]))
: OpFoldResult(rewriter.getIndexAttr(staticPos));
}
// This condition is always true for 0-d vectors.
if (vecType == distrType) {
Value newInsert;
SmallVector<OpFoldResult> indices;
if (pos) {
indices.push_back(pos);
}
newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
distributedVec, indices);
// Broadcast: Simply move the vector.insert op out.
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newInsert);
return success();
}
// This is a distribution. Only one lane should insert.
int64_t elementsPerLane = distrType.getShape()[0];
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
// tid of inserting lane: pos / elementsPerLane
Value insertingLane = affine::makeComposedAffineApply(
rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
// Insert position: pos % elementsPerLane
OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
rewriter, loc, sym0 % elementsPerLane, pos);
Value isInsertingLane = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
Value newResult =
rewriter
.create<scf::IfOp>(
loc, isInsertingLane,
/*thenBuilder=*/
[&](OpBuilder &builder, Location loc) {
Value newInsert = builder.create<vector::InsertOp>(
loc, newSource, distributedVec, newPos);
builder.create<scf::YieldOp>(loc, newInsert);
},
/*elseBuilder=*/
[&](OpBuilder &builder, Location loc) {
builder.create<scf::YieldOp>(loc, distributedVec);
})
.getResult(0);
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
return success();
}
};
struct WarpOpInsert : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
if (!operand)
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
Location loc = insertOp.getLoc();
// For 1-D or 0-D destination cases, we rely on the WarpOpInsertScalar
// pattern.
if (insertOp.getDestVectorType().getRank() <= 1) {
return failure();
}
// All following cases handle 2-D or higher-dimensional destination vectors.
if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
// There is no distribution, this is a broadcast. Simply move the insert
// out of the warp op.
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
{insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
Value newResult = rewriter.create<vector::InsertOp>(
loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
newResult);
return success();
}
// Find the distributed dimension. There should be exactly one.
auto distrDestType =
cast<VectorType>(warpOp.getResult(operandNumber).getType());
auto yieldedType = cast<VectorType>(operand->get().getType());
int64_t distrDestDim = -1;
for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
// Keep this assert here in case WarpExecuteOnLane0Op gets extended to
// support distributing multiple dimensions in the future.
assert(distrDestDim == -1 && "found multiple distributed dims");
distrDestDim = i;
}
}
assert(distrDestDim != -1 && "could not find distributed dimension");
// Compute the distributed source vector type.
VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
// E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
// Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
// insert a smaller vector<3xf32>.
// Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
// case, one lane will insert the source vector<96xf32>. The other
// lanes will not do anything.
int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
if (distrSrcDim >= 0)
distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
auto distrSrcType =
VectorType::get(distrSrcShape, distrDestType.getElementType());
// Yield source and dest vectors from warp op.
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
{distrSrcType, distrDestType}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
// Insert into the distributed vector.
Value newResult;
if (distrSrcDim >= 0) {
// Every lane inserts a small piece.
newResult = rewriter.create<vector::InsertOp>(
loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
} else {
// One lane inserts the entire source vector.
int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
SmallVector<int64_t> newPos = getAsIntegers(pos);
// tid of inserting lane: pos / elementsPerLane
Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
loc, newPos[distrDestDim] / elementsPerLane);
Value isInsertingLane = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
// Insert position: pos % elementsPerLane
newPos[distrDestDim] %= elementsPerLane;
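      // E.g., continuing the example above with distrDestDim = 0: the dest
      // distributes to vector<4x96xf32> per lane (elementsPerLane = 4), so
      // inserting at [2] means lane 2 / 4 = 0 performs the insert at local
      // position 2 % 4 = 2.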
auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
Value newInsert = builder.create<vector::InsertOp>(
loc, distributedSrc, distributedDest, newPos);
builder.create<scf::YieldOp>(loc, newInsert);
};
auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
builder.create<scf::YieldOp>(loc, distributedDest);
};
newResult = rewriter
.create<scf::IfOp>(loc, isInsertingLane,
/*thenBuilder=*/insertingBuilder,
/*elseBuilder=*/nonInsertingBuilder)
.getResult(0);
}
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
return success();
}
};
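/// Pattern to rewrite a `vector.insertelement` op yielded by a
/// WarpExecuteOnLane0Op as an equivalent `vector.insert` op, which the
/// WarpOpInsertScalar / WarpOpInsert patterns above can then distribute.
/// A minimal sketch of the rewrite itself:
/// ```
/// %1 = vector.insertelement %f, %v[%i : index] : vector<4xf32>
/// ```
/// becomes:
/// ```
/// %1 = vector.insert %f, %v[%i] : f32 into vector<4xf32>
/// ```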
struct WarpOpInsertElement : public WarpDistributionPattern {
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *operand =
getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
if (!operand)
return failure();
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
SmallVector<OpFoldResult> indices;
if (auto pos = insertOp.getPosition()) {
indices.push_back(pos);
}
rewriter.setInsertionPoint(insertOp);
rewriter.replaceOpWithNewOp<vector::InsertOp>(
insertOp, insertOp.getSource(), insertOp.getDest(), indices);
return success();
}
};
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
/// ...
/// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
/// -> (vector<128xf32>) {
/// ...
/// scf.yield %r : vector<128xf32>
/// }
/// gpu.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   gpu.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
/// -> (vector<4xf32>) {
/// %iw = gpu.warp_execute_on_lane_0(%laneid)
/// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
/// ^bb0(%arg: vector<128xf32>):
/// ...
/// gpu.yield %ir : vector<128xf32>
/// }
/// scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public WarpDistributionPattern {
WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
auto yield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
// Only pick up forOp if it is the last op in the region.
Operation *lastNode = yield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
if (!forOp)
return failure();
    // Collect values that come from the warp op but are defined outside the
    // forOp. Those values need to be returned by the original warpOp and
    // passed to the new op.
llvm::SmallSetVector<Value, 32> escapingValues;
SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
mlir::visitUsedValuesDefinedAbove(
forOp.getBodyRegion(), [&](OpOperand *operand) {
Operation *parent = operand->get().getParentRegion()->getParentOp();
if (warpOp->isAncestor(parent)) {
if (!escapingValues.insert(operand->get()))
return;
Type distType = operand->get().getType();
if (auto vecType = dyn_cast<VectorType>(distType)) {
AffineMap map = distributionMapFn(operand->get());
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
}
inputTypes.push_back(operand->get().getType());
distTypes.push_back(distType);
}
});
if (llvm::is_contained(distTypes, Type{}))
return failure();
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
newRetIndices);
yield = cast<gpu::YieldOp>(
newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
SmallVector<Value> newOperands;
SmallVector<unsigned> resultIdx;
// Collect all the outputs coming from the forOp.
for (OpOperand &yieldOperand : yield->getOpOperands()) {
if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
continue;
auto forResult = cast<OpResult>(yieldOperand.get());
newOperands.push_back(
newWarpOp.getResult(yieldOperand.getOperandNumber()));
yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
resultIdx.push_back(yieldOperand.getOperandNumber());
}
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
// Create a new for op outside the region with a WarpExecuteOnLane0Op
// region inside.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newOperands);
rewriter.setInsertionPointToStart(newForOp.getBody());
SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
newForOp.getRegionIterArgs().end());
SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
forOp.getResultTypes().end());
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
warpInput.push_back(newWarpOp.getResult(retIdx));
argIndexMapping[escapingValues[i]] = warpInputType.size();
warpInputType.push_back(inputTypes[i]);
}
auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
newWarpOp.getWarpSize(), warpInput, warpInputType);
SmallVector<Value> argMapping;
argMapping.push_back(newForOp.getInductionVar());
    for (Value arg : innerWarp.getBody()->getArguments())
      argMapping.push_back(arg);
argMapping.resize(forOp.getBody()->getNumArguments());
SmallVector<Value> yieldOperands;
for (Value operand : forOp.getBody()->getTerminator()->getOperands())
yieldOperands.push_back(operand);
rewriter.eraseOp(forOp.getBody()->getTerminator());
rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
rewriter.setInsertionPointToEnd(innerWarp.getBody());
rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
rewriter.setInsertionPointAfter(innerWarp);
if (!innerWarp.getResults().empty())
rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
rewriter.eraseOp(forOp);
    // Replace the warpOp results coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      // Operands 0 to 2 of scf.for are the lower bound, upper bound and step;
      // offset by 3 to update the corresponding init arg of the new loop.
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
}
newForOp.walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
auto it = argIndexMapping.find(operand.get());
if (it == argIndexMapping.end())
continue;
operand.set(innerWarp.getBodyRegion().getArgument(it->second));
}
});
    // Finally, hoist out any now-uniform code from the inner warp op.
mlir::vector::moveScalarUniformCode(innerWarp);
return success();
}
private:
DistributionMapFn distributionMapFn;
};
/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes that
/// are a multiple of the warp size. E.g.:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
/// %0 = "some_def"() : () -> (vector<32xf32>)
/// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
/// gpu.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
/// %1 = "some_def"() : () -> (vector<32xf32>)
/// gpu.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public WarpDistributionPattern {
WarpOpReduction(MLIRContext *context,
DistributedReductionFn distributedReductionFn,
PatternBenefit benefit = 1)
: WarpDistributionPattern(context, benefit),
distributedReductionFn(std::move(distributedReductionFn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
OpOperand *yieldOperand =
getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
if (!yieldOperand)
return failure();
auto reductionOp =
cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
// Only rank 1 vectors supported.
if (vectorType.getRank() != 1)
return rewriter.notifyMatchFailure(
warpOp, "Only rank 1 reductions can be distributed.");
    // Only vector sizes that are a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction vector dimension must be a multiple of the warp size.");
if (!reductionOp.getType().isIntOrFloat())
return rewriter.notifyMatchFailure(
warpOp, "Reduction distribution currently only supports floats and "
"integer types.");
int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
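    // E.g., reducing a vector<64xf32> with a warp size of 32 leaves each lane
    // a vector<2xf32> to combine locally before the cross-lane reduction.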
    // Return the vector that will be reduced from the WarpExecuteOnLane0Op.
unsigned operandIndex = yieldOperand->getOperandNumber();
SmallVector<Value> yieldValues = {reductionOp.getVector()};
SmallVector<Type> retTypes = {
VectorType::get({numElements}, reductionOp.getType())};
if (reductionOp.getAcc()) {
yieldValues.push_back(reductionOp.getAcc());
retTypes.push_back(reductionOp.getAcc().getType());
}
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, yieldValues, retTypes, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
// Obtain data to reduce for a single lane.
Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
// Distribute and reduce across threads.
Value fullReduce =
distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
reductionOp.getKind(), newWarpOp.getWarpSize());
if (reductionOp.getAcc()) {
fullReduce = vector::makeArithReduction(
rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
newWarpOp.getResult(newRetIndices[1]));
}
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
return success();
}
private:
DistributedReductionFn distributedReductionFn;
};
} // namespace
void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
RewritePatternSet &patterns,
const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
}
void mlir::vector::populateDistributeTransferWriteOpPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
unsigned maxNumElementsToExtract, PatternBenefit benefit) {
patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
maxNumElementsToExtract, benefit);
}
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
PatternBenefit readBenefit) {
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
patterns.getContext(), benefit);
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
benefit);
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
benefit);
}
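// A minimal wiring sketch for these entry points, modeled on the vector
// distribution test pass. The innermost-dimension distribution map and the
// gpu.shuffle-based index callback below are illustrative choices, not the
// only valid ones:
//
//   RewritePatternSet patterns(ctx);
//   DistributionMapFn distributionFn = [](Value val) {
//     // Distribute along the innermost vector dimension.
//     auto vecType = cast<VectorType>(val.getType());
//     int64_t lastDim = vecType.getRank() - 1;
//     return AffineMap::get(vecType.getRank(), /*symbolCount=*/0,
//                           getAffineDimExpr(lastDim, val.getContext()));
//   };
//   WarpShuffleFromIdxFn shuffleFn = [](Location loc, OpBuilder &builder,
//                                       Value val, Value srcIdx,
//                                       int64_t warpSz) {
//     // gpu.shuffle expects i32 operands for the source lane and width.
//     Type i32Type = builder.getIntegerType(32);
//     Value srcIdxI32 =
//         builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
//     Value warpSzI32 = builder.create<arith::ConstantOp>(
//         loc, builder.getIntegerAttr(i32Type, warpSz));
//     return builder
//         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
//                                 gpu::ShuffleMode::IDX)
//         .getResult(0);
//   };
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionFn, shuffleFn);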
void mlir::vector::populateDistributeReduction(
RewritePatternSet &patterns,
const DistributedReductionFn &distributedReductionFn,
PatternBenefit benefit) {
patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
benefit);
}
/// Helper to determine whether an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
function_ref<bool(Value)> definedOutside) {
return llvm::all_of(op->getOperands(), definedOutside) &&
isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
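/// Hoist ops out of the warp-op body when they have no vector results, are
/// memory-effect free, have no regions, and use only values defined outside
/// the region (i.e., they are lane-uniform). A minimal sketch, assuming %x
/// and %y are defined above the warp op:
/// ```
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
///   %sum = arith.addi %x, %y : index
///   "some_use"(%sum) : (index) -> ()
/// }
/// ```
/// becomes:
/// ```
/// %sum = arith.addi %x, %y : index
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
///   "some_use"(%sum) : (index) -> ()
/// }
/// ```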
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
Block *body = warpOp.getBody();
// Keep track of the ops we want to hoist.
llvm::SmallSetVector<Operation *, 8> opsToMove;
// Helper to check if a value is or will be defined outside of the region.
auto isDefinedOutsideOfBody = [&](Value value) {
auto *definingOp = value.getDefiningOp();
return (definingOp && opsToMove.count(definingOp)) ||
warpOp.isDefinedOutsideOfRegion(value);
};
// Do not use walk here, as we do not want to go into nested regions and hoist
// operations from there.
for (auto &op : body->without_terminator()) {
bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
return isa<VectorType>(result.getType());
});
if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
opsToMove.insert(&op);
}
// Move all the ops marked as uniform outside of the region.
for (Operation *op : opsToMove)
op->moveBefore(warpOp);
}