[mlir][EmitC]Expand the MemRefToEmitC pass - Lowering AllocOp (#148257)

This aims to lower `memref.alloc` to `emitc.call_opaque “malloc” ` or
`emitc.call_opaque “aligned_alloc” `
From:
```
module{
  func.func @allocating() {
  %alloc_5 = memref.alloc() : memref<999xi32>
  return
  }
}
```

To:
```
module {
  emitc.include <"stdlib.h">
  func.func @allocating() {
    %0 = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
    %1 = "emitc.constant"() <{value = 999 : index}> : () -> index
    %2 = emitc.mul %0, %1 : (!emitc.size_t, index) -> !emitc.size_t
    %3 = emitc.call_opaque "malloc"(%2) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
    %4 = emitc.cast %3 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
    return
  }
}
```
Which is then translated as:
```
#include <stdlib.h>
void allocating() {
  size_t v1 = sizeof(int32_t);
  size_t v2 = 999;
  size_t v3 = v1 * v2;
  void* v4 = malloc(v3);
  int32_t* v5 = (int32_t*) v4;
  return;
}
```
This commit is contained in:
Jaden Angella 2025-07-28 18:48:26 -07:00 committed by GitHub
parent 0d05e55f69
commit 5949f4596e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 192 additions and 7 deletions

View File

@ -18,6 +18,8 @@ The following convention is followed:
GCC or Clang. GCC or Clang.
* If `emitc.array` with a dimension of size zero is used, then the code * If `emitc.array` with a dimension of size zero is used, then the code
requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html). requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html).
* If `aligned_alloc` is passed to an `emitc.call_opaque` operation, then C++17
or C11 is required.
* Else the generated code is compatible with C99. * Else the generated code is compatible with C99.
These restrictions are neither inherent to the EmitC dialect itself nor to the These restrictions are neither inherent to the EmitC dialect itself nor to the

View File

