llvm-project/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
Jacques Pienaar 09dfc5713d
[mlir] Enable decoupling two kinds of greedy behavior. (#104649)
The greedy rewriter is used in many different flows and it has a lot of
convenience (work list management, debugging actions, tracing, etc). But
it combines two kinds of greedy behavior: 1) how ops are matched, and 2)
folding wherever it can.

These are independent forms of greediness, and coupling them leads to
inefficiency. E.g., there are cases where one needs to create different
phases in lowering and is required to apply patterns in a specific order
split across different passes. Using the driver, one ends up needlessly
retrying folding/having multiple rounds of folding attempts, where one
final run would have sufficed.

Of course folks can locally avoid this behavior by just building their
own, but this is also a commonly requested feature that folks keep on
working around locally in suboptimal ways.

For downstream users, there should be no behavioral change. Updating
from the deprecated API should just be a find and replace (e.g., `find ./
-type f -exec sed -i
's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety)
as the API arguments haven't changed between the two.
2024-12-20 08:15:48 -08:00

261 lines
11 KiB
C++

//===- TestMathToVCIXConversion.cpp - Test conversion to VCIX ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/VCIXDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace {
/// Return the number of extracts required to make the input vector type legal,
/// together with that legal (per-part) vector type.
///
/// Fixed-length vectors need nothing special: they are returned unchanged with
/// a count of 1. Scalable vectors are legalized according to LLVM's encoding:
/// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html
/// A scalable vector whose LMUL (eltCount * SEW / 64) exceeds 8 is split into
/// n = 2^k equally sized parts of eltCount / n elements each.
///
/// Returns {0, nullptr} when \p type is not a 1-D vector, has an unsupported
/// element type, or cannot be split into equally sized parts.
static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
  // dyn_cast rather than cast: the patterns call this on arbitrary operand
  // types (including scalars) and must fail to match instead of asserting.
  VectorType vt = dyn_cast<VectorType>(type);
  // To simplify test pass, avoid multi-dimensional vectors.
  if (!vt || vt.getRank() != 1)
    return {0, nullptr};
  if (!vt.isScalable())
    return {1, vt};

  Type eltTy = vt.getElementType();
  unsigned sew = 0;
  if (eltTy.isF32())
    sew = 32;
  else if (eltTy.isF64())
    sew = 64;
  else if (auto intTy = dyn_cast<IntegerType>(eltTy))
    sew = intTy.getWidth();
  else
    return {0, nullptr};

  const unsigned eltCount = vt.getShape()[0];
  const unsigned lmul = eltCount * sew / 64;
  if (lmul <= 8)
    return {1, vt};

  // Split into n = 2^shift parts. The part count and the shift applied to the
  // element count must agree so that the n parts cover the whole vector.
  const unsigned shift = llvm::Log2_32(lmul) - 3;
  const unsigned n = 1u << shift;
  // Refuse shapes that would leave trailing elements uncovered.
  if (eltCount % n != 0)
    return {0, nullptr};
  return {n, VectorType::get({eltCount >> shift}, eltTy, {true})};
}
/// Rewrite `math.cos(v)` into `vcix.v.iv(v)`, i.e. a vcix::BinaryImmOp with
/// opcode 0 and immediate 0.
///
/// When legalizeVectorType asks for more than one part, each part is pulled
/// out with vector.scalable.extract, rewritten on its own, and written back
/// with vector.scalable.insert into a zero-splat placeholder.
struct MathCosToVCIX final : OpRewritePattern<math::CosOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(math::CosOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getOperand();
    const Type inputType = input.getType();
    auto [numParts, partType] = legalizeVectorType(inputType);
    if (!partType)
      return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");

    const Location loc = op.getLoc();
    Attribute opcode = rewriter.getI64IntegerAttr(0);
    Attribute imm = rewriter.getI32IntegerAttr(0);

    // Scalable types take an explicit runtime vector length. The value 9 is
    // arbitrary; a proper conversion pass would take it from the IR.
    // Fixed-length types pass a null length.
    Value vl;
    if (partType.isScalable())
      vl = rewriter.create<arith::ConstantOp>(loc,
                                              rewriter.getI64IntegerAttr(9));

    // The whole operand is already legal: a single VCIX op suffices.
    if (numParts == 1) {
      Value single = rewriter.create<vcix::BinaryImmOp>(
          loc, partType, opcode, input, imm, vl);
      rewriter.replaceOp(op, single);
      return success();
    }

    // Otherwise rewrite part by part into a dummy zero-splat destination.
    Type elemType = partType.getElementType();
    const unsigned partLen = partType.getShape()[0];
    Value zeroElem = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value result =
        rewriter.create<vector::BroadcastOp>(loc, inputType, zeroElem);
    unsigned offset = 0;
    for (unsigned part = 0; part < numParts; ++part, offset += partLen) {
      Value slice = rewriter.create<vector::ScalableExtractOp>(
          loc, partType, input, offset);
      Value mapped = rewriter.create<vcix::BinaryImmOp>(loc, partType, opcode,
                                                        slice, imm, vl);
      result =
          rewriter.create<vector::ScalableInsertOp>(loc, mapped, result, offset);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Rewrite `math.sin(v)` into `vcix.v.sv(v, v)`, i.e. a vcix::BinaryOp with
/// opcode 0 whose two operands are both the (per-part) input vector.
///
/// When legalizeVectorType asks for more than one part, each part is pulled
/// out with vector.scalable.extract, rewritten on its own, and written back
/// with vector.scalable.insert into a zero-splat placeholder.
struct MathSinToVCIX final : OpRewritePattern<math::SinOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(math::SinOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getOperand();
    const Type inputType = input.getType();
    auto [numParts, partType] = legalizeVectorType(inputType);
    if (!partType)
      return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");

    const Location loc = op.getLoc();
    Attribute opcode = rewriter.getI64IntegerAttr(0);

    // Scalable types take an explicit runtime vector length. The value 9 is
    // arbitrary; a proper conversion pass would take it from the IR.
    // Fixed-length types pass a null length.
    Value vl;
    if (partType.isScalable())
      vl = rewriter.create<arith::ConstantOp>(loc,
                                              rewriter.getI64IntegerAttr(9));

    // The whole operand is already legal: a single VCIX op suffices.
    if (numParts == 1) {
      Value single = rewriter.create<vcix::BinaryOp>(loc, partType, opcode,
                                                     input, input, vl);
      rewriter.replaceOp(op, single);
      return success();
    }

    // Otherwise rewrite part by part into a dummy zero-splat destination.
    Type elemType = partType.getElementType();
    const unsigned partLen = partType.getShape()[0];
    Value zeroElem = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value result =
        rewriter.create<vector::BroadcastOp>(loc, inputType, zeroElem);
    unsigned offset = 0;
    for (unsigned part = 0; part < numParts; ++part, offset += partLen) {
      Value slice = rewriter.create<vector::ScalableExtractOp>(
          loc, partType, input, offset);
      Value mapped = rewriter.create<vcix::BinaryOp>(loc, partType, opcode,
                                                     slice, slice, vl);
      result =
          rewriter.create<vector::ScalableInsertOp>(loc, mapped, result, offset);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Replace math.tan(v) operation with vcix.v.sv(v, 0.0f): a vcix::BinaryOp
/// with opcode 0 whose scalar operand is a zero constant of the vector's
/// element type. That same zero constant also seeds the placeholder vector
/// when the operand has to be split into several parts.
struct MathTanToVCIX final : OpRewritePattern<math::TanOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(math::TanOp op,
                                PatternRewriter &rewriter) const override {
    const Type opType = op.getOperand().getType();
    auto [n, legalType] = legalizeVectorType(opType);
    // legalType is null when legalization fails, so it must be checked before
    // querying its element type.
    if (!legalType)
      return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");
    Type eltTy = legalType.getElementType();
    Location loc = op.getLoc();
    Value vec = op.getOperand();
    Attribute opcodeAttr = rewriter.getI64IntegerAttr(0);
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, eltTy, rewriter.getZeroAttr(eltTy));
    Value rvl = nullptr;
    if (legalType.isScalable())
      // Use arbitrary runtime vector length when vector type is scalable.
      // Proper conversion pass should take it from the IR.
      rvl = rewriter.create<arith::ConstantOp>(loc,
                                               rewriter.getI64IntegerAttr(9));
    Value res;
    if (n == 1) {
      res = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr, vec,
                                            zero, rvl);
    } else {
      // Rewrite each legal-sized part separately and insert it into a dummy
      // zero-splat vector of the original type.
      const unsigned eltCount = legalType.getShape()[0];
      res = rewriter.create<vector::BroadcastOp>(loc, opType, zero /*dummy*/);
      for (unsigned i = 0; i < n; ++i) {
        Value extracted = rewriter.create<vector::ScalableExtractOp>(
            loc, legalType, vec, i * eltCount);
        Value v = rewriter.create<vcix::BinaryOp>(loc, legalType, opcodeAttr,
                                                  extracted, zero, rvl);
        res = rewriter.create<vector::ScalableInsertOp>(loc, v, res,
                                                        i * eltCount);
      }
    }
    rewriter.replaceOp(op, res);
    return success();
  }
};
/// Rewrite `math.log(v)` into `vcix.v.sv(v, 0)`, i.e. a vcix::BinaryOp with
/// opcode 0 whose scalar operand is an i32 zero constant.
///
/// When legalizeVectorType asks for more than one part, each part is pulled
/// out with vector.scalable.extract, rewritten on its own, and written back
/// with vector.scalable.insert into a zero-splat placeholder.
struct MathLogToVCIX final : OpRewritePattern<math::LogOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(math::LogOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getOperand();
    const Type inputType = input.getType();
    auto [numParts, partType] = legalizeVectorType(inputType);
    if (!partType)
      return rewriter.notifyMatchFailure(op, "cannot legalize type for RVV");

    const Location loc = op.getLoc();
    Attribute opcode = rewriter.getI64IntegerAttr(0);
    // i32 zero passed as the scalar operand of every VCIX op created below.
    Value scalarZero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));

    // Scalable types take an explicit runtime vector length. The value 9 is
    // arbitrary; a proper conversion pass would take it from the IR.
    // Fixed-length types pass a null length.
    Value vl;
    if (partType.isScalable())
      vl = rewriter.create<arith::ConstantOp>(loc,
                                              rewriter.getI64IntegerAttr(9));

    // The whole operand is already legal: a single VCIX op suffices.
    if (numParts == 1) {
      Value single = rewriter.create<vcix::BinaryOp>(loc, partType, opcode,
                                                     input, scalarZero, vl);
      rewriter.replaceOp(op, single);
      return success();
    }

    // Otherwise rewrite part by part into a dummy zero-splat destination.
    Type elemType = partType.getElementType();
    const unsigned partLen = partType.getShape()[0];
    Value zeroElem = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value result =
        rewriter.create<vector::BroadcastOp>(loc, inputType, zeroElem);
    unsigned offset = 0;
    for (unsigned part = 0; part < numParts; ++part, offset += partLen) {
      Value slice = rewriter.create<vector::ScalableExtractOp>(
          loc, partType, input, offset);
      Value mapped = rewriter.create<vcix::BinaryOp>(loc, partType, opcode,
                                                     slice, scalarZero, vl);
      result =
          rewriter.create<vector::ScalableInsertOp>(loc, mapped, result, offset);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Test pass, registered as `test-math-to-vcix`, that greedily applies the
/// math.cos / math.sin / math.tan / math.log -> VCIX patterns to each
/// func.func it runs on.
struct TestMathToVCIX
    : PassWrapper<TestMathToVCIX, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMathToVCIX)

  StringRef getArgument() const final { return "test-math-to-vcix"; }

  StringRef getDescription() const final {
    return "Test lowering patterns that converts some vector operations to "
           "VCIX. Since DLA can implement VCIX instructions in completely "
           "different way, conversions of that test pass only lives here.";
  }

  /// Dialects loaded for this pass; the patterns create arith, vector and
  /// vcix operations.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, func::FuncDialect, math::MathDialect,
                    vcix::VCIXDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet rewrites(context);
    rewrites.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>(
        context);
    // The driver's convergence result is discarded; this pass never signals
    // failure.
    (void)applyPatternsGreedily(getOperation(), std::move(rewrites));
  }
};
} // namespace
namespace test {
void registerTestMathToVCIXPass() { PassRegistration<TestMathToVCIX>(); }
} // namespace test
} // namespace mlir