[MLIR][XeGPU] Validate DPAS operand types against uArch in XeGPUToXeVM conversion (#185081)

The `DpasOp` would crash with `llvm_unreachable` with unsupported types
(like i16, or i32 in operand) when during lowering to the XeVM dialect.
This happens in both `encodePrecision` and `getNumOperandsPerDword`.

Per
https://github.com/llvm/llvm-project/issues/180107#issuecomment-4009160113,
we handle this in the `matchAndRewrite` by retrieving the uArch instance
and fetching the registered `SubgroupMatrixMultiplyAcc` instruction.
Then, we validate with `getSupportedTypes` and check `aTy`, `bTy`, and
`resultType` correctly with `notifyMatchError` for reporting and
graceful handling.

We add a failed conversion test for a simplified version of the
reproducible error in #180107
This commit is contained in:
Arjun Bhamra 2026-03-25 10:37:17 -04:00 committed by GitHub
parent a36b969294
commit e6cfdd01ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 1 deletions

View File

@ -23,6 +23,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
@ -902,6 +903,40 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
auto bTy = cast<VectorType>(op.getRhs().getType());
auto resultType = cast<VectorType>(op.getResultType());
// get the correct dpasInst by getting info from chip
auto chipStr = xegpu::getChipStr(op);
if (!chipStr)
return rewriter.notifyMatchFailure(op, "cannot determine target chip");
const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
if (!uArch)
return rewriter.notifyMatchFailure(op, "unsupported target uArch");
auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
uArch->getInstruction(
xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
if (!dpasInst)
return rewriter.notifyMatchFailure(op,
"DPAS not supported by target uArch");
auto checkSupportedTypes = [&](VectorType vecTy,
xegpu::uArch::MMAOpndKind kind) -> bool {
auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
return llvm::find(supported, vecTy.getElementType()) != supported.end();
};
if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
return rewriter.notifyMatchFailure(
op, "A-matrix element type not supported by target uArch");
if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
return rewriter.notifyMatchFailure(
op, "B-matrix element type not supported by target uArch");
// NOTE: Supported types for MatrixC and MatrixD are identical
if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
return rewriter.notifyMatchFailure(
op, "result/accumulator element type not supported by target uArch");
auto encodePrecision = [&](Type type) -> xevm::ElemType {
if (type == rewriter.getBF16Type())
return xevm::ElemType::BF16;

View File

@ -1,6 +1,6 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
gpu.module @test_kernel {
gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// CHECK-LABEL: func.func @dpas(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {

View File

@ -0,0 +1,14 @@
// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics
// Verify that xegpu.dpas with unsupported element types (i16) is rejected
// during XeGPUToXeVM conversion rather than crashing.
gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
func.func @main() {
%0 = arith.constant dense<0> : vector<4xi16>
%1 = arith.constant dense<0> : vector<4xi32>
// expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}}
%2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32>
return
}
}