Revert "[MLIR][Conversion] Add convert-xevm-to-llvm pass." (#148081)

Reverts llvm/llvm-project#147375
This commit is contained in:
Charitha Saumya 2025-07-10 16:21:11 -07:00 committed by GitHub
parent eb97422e00
commit da608271ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 0 additions and 938 deletions

View File

@@ -80,7 +80,6 @@
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
namespace mlir {

View File

@@ -1495,13 +1495,4 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
];
}
//===----------------------------------------------------------------------===//
// XeVMToLLVM
//===----------------------------------------------------------------------===//
// Pass that lowers XeVM dialect ops to the LLVM dialect (registered as
// "convert-xevm-to-llvm").
def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
  let summary = "Convert XeVM to LLVM dialect";
  // The lowering creates llvm.func / llvm.call ops, so LLVM must be loaded.
  let dependentDialects = ["LLVM::LLVMDialect"];
}
#endif // MLIR_CONVERSION_PASSES

View File

@@ -1,27 +0,0 @@
//===-- XeVMToLLVM.h - Convert XeVM to LLVM dialect -------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
#define MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_
#include <memory>
namespace mlir {
class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class Pass;
// Generated declaration for ConvertXeVMToLLVMPass (see Passes.td).
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
/// Appends all XeVM-to-LLVM lowering patterns to `patterns`.
void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
/// Registers the ConvertToLLVMPatternInterface extension on the XeVM dialect
/// so the generic convert-to-llvm pass can lower it.
void registerConvertXeVMToLLVMInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_CONVERSION_XEVMTOLLVM_XEVMTOLLVMPASS_H_

View File

@@ -32,7 +32,6 @@
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
@@ -92,7 +91,6 @@ inline void registerAllExtensions(DialectRegistry &registry) {
gpu::registerConvertGpuToLLVMInterface(registry);
NVVM::registerConvertGpuToNVVMInterface(registry);
vector::registerConvertVectorToLLVMInterface(registry);
registerConvertXeVMToLLVMInterface(registry);
// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);

View File

@@ -73,4 +73,3 @@ add_subdirectory(VectorToLLVM)
add_subdirectory(VectorToSCF)
add_subdirectory(VectorToSPIRV)
add_subdirectory(VectorToXeGPU)
add_subdirectory(XeVMToLLVM)

View File

@@ -1,21 +0,0 @@
# Conversion library lowering the XeVM dialect to LLVM-dialect calls
# (implemented in XeVMToLLVM.cpp).
add_mlir_conversion_library(MLIRXeVMToLLVM
  XeVMToLLVM.cpp
  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeVMToLLVM
  # Pass declarations are tablegen'd; build them first.
  DEPENDS
  MLIRConversionPassIncGen
  LINK_COMPONENTS
  Core
  LINK_LIBS PUBLIC
  MLIRFuncDialect
  MLIRGPUDialect
  MLIRLLVMCommonConversion
  MLIRLLVMDialect
  MLIRXeVMDialect
  MLIRPass
  MLIRTransforms
  )

View File

@@ -1,633 +0,0 @@
//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace xevm;
namespace {
/// Bundle of the LLVM function attributes that get attached to the generated
/// OCL built-in declarations (and copied onto their call sites).
struct LLVMFuncAttributeOptions {
  bool isConvergent = false;
  bool isNoUnwind = false;
  bool isWillReturn = false;
  LLVM::MemoryEffectsAttr memEffectsAttr{};
};
// Pre-canned attribute combinations used by the patterns below; the fields
// are {isConvergent, isNoUnwind, isWillReturn, memEffectsAttr}.
static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
    false, true, false, {}};
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
    false, true, true, {}};
static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
    true, true, true, {}};
