Revert "[MLIR][Conversion] Add convert-xevm-to-llvm pass." (#148081)
Reverts llvm/llvm-project#147375
This commit is contained in:
parent
eb97422e00
commit
da608271ae
@ -80,7 +80,6 @@
|
||||
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
|
||||
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
|
||||
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
|
||||
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -1495,13 +1495,4 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// XeVMToLLVM
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
|
||||
let summary = "Convert XeVM to LLVM dialect";
|
||||
let dependentDialects = ["LLVM::LLVMDialect"];
|
||||
}
|
||||
|
||||
#endif // MLIR_CONVERSION_PASSES
|
||||
|
@ -1,27 +0,0 @@
|
||||
//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
|
||||
#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
class LLVMTypeConverter;
|
||||
class RewritePatternSet;
|
||||
class Pass;
|
||||
|
||||
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
void registerConvertXeVMToLLVMInterface(DialectRegistry ®istry);
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
|
@ -32,7 +32,6 @@
|
||||
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
|
||||
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
|
||||
#include "mlir/Dialect/AMX/Transforms.h"
|
||||
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
|
||||
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
|
||||
@ -92,7 +91,6 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
|
||||
gpu::registerConvertGpuToLLVMInterface(registry);
|
||||
NVVM::registerConvertGpuToNVVMInterface(registry);
|
||||
vector::registerConvertVectorToLLVMInterface(registry);
|
||||
registerConvertXeVMToLLVMInterface(registry);
|
||||
|
||||
// Register all transform dialect extensions.
|
||||
affine::registerTransformDialectExtension(registry);
|
||||
|
@ -73,4 +73,3 @@ add_subdirectory(VectorToLLVM)
|
||||
add_subdirectory(VectorToSCF)
|
||||
add_subdirectory(VectorToSPIRV)
|
||||
add_subdirectory(VectorToXeGPU)
|
||||
add_subdirectory(XeVMToLLVM)
|
||||
|
@ -1,21 +0,0 @@
|
||||
add_mlir_conversion_library(MLIRXeVMToLLVM
|
||||
XeVMToLLVM.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRFuncDialect
|
||||
MLIRGPUDialect
|
||||
MLIRLLVMCommonConversion
|
||||
MLIRLLVMDialect
|
||||
MLIRXeVMDialect
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
)
|
@ -1,633 +0,0 @@
|
||||
//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
|
||||
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
using namespace mlir;
|
||||
using namespace xevm;
|
||||
|
||||
namespace {
|
||||
|
||||
struct LLVMFuncAttributeOptions {
|
||||
bool isConvergent = false;
|
||||
bool isNoUnwind = false;
|
||||
bool isWillReturn = false;
|
||||
LLVM::MemoryEffectsAttr memEffectsAttr{};
|
||||
};
|
||||
static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
|
||||
false, true, false, {}};
|
||||
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
|
||||
false, true, true, {}};
|
||||
static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
|
||||
true, true, true, {}};
|
||||
|
||||
std::string getTypeMangling(Type ty, bool isUnsigned = false) {
|
||||
return TypeSwitch<Type, std::string>(ty)
|
||||
.Case([isUnsigned](VectorType ty) -> std::string {
|
||||
return "Dv" + std::to_string(ty.getNumElements()) + "_" +
|
||||
getTypeMangling(ty.getElementType(), isUnsigned);
|
||||
})
|
||||
.Case([](Float16Type) -> std::string { return "Dh"; })
|
||||
.Case([](Float32Type) -> std::string { return "f"; })
|
||||
.Case([](Float64Type) -> std::string { return "d"; })
|
||||
.Case([isUnsigned](IntegerType ty) -> std::string {
|
||||
switch (ty.getWidth()) {
|
||||
case 8:
|
||||
return isUnsigned ? "h" : "c";
|
||||
case 16:
|
||||
return isUnsigned ? "t" : "s";
|
||||
case 32:
|
||||
return isUnsigned ? "j" : "i";
|
||||
case 64:
|
||||
return isUnsigned ? "m" : "l";
|
||||
default:
|
||||
llvm_unreachable("unhandled integer type");
|
||||
}
|
||||
})
|
||||
.Default([](Type) -> std::string {
|
||||
llvm_unreachable("unhandled type for mangling");
|
||||
});
|
||||
}
|
||||
|
||||
std::string mangle(StringRef baseName, ArrayRef<Type> types,
|
||||
ArrayRef<bool> isUnsigned = {}) {
|
||||
assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
|
||||
"Signedness info doesn't match");
|
||||
std::string s;
|
||||
llvm::raw_string_ostream os(s);
|
||||
llvm::SmallDenseMap<Type, unsigned> substitutions;
|
||||
os << "_Z" << baseName.size() << baseName;
|
||||
for (auto [idx, type] : llvm::enumerate(types)) {
|
||||
auto it = substitutions.find(type);
|
||||
if (it != substitutions.end()) {
|
||||
os << "S";
|
||||
// First substitution is `S_`, second is `S0_`, and so on.
|
||||
if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
|
||||
os << firstIdx - 1;
|
||||
os << "_";
|
||||
} else {
|
||||
if (!type.isIntOrFloat())
|
||||
substitutions[type] = substitutions.size();
|
||||
os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
|
||||
}
|
||||
}
|
||||
return os.str();
|
||||
}
|
||||
|
||||
template <bool isLoad, typename OpType>
|
||||
int32_t getL1CacheControl(OpType op) {
|
||||
int32_t control = 0;
|
||||
if constexpr (isLoad) {
|
||||
switch (*op.getCacheControl()) {
|
||||
case LoadCacheControl::L1UC_L2UC_L3UC:
|
||||
case LoadCacheControl::L1UC_L2UC_L3C:
|
||||
case LoadCacheControl::L1UC_L2C_L3UC:
|
||||
case LoadCacheControl::L1UC_L2C_L3C:
|
||||
control = 1;
|
||||
break;
|
||||
case LoadCacheControl::L1C_L2UC_L3UC:
|
||||
case LoadCacheControl::L1C_L2UC_L3C:
|
||||
case LoadCacheControl::L1C_L2C_L3UC:
|
||||
case LoadCacheControl::L1C_L2C_L3C:
|
||||
control = 2;
|
||||
break;
|
||||
case LoadCacheControl::L1S_L2UC_L3UC:
|
||||
case LoadCacheControl::L1S_L2UC_L3C:
|
||||
case LoadCacheControl::L1S_L2C_L3UC:
|
||||
case LoadCacheControl::L1S_L2C_L3C:
|
||||
control = 3;
|
||||
break;
|
||||
case LoadCacheControl::INVALIDATE_READ:
|
||||
control = 4;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (*op.getCacheControl()) {
|
||||
case StoreCacheControl::L1UC_L2UC_L3UC:
|
||||
case StoreCacheControl::L1UC_L2UC_L3WB:
|
||||
case StoreCacheControl::L1UC_L2WB_L3UC:
|
||||
case StoreCacheControl::L1UC_L2WB_L3WB:
|
||||
control = 1;
|
||||
break;
|
||||
case StoreCacheControl::L1WT_L2UC_L3UC:
|
||||
case StoreCacheControl::L1WT_L2UC_L3WB:
|
||||
case StoreCacheControl::L1WT_L2WB_L3UC:
|
||||
case StoreCacheControl::L1WT_L2WB_L3WB:
|
||||
control = 2;
|
||||
break;
|
||||
case StoreCacheControl::L1S_L2UC_L3UC:
|
||||
case StoreCacheControl::L1S_L2UC_L3WB:
|
||||
case StoreCacheControl::L1S_L2WB_L3UC:
|
||||
case StoreCacheControl::L1S_L2WB_L3WB:
|
||||
control = 3;
|
||||
break;
|
||||
case StoreCacheControl::L1WB_L2UC_L3UC:
|
||||
case StoreCacheControl::L1WB_L2WB_L3UC:
|
||||
case StoreCacheControl::L1WB_L2UC_L3WB:
|
||||
control = 4;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return control;
|
||||
}
|
||||
|
||||
template <bool isLoad, typename OpType>
|
||||
int32_t getL3CacheControl(OpType op) {
|
||||
int32_t control = 0;
|
||||
if constexpr (isLoad) {
|
||||
switch (*op.getCacheControl()) {
|
||||
case LoadCacheControl::L1UC_L2UC_L3UC:
|
||||
case LoadCacheControl::L1UC_L2C_L3UC:
|
||||
case LoadCacheControl::L1C_L2UC_L3UC:
|
||||
case LoadCacheControl::L1C_L2C_L3UC:
|
||||
case LoadCacheControl::L1S_L2UC_L3UC:
|
||||
case LoadCacheControl::L1S_L2C_L3UC:
|
||||
control = 1;
|
||||
break;
|
||||
case LoadCacheControl::L1UC_L2UC_L3C:
|
||||
case LoadCacheControl::L1UC_L2C_L3C:
|
||||
case LoadCacheControl::L1C_L2UC_L3C:
|
||||
case LoadCacheControl::L1C_L2C_L3C:
|
||||
case LoadCacheControl::L1S_L2UC_L3C:
|
||||
case LoadCacheControl::L1S_L2C_L3C:
|
||||
control = 2;
|
||||
break;
|
||||
case LoadCacheControl::INVALIDATE_READ:
|
||||
control = 4;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (*op.getCacheControl()) {
|
||||
case StoreCacheControl::L1UC_L2UC_L3UC:
|
||||
case StoreCacheControl::L1UC_L2WB_L3UC:
|
||||
case StoreCacheControl::L1WT_L2UC_L3UC:
|
||||
case StoreCacheControl::L1WT_L2WB_L3UC:
|
||||
case StoreCacheControl::L1S_L2UC_L3UC:
|
||||
case StoreCacheControl::L1S_L2WB_L3UC:
|
||||
case StoreCacheControl::L1WB_L2UC_L3UC:
|
||||
case StoreCacheControl::L1WB_L2WB_L3UC:
|
||||
control = 1;
|
||||
break;
|
||||
case StoreCacheControl::L1UC_L2UC_L3WB:
|
||||
case StoreCacheControl::L1UC_L2WB_L3WB:
|
||||
case StoreCacheControl::L1WT_L2UC_L3WB:
|
||||
case StoreCacheControl::L1WT_L2WB_L3WB:
|
||||
case StoreCacheControl::L1S_L2UC_L3WB:
|
||||
case StoreCacheControl::L1S_L2WB_L3WB:
|
||||
case StoreCacheControl::L1WB_L2UC_L3WB:
|
||||
control = 2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return control;
|
||||
}
|
||||
|
||||
template <bool isLoad, typename OpType>
|
||||
static std::optional<ArrayAttr>
|
||||
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
|
||||
if (!op.getCacheControl())
|
||||
return {};
|
||||
constexpr int32_t decorationCacheControlArity{4};
|
||||
constexpr int32_t loadCacheControlKey{6442};
|
||||
constexpr int32_t storeCacheControlKey{6443};
|
||||
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
|
||||
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
|
||||
controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
|
||||
SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
|
||||
controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
|
||||
auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
|
||||
auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
|
||||
|
||||
SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
|
||||
return rewriter.getArrayAttr(combinedAttrs);
|
||||
}
|
||||
|
||||
static LLVM::CallOp createDeviceFunctionCall(
|
||||
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
|
||||
ArrayRef<Type> argTypes, ArrayRef<Value> args,
|
||||
mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
|
||||
LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
|
||||
auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
|
||||
assert(moduleOp && "Expecting module");
|
||||
Location loc = op->getLoc();
|
||||
|
||||
auto funcOpRes =
|
||||
LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
|
||||
assert(!failed(funcOpRes));
|
||||
LLVM::LLVMFuncOp funcOp = funcOpRes.value();
|
||||
funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
|
||||
funcOp.setConvergent(funcAttributeOptions.isConvergent);
|
||||
funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
|
||||
funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
|
||||
|
||||
if (funcAttributeOptions.memEffectsAttr)
|
||||
funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
|
||||
|
||||
for (auto [idx, attrName] : paramAttrs)
|
||||
funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
|
||||
|
||||
auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
|
||||
callOp->setAttrs(funcOp->getAttrs());
|
||||
|
||||
return callOp;
|
||||
}
|
||||
|
||||
class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!op.getC()) {
|
||||
return rewriter.notifyMatchFailure(op, "OCL requires C operand");
|
||||
}
|
||||
auto precisionA = op.getTypes().getA();
|
||||
auto precisionB = op.getTypes().getB();
|
||||
auto precisionC = op.getTypes().getC();
|
||||
auto precisionD = op.getTypes().getD();
|
||||
if (precisionC != precisionD) {
|
||||
return rewriter.notifyMatchFailure(op, "type of C and D need to match");
|
||||
}
|
||||
if (precisionC != xevm::ElemType::S32 &&
|
||||
precisionC != xevm::ElemType::F32 &&
|
||||
precisionC != xevm::ElemType::F16 &&
|
||||
precisionC != xevm::ElemType::BF16) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "type of C and D must be S32, F32, F16 or BF16");
|
||||
}
|
||||
if (precisionA == xevm::ElemType::S32 ||
|
||||
precisionA == xevm::ElemType::F32) {
|
||||
return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
|
||||
}
|
||||
if (precisionB == xevm::ElemType::S32 ||
|
||||
precisionB == xevm::ElemType::F32) {
|
||||
return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
|
||||
}
|
||||
constexpr uint32_t bitWidthPackedA{16};
|
||||
constexpr uint32_t bitWidthPackedB{32};
|
||||
auto loc = op.getLoc();
|
||||
|
||||
auto castIfNeeded = [&](Value val, Type packedType) -> Value {
|
||||
VectorType origTy = cast<VectorType>(val.getType());
|
||||
const uint32_t vecBitSize =
|
||||
origTy.getNumElements() *
|
||||
origTy.getElementType().getIntOrFloatBitWidth();
|
||||
VectorType newTy = VectorType::get(
|
||||
vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
|
||||
if (origTy != newTy)
|
||||
val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
|
||||
return val;
|
||||
};
|
||||
|
||||
Value a = op.getA();
|
||||
Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
|
||||
? cast<Type>(rewriter.getF32Type())
|
||||
: rewriter.getIntegerType(bitWidthPackedA);
|
||||
a = castIfNeeded(a, packedAType);
|
||||
|
||||
Value b = op.getB();
|
||||
Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
|
||||
? cast<Type>(rewriter.getF32Type())
|
||||
: rewriter.getIntegerType(bitWidthPackedB);
|
||||
b = castIfNeeded(b, packedBType);
|
||||
|
||||
Value c = op.getC();
|
||||
VectorType cOrigTy = cast<VectorType>(c.getType());
|
||||
VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
|
||||
assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
|
||||
// OCL builtins encode bfloat16 as int16
|
||||
VectorType cTy =
|
||||
cOrigTy.getElementType().isBF16()
|
||||
? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
|
||||
: cOrigTy;
|
||||
VectorType resTy = cTy;
|
||||
if (cOrigTy != cTy)
|
||||
c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
|
||||
|
||||
constexpr int32_t systolicDepth{8};
|
||||
std::string fnName =
|
||||
llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
|
||||
stringifyElemType(op.getTypes().getA()).str(),
|
||||
stringifyElemType(op.getTypes().getB()).str(),
|
||||
systolicDepth *
|
||||
getNumOperandsPerDword(op.getTypes().getA()))
|
||||
.str();
|
||||
SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
|
||||
fnName = mangle(fnName, argTypes);
|
||||
SmallVector<Value> args{a, b, c};
|
||||
|
||||
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
|
||||
/*other=*/LLVM::ModRefInfo::NoModRef,
|
||||
/*argMem=*/LLVM::ModRefInfo::NoModRef,
|
||||
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
|
||||
auto funcAttrs = convergentNoUnwindWillReturnAttrs;
|
||||
funcAttrs.memEffectsAttr = memAttr;
|
||||
Value result =
|
||||
createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
|
||||
funcAttrs, op.getOperation())
|
||||
->getResult(0);
|
||||
|
||||
if (resOrigTy != resTy)
|
||||
result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result);
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
|
||||
switch (pTy) {
|
||||
case xevm::ElemType::TF32:
|
||||
return 1;
|
||||
case xevm::ElemType::BF16:
|
||||
case xevm::ElemType::F16:
|
||||
return 2;
|
||||
case xevm::ElemType::U8:
|
||||
case xevm::ElemType::S8:
|
||||
return 4;
|
||||
default:
|
||||
llvm_unreachable("unsupported xevm::ElemType");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
|
||||
Value one =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1);
|
||||
SmallVector<Value> args{op.getPtr(), one};
|
||||
SmallVector<Type> argTypes;
|
||||
for (auto arg : args)
|
||||
argTypes.push_back(arg.getType());
|
||||
auto funcAttr = noUnwindAttrs;
|
||||
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
|
||||
/*other=*/LLVM::ModRefInfo::NoModRef,
|
||||
/*argMem=*/LLVM::ModRefInfo::Ref,
|
||||
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
|
||||
funcAttr.memEffectsAttr = memAttr;
|
||||
|
||||
LLVM::CallOp call = createDeviceFunctionCall(
|
||||
rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
|
||||
argTypes, args, {}, funcAttr, op.getOperation());
|
||||
if (std::optional<ArrayAttr> optCacheControls =
|
||||
getCacheControlMetadata<true>(rewriter, op))
|
||||
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
const std::string fnName{"atomic_work_item_fence"};
|
||||
int memScope, addrSpace;
|
||||
switch (op.getAddrspace()) {
|
||||
case xevm::AddrSpace::SHARED:
|
||||
addrSpace = 1; // CLK_LOCAL_MEM_FENCE
|
||||
break;
|
||||
case xevm::AddrSpace::GLOBAL:
|
||||
addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
|
||||
break;
|
||||
default:
|
||||
// GENERIC is not supported in OpenCL
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Fence only supports global and shared address spaces.");
|
||||
}
|
||||
switch (op.getScope()) {
|
||||
case xevm::MemScope::WORKGROUP:
|
||||
memScope = 1;
|
||||
break;
|
||||
case xevm::MemScope::DEVICE:
|
||||
memScope = 2;
|
||||
break;
|
||||
default:
|
||||
// CLUSTER and SYSTEM are not supported in OpenCL
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Fence only supports workgroup and device memory scopes.");
|
||||
}
|
||||
Type i32Type = rewriter.getI32Type();
|
||||
Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4);
|
||||
Value memScopeConst =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope);
|
||||
Value addrSpaceConst =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace);
|
||||
SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
|
||||
SmallVector<Type> argTypes{3, i32Type};
|
||||
createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
|
||||
LLVM::LLVMVoidType::get(rewriter.getContext()),
|
||||
argTypes, args, {}, noUnwindAttrs,
|
||||
op.getOperation());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
template <typename OpType>
|
||||
class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
|
||||
using OpConversionPattern<OpType>::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
|
||||
constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
|
||||
|
||||
auto loc = op.getLoc();
|
||||
VectorType vecType;
|
||||
bool packReg = false;
|
||||
bool transpose = false;
|
||||
if constexpr (isLoad) {
|
||||
vecType = op.getRes().getType();
|
||||
packReg = op.getPackRegister();
|
||||
transpose = op.getTranspose();
|
||||
} else if constexpr (!isPrefetch) {
|
||||
vecType = op.getStoredVal().getType();
|
||||
}
|
||||
|
||||
auto i32Type = rewriter.getI32Type();
|
||||
Value byteCoord =
|
||||
rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0);
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1);
|
||||
byteCoord = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
|
||||
byteCoord = rewriter.create<LLVM::InsertElementOp>(
|
||||
loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
|
||||
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
|
||||
op.getBasePitch(), byteCoord};
|
||||
SmallVector<Type> retTypes;
|
||||
Value spvLoadDstPtr;
|
||||
std::string funcName{"intel_sub_group_2d_block_"};
|
||||
std::string bitWidthId;
|
||||
LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
|
||||
SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
|
||||
if constexpr (isPrefetch) { // Prefetch
|
||||
funcName += "prefetch";
|
||||
paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
|
||||
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
|
||||
/*other=*/LLVM::ModRefInfo::NoModRef,
|
||||
/*argMem=*/LLVM::ModRefInfo::Ref,
|
||||
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
|
||||
funcAttr = noUnwindAttrs;
|
||||
funcAttr.memEffectsAttr = memAttr;
|
||||
} else {
|
||||
auto vecElemType = vecType.getElementType();
|
||||
auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
|
||||
Value numElems = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, i32Type, vecType.getNumElements());
|
||||
auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
|
||||
loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType,
|
||||
numElems);
|
||||
args.push_back(dstOrSrcPtr);
|
||||
if constexpr (isLoad) { // Load
|
||||
funcName += "read";
|
||||
bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
|
||||
if (packReg)
|
||||
funcName += "_transform";
|
||||
else if (transpose)
|
||||
funcName += "_transpose";
|
||||
spvLoadDstPtr = dstOrSrcPtr;
|
||||
retTypes.push_back(vecType);
|
||||
paramAttrs = {
|
||||
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
|
||||
std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
|
||||
std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
|
||||
std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
|
||||
};
|
||||
} else { // Store
|
||||
funcName += "write";
|
||||
bitWidthId = (vecElemBitWidth == 32)
|
||||
? "j"
|
||||
: ((vecElemBitWidth == 16) ? "t" : "h");
|
||||
rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
|
||||
paramAttrs = {
|
||||
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
|
||||
std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
|
||||
std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
|
||||
std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
funcName =
|
||||
llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
|
||||
op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
|
||||
.str();
|
||||
funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
|
||||
funcName, isPrefetch ? "" : "P", bitWidthId)
|
||||
.str();
|
||||
SmallVector<Type> argTypes;
|
||||
for (auto arg : args) {
|
||||
argTypes.push_back(arg.getType());
|
||||
}
|
||||
LLVM::CallOp call = createDeviceFunctionCall(
|
||||
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
|
||||
argTypes, args, paramAttrs, funcAttr, op.getOperation());
|
||||
if (std::optional<ArrayAttr> optCacheControls =
|
||||
getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
|
||||
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
|
||||
}
|
||||
if constexpr (isLoad)
|
||||
rewriter.replaceOp(
|
||||
op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
|
||||
else
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass Definition
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct ConvertXeVMToLLVMPass
|
||||
: public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
|
||||
using Base::Base;
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect, XeVMDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addIllegalDialect<XeVMDialect>();
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateXeVMToLLVMConversionPatterns(patterns);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConvertToLLVMPatternInterface implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// Implement the interface to convert XeVM to LLVM.
|
||||
struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
||||
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
||||
void loadDependentDialects(MLIRContext *context) const final {
|
||||
context->loadDialect<LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
/// Hook for derived dialect interface to provide conversion patterns
|
||||
/// and mark dialect legal for the conversion target.
|
||||
void populateConvertToLLVMConversionPatterns(
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) const final {
|
||||
populateXeVMToLLVMConversionPatterns(patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern Population
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
|
||||
patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
|
||||
LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
|
||||
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
|
||||
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
|
||||
dialect->addInterfaces<XeVMToLLVMDialectInterface>();
|
||||
});
|
||||
}
|
@ -1,244 +0,0 @@
|
||||
// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s
|
||||
|
||||
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
|
||||
// and the generic `convert-to-llvm` pass.
|
||||
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
|
||||
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
|
||||
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
|
||||
// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
|
||||
llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
|
||||
// CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
|
||||
// CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
|
||||
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
|
||||
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
|
||||
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
|
||||
// CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
|
||||
// CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
|
||||
// CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
|
||||
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
|
||||
// CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
|
||||
// CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
|
||||
// CHECK-SAME: will_return} :
|
||||
// CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
|
||||
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
|
||||
// CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
|
||||
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
|
||||
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
|
||||
pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
|
||||
llvm.return %loaded_a : vector<8xi16>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Verifies that a load cache_control attribute is lowered to a
// xevm.DecorationCacheControl annotation on the generated builtin call.
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
  // CHECK: xevm.DecorationCacheControl =
  // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
  // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
      pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
  llvm.return %loaded_a : vector<8xi16>
}
|
||||
|
||||
// -----
|
||||
// Verifies lowering with v_blocks=2: the builtin name encodes 16x2c and the
// scratch alloca is sized for 16 elements (two blocks of 8xi16).
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockload2d_v_blocks(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<16xi16> {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
  // CHECK-SAME: will_return}
  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16>
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false,
      pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
  llvm.return %loaded_a : vector<16xi16>
}
|
||||
|
||||
// -----
|
||||
// Verifies lowering with pack_register=true: the "transform" builtin variant
// is selected and the scratch buffer element type widens to i32.
// CHECK-LABEL: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockload2d_pack_register(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
  // CHECK-SAME: will_return} :
  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32>
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false,
      pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
  llvm.return %loaded_a : vector<8xi32>
}
|
||||
|
||||
// -----
|
||||
// Verifies lowering with transpose=true: the "transpose" builtin variant is
// selected for a 32-bit 8x16 tile.
// CHECK-LABEL: llvm.func spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockload2d_transpose(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
  // CHECK-SAME: will_return}
  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32>
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=32 : i32, tile_width=8 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=true,
      pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
  llvm.return %loaded_a : vector<8xi32>
}
|
||||
|
||||
// -----
|
||||
// Verifies lowering of xevm.blockstore2d: the value operand is first stored
// into an alloca'd scratch buffer, whose pointer is passed to the
// intel_sub_group_2d_block_write builtin (note the swapped readonly/writeonly
// argument attributes compared to the load).
// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) {
llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
  // CHECK: llvm.store %[[ARG6]], %[[VAR6]] : vector<8xi32>, !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
  // CHECK-SAME: will_return}
  // CHECK-SAME: : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) -> ()
  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
    <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
    : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
