diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 6df209438447..daee02990ee9 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" @@ -902,6 +903,40 @@ class DpasToXeVMPattern : public OpConversionPattern { auto bTy = cast(op.getRhs().getType()); auto resultType = cast(op.getResultType()); + // get the correct dpasInst by getting info from chip + auto chipStr = xegpu::getChipStr(op); + if (!chipStr) + return rewriter.notifyMatchFailure(op, "cannot determine target chip"); + + const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr); + if (!uArch) + return rewriter.notifyMatchFailure(op, "unsupported target uArch"); + + auto *dpasInst = const_cast( + llvm::dyn_cast_or_null( + uArch->getInstruction( + xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc))); + if (!dpasInst) + return rewriter.notifyMatchFailure(op, + "DPAS not supported by target uArch"); + + auto checkSupportedTypes = [&](VectorType vecTy, + xegpu::uArch::MMAOpndKind kind) -> bool { + auto supported = dpasInst->getSupportedTypes(*ctxt, kind); + return llvm::find(supported, vecTy.getElementType()) != supported.end(); + }; + + if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA)) + return rewriter.notifyMatchFailure( + op, "A-matrix element type not supported by target uArch"); + if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB)) + return rewriter.notifyMatchFailure( + op, "B-matrix element type not supported by target uArch"); + // NOTE: Supported types for MatrixC and MatrixD are identical + if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD)) + return rewriter.notifyMatchFailure( + op, "result/accumulator element type not supported by target uArch"); + auto encodePrecision = [&](Type type) -> xevm::ElemType { if (type == rewriter.getBF16Type()) return xevm::ElemType::BF16; diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir index a9ab0be00722..7cc59f4cfdcd 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s -gpu.module @test_kernel { +gpu.module @test_kernel [#xevm.target] { // CHECK-LABEL: func.func @dpas( // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32> func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { diff --git a/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir new file mode 100644 index 000000000000..95211dcff250 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/failed_conversion.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt --convert-xegpu-to-xevm %s -split-input-file -verify-diagnostics + +// Verify that xegpu.dpas with unsupported element types (i16) is rejected +// during XeGPUToXeVM conversion rather than crashing. + +gpu.module @test_kernel [#xevm.target] { + func.func @main() { + %0 = arith.constant dense<0> : vector<4xi16> + %1 = arith.constant dense<0> : vector<4xi32> + // expected-error@+1 {{failed to legalize operation 'xegpu.dpas' that was explicitly marked illegal}} + %2 = xegpu.dpas %0, %0, %1 : vector<4xi16>, vector<4xi16>, vector<4xi32> -> vector<4xi32> + return + } +}