/// Return the Itanium-ABI mangling code for `ty` as used in the OCL built-in
/// names (e.g. "f" for f32, "Dh" for f16, "Dv8_i" for vector<8xi32>).
/// `isUnsigned` selects the unsigned codes for integer types
/// ("h"/"t"/"j"/"m" instead of "c"/"s"/"i"/"l"); it is propagated through
/// vector element types. Unhandled types are fatal (llvm_unreachable).
std::string getTypeMangling(Type ty, bool isUnsigned = false) {
  return TypeSwitch<Type, std::string>(ty)
      .Case([isUnsigned](VectorType ty) -> std::string {
        // Vector: "Dv<N>_" followed by the element type's code.
        return "Dv" + std::to_string(ty.getNumElements()) + "_" +
               getTypeMangling(ty.getElementType(), isUnsigned);
      })
      .Case([](Float16Type) -> std::string { return "Dh"; })
      .Case([](Float32Type) -> std::string { return "f"; })
      .Case([](Float64Type) -> std::string { return "d"; })
      .Case([isUnsigned](IntegerType ty) -> std::string {
        switch (ty.getWidth()) {
        case 8:
          return isUnsigned ? "h" : "c";
        case 16:
          return isUnsigned ? "t" : "s";
        case 32:
          return isUnsigned ? "j" : "i";
        case 64:
          return isUnsigned ? "m" : "l";
        default:
          llvm_unreachable("unhandled integer type");
        }
      })
      .Default([](Type) -> std::string {
        llvm_unreachable("unhandled type for mangling");
      });
}
/// Mangle `baseName` with its parameter `types` in the Itanium-style scheme
/// used by the OCL built-ins: "_Z<len><name><param-codes...>". Repeated
/// non-builtin (here: vector) parameter types are emitted as back-references
/// "S_", "S0_", ... per the ABI's substitution rules. `isUnsigned`, when
/// non-empty, must be parallel to `types` and selects unsigned integer
/// mangling per parameter.
std::string mangle(StringRef baseName, ArrayRef<Type> types,
                   ArrayRef<bool> isUnsigned = {}) {
  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
         "Signedness info doesn't match");
  std::string s;
  llvm::raw_string_ostream os(s);
  llvm::SmallDenseMap<Type, unsigned> substitutions;
  os << "_Z" << baseName.size() << baseName;
  for (auto [idx, type] : llvm::enumerate(types)) {
    auto it = substitutions.find(type);
    if (it != substitutions.end()) {
      os << "S";
      // First substitution is `S_`, second is `S0_`, and so on.
      if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
        os << firstIdx - 1;
      os << "_";
    } else {
      // Builtin (int/float) types are never substitution candidates.
      if (!type.isIntOrFloat())
        substitutions[type] = substitutions.size();
      os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
    }
  }
  // raw_string_ostream is unbuffered and writes straight through to `s`, so
  // return it directly instead of calling the deprecated str() accessor.
  return s;
}
/// Map `op`'s cache-control attribute to the L1-level control code placed in
/// the cache-control decoration metadata. The code follows the L1 prefix of
/// each enumerator: loads 1=UC, 2=C, 3=S, 4=INVALIDATE_READ; stores 1=UC,
/// 2=WT, 3=S, 4=WB.
/// Precondition: op.getCacheControl() has a value (dereferenced below;
/// callers check via getCacheControlMetadata).
template <bool isLoad, typename OpType>
int32_t getL1CacheControl(OpType op) {
  int32_t control = 0;
  if constexpr (isLoad) {
    switch (*op.getCacheControl()) {
    case LoadCacheControl::L1UC_L2UC_L3UC:
    case LoadCacheControl::L1UC_L2UC_L3C:
    case LoadCacheControl::L1UC_L2C_L3UC:
    case LoadCacheControl::L1UC_L2C_L3C:
      control = 1;
      break;
    case LoadCacheControl::L1C_L2UC_L3UC:
    case LoadCacheControl::L1C_L2UC_L3C:
    case LoadCacheControl::L1C_L2C_L3UC:
    case LoadCacheControl::L1C_L2C_L3C:
      control = 2;
      break;
    case LoadCacheControl::L1S_L2UC_L3UC:
    case LoadCacheControl::L1S_L2UC_L3C:
    case LoadCacheControl::L1S_L2C_L3UC:
    case LoadCacheControl::L1S_L2C_L3C:
      control = 3;
      break;
    case LoadCacheControl::INVALIDATE_READ:
      control = 4;
      break;
    }
  } else {
    switch (*op.getCacheControl()) {
    case StoreCacheControl::L1UC_L2UC_L3UC:
    case StoreCacheControl::L1UC_L2UC_L3WB:
    case StoreCacheControl::L1UC_L2WB_L3UC:
    case StoreCacheControl::L1UC_L2WB_L3WB:
      control = 1;
      break;
    case StoreCacheControl::L1WT_L2UC_L3UC:
    case StoreCacheControl::L1WT_L2UC_L3WB:
    case StoreCacheControl::L1WT_L2WB_L3UC:
    case StoreCacheControl::L1WT_L2WB_L3WB:
      control = 2;
      break;
    case StoreCacheControl::L1S_L2UC_L3UC:
    case StoreCacheControl::L1S_L2UC_L3WB:
    case StoreCacheControl::L1S_L2WB_L3UC:
    case StoreCacheControl::L1S_L2WB_L3WB:
      control = 3;
      break;
    case StoreCacheControl::L1WB_L2UC_L3UC:
    case StoreCacheControl::L1WB_L2WB_L3UC:
    case StoreCacheControl::L1WB_L2UC_L3WB:
      control = 4;
      break;
    }
  }
  return control;
}
/// Map `op`'s cache-control attribute to the L3-level control code placed in
/// the cache-control decoration metadata. The code follows the L3 suffix of
/// each enumerator: loads 1=UC, 2=C, 4=INVALIDATE_READ; stores 1=UC, 2=WB.
/// Precondition: op.getCacheControl() has a value (dereferenced below;
/// callers check via getCacheControlMetadata).
template <bool isLoad, typename OpType>
int32_t getL3CacheControl(OpType op) {
  int32_t control = 0;
  if constexpr (isLoad) {
    switch (*op.getCacheControl()) {
    case LoadCacheControl::L1UC_L2UC_L3UC:
    case LoadCacheControl::L1UC_L2C_L3UC:
    case LoadCacheControl::L1C_L2UC_L3UC:
    case LoadCacheControl::L1C_L2C_L3UC:
    case LoadCacheControl::L1S_L2UC_L3UC:
    case LoadCacheControl::L1S_L2C_L3UC:
      control = 1;
      break;
    case LoadCacheControl::L1UC_L2UC_L3C:
    case LoadCacheControl::L1UC_L2C_L3C:
    case LoadCacheControl::L1C_L2UC_L3C:
    case LoadCacheControl::L1C_L2C_L3C:
    case LoadCacheControl::L1S_L2UC_L3C:
    case LoadCacheControl::L1S_L2C_L3C:
      control = 2;
      break;
    case LoadCacheControl::INVALIDATE_READ:
      control = 4;
      break;
    }
  } else {
    switch (*op.getCacheControl()) {
    case StoreCacheControl::L1UC_L2UC_L3UC:
    case StoreCacheControl::L1UC_L2WB_L3UC:
    case StoreCacheControl::L1WT_L2UC_L3UC:
    case StoreCacheControl::L1WT_L2WB_L3UC:
    case StoreCacheControl::L1S_L2UC_L3UC:
    case StoreCacheControl::L1S_L2WB_L3UC:
    case StoreCacheControl::L1WB_L2UC_L3UC:
    case StoreCacheControl::L1WB_L2WB_L3UC:
      control = 1;
      break;
    case StoreCacheControl::L1UC_L2UC_L3WB:
    case StoreCacheControl::L1UC_L2WB_L3WB:
    case StoreCacheControl::L1WT_L2UC_L3WB:
    case StoreCacheControl::L1WT_L2WB_L3WB:
    case StoreCacheControl::L1S_L2UC_L3WB:
    case StoreCacheControl::L1S_L2WB_L3WB:
    case StoreCacheControl::L1WB_L2UC_L3WB:
      control = 2;
      break;
    }
  }
  return control;
}
/// Build the two-row array attribute that encodes the op's cache controls,
/// one row per cache level: [key, level, control, 0] for L1 (level 0) and L3
/// (level 1). Returns std::nullopt when the op carries no cache-control
/// attribute. Keys 6442/6443 distinguish load vs. store controls
/// (NOTE(review): presumably the SPIR-V CacheControl{Load,Store}INTEL
/// decoration ids — confirm against the SPIR-V extension spec).
template <bool isLoad, typename OpType>
static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
  if (!op.getCacheControl())
    return {};
  constexpr int32_t decorationCacheControlArity{4};
  constexpr int32_t loadCacheControlKey{6442};
  constexpr int32_t storeCacheControlKey{6443};
  const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
  SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
      controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
  SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
      controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
  auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
  auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
  SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
  return rewriter.getArrayAttr(combinedAttrs);
}
/// Find or declare the external device function `funcName` in `op`'s
/// enclosing symbol table, give it the SPIR_FUNC calling convention plus the
/// attributes requested in `funcAttributeOptions`, set the per-argument unit
/// attributes listed in `paramAttrs` (index, attribute-name pairs), and emit
/// a call to it at `op`'s location. The function's attributes are copied
/// onto the generated call op as well.
static LLVM::CallOp createDeviceFunctionCall(
    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
    ArrayRef<Type> argTypes, ArrayRef<Value> args,
    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
    LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
  auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
  assert(moduleOp && "Expecting module");
  Location loc = op->getLoc();
  auto funcOpRes =
      LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
  // NOTE(review): a lookup failure is only caught in assert builds.
  assert(!failed(funcOpRes));
  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
  funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
  funcOp.setConvergent(funcAttributeOptions.isConvergent);
  funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
  funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
  if (funcAttributeOptions.memEffectsAttr)
    funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
  for (auto [idx, attrName] : paramAttrs)
    funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
  auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
  // Mirror the declaration's attributes on the call site.
  callOp->setAttrs(funcOp->getAttrs());
  return callOp;
}
/// Lower xevm.mma to a call to the matching
/// intel_sub_group_<a>_<b>_matrix_mad_k<N> OCL built-in, bitcasting the
/// operands to the packed element types the built-in expects.
class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getC()) {
      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
    }
    // Reject element-type combinations the OCL built-ins don't cover.
    auto precisionA = op.getTypes().getA();
    auto precisionB = op.getTypes().getB();
    auto precisionC = op.getTypes().getC();
    auto precisionD = op.getTypes().getD();
    if (precisionC != precisionD) {
      return rewriter.notifyMatchFailure(op, "type of C and D need to match");
    }
    if (precisionC != xevm::ElemType::S32 &&
        precisionC != xevm::ElemType::F32 &&
        precisionC != xevm::ElemType::F16 &&
        precisionC != xevm::ElemType::BF16) {
      return rewriter.notifyMatchFailure(
          op, "type of C and D must be S32, F32, F16 or BF16");
    }
    if (precisionA == xevm::ElemType::S32 ||
        precisionA == xevm::ElemType::F32) {
      return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
    }
    if (precisionB == xevm::ElemType::S32 ||
        precisionB == xevm::ElemType::F32) {
      return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
    }
    // A is repacked into 16-bit lanes, B into 32-bit lanes (TF32 stays f32).
    constexpr uint32_t bitWidthPackedA{16};
    constexpr uint32_t bitWidthPackedB{32};
    auto loc = op.getLoc();
    // Bitcast `val` to a vector of `packedType` with the same total bit size,
    // if it isn't already of that type.
    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
      VectorType origTy = cast<VectorType>(val.getType());
      const uint32_t vecBitSize =
          origTy.getNumElements() *
          origTy.getElementType().getIntOrFloatBitWidth();
      VectorType newTy = VectorType::get(
          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
      if (origTy != newTy)
        val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
      return val;
    };
    Value a = op.getA();
    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedA);
    a = castIfNeeded(a, packedAType);
    Value b = op.getB();
    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedB);
    b = castIfNeeded(b, packedBType);
    Value c = op.getC();
    VectorType cOrigTy = cast<VectorType>(c.getType());
    VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
    assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
    // OCL builtins encode bfloat16 as int16
    VectorType cTy =
        cOrigTy.getElementType().isBF16()
            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
            : cOrigTy;
    VectorType resTy = cTy;
    if (cOrigTy != cTy)
      c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
    // K = systolic depth * operands-per-dword, encoded in the built-in name.
    constexpr int32_t systolicDepth{8};
    std::string fnName =
        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
                      stringifyElemType(op.getTypes().getA()).str(),
                      stringifyElemType(op.getTypes().getB()).str(),
                      systolicDepth *
                          getNumOperandsPerDword(op.getTypes().getA()))
            .str();
    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
    fnName = mangle(fnName, argTypes);
    SmallVector<Value> args{a, b, c};
    // The built-in is a pure computation: no memory effects whatsoever.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
    funcAttrs.memEffectsAttr = memAttr;
    Value result =
        createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
                                 funcAttrs, op.getOperation())
            ->getResult(0);
    // Cast a bf16-as-i16 result back to the op's original result type.
    if (resOrigTy != resTy)
      result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result);
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Number of matrix operands packed per 32-bit dword for element type
  /// `pTy`; used above to derive the built-in's K dimension.
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
/// Lower xevm.prefetch to a call to the pre-mangled OCL built-in
/// "_Z8prefetchPU3AS1Kcm" (prefetch of a global const pointer plus an i64
/// count), forwarding any cache-control attribute as call metadata.
class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
    // Second built-in argument: a constant count of 1.
    Value one =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1);
    SmallVector<Value> args{op.getPtr(), one};
    SmallVector<Type> argTypes;
    for (auto arg : args)
      argTypes.push_back(arg.getType());
    // The built-in only reads through its pointer argument.
    auto funcAttr = noUnwindAttrs;
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::Ref,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
    funcAttr.memEffectsAttr = memAttr;
    LLVM::CallOp call = createDeviceFunctionCall(
        rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
        argTypes, args, {}, funcAttr, op.getOperation());
    if (std::optional<ArrayAttr> optCacheControls =
            getCacheControlMetadata<true>(rewriter, op))
      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Lower xevm.memfence to a call to the OCL atomic_work_item_fence built-in.
