Andrzej Warzyński 2ee5586ac7
[mlir][vector] Make the in_bounds attribute mandatory (#97049)
At the moment, the in_bounds attribute has two confusing/contradicting
properties:
  1. It is both optional _and_ has an effective default-value.
  2. The default value is "out-of-bounds" for non-broadcast dims, and
     "in-bounds" for broadcast dims.

(see the `isDimInBounds` vector interface method for an example of this
"default" behaviour [1]).

This PR aims to clarify the logic surrounding the `in_bounds` attribute
by:
  * making the attribute mandatory (i.e. it is always present),
  * always setting the default value to "out of bounds" (that's
    consistent with the current behaviour for the most common cases).

#### Broadcast dimensions in tests

As per [2], broadcast dimensions require the corresponding
`in_bounds` attribute to be `true`:
```
  vector.transfer_read op requires broadcast dimensions to be in-bounds
```

The changes in this PR mean that we can no longer rely on the
default value in cases like the following (dim 0 is a broadcast dim):
```mlir
  %read = vector.transfer_read %A[%base1, %base2], %f, %mask
      {permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
    memref<?x?xf32>, vector<4x9xf32>
```

Instead, the broadcast dimension has to be explicitly marked as "in
bounds":

```mlir
  %read = vector.transfer_read %A[%base1, %base2], %f, %mask
      {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (0, d1)>} :
    memref<?x?xf32>, vector<4x9xf32>
```

All tests with broadcast dims are updated accordingly.

#### Changes in "SuperVectorize.cpp" and "Vectorization.cpp"

The following patterns in "Vectorization.cpp" are updated to explicitly
set the `in_bounds` attribute to `false`:
* `LinalgCopyVTRForwardingPattern` and `LinalgCopyVTWForwardingPattern`

Also, `vectorizeAffineLoad` (from "SuperVectorize.cpp") and
`vectorizeAsLinalgGeneric` (from "Vectorization.cpp") are updated to
make sure that xfer Ops created by these hooks mark the dimensions
corresponding to broadcast dims as "in bounds". Otherwise, the Op
verifier would complain.

Note that there is no mechanism to verify whether the corresponding
memory accesses are indeed in bounds. Still, this is consistent with the
current behaviour where the broadcast dim would be implicitly assumed
to be "in bounds".

[1]
4145ad2bac/mlir/include/mlir/Interfaces/VectorInterfaces.td (L243-L246)
[2]
https://mlir.llvm.org/docs/Dialects/Vector/#vectortransfer_read-vectortransferreadop
2024-07-16 16:49:52 +01:00

312 lines
12 KiB
C++

//===- LowerVectorMask.cpp - Lower 'vector.mask' operation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.mask' operation.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "lower-vector-mask"
namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMASKPASS
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir
using namespace mlir;
using namespace mlir::vector;
//===----------------------------------------------------------------------===//
// populateVectorMaskOpLoweringPatterns
//===----------------------------------------------------------------------===//
namespace {
/// Unrolls the leading dimension of an n-D `vector.create_mask` (n > 1).
///
/// Given:
///   %x = vector.create_mask %a, ... : vector<dx...>
/// the pattern emits:
///   %l = vector.create_mask ... : vector<...>   ; one rank lower
///   %z = all-false constant of the lower-rank type
///   %r = all-false constant of the result type
/// followed, for every i in [0, d), by:
///   %c = arith.cmpi "slt", %ci, %a
///   %s = arith.select %c, %l, %z
///   %r = vector.insert %s, %r [i]
/// and finally replaces %x with the last %r. The freshly created lower-rank
/// `vector.create_mask` can be matched again, so one dimension is peeled per
/// application until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto maskType = cast<VectorType>(op.getResult().getType());

    // Rank 0 and rank 1 masks have no leading dimension to peel off.
    if (maskType.getRank() <= 1)
      return rewriter.notifyMatchFailure(
          op, "0-D and 1-D vectors are handled separately");

    // The unrolling below emits one insert per element of the leading
    // dimension, which needs a fixed (non-scalable) size.
    if (maskType.getScalableDims().front())
      return rewriter.notifyMatchFailure(
          op, "Cannot unroll leading scalable dim in dstType");

    Location loc = op.getLoc();
    const int64_t numRows = maskType.getDimSize(0);
    Value leadingBound = op.getOperand(0);
    VectorType rowType = VectorType::Builder(maskType).dropDim(0);

    // Mask for one "row" (built from the remaining bounds), its all-false
    // counterpart, and the all-false accumulator for the final result.
    Value rowMask = rewriter.create<vector::CreateMaskOp>(
        loc, rowType, op.getOperands().drop_front());
    Value rowZero = rewriter.create<arith::ConstantOp>(
        loc, rowType, rewriter.getZeroAttr(rowType));
    Value unrolled = rewriter.create<arith::ConstantOp>(
        loc, maskType, rewriter.getZeroAttr(maskType));

    for (int64_t row = 0; row < numRows; ++row) {
      // Row `row` is enabled iff `row < leadingBound` (signed comparison).
      Value rowIndex =
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(row));
      Value isEnabled = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, rowIndex, leadingBound);
      Value chosen =
          rewriter.create<arith::SelectOp>(loc, isEnabled, rowMask, rowZero);
      unrolled = rewriter.create<vector::InsertOp>(loc, chosen, unrolled, row);
    }

    rewriter.replaceOp(op, unrolled);
    return success();
  }
};
/// Lowers `vector.constant_mask` to constants, unrolling the leading
/// dimension for ranks above one.
///
///  * 0-D: replaced by an `arith.constant` holding the single i1 value.
///  * 1-D: replaced by an `arith.constant`; a splat when the mask is entirely
///    set or entirely unset, otherwise the explicit form [T,..,T,F,..,F].
///  * n-D (n > 1):
///      %x = vector.constant_mask [a,b]
///    becomes
///      %z = zero-result
///      %l = vector.constant_mask [b]
///      %4 = vector.insert %l, %z[0]
///      ..
///      %x = vector.insert %l, %..[a-1]
///    The lower-rank `vector.constant_mask` can be matched again until a
///    one-dimensional vector is reached. All these operations will be folded
///    at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto maskType = op.getType();
    auto maskSizes = op.getMaskDimSizes();
    const int64_t rank = maskType.getRank();

    // 0-D mask: a single boolean, set iff the (only) dim size equals 1.
    if (rank == 0) {
      assert(dimSizes.size() == 1 &&
             "Expected exactly one dim size for a 0-D vector");
      bool isSet = cast<IntegerAttr>(maskSizes[0]).getInt() == 1;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, maskType,
          DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
                                    isSet));
      return success();
    }

    // Number of leading positions along dim 0 that are set.
    const int64_t numLeadingSet = cast<IntegerAttr>(maskSizes[0]).getInt();

    if (rank == 1) {
      // 'All set' or 'none set' becomes a constant splat. This produces
      // correct code for scalable dimensions (it will lower to a constant
      // splat).
      if (numLeadingSet == 0 || numLeadingSet == maskType.getDimSize(0)) {
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
            op, DenseElementsAttr::get(maskType, numLeadingSet != 0));
        return success();
      }

      // Partial mask: spell out the constant 1-D vector [T,..,T,F,..,F].
      // Note: The verifier would reject this case for scalable vectors.
      SmallVector<bool> bits(maskType.getDimSize(0), false);
      for (int64_t pos = 0; pos < numLeadingSet; ++pos)
        bits[pos] = true;
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          op, maskType, rewriter.getBoolVectorAttr(bits));
      return success();
    }

    // Rank > 1: peel the leading dimension, which therefore must be fixed.
    if (maskType.getScalableDims().front())
      return rewriter.notifyMatchFailure(
          op, "Cannot unroll leading scalable dim in dstType");

    VectorType rowType = VectorType::Builder(maskType).dropDim(0);
    Value rowMask = rewriter.create<vector::ConstantMaskOp>(
        loc, rowType, rewriter.getArrayAttr(maskSizes.getValue().drop_front()));
    Value unrolled = rewriter.create<arith::ConstantOp>(
        loc, maskType, rewriter.getZeroAttr(maskType));

    // Only the first `numLeadingSet` rows receive the lower-rank mask; the
    // remaining rows keep the all-false initial value.
    for (int64_t row = 0; row < numLeadingSet; ++row)
      unrolled = rewriter.create<vector::InsertOp>(loc, rowMask, unrolled, row);

    rewriter.replaceOp(op, unrolled);
    return success();
  }
};
} // namespace
/// Adds the progressive lowerings of `vector.create_mask` and
/// `vector.constant_mask` to `patterns`, each registered with `benefit`.
void mlir::vector::populateVectorMaskOpLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(ctx, benefit);
}
//===----------------------------------------------------------------------===//
// populateVectorMaskLoweringPatternsForSideEffectingOps
//===----------------------------------------------------------------------===//
namespace {
/// Base class for patterns that rewrite a `vector.mask` operation based on the
/// operation nested inside it. Matching happens in two steps:
///   1. The pattern is rooted at a `vector.mask` operation.
///   2. If the op nested in that `vector.mask` is a `SourceOp`, the pattern
///      forwards to `matchAndRewriteMaskableOp`, which derived classes
///      implement.
///
/// The replacement performed by a derived pattern must replace the
/// `vector.mask` operation itself and not the nested `MaskableOpInterface`.
/// This way patterns "stop" at every `vector.mask` operation while still
/// matching on the traits of the nested `MaskableOpInterface`.
template <typename SourceOp>
struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
  using OpRewritePattern<MaskOp>::OpRewritePattern;

private:
  /// Entry point used by the rewrite driver; fails unless the `vector.mask`
  /// wraps an operation of type `SourceOp`.
  LogicalResult matchAndRewrite(MaskOp maskOp,
                                PatternRewriter &rewriter) const final {
    // A `vector.mask` without a nested maskable op cannot be handled here.
    auto nestedOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
    if (!nestedOp)
      return failure();

    if (SourceOp sourceOp = dyn_cast<SourceOp>(nestedOp.getOperation()))
      return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
    return failure();
  }

protected:
  /// Rewrites `maskingOp` (the `vector.mask`), given the `sourceOp` nested in
  /// it. Implementations are expected to replace `maskingOp`, not `sourceOp`.
  virtual LogicalResult
  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const = 0;
};
/// Lowers a masked `vector.transfer_read`: the enclosing `vector.mask` is
/// replaced by a new `vector.transfer_read` that takes the mask as an operand
/// and otherwise reuses the operands and attributes of the nested read.
struct MaskedTransferReadOpPattern
    : public MaskOpRewritePattern<TransferReadOp> {
public:
  using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // TODO: The 'vector.mask' passthru is a vector and 'vector.transfer_read'
    // expects a scalar. We could only lower one to the other for cases where
    // the passthru is a broadcast of a scalar.
    if (maskingOp.hasPassthru())
      return rewriter.notifyMatchFailure(
          maskingOp, "Can't lower passthru to vector.transfer_read");

    // The op being replaced is the `vector.mask`, not the nested read.
    Operation *maskRegionOp = maskingOp.getOperation();
    Value mask = maskingOp.getMask();
    rewriter.replaceOpWithNewOp<TransferReadOp>(
        maskRegionOp, readOp.getVectorType(), readOp.getSource(),
        readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
        mask, readOp.getInBounds());
    return success();
  }
};
/// Lowers a masked `vector.transfer_write`: the enclosing `vector.mask` is
/// replaced by a new `vector.transfer_write` that takes the mask as an operand
/// and otherwise reuses the operands and attributes of the nested write.
struct MaskedTransferWriteOpPattern
    : public MaskOpRewritePattern<TransferWriteOp> {
public:
  using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(TransferWriteOp writeOp,
                            MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // The nested write may or may not produce a result; a null `Type` is
    // passed on when it does not.
    Type resultType;
    if (auto writeResult = writeOp.getResult())
      resultType = writeResult.getType();

    // The op being replaced is the `vector.mask`, not the nested write.
    Operation *maskRegionOp = maskingOp.getOperation();
    Value mask = maskingOp.getMask();
    rewriter.replaceOpWithNewOp<TransferWriteOp>(
        maskRegionOp, resultType, writeOp.getVector(), writeOp.getSource(),
        writeOp.getIndices(), writeOp.getPermutationMap(), mask,
        writeOp.getInBounds());
    return success();
  }
};
/// Lowers a masked `vector.gather`: the enclosing `vector.mask` is replaced by
/// a new `vector.gather` that uses the mask of the `vector.mask` and, as
/// pass-through, either the `vector.mask` passthru or a zero constant.
struct MaskedGatherOpPattern : public MaskOpRewritePattern<GatherOp> {
public:
  using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;

