
Extend the "storage class" <-> "memory space" map for the Vulkan SPIR-V environment to include the Image class. 12 is chosen as the next available value in the MemRef memory space list. Signed-off-by: Jack Frankland <jack.frankland@arm.com>
330 lines
12 KiB
C++
330 lines
12 KiB
C++
//===- MapMemRefStorageClassPass.cpp --------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass to map numeric MemRef memory spaces to
|
|
// symbolic ones defined in the SPIR-V specification.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
|
|
|
|
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
|
|
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/Visitors.h"
|
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
|
#include "llvm/ADT/SmallVectorExtras.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include <optional>
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
#define DEBUG_TYPE "mlir-map-memref-storage-class"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Mappings
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Table pairing each SPIR-V storage class used for Vulkan with the numeric
/// memref memory space that stands for it.
///
/// `ENTRY` is invoked once per pair as `ENTRY(storageClass, memorySpace)`, so
/// the same table can drive both directions of the mapping.
///
/// Note: memref does not attach fixed semantics to a memory space number; the
/// meaning depends on the context where it is used. The numbers carry no
/// particular significance: they loosely follow NVVM conventions and give
/// common storage classes smaller values. The value 2 has no entry.
#define VULKAN_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::StorageBuffer, 0)                                 \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::Uniform, 4)                                       \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::PushConstant, 7)                                  \
  ENTRY(spirv::StorageClass::UniformConstant, 8)                               \
  ENTRY(spirv::StorageClass::Input, 9)                                         \
  ENTRY(spirv::StorageClass::Output, 10)                                       \
  ENTRY(spirv::StorageClass::PhysicalStorageBuffer, 11)                        \
  ENTRY(spirv::StorageClass::Image, 12)
|
|
|
|
std::optional<spirv::StorageClass>
|
|
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
|
|
// Handle null memory space attribute specially.
|
|
if (!memorySpaceAttr)
|
|
return spirv::StorageClass::StorageBuffer;
|
|
|
|
// Unknown dialect custom attributes are not supported by default.
|
|
// Downstream callers should plug in more specialized ones.
|
|
auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
|
|
if (!intAttr)
|
|
return std::nullopt;
|
|
unsigned memorySpace = intAttr.getInt();
|
|
|
|
#define STORAGE_SPACE_MAP_FN(storage, space) \
|
|
case space: \
|
|
return storage;
|
|
|
|
switch (memorySpace) {
|
|
VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
|
|
default:
|
|
break;
|
|
}
|
|
return std::nullopt;
|
|
|
|
#undef STORAGE_SPACE_MAP_FN
|
|
}
|
|
|
|
std::optional<unsigned>
|
|
spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) {
|
|
#define STORAGE_SPACE_MAP_FN(storage, space) \
|
|
case storage: \
|
|
return space;
|
|
|
|
switch (storageClass) {
|
|
VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
|
|
default:
|
|
break;
|
|
}
|
|
return std::nullopt;
|
|
|
|
#undef STORAGE_SPACE_MAP_FN
|
|
}
|
|
|
|
#undef VULKAN_STORAGE_SPACE_MAP_LIST
|
|
|
|
/// Table pairing each SPIR-V storage class used for OpenCL with the numeric
/// memref memory space that stands for it.
///
/// `ENTRY` is invoked once per pair as `ENTRY(storageClass, memorySpace)`.
/// The value 2 has no entry.
#define OPENCL_STORAGE_SPACE_MAP_LIST(ENTRY)                                   \
  ENTRY(spirv::StorageClass::CrossWorkgroup, 0)                                \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::UniformConstant, 4)                               \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::Image, 7)