/// Only the GLOBAL/SHARED address spaces and WORKGROUP/DEVICE scopes have
/// OpenCL equivalents; everything else fails to match.
class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    const std::string fnName{"atomic_work_item_fence"};
    int memScope, addrSpace;
    switch (op.getAddrspace()) {
    case xevm::AddrSpace::SHARED:
      addrSpace = 1; // CLK_LOCAL_MEM_FENCE
      break;
    case xevm::AddrSpace::GLOBAL:
      addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
      break;
    default:
      // GENERIC is not supported in OpenCL
      return rewriter.notifyMatchFailure(
          op, "Fence only supports global and shared address spaces.");
    }
    switch (op.getScope()) {
    case xevm::MemScope::WORKGROUP:
      memScope = 1;
      break;
    case xevm::MemScope::DEVICE:
      memScope = 2;
      break;
    default:
      // CLUSTER and SYSTEM are not supported in OpenCL
      return rewriter.notifyMatchFailure(
          op, "Fence only supports workgroup and device memory scopes.");
    }
    Type i32Type = rewriter.getI32Type();
    // 4 is the ordering argument; NOTE(review): presumably
    // memory_order_acq_rel — confirm against the OpenCL C enum values.
    Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4);
    Value memScopeConst =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope);
    Value addrSpaceConst =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace);
    // Argument order: (flags, order, scope).
    SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
    SmallVector<Type> argTypes{3, i32Type};
    createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
                             LLVM::LLVMVoidType::get(rewriter.getContext()),
                             argTypes, args, {}, noUnwindAttrs,
                             op.getOperation());
    rewriter.eraseOp(op);
    return success();
  }
};
/// Shared lowering of xevm.blockload2d / blockstore2d / blockprefetch2d to
/// the intel_sub_group_2d_block_{read,write,prefetch} OCL built-ins. Loads
/// and stores pass their vector payload through an alloca'd staging pointer
/// (the extra last argument of the built-in); prefetch takes no payload.
template <typename OpType>
class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
    constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
    auto loc = op.getLoc();
    VectorType vecType;
    bool packReg = false;
    bool transpose = false;
    if constexpr (isLoad) {
      vecType = op.getRes().getType();
      packReg = op.getPackRegister();
      transpose = op.getTranspose();
    } else if constexpr (!isPrefetch) {
      vecType = op.getStoredVal().getType();
    }
    auto i32Type = rewriter.getI32Type();
    // Pack the (x, y) coordinates into the vector<2xi32> the built-in takes.
    Value byteCoord =
        rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
    Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0);
    Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1);
    byteCoord = rewriter.create<LLVM::InsertElementOp>(
        loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
    byteCoord = rewriter.create<LLVM::InsertElementOp>(
        loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
    SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
                            op.getBasePitch(), byteCoord};
    SmallVector<Type> retTypes;
    Value spvLoadDstPtr;
    std::string funcName{"intel_sub_group_2d_block_"};
    std::string bitWidthId;
    LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
    SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
    if constexpr (isPrefetch) { // Prefetch
      funcName += "prefetch";
      paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
      // Prefetch only reads through its pointer argument.
      auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/LLVM::ModRefInfo::NoModRef,
          /*argMem=*/LLVM::ModRefInfo::Ref,
          /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
      funcAttr = noUnwindAttrs;
      funcAttr.memEffectsAttr = memAttr;
    } else {
      // Load/store: allocate the staging buffer passed as last argument.
      auto vecElemType = vecType.getElementType();
      auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
      Value numElems = rewriter.create<LLVM::ConstantOp>(
          loc, i32Type, vecType.getNumElements());
      auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
          loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType,
          numElems);
      args.push_back(dstOrSrcPtr);
      if constexpr (isLoad) { // Load
        funcName += "read";
        // Loads mangle their dst pointer with the unsigned element code.
        bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
        if (packReg)
          funcName += "_transform";
        else if (transpose)
          funcName += "_transpose";
        spvLoadDstPtr = dstOrSrcPtr;
        retTypes.push_back(vecType);
        paramAttrs = {
            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
        };
      } else { // Store
        funcName += "write";
        // 32-bit -> "j" (uint), 16-bit -> "t" (ushort), else "h" (uchar).
        bitWidthId = (vecElemBitWidth == 32)
                         ? "j"
                         : ((vecElemBitWidth == 16) ? "t" : "h");
        // Spill the stored vector into the staging buffer.
        rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
        paramAttrs = {
            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
        };
      }
    }
    // Encode the tile geometry in the name, then wrap it in the pre-computed
    // mangling (global void ptr, 3 x i32, vector<2xi32>, optional data ptr).
    funcName =
        llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
                      op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
            .str();
    funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
                             funcName, isPrefetch ? "" : "P", bitWidthId)
                   .str();
    SmallVector<Type> argTypes;
    for (auto arg : args) {
      argTypes.push_back(arg.getType());
    }
    LLVM::CallOp call = createDeviceFunctionCall(
        rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
        argTypes, args, paramAttrs, funcAttr, op.getOperation());
    // Note: `< ... >` below is an explicit template argument, not comparison.
    if (std::optional<ArrayAttr> optCacheControls =
            getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
    }
    // Loads read the result back out of the staging buffer.
    if constexpr (isLoad)
      rewriter.replaceOp(
          op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
    else
      rewriter.eraseOp(op);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
/// Pass that rewrites every XeVM op into its LLVM-dialect lowering.
struct ConvertXeVMToLLVMPass
    : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect, XeVMDialect>();
  }

  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    // Everything in the LLVM dialect is legal; any surviving XeVM op is a
    // conversion failure.
    ConversionTarget convTarget(ctx);
    convTarget.addLegalDialect<LLVM::LLVMDialect>();
    convTarget.addIllegalDialect<XeVMDialect>();
    RewritePatternSet patternSet(&ctx);
    populateXeVMToLLVMConversionPatterns(patternSet);
    if (failed(applyPartialConversion(getOperation(), convTarget,
                                      std::move(patternSet))))
      signalPassFailure();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// ConvertToLLVMPatternInterface implementation
