[mlir][EmitC]Expand the MemRefToEmitC pass - Lowering AllocOp (#148257)

This aims to lower `memref.alloc` to `emitc.call_opaque “malloc” ` or
`emitc.call_opaque “aligned_alloc” `
From:
```
module{
  func.func @allocating() {
  %alloc_5 = memref.alloc() : memref<999xi32>
  return
  }
}
```

To:
```
module {
  emitc.include <"stdlib.h">
  func.func @allocating() {
    %0 = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
    %1 = "emitc.constant"() <{value = 999 : index}> : () -> index
    %2 = emitc.mul %0, %1 : (!emitc.size_t, index) -> !emitc.size_t
    %3 = emitc.call_opaque "malloc"(%2) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
    %4 = emitc.cast %3 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
    return
  }
}
```
Which is then translated as:
```
#include <stdlib.h>
void allocating() {
  size_t v1 = sizeof(int32_t);
  size_t v2 = 999;
  size_t v3 = v1 * v2;
  void* v4 = malloc(v3);
  int32_t* v5 = (int32_t*) v4;
  return;
}
```
This commit is contained in:
Jaden Angella 2025-07-28 18:48:26 -07:00 committed by GitHub
parent 0d05e55f69
commit 5949f4596e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 192 additions and 7 deletions

View File

@ -18,6 +18,8 @@ The following convention is followed:
GCC or Clang.
* If `emitc.array` with a dimension of size zero is used, then the code
requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html).
* If `aligned_alloc` is passed to an `emitc.call_opaque` operation, then C++17
or C11 is required.
* Else the generated code is compatible with C99.
These restrictions are neither inherent to the EmitC dialect itself nor to the

View File

@ -8,6 +8,11 @@
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
constexpr const char *alignedAllocFunctionName = "aligned_alloc";
constexpr const char *mallocFunctionName = "malloc";
constexpr const char *cppStandardLibraryHeader = "cstdlib";
constexpr const char *cStandardLibraryHeader = "stdlib.h";
namespace mlir {
class DialectRegistry;
class RewritePatternSet;

View File

@ -841,9 +841,13 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
// MemRefToEmitC
//===----------------------------------------------------------------------===//
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> {
let summary = "Convert MemRef dialect to EmitC dialect";
let dependentDialects = ["emitc::EmitCDialect"];
let options = [Option<
"lowerToCpp", "lower-to-cpp", "bool",
/*default=*/"false",
/*description=*/"Target C++ (true) instead of C (false)">];
}
//===----------------------------------------------------------------------===//

View File

@ -19,10 +19,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
using namespace mlir;
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
memRefType.getRank() != 0 &&
!llvm::is_contained(memRefType.getShape(), 0);
}
namespace {
/// Implement the interface to convert MemRef to EmitC.
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = allocOp.getLoc();
MemRefType memrefType = allocOp.getType();
if (!isMemRefTypeLegalForEmitC(memrefType)) {
return rewriter.notifyMatchFailure(
loc, "incompatible memref type for EmitC conversion");
}
Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
Type elementType = memrefType.getElementType();
IndexType indexType = rewriter.getIndexType();
emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
int64_t numElements = 1;
for (int64_t dimSize : memrefType.getShape()) {
numElements *= dimSize;
}
Value numElementsValue = rewriter.create<emitc::ConstantOp>(
loc, indexType, rewriter.getIndexAttr(numElements));
Value totalSizeBytes = rewriter.create<emitc::MulOp>(
loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
emitc::CallOpaqueOp allocCall;
StringAttr allocFunctionName;
Value alignmentValue;
SmallVector<Value, 2> argsVec;
if (allocOp.getAlignment()) {
allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
alignmentValue = rewriter.create<emitc::ConstantOp>(
loc, sizeTType,
rewriter.getIntegerAttr(indexType,
allocOp.getAlignment().value_or(0)));
argsVec.push_back(alignmentValue);
} else {
allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
}
argsVec.push_back(totalSizeBytes);
ValueRange args(argsVec);
allocCall = rewriter.create<emitc::CallOpaqueOp>(
loc,
emitc::PointerType::get(
emitc::OpaqueType::get(rewriter.getContext(), "void")),
allocFunctionName, args);
emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
loc, targetPointerType, allocCall.getResult(0));
rewriter.replaceOp(allocOp, castOp);
return success();
}
};
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
if (!memRefType.hasStaticShape() ||
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
llvm::is_contained(memRefType.getShape(), 0)) {
if (!isMemRefTypeLegalForEmitC(memRefType)) {
return {};
}
Type convertedElementType =
@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
ConvertLoad, ConvertStore>(converter, patterns.getContext());
}

View File

@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@ -28,9 +29,11 @@ using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
void runOnOperation() override {
TypeConverter converter;
ConvertMemRefToEmitCOptions options;
options.lowerToCpp = this->lowerToCpp;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type))
@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
mlir::ModuleOp module = getOperation();
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
if (callOp.getCallee() != alignedAllocFunctionName &&
callOp.getCallee() != mallocFunctionName) {
return mlir::WalkResult::advance();
}
for (auto &op : *module.getBody()) {
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
if (!includeOp) {
continue;
}
if (includeOp.getIsStandardInclude() &&
((options.lowerToCpp &&
includeOp.getInclude() == cppStandardLibraryHeader) ||
(!options.lowerToCpp &&
includeOp.getInclude() == cStandardLibraryHeader))) {
return mlir::WalkResult::interrupt();
}
}
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
StringAttr includeAttr =
builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
: cStandardLibraryHeader);
builder.create<mlir::emitc::IncludeOp>(
module.getLoc(), includeAttr,
/*is_standard_include=*/builder.getUnitAttr());
return mlir::WalkResult::interrupt();
});
}
};
} // namespace

View File

@ -0,0 +1,72 @@
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
func.func @alloc() {
%alloc = memref.alloc() : memref<999xi32>
return
}
// CPP: module {
// CPP-NEXT: emitc.include <"cstdlib">
// CPP-LABEL: alloc()
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// CPP-NEXT: return
// NOCPP: module {
// NOCPP-NEXT: emitc.include <"stdlib.h">
// NOCPP-LABEL: alloc()
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// NOCPP-NEXT: return
func.func @alloc_aligned() {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32>
return
}
// CPP-LABEL: alloc_aligned
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
// CPP-NEXT: return
// NOCPP-LABEL: alloc_aligned
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
// NOCPP-NEXT: return
func.func @allocating_multi() {
%alloc_5 = memref.alloc() : memref<7x999xi32>
return
}
// CPP-LABEL: allocating_multi
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// CPP-NEXT: return
// NOCPP-LABEL: allocating_multi
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// NOCPP-NEXT: return