Compare commits

...

2 Commits

Author SHA1 Message Date
William S. Moses
c75561b2d5 fmt 2025-08-18 05:19:22 -05:00
William S. Moses
b5a6e88845 [MLIR][Math] Fix mathtolibm to use conversion patterns 2025-08-18 04:58:26 -05:00

View File

@ -29,32 +29,38 @@ namespace {
// Pattern to convert vector operations to scalar operations. This is needed as
// libm calls require scalars.
template <typename Op>
struct VecOpToScalarOp : public OpRewritePattern<Op> {
struct VecOpToScalarOp : public OpConversionPattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
LogicalResult
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final;
};
// Pattern to promote an op of a smaller floating point type to F32.
template <typename Op>
struct PromoteOpToF32 : public OpRewritePattern<Op> {
struct PromoteOpToF32 : public OpConversionPattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
LogicalResult
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final;
};
// Pattern to convert scalar math operations to calls to libm functions.
// Additionally the libm function signatures are declared.
template <typename Op>
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
struct ScalarOpToLibmCall : public OpConversionPattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
using OpConversionPattern<Op>::OpConversionPattern;
ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
StringRef floatFunc, StringRef doubleFunc)
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
doubleFunc(doubleFunc) {};
: OpConversionPattern<Op>(context, benefit), floatFunc(floatFunc),
doubleFunc(doubleFunc){};
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
LogicalResult
matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final;
private:
std::string floatFunc, doubleFunc;
@ -71,8 +77,9 @@ void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
} // namespace
template <typename Op>
LogicalResult
VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
LogicalResult VecOpToScalarOp<Op>::matchAndRewrite(
Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto opType = op.getType();
auto loc = op.getLoc();
auto vecType = dyn_cast<VectorType>(opType);
@ -92,7 +99,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (auto input : op->getOperands())
for (auto input : adaptor.getOperands())
operands.push_back(
vector::ExtractOp::create(rewriter, loc, input, positions));
Value scalarOp =
@ -105,8 +112,9 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
}
template <typename Op>
LogicalResult
PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
LogicalResult PromoteOpToF32<Op>::matchAndRewrite(
Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto opType = op.getType();
if (!isa<Float16Type, BFloat16Type>(opType))
return failure();
@ -114,7 +122,7 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto f32 = rewriter.getF32Type();
auto extendedOperands = llvm::to_vector(
llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
llvm::map_range(adaptor.getOperands(), [&](Value operand) -> Value {
return arith::ExtFOp::create(rewriter, loc, f32, operand);
}));
auto newOp = Op::create(rewriter, loc, f32, extendedOperands);
@ -123,9 +131,9 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
}
template <typename Op>
LogicalResult
ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
PatternRewriter &rewriter) const {
LogicalResult ScalarOpToLibmCall<Op>::matchAndRewrite(
Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto module = SymbolTable::getNearestSymbolTable(op);
auto type = op.getType();
if (!isa<Float32Type, Float64Type>(type))
@ -155,7 +163,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
op->getOperands());
adaptor.getOperands());
return success();
}