//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert memref ops into emitc ops. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include #include using namespace mlir; static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && memRefType.getRank() != 0 && !llvm::is_contained(memRefType.getShape(), 0); } namespace { /// Implement the interface to convert MemRef to EmitC. struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { using ConvertToEmitCPatternInterface::ConvertToEmitCPatternInterface; /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToEmitCConversionPatterns( ConversionTarget &target, TypeConverter &typeConverter, RewritePatternSet &patterns) const final { populateMemRefToEmitCTypeConversion(typeConverter); populateMemRefToEmitCConversionPatterns(patterns, typeConverter); } }; } // namespace void mlir::registerConvertMemRefToEmitCInterface(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { dialect->addInterfaces(); }); } //===----------------------------------------------------------------------===// // Conversion Patterns //===----------------------------------------------------------------------===// namespace { struct ConvertAlloca final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::AllocaOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { if (!op.getType().hasStaticShape()) { return rewriter.notifyMatchFailure( op.getLoc(), "cannot transform alloca with dynamic shape"); } if (op.getAlignment().value_or(1) > 1) { // TODO: Allow alignment if it is not more than the natural alignment // of the C array. return rewriter.notifyMatchFailure( op.getLoc(), "cannot transform alloca with alignment requirement"); } auto resultTy = getTypeConverter()->convertType(op.getType()); if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); } auto noInit = emitc::OpaqueAttr::get(getContext(), ""); rewriter.replaceOpWithNewOp(op, resultTy, noInit); return success(); } }; Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { Type resultTy; if (opTy.getRank() == 0) { resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy)); } else { resultTy = typeConverter->convertType(opTy); } return resultTy; } static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, OpBuilder &builder) { assert(isMemRefTypeLegalForEmitC(memrefType) && "incompatible memref type for EmitC conversion"); emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create( builder, loc, emitc::SizeTType::get(builder.getContext()), builder.getStringAttr("sizeof"), ValueRange{}, ArrayAttr::get(builder.getContext(), {TypeAttr::get(memrefType.getElementType())})); IndexType indexType = builder.getIndexType(); int64_t numElements = std::accumulate(memrefType.getShape().begin(), memrefType.getShape().end(), int64_t{1}, std::multiplies()); emitc::ConstantOp numElementsValue = emitc::ConstantOp::create( builder, loc, indexType, builder.getIndexAttr(numElements)); Type sizeTType = emitc::SizeTType::get(builder.getContext()); emitc::MulOp totalSizeBytes = emitc::MulOp::create( builder, loc, sizeTType, elementSize.getResult(0), numElementsValue); return totalSizeBytes.getResult(); } static emitc::ApplyOp createPointerFromEmitcArray(Location loc, OpBuilder &builder, TypedValue arrayValue) { emitc::ConstantOp zeroIndex = emitc::ConstantOp::create( builder, loc, builder.getIndexType(), builder.getIndexAttr(0)); emitc::ArrayType arrayType = arrayValue.getType(); llvm::SmallVector indices(arrayType.getRank(), zeroIndex); emitc::SubscriptOp subPtr = emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); emitc::ApplyOp ptr = emitc::ApplyOp::create( builder, loc, emitc::PointerType::get(arrayType.getElementType()), builder.getStringAttr("&"), subPtr); return ptr; } struct ConvertAlloc final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { Location loc = allocOp.getLoc(); MemRefType memrefType = allocOp.getType(); if (!isMemRefTypeLegalForEmitC(memrefType)) { return rewriter.notifyMatchFailure( loc, "incompatible memref type for EmitC conversion"); } Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); Type elementType = memrefType.getElementType(); IndexType indexType = rewriter.getIndexType(); emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create( rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); int64_t numElements = 1; for (int64_t dimSize : memrefType.getShape()) { numElements *= dimSize; } Value numElementsValue = emitc::ConstantOp::create( rewriter, loc, indexType, rewriter.getIndexAttr(numElements)); Value totalSizeBytes = emitc::MulOp::create(rewriter, loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); emitc::CallOpaqueOp allocCall; StringAttr allocFunctionName; Value alignmentValue; SmallVector argsVec; if (allocOp.getAlignment()) { allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); alignmentValue = emitc::ConstantOp::create( rewriter, loc, sizeTType, rewriter.getIntegerAttr(indexType, allocOp.getAlignment().value_or(0))); argsVec.push_back(alignmentValue); } else { allocFunctionName = rewriter.getStringAttr(mallocFunctionName); } argsVec.push_back(totalSizeBytes); ValueRange args(argsVec); allocCall = emitc::CallOpaqueOp::create( rewriter, loc, emitc::PointerType::get( emitc::OpaqueType::get(rewriter.getContext(), "void")), allocFunctionName, args); emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); emitc::CastOp castOp = emitc::CastOp::create( rewriter, loc, targetPointerType, allocCall.getResult(0)); rewriter.replaceOp(allocOp, castOp); return success(); } }; struct ConvertCopy final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { Location loc = copyOp.getLoc(); MemRefType srcMemrefType = cast(copyOp.getSource().getType()); MemRefType targetMemrefType = cast(copyOp.getTarget().getType()); if (!isMemRefTypeLegalForEmitC(srcMemrefType)) return rewriter.notifyMatchFailure( loc, "incompatible source memref type for EmitC conversion"); if (!isMemRefTypeLegalForEmitC(targetMemrefType)) return rewriter.notifyMatchFailure( loc, "incompatible target memref type for EmitC conversion"); auto srcArrayValue = cast>(operands.getSource()); emitc::ApplyOp srcPtr = createPointerFromEmitcArray(loc, rewriter, srcArrayValue); auto targetArrayValue = cast>(operands.getTarget()); emitc::ApplyOp targetPtr = createPointerFromEmitcArray(loc, rewriter, targetArrayValue); emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create( rewriter, loc, TypeRange{}, "memcpy", ValueRange{ targetPtr.getResult(), srcPtr.getResult(), calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); rewriter.replaceOp(copyOp, memCpyCall.getResults()); return success(); } }; struct ConvertGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { MemRefType opTy = op.getType(); if (!op.getType().hasStaticShape()) { return rewriter.notifyMatchFailure( op.getLoc(), "cannot transform global with dynamic shape"); } if (op.getAlignment().value_or(1) > 1) { // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier. return rewriter.notifyMatchFailure( op.getLoc(), "global variable with alignment requirement is " "currently not supported"); } Type resultTy = convertMemRefType(opTy, getTypeConverter()); if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); } SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op); if (visibility != SymbolTable::Visibility::Public && visibility != SymbolTable::Visibility::Private) { return rewriter.notifyMatchFailure( op.getLoc(), "only public and private visibility is currently supported"); } // We are explicit in specifing the linkage because the default linkage // for constants is different in C and C++. bool staticSpecifier = visibility == SymbolTable::Visibility::Private; bool externSpecifier = !staticSpecifier; Attribute initialValue = operands.getInitialValueAttr(); if (opTy.getRank() == 0) { auto elementsAttr = llvm::cast(*op.getInitialValue()); initialValue = elementsAttr.getSplatValue(); } if (isa_and_present(initialValue)) initialValue = {}; rewriter.replaceOpWithNewOp( op, operands.getSymName(), resultTy, initialValue, externSpecifier, staticSpecifier, operands.getConstant()); return success(); } }; struct ConvertGetGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { MemRefType opTy = op.getType(); Type resultTy = convertMemRefType(opTy, getTypeConverter()); if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); } if (opTy.getRank() == 0) { emitc::LValueType lvalueType = emitc::LValueType::get(resultTy); emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create( rewriter, op.getLoc(), lvalueType, operands.getNameAttr()); emitc::PointerType pointerType = emitc::PointerType::get(resultTy); rewriter.replaceOpWithNewOp( op, pointerType, rewriter.getStringAttr("&"), globalLValue); return success(); } rewriter.replaceOpWithNewOp(op, resultTy, operands.getNameAttr()); return success(); } }; struct ConvertLoad final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::LoadOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { auto resultTy = getTypeConverter()->convertType(op.getType()); if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type"); } auto arrayValue = dyn_cast>(operands.getMemref()); if (!arrayValue) { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } auto subscript = emitc::SubscriptOp::create( rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp(op, resultTy, subscript); return success(); } }; struct ConvertStore final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::StoreOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { auto arrayValue = dyn_cast>(operands.getMemref()); if (!arrayValue) { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } auto subscript = emitc::SubscriptOp::create( rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp(op, subscript, operands.getValue()); return success(); } }; } // namespace void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional { if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = typeConverter.convertType(memRefType.getElementType()); if (!convertedElementType) return {}; return emitc::ArrayType::get(memRefType.getShape(), convertedElementType); }); auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { if (inputs.size() != 1) return Value(); return UnrealizedConversionCastOp::create(builder, loc, resultType, inputs) .getResult(0); }; typeConverter.addSourceMaterialization(materializeAsUnrealizedCast); typeConverter.addTargetMaterialization(materializeAsUnrealizedCast); } void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { patterns.add( converter, patterns.getContext()); }