// Verifies that a store cache_control attribute is lowered to a
// xevm.DecorationCacheControl annotation on the generated builtin call.
// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(
llvm.func @blockstore2d_cache_control(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
  // CHECK: xevm.DecorationCacheControl =
  // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32
  // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32
  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
    <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control = #xevm.store_cache_control<L1wt_L2uc_L3wb>}>
    : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
// Verifies lowering of xevm.blockprefetch2d: no scratch buffer is needed
// (no result), and the builtin declaration carries read-only argMem effects.
// CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes
// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) {
llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>,
  // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind,
  // CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64
  xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y
    <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32,
      cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
    : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
// Verifies lowering of xevm.mma (f16 x f16 -> f32, m8n16k16) to the
// intel_sub_group matrix-mad builtin; the builtin is marked convergent with
// no memory effects, and the operand order becomes (a, b, c).
// CHECK-LABEL: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(
// CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes
// CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none,
// CHECK-SAME: inaccessibleMem = none>, no_unwind, will_return}
// CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> {
llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(
  // CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type =
  // CHECK-SAME: !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage<external>,
  // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind,
  // CHECK-SAME: sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return}
  // CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted
    { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> }
    : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
  llvm.return %c_result : vector<8xf32>
}
|
||||
|
||||
// -----
|
||||
// Verifies lowering of xevm.memfence to the OpenCL atomic_work_item_fence
// builtin: the addrspace/scope/order attributes become the three i32
// constant arguments of the call.
// CHECK-LABEL: llvm.func spir_funccc @_Z22atomic_work_item_fenceiii(i32, i32, i32) attributes {no_unwind}
llvm.func @memfence() {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(2 : i32) : i32
  // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]])
  // CHECK-SAME: {function_type = !llvm.func<void (i32, i32, i32)>, linkage = #llvm.linkage<external>, no_unwind,
  // CHECK-SAME: sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> ()
  xevm.memfence <{addrspace=#xevm.addr_space<global>, scope=#xevm.mem_scope<workgroup>}>
  llvm.return
}
|
||||
|
||||
// -----
|
||||
// Verifies lowering of xevm.prefetch to the OpenCL prefetch builtin with a
// constant element count of 1 and read-only argMem effects.
// CHECK-LABEL: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes
// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
// CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) {
llvm.func @prefetch(%ptr: !llvm.ptr<1>) {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64
  // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>,
  // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>,
  // CHECK-SAME: no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64
  xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
  llvm.return
}
|
||||
|