@ -8,6 +8,11 @@
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H #ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H #define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
constexpr const char *alignedAllocFunctionName = "aligned_alloc";
constexpr const char *mallocFunctionName = "malloc";
constexpr const char *cppStandardLibraryHeader = "cstdlib";
constexpr const char *cStandardLibraryHeader = "stdlib.h";
namespace mlir { namespace mlir {
class DialectRegistry; class DialectRegistry;
class RewritePatternSet; class RewritePatternSet;

View File

@ -841,9 +841,13 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
// MemRefToEmitC // MemRefToEmitC
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> { def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> {
let summary = "Convert MemRef dialect to EmitC dialect"; let summary = "Convert MemRef dialect to EmitC dialect";
let dependentDialects = ["emitc::EmitCDialect"]; let dependentDialects = ["emitc::EmitCDialect"];
let options = [Option<
"lowerToCpp", "lower-to-cpp", "bool",
/*default=*/"false",
/*description=*/"Target C++ (true) instead of C (false)">];
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -19,10 +19,18 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h" #include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
using namespace mlir; using namespace mlir;
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
memRefType.getRank() != 0 &&
!llvm::is_contained(memRefType.getShape(), 0);
}
namespace { namespace {
/// Implement the interface to convert MemRef to EmitC. /// Implement the interface to convert MemRef to EmitC.
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy; return resultTy;
} }
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = allocOp.getLoc();
MemRefType memrefType = allocOp.getType();
if (!isMemRefTypeLegalForEmitC(memrefType)) {
return rewriter.notifyMatchFailure(
loc, "incompatible memref type for EmitC conversion");
}
Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
Type elementType = memrefType.getElementType();
IndexType indexType = rewriter.getIndexType();
emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
int64_t numElements = 1;
for (int64_t dimSize : memrefType.getShape()) {
numElements *= dimSize;
}
Value numElementsValue = rewriter.create<emitc::ConstantOp>(
loc, indexType, rewriter.getIndexAttr(numElements));
Value totalSizeBytes = rewriter.create<emitc::MulOp>(
loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
emitc::CallOpaqueOp allocCall;
StringAttr allocFunctionName;
Value alignmentValue;
SmallVector<Value, 2> argsVec;
if (allocOp.getAlignment()) {
allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
alignmentValue = rewriter.create<emitc::ConstantOp>(
loc, sizeTType,
rewriter.getIntegerAttr(indexType,
allocOp.getAlignment().value_or(0)));
argsVec.push_back(alignmentValue);
} else {
allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
}
argsVec.push_back(totalSizeBytes);
ValueRange args(argsVec);
allocCall = rewriter.create<emitc::CallOpaqueOp>(
loc,
emitc::PointerType::get(
emitc::OpaqueType::get(rewriter.getContext(), "void")),
allocFunctionName, args);
emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
loc, targetPointerType, allocCall.getResult(0));
rewriter.replaceOp(allocOp, castOp);
return success();
}
};
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion( typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> { [&](MemRefType memRefType) -> std::optional<Type> {
if (!memRefType.hasStaticShape() || if (!isMemRefTypeLegalForEmitC(memRefType)) {
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
llvm::is_contained(memRefType.getShape(), 0)) {
return {}; return {};
} }
Type convertedElementType = Type convertedElementType =
@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns( void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) { RewritePatternSet &patterns, const TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
ConvertStore>(converter, patterns.getContext()); ConvertLoad, ConvertStore>(converter, patterns.getContext());
} }

View File

@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
@ -28,9 +29,11 @@ using namespace mlir;
namespace { namespace {
struct ConvertMemRefToEmitCPass struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
void runOnOperation() override { void runOnOperation() override {
TypeConverter converter; TypeConverter converter;
ConvertMemRefToEmitCOptions options;
options.lowerToCpp = this->lowerToCpp;
// Fallback for other types. // Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> { converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type)) if (!emitc::isSupportedEmitCType(type))
@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target, if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) std::move(patterns))))
return signalPassFailure(); return signalPassFailure();
mlir::ModuleOp module = getOperation();
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
if (callOp.getCallee() != alignedAllocFunctionName &&
callOp.getCallee() != mallocFunctionName) {
return mlir::WalkResult::advance();
}
for (auto &op : *module.getBody()) {
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
if (!includeOp) {
continue;
}
if (includeOp.getIsStandardInclude() &&
((options.lowerToCpp &&
includeOp.getInclude() == cppStandardLibraryHeader) ||
(!options.lowerToCpp &&
includeOp.getInclude() == cStandardLibraryHeader))) {
return mlir::WalkResult::interrupt();
}
}
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
StringAttr includeAttr =
builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
: cStandardLibraryHeader);
builder.create<mlir::emitc::IncludeOp>(
module.getLoc(), includeAttr,
/*is_standard_include=*/builder.getUnitAttr());
return mlir::WalkResult::interrupt();
});
} }
}; };
} // namespace } // namespace

View File

@ -0,0 +1,72 @@
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
func.func @alloc() {
%alloc = memref.alloc() : memref<999xi32>
return
}
// CPP: module {
// CPP-NEXT: emitc.include <"cstdlib">
// CPP-LABEL: alloc()
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// CPP-NEXT: return
// NOCPP: module {
// NOCPP-NEXT: emitc.include <"stdlib.h">
// NOCPP-LABEL: alloc()
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// NOCPP-NEXT: return
func.func @alloc_aligned() {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32>
return
}
// CPP-LABEL: alloc_aligned
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
// CPP-NEXT: return
// NOCPP-LABEL: alloc_aligned
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
// NOCPP-NEXT: return
func.func @allocating_multi() {
%alloc_5 = memref.alloc() : memref<7x999xi32>
return
}
// CPP-LABEL: allocating_multi
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// CPP-NEXT: return
// NOCPP-LABEL: allocating_multi
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
// NOCPP-NEXT: return