//===----------------------------------------------------------------------===//
namespace {
/// Implement the interface to convert XeVM to LLVM.
struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
void loadDependentDialects(MLIRContext *context) const final {
context->loadDialect<LLVM::LLVMDialect>();
}
/// Hook for derived dialect interface to provide conversion patterns
/// and mark dialect legal for the conversion target.
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
populateXeVMToLLVMConversionPatterns(patterns);
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
/// Append every XeVM-to-OCL-builtin lowering pattern to `patterns`.
void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
               LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
               LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
               MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
      ctx);
}
/// Attach the ConvertToLLVMPatternInterface extension to the XeVM dialect so
/// the generic convert-to-llvm pass picks up its patterns.
void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *, XeVMDialect *xevmDialect) {
    xevmDialect->addInterfaces<XeVMToLLVMDialectInterface>();
  });
}

View File

@@ -1,244 +0,0 @@
// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s
// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass.
// RUN: mlir-opt --convert-to-llvm --split-input-file %s | FileCheck %s
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
  // Plain 16-bit 8r16x1c block load: the builtin writes into an alloca'd
  // staging pointer and the result is loaded back as a vector.
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
  // CHECK-SAME: will_return} :
  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
    pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
  llvm.return %loaded_a : vector<8xi16>
}
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
  // L1uc_L2uc_L3uc maps to control code 1 at both cache levels (rows are
  // [key, level, control, 0] with load key 6442).
  // CHECK: xevm.DecorationCacheControl =
  // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
  // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
    pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
  llvm.return %loaded_a : vector<8xi16>
}
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockload2d_v_blocks(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<16xi16> {
  // v_blocks=2 doubles the payload (16 elements) and shows up as "x2c" in
  // the builtin name.
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
  // CHECK-SAME: will_return}
  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16>
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false,
    pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
  llvm.return %loaded_a : vector<16xi16>
}
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockload2d_pack_register(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> {
  // pack_register=true selects the "_transform" builtin variant and an i32
  // staging element type.
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
  // CHECK-SAME: will_return} :
  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32>
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false,
    pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
  llvm.return %loaded_a : vector<8xi32>
}
// -----
// Verifies that xevm.blockload2d with transpose=true lowers to the
// "2d_block_read_transpose" builtin instead of the transform variant, with the
// same coordinate-packing and alloca-based result-return pattern as above.
// Also checks the generated builtin declaration's argument attributes
// (nonnull/readonly on the source pointer, nonnull/writeonly on the out-pointer).
// CHECK-LABEL: llvm.func spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockload2d_transpose(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi32> {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
  // CHECK-SAME: will_return}
  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32>
  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
    <{elem_size_in_bits=32 : i32, tile_width=8 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=true,
      pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
  llvm.return %loaded_a : vector<8xi32>
}
// -----
// Verifies that xevm.blockstore2d lowers to the "2d_block_write" builtin.
// Mirrors the load pattern, but the value to store is first spilled into an
// 8 x i32 alloca (llvm.store) whose pointer is passed to the builtin; pointer
// argument attributes are inverted relative to the read case
// (destination: writeonly, staging buffer: readonly).
// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) attributes {no_unwind, will_return}
// CHECK: llvm.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) {
llvm.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
  // CHECK: llvm.store %[[ARG6]], %[[VAR6]] : vector<8xi32>, !llvm.ptr
  // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
  // CHECK-SAME: "_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
  // CHECK-SAME: will_return}
  // CHECK-SAME: : (!llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>,
  // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.readonly}) -> ()
  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
    <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32}>
    : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
  llvm.return
}
// -----
// Verifies that a cache_control attribute on xevm.blockstore2d is lowered to a
// xevm.DecorationCacheControl dictionary on the generated call: two entries with
// decoration id 6443 (one per cache level shown), encoding the
// L1wt_L2uc_L3wb store policy. Only the decoration payload is checked here.
// CHECK-LABEL: llvm.func spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(
llvm.func @blockstore2d_cache_control(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) {
  // CHECK: xevm.DecorationCacheControl =
  // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32
  // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32
  xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted
    <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control = #xevm.store_cache_control<L1wt_L2uc_L3wb>}>
    : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
  llvm.return
}
// -----
// Verifies that xevm.blockprefetch2d lowers to the "2d_block_prefetch" builtin.
// Unlike the load/store cases there is no result, so no alloca is emitted; the
// declaration and call carry read-only argMem memory_effects and no will_return.
// CHECK-LABEL: llvm.func spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(
// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes
// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
// CHECK: llvm.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>,
// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) {
llvm.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
  // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>)>, linkage = #llvm.linkage<external>,
  // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind,
  // CHECK-SAME: sym_name = "_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i", visibility_ = 0 : i64
  xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y
    <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32,
      cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}>
    : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
  llvm.return
}
// -----
// Verifies that xevm.mma (8x16x16, d=f32, a=f16, b=f16, c=f32) lowers to the
// sub-group matrix-mad builtin. Note the operand reordering: the xevm op takes
// (a, b, c) while the builtin call receives (%[[ARG1]], %[[ARG2]], %[[ARG0]]),
// i.e. the accumulator moves last. The declaration carries convergent,
// no-memory effects, no_unwind and will_return.
// CHECK-LABEL: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(
// CHECK-SAME: vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes
// CHECK-SAME: {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none,
// CHECK-SAME: inaccessibleMem = none>, no_unwind, will_return}
// CHECK: llvm.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) -> vector<8xf32> {
llvm.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> {
  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(
  // CHECK-SAME: %[[ARG1]], %[[ARG2]], %[[ARG0]]) {convergent, function_type =
  // CHECK-SAME: !llvm.func<vector<8xf32> (vector<8xi16>, vector<8xi32>, vector<8xf32>)>, linkage = #llvm.linkage<external>,
  // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind,
  // CHECK-SAME: sym_name = "_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f", visibility_ = 0 : i64, will_return}
  // CHECK-SAME: : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
  %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted
    { shape=<m=8, n=16, k=16>, types=<d=f32, a=f16, b=f16, c=f32> }
    : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
  llvm.return %c_result : vector<8xf32>
}
// -----
// Verifies that xevm.memfence lowers to the OpenCL atomic_work_item_fence
// builtin with three i32 constants. For addrspace=global, scope=workgroup the
// call receives (2, 4, 1) — flags, order, and scope encoded as integers; only
// the constant values matched below are pinned by this test.
// CHECK-LABEL: llvm.func spir_funccc @_Z22atomic_work_item_fenceiii(i32, i32, i32) attributes {no_unwind}
llvm.func @memfence() {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(4 : i32) : i32
  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(2 : i32) : i32
  // CHECK: llvm.call spir_funccc @_Z22atomic_work_item_fenceiii(%[[VAR2]], %[[VAR0]], %[[VAR1]])
  // CHECK-SAME: {function_type = !llvm.func<void (i32, i32, i32)>, linkage = #llvm.linkage<external>, no_unwind,
  // CHECK-SAME: sym_name = "_Z22atomic_work_item_fenceiii", visibility_ = 0 : i64} : (i32, i32, i32) -> ()
  xevm.memfence <{addrspace=#xevm.addr_space<global>, scope=#xevm.mem_scope<workgroup>}>
  llvm.return
}
// -----
// Verifies that xevm.prefetch on a global (address-space-1) pointer lowers to
// the OpenCL prefetch builtin with an i64 element count of 1; the declaration
// carries read-only argMem memory_effects and no_unwind.
// CHECK-LABEL: llvm.func spir_funccc @_Z8prefetchPU3AS1Kcm(!llvm.ptr<1>, i64) attributes
// CHECK-SAME: {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
// CHECK: llvm.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) {
llvm.func @prefetch(%ptr: !llvm.ptr<1>) {
  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i64) : i64
  // CHECK: llvm.call spir_funccc @_Z8prefetchPU3AS1Kcm(%[[ARG0]], %[[VAR0]])
  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i64)>, linkage = #llvm.linkage<external>,
  // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>,
  // CHECK-SAME: no_unwind, sym_name = "_Z8prefetchPU3AS1Kcm", visibility_ = 0 : i64
  xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>)
  llvm.return
}