|
|
|
|
std::optional<spirv::StorageClass>
|
|
spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) {
|
|
// Handle null memory space attribute specially.
|
|
if (!memorySpaceAttr)
|
|
return spirv::StorageClass::CrossWorkgroup;
|
|
|
|
// Unknown dialect custom attributes are not supported by default.
|
|
// Downstream callers should plug in more specialized ones.
|
|
auto intAttr = dyn_cast<IntegerAttr>(memorySpaceAttr);
|
|
if (!intAttr)
|
|
return std::nullopt;
|
|
unsigned memorySpace = intAttr.getInt();
|
|
|
|
#define STORAGE_SPACE_MAP_FN(storage, space) \
|
|
case space: \
|
|
return storage;
|
|
|
|
switch (memorySpace) {
|
|
OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
|
|
default:
|
|
break;
|
|
}
|
|
return std::nullopt;
|
|
|
|
#undef STORAGE_SPACE_MAP_FN
|
|
}
|
|
|
|
std::optional<unsigned>
|
|
spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) {
|
|
#define STORAGE_SPACE_MAP_FN(storage, space) \
|
|
case storage: \
|
|
return space;
|
|
|
|
switch (storageClass) {
|
|
OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
|
|
default:
|
|
break;
|
|
}
|
|
return std::nullopt;
|
|
|
|
#undef STORAGE_SPACE_MAP_FN
|
|
}
|
|
|
|
#undef OPENCL_STORAGE_SPACE_MAP_LIST
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Type Converter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds a type converter that rewrites memref memory spaces into SPIR-V
/// storage class attributes using `memorySpaceMap`.
///
/// Three conversions are registered: a pass-through for arbitrary types, one
/// for ranked/unranked memrefs, and one for function types that converts each
/// input and result type.
spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
    const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
    : memorySpaceMap(memorySpaceMap) {
  // Every type without a dedicated rule is handed back unchanged.
  addConversion([](Type type) { return type; });

  // Memrefs: replace the memory space with the mapped storage class attribute.
  addConversion([this](BaseMemRefType memRefType) -> std::optional<Type> {
    std::optional<spirv::StorageClass> storageClass =
        this->memorySpaceMap(memRefType.getMemorySpace());
    if (!storageClass.has_value()) {
      LLVM_DEBUG(llvm::dbgs()
                 << "cannot convert " << memRefType
                 << " due to being unable to find memory space in map\n");
      return std::nullopt;
    }

    auto storageAttr =
        spirv::StorageClassAttr::get(memRefType.getContext(), *storageClass);

    // Unranked memrefs only carry an element type and a memory space.
    auto rankedType = dyn_cast<MemRefType>(memRefType);
    if (!rankedType)
      return UnrankedMemRefType::get(memRefType.getElementType(), storageAttr);

    // Ranked memrefs keep their shape, element type and layout.
    return MemRefType::get(rankedType.getShape(), rankedType.getElementType(),
                           rankedType.getLayout(), storageAttr);
  });

  // Function types: convert every input and result type individually.
  addConversion([this](FunctionType type) {
    auto convertOne = [this](Type ty) { return convertType(ty); };
    auto convertedInputs = llvm::map_to_vector(type.getInputs(), convertOne);
    auto convertedResults = llvm::map_to_vector(type.getResults(), convertOne);
    return FunctionType::get(type.getContext(), convertedInputs,
                             convertedResults);
  });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Conversion Target
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns true if the given `type` is considered as legal for SPIR-V
|
|
/// conversion.
|
|
static bool isLegalType(Type type) {
|
|
if (auto memRefType = dyn_cast<BaseMemRefType>(type)) {
|
|
Attribute spaceAttr = memRefType.getMemorySpace();
|
|
return isa_and_nonnull<spirv::StorageClassAttr>(spaceAttr);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Returns true if the given `attr` is considered as legal for SPIR-V
|
|
/// conversion.
|
|
static bool isLegalAttr(Attribute attr) {
|
|
if (auto typeAttr = dyn_cast<TypeAttr>(attr))
|
|
return isLegalType(typeAttr.getValue());
|
|
return true;
|
|
}
|
|
|
|
/// Returns true if the given `op` is considered as legal for SPIR-V conversion.
|
|
static bool isLegalOp(Operation *op) {
|
|
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
|
|
return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
|
|
llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
|
|
llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
|
|
isLegalType);
|
|
}
|
|
|
|
auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
|
|
return attr.getValue();
|
|
});
|
|
|
|
return llvm::all_of(op->getOperandTypes(), isLegalType) &&
|
|
llvm::all_of(op->getResultTypes(), isLegalType) &&
|
|
llvm::all_of(attrs, isLegalAttr);
|
|
}
|
|
|
|
std::unique_ptr<ConversionTarget>
|
|
spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
|
|
auto target = std::make_unique<ConversionTarget>(context);
|
|
target->markUnknownOpDynamicallyLegal(isLegalOp);
|
|
return target;
|
|
}
|
|
|
|
/// Rewrites the memref types reachable from `op` through `typeConverter`.
///
/// The replacement is applied recursively to the types and attributes under
/// `op`; locations are left alone.
void spirv::convertMemRefTypesAndAttrs(
    Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
  AttrTypeReplacer memRefReplacer;

  // Each memref type encountered is run through the type converter.
  memRefReplacer.addReplacement(
      [&typeConverter](
          BaseMemRefType memRefType) -> std::optional<BaseMemRefType> {
        return typeConverter.convertType<BaseMemRefType>(memRefType);
      });

  memRefReplacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
                                              /*replaceLocs=*/false,
                                              /*replaceTypes=*/true);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Conversion Pass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class MapMemRefStorageClassPass final
|
|
: public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
|
|
public:
|
|
MapMemRefStorageClassPass() = default;
|
|
|
|
explicit MapMemRefStorageClassPass(
|
|
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
|
|
: memorySpaceMap(memorySpaceMap) {}
|
|
|
|
LogicalResult initializeOptions(
|
|
StringRef options,
|
|
function_ref<LogicalResult(const Twine &)> errorHandler) override {
|
|
if (failed(Pass::initializeOptions(options, errorHandler)))
|
|
return failure();
|
|
|
|
if (clientAPI == "opencl")
|
|
memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
|
|
else if (clientAPI != "vulkan")
|
|
return errorHandler(llvm::Twine("Invalid clienAPI: ") + clientAPI);
|
|
|
|
return success();
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
Operation *op = getOperation();
|
|
|
|
spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
|
|
if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
|
|
spirv::TargetEnv targetEnv(attr);
|
|
if (targetEnv.allows(spirv::Capability::Kernel)) {
|
|
spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
|
|
} else if (targetEnv.allows(spirv::Capability::Shader)) {
|
|
spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
|
|
}
|
|
}
|
|
|
|
spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
|
|
// Perform the replacement.
|
|
spirv::convertMemRefTypesAndAttrs(op, converter);
|
|
|
|
// Check if there are any illegal ops remaining.
|
|
std::unique_ptr<ConversionTarget> target =
|
|
spirv::getMemorySpaceToStorageClassTarget(*context);
|
|
op->walk([&target, this](Operation *childOp) {
|
|
if (target->isIllegal(childOp)) {
|
|
childOp->emitOpError("failed to legalize memory space");
|
|
signalPassFailure();
|
|
return WalkResult::interrupt();
|
|
}
|
|
return WalkResult::advance();
|
|
});
|
|
}
|
|
|
|
private:
|
|
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
|
|
spirv::mapMemorySpaceToVulkanStorageClass;
|
|
};
|
|
} // namespace
|
|
|
|
/// Creates a default-constructed instance of the memref storage class mapping
/// pass.
std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
  std::unique_ptr<OperationPass<>> pass =
      std::make_unique<MapMemRefStorageClassPass>();
  return pass;
}
|