[MLIR][XeGPU] Validate DPAS operand types against uArch in XeGPUToXeVM conversion (#185081)
The `DpasOp` would crash with `llvm_unreachable` with unsupported types (like i16, or i32 in operand) when during lowering to the XeVM dialect. This happens in both `encodePrecision` and `getNumOperandsPerDword`. Per https://github.com/llvm/llvm-project/issues/180107#issuecomment-4009160113, we handle this in the `matchAndRewrite` by retrieving the uArch instance and fetching the registered `SubgroupMatrixMultiplyAcc` instruction. Then, we validate with `getSupportedTypes` and check `aTy`, `bTy`, and `resultType` correctly with `notifyMatchError` for reporting and graceful handling. We add a failed conversion test for a simplified version of the reproducible error in #180107
This commit is contained in:
parent
a36b969294
commit
e6cfdd01ae
@ -23,6 +23,7 @@
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
||||
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
|
||||
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
@ -902,6 +903,40 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
|
||||
auto bTy = cast<VectorType>(op.getRhs().getType());
|
||||
auto resultType = cast<VectorType>(op.getResultType());
|
||||
|
||||
// get the correct dpasInst by getting info from chip
|
||||
auto chipStr = xegpu::getChipStr(op);
|
||||
if (!chipStr)
|
||||
return rewriter.notifyMatchFailure(op, "cannot determine target chip");
|
||||
|
||||
const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
|
||||
if (!uArch)
|
||||
return rewriter.notifyMatchFailure(op, "unsupported target uArch");
|
||||
|
||||
auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
|
||||
llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
|
||||
uArch->getInstruction(
|
||||
xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
|
||||
if (!dpasInst)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"DPAS not supported by target uArch");
|
||||
|
||||
auto checkSupportedTypes = [&](VectorType vecTy,
|
||||
xegpu::uArch::MMAOpndKind kind) -> bool {
|
||||
auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
|
||||
return llvm::find(supported, vecTy.getElementType()) != supported.end();
|
||||
};
|
||||
|
||||
if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "A-matrix element type not supported by target uArch");
|
||||
if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "B-matrix element type not supported by target uArch");
|
||||
// NOTE: Supported types for MatrixC and MatrixD are identical
|
||||
if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "result/accumulator element type not supported by target uArch");
|
||||
|
||||
auto encodePrecision = [&](Type type) -> xevm::ElemType {
|
||||
if (type == rewriter.getBF16Type())
|
||||
return xevm::ElemType::BF16;
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
|
||||
|
||||
gpu.module @test_kernel {
|
||||
gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
|
||||
// CHECK-LABEL: func.func @dpas(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
|
||||
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
|
||||
|
||||
14
mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
Normal file
14
mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
|
||||
|
||||
// Verify that xegpu.dpas with unsupported element types (i16) is rejected
|
||||
// during XeGPUToXeVM conversion rather than crashing.
|
||||
|
||||
gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
|
||||
func.func @main() {
|
||||
%0 = arith.constant dense<0> : vector<4xi16>
|
||||
%1 = arith.constant dense<0> : vector<4xi32>
|
||||
// expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
|
||||
%2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
|
||||
return
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user