Ivan Butygin 32389d0c2e [mlir][spirv] Add OpenCL fma op and lowering
Also, it seems Khronos has changed the HTML spec format, so a small adjustment to the script was needed.
Base op parsing is also probably broken.

Differential Revision: https://reviews.llvm.org/D119678
2022-02-15 11:28:20 +03:00

123 lines
5.4 KiB
C++

//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Math dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "math-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern, whereas DRR
// generates a normal RewritePattern.
namespace {
/// Converts math.expm1 to SPIR-V ops.
///
/// SPIR-V has no direct operation for exp(x) - 1, so the op is expanded into
/// `ExpOp` applied to the operand followed by a spirv::FSubOp that subtracts
/// the constant one. `ExpOp` is the SPIR-V exponential op to emit; this file
/// instantiates the pattern with the GLSL and the OpenCL variants.
template <typename ExpOp>
class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
public:
  using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;

  /// Rewrites `operation` as `ExpOp(x) - 1`. Returns failure, leaving the IR
  /// untouched, when the result type cannot be converted.
  LogicalResult
  matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(adaptor.getOperands().size() == 1);
    Location loc = operation.getLoc();
    auto type = this->getTypeConverter()->convertType(operation.getType());
    // A failed type conversion yields a null type; bail out before any op is
    // created instead of building ops with a null result type.
    if (!type)
      return failure();
    auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
    auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
    rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
    return success();
  }
};
/// Converts math.log1p to SPIR-V ops.
///
/// SPIR-V has no direct operation for log(1 + x), so the op is expanded into
/// a spirv::FAddOp of the constant one and the operand, followed by `LogOp`.
/// `LogOp` is the SPIR-V logarithm op to emit; this file instantiates the
/// pattern with the GLSL and the OpenCL variants.
template <typename LogOp>
class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
public:
  using OpConversionPattern<math::Log1pOp>::OpConversionPattern;

  /// Rewrites `operation` as `LogOp(1 + x)`. Returns failure, leaving the IR
  /// untouched, when the result type cannot be converted.
  LogicalResult
  matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    assert(adaptor.getOperands().size() == 1);
    Location loc = operation.getLoc();
    auto type = this->getTypeConverter()->convertType(operation.getType());
    // A failed type conversion yields a null type; bail out before any op is
    // created instead of building ops with a null result type.
    if (!type)
      return failure();
    auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
    auto onePlus =
        rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
    rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
namespace mlir {
/// Appends the Math-to-SPIR-V conversion patterns to `patterns`. Every
/// pattern is constructed with `typeConverter` and the context owned by
/// `patterns`. Lowerings to both the GLSL and the OpenCL SPIR-V ops are
/// registered, GLSL first.
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns) {
  auto *context = patterns.getContext();

  // GLSL: ops that expand into several SPIR-V ops.
  patterns.add<Log1pOpPattern<spirv::GLSLLogOp>,
               ExpM1OpPattern<spirv::GLSLExpOp>>(typeConverter, context);
  // GLSL: ops with a one-to-one SPIR-V counterpart.
  patterns.add<
      spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
      spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
      spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
      spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
      spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
      spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
      spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
      spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
      spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
      spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
      spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
      spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
      typeConverter, context);

  // OpenCL: ops that expand into several SPIR-V ops.
  patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
               ExpM1OpPattern<spirv::OCLExpOp>>(typeConverter, context);
  // OpenCL: ops with a one-to-one SPIR-V counterpart.
  patterns.add<
      spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
      spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
      spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
      spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
      spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
      spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
      spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
      spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
      spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
      spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
      spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
      spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
      spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
      typeConverter, context);
}
} // namespace mlir