[mlir][EmitC]Expand the MemRefToEmitC pass - Lowering AllocOp
(#148257)
This aims to lower `memref.alloc` to `emitc.call_opaque “malloc” ` or `emitc.call_opaque “aligned_alloc” ` From: ``` module{ func.func @allocating() { %alloc_5 = memref.alloc() : memref<999xi32> return } } ``` To: ``` module { emitc.include <"stdlib.h"> func.func @allocating() { %0 = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t %1 = "emitc.constant"() <{value = 999 : index}> : () -> index %2 = emitc.mul %0, %1 : (!emitc.size_t, index) -> !emitc.size_t %3 = emitc.call_opaque "malloc"(%2) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> %4 = emitc.cast %3 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> return } } ``` Which is then translated as: ``` #include <stdlib.h> void allocating() { size_t v1 = sizeof(int32_t); size_t v2 = 999; size_t v3 = v1 * v2; void* v4 = malloc(v3); int32_t* v5 = (int32_t*) v4; return; } ```
This commit is contained in:
parent
0d05e55f69
commit
5949f4596e
@ -18,6 +18,8 @@ The following convention is followed:
|
||||
GCC or Clang.
|
||||
* If `emitc.array` with a dimension of size zero is used, then the code
|
||||
requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html).
|
||||
* If `aligned_alloc` is passed to an `emitc.call_opaque` operation, then C++17
|
||||
or C11 is required.
|
||||
* Else the generated code is compatible with C99.
|
||||
|
||||
These restrictions are neither inherent to the EmitC dialect itself nor to the
|
||||
|
@ -8,6 +8,11 @@
|
||||
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
|
||||
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
|
||||
|
||||
constexpr const char *alignedAllocFunctionName = "aligned_alloc";
|
||||
constexpr const char *mallocFunctionName = "malloc";
|
||||
constexpr const char *cppStandardLibraryHeader = "cstdlib";
|
||||
constexpr const char *cStandardLibraryHeader = "stdlib.h";
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
class RewritePatternSet;
|
||||
|
@ -841,9 +841,13 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
|
||||
// MemRefToEmitC
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
|
||||
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> {
|
||||
let summary = "Convert MemRef dialect to EmitC dialect";
|
||||
let dependentDialects = ["emitc::EmitCDialect"];
|
||||
let options = [Option<
|
||||
"lowerToCpp", "lower-to-cpp", "bool",
|
||||
/*default=*/"false",
|
||||
/*description=*/"Target C++ (true) instead of C (false)">];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -19,10 +19,18 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeRange.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <cstdint>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
|
||||
return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
|
||||
memRefType.getRank() != 0 &&
|
||||
!llvm::is_contained(memRefType.getShape(), 0);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Implement the interface to convert MemRef to EmitC.
|
||||
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
|
||||
@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
|
||||
return resultTy;
|
||||
}
|
||||
|
||||
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = allocOp.getLoc();
|
||||
MemRefType memrefType = allocOp.getType();
|
||||
if (!isMemRefTypeLegalForEmitC(memrefType)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
loc, "incompatible memref type for EmitC conversion");
|
||||
}
|
||||
|
||||
Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
|
||||
Type elementType = memrefType.getElementType();
|
||||
IndexType indexType = rewriter.getIndexType();
|
||||
emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
|
||||
loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
|
||||
ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
|
||||
|
||||
int64_t numElements = 1;
|
||||
for (int64_t dimSize : memrefType.getShape()) {
|
||||
numElements *= dimSize;
|
||||
}
|
||||
Value numElementsValue = rewriter.create<emitc::ConstantOp>(
|
||||
loc, indexType, rewriter.getIndexAttr(numElements));
|
||||
|
||||
Value totalSizeBytes = rewriter.create<emitc::MulOp>(
|
||||
loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
|
||||
|
||||
emitc::CallOpaqueOp allocCall;
|
||||
StringAttr allocFunctionName;
|
||||
Value alignmentValue;
|
||||
SmallVector<Value, 2> argsVec;
|
||||
if (allocOp.getAlignment()) {
|
||||
allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
|
||||
alignmentValue = rewriter.create<emitc::ConstantOp>(
|
||||
loc, sizeTType,
|
||||
rewriter.getIntegerAttr(indexType,
|
||||
allocOp.getAlignment().value_or(0)));
|
||||
argsVec.push_back(alignmentValue);
|
||||
} else {
|
||||
allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
|
||||
}
|
||||
|
||||
argsVec.push_back(totalSizeBytes);
|
||||
ValueRange args(argsVec);
|
||||
|
||||
allocCall = rewriter.create<emitc::CallOpaqueOp>(
|
||||
loc,
|
||||
emitc::PointerType::get(
|
||||
emitc::OpaqueType::get(rewriter.getContext(), "void")),
|
||||
allocFunctionName, args);
|
||||
|
||||
emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
|
||||
emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
|
||||
loc, targetPointerType, allocCall.getResult(0));
|
||||
|
||||
rewriter.replaceOp(allocOp, castOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
|
||||
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
|
||||
typeConverter.addConversion(
|
||||
[&](MemRefType memRefType) -> std::optional<Type> {
|
||||
if (!memRefType.hasStaticShape() ||
|
||||
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
|
||||
llvm::is_contained(memRefType.getShape(), 0)) {
|
||||
if (!isMemRefTypeLegalForEmitC(memRefType)) {
|
||||
return {};
|
||||
}
|
||||
Type convertedElementType =
|
||||
@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
|
||||
|
||||
void mlir::populateMemRefToEmitCConversionPatterns(
|
||||
RewritePatternSet &patterns, const TypeConverter &converter) {
|
||||
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
|
||||
ConvertStore>(converter, patterns.getContext());
|
||||
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
|
||||
ConvertLoad, ConvertStore>(converter, patterns.getContext());
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
|
||||
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
@ -28,9 +29,11 @@ using namespace mlir;
|
||||
namespace {
|
||||
struct ConvertMemRefToEmitCPass
|
||||
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
|
||||
using Base::Base;
|
||||
void runOnOperation() override {
|
||||
TypeConverter converter;
|
||||
|
||||
ConvertMemRefToEmitCOptions options;
|
||||
options.lowerToCpp = this->lowerToCpp;
|
||||
// Fallback for other types.
|
||||
converter.addConversion([](Type type) -> std::optional<Type> {
|
||||
if (!emitc::isSupportedEmitCType(type))
|
||||
@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
mlir::ModuleOp module = getOperation();
|
||||
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
|
||||
if (callOp.getCallee() != alignedAllocFunctionName &&
|
||||
callOp.getCallee() != mallocFunctionName) {
|
||||
return mlir::WalkResult::advance();
|
||||
}
|
||||
|
||||
for (auto &op : *module.getBody()) {
|
||||
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
|
||||
if (!includeOp) {
|
||||
continue;
|
||||
}
|
||||
if (includeOp.getIsStandardInclude() &&
|
||||
((options.lowerToCpp &&
|
||||
includeOp.getInclude() == cppStandardLibraryHeader) ||
|
||||
(!options.lowerToCpp &&
|
||||
includeOp.getInclude() == cStandardLibraryHeader))) {
|
||||
return mlir::WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
|
||||
StringAttr includeAttr =
|
||||
builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
|
||||
: cStandardLibraryHeader);
|
||||
builder.create<mlir::emitc::IncludeOp>(
|
||||
module.getLoc(), includeAttr,
|
||||
/*is_standard_include=*/builder.getUnitAttr());
|
||||
return mlir::WalkResult::interrupt();
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -0,0 +1,72 @@
|
||||
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
|
||||
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
|
||||
|
||||
func.func @alloc() {
|
||||
%alloc = memref.alloc() : memref<999xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// CPP: module {
|
||||
// CPP-NEXT: emitc.include <"cstdlib">
|
||||
// CPP-LABEL: alloc()
|
||||
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
|
||||
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
|
||||
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
|
||||
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
|
||||
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
|
||||
// CPP-NEXT: return
|
||||
|
||||
// NOCPP: module {
|
||||
// NOCPP-NEXT: emitc.include <"stdlib.h">
|
||||
// NOCPP-LABEL: alloc()
|
||||
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
|
||||
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
|
||||
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
|
||||
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
|
||||
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
|
||||
// NOCPP-NEXT: return
|
||||
|
||||
func.func @alloc_aligned() {
|
||||
%alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// CPP-LABEL: alloc_aligned
|
||||
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
|
||||
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
|
||||
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
|
||||
// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
|
||||
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
|
||||
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
|
||||
// CPP-NEXT: return
|
||||
|
||||
// NOCPP-LABEL: alloc_aligned
|
||||
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
|
||||
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
|
||||
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
|
||||
// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
|
||||
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
|
||||
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
|
||||
// NOCPP-NEXT: return
|
||||
|
||||
func.func @allocating_multi() {
|
||||
%alloc_5 = memref.alloc() : memref<7x999xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// CPP-LABEL: allocating_multi
|
||||
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
|
||||
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
|
||||
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
|
||||
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">
|
||||
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
|
||||
// CPP-NEXT: return
|
||||
|
||||
// NOCPP-LABEL: allocating_multi
|
||||
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
|
||||
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
|
||||
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
|
||||
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
|
||||
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
|
||||
// NOCPP-NEXT: return
|
||||
|
Loading…
x
Reference in New Issue
Block a user