  LogicalResult
  matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // Reuse the passthru of the `vector.mask` when there is one; otherwise
    // materialize a zero constant of the gathered vector type.
    Value passthru;
    if (maskingOp.hasPassthru())
      passthru = maskingOp.getPassthru();
    else
      passthru = rewriter.create<arith::ConstantOp>(
          gatherOp.getLoc(), rewriter.getZeroAttr(gatherOp.getVectorType()));

    // The op being replaced is the `vector.mask`, not the nested gather.
    Operation *maskRegionOp = maskingOp.getOperation();
    rewriter.replaceOpWithNewOp<GatherOp>(
        maskRegionOp, gatherOp.getVectorType(), gatherOp.getBase(),
        gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
        passthru);
    return success();
  }
};
/// Pass that greedily applies the `vector.mask` lowering patterns for
/// side-effecting ops, together with the `vector.mask` canonicalization
/// patterns, to the operation it runs on.
struct LowerVectorMaskPass
    : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
  using Base::Base;

  void runOnOperation() override {
    Operation *root = getOperation();
    MLIRContext *ctx = root->getContext();

    // Collect the lowering patterns first, then the MaskOp canonicalizations.
    RewritePatternSet patterns(ctx);
    populateVectorMaskLoweringPatternsForSideEffectingOps(patterns);
    MaskOp::getCanonicalizationPatterns(patterns, ctx);

    // Report a pass failure when the greedy driver reports failure.
    if (succeeded(applyPatternsAndFoldGreedily(root, std::move(patterns))))
      return;
    signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};
} // namespace
/// Adds to `patterns` the `MaskOpRewritePattern` instances that lower
/// `vector.mask` wrapping `vector.transfer_read`, `vector.transfer_write` and
/// `vector.gather`. These patterns rewrite the `vector.mask` operation and not
/// its nested `MaskableOpInterface`.
void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
               MaskedGatherOpPattern>(ctx);
}
/// Creates a new instance of the pass that lowers `vector.mask` operations.
std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
  std::unique_ptr<Pass> pass = std::make_unique<LowerVectorMaskPass>();
  return pass;
}