Adam Siemieniuk e44fd05035
[mlir][x86] Move AMX dialect into X86 dialect (#183717)
Unifies the two dialects that define x86 operations into a single one.
The AMX dialect is moved into X86 in line with other x86 extensions.

Following the dialect renaming, the X86 dialect is now a suitable home for
a wider range of operations targeting specific hardware features. Moving
AMX definitions to the X86 dialect creates a single, centralized hub for
defining all x86 intrinsic-like operations. The new grouping aims to
eliminate the need for new dialects as new hardware extensions become
available.

The two dialects are simply merged together; refactoring of the X86
dialect will be addressed separately.

List of changes:
  - operations: 'amx.tile_*' => 'x86.amx.tile_*'
  - types: '!amx.tile' => '!x86.amx.tile'
  - namespace: 'mlir::amx' => 'mlir::x86::amx'
  - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  - vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under the X86 directory. To enable the AMX
integration tests, 'MLIR_RUN_X86_TESTS' must be defined in addition to
'MLIR_RUN_X86_AMX_TESTS'.
2026-03-02 11:47:30 +01:00

159 lines
4.8 KiB
MLIR

// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// A tile with 17 rows must be rejected by the row-height check.
func.func @tile_row_height() {
  // expected-error@+1 {{'x86.amx.tile_zero' op bad row height: 17}}
  %tall = x86.amx.tile_zero : !x86.amx.tile<17x16xbf16>
  return
}
// -----
// A row of 65 i8 elements (65 bytes) must be rejected by the column-width check.
func.func @tile_col_width() {
  // expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 65}}
  %wide = x86.amx.tile_zero : !x86.amx.tile<16x65xi8>
  return
}
// -----
// i16 is not an accepted tile element type, so verification of the tile
// type's element type must fail.
func.func @tile_element_type() {
  // expected-error@+1 {{failed to verify 'elementType'}}
  %bad_elt = x86.amx.tile_zero : !x86.amx.tile<8x8xi16>
  return
}
// -----
// A 1-D tile shape does not satisfy the tile type constraint on the result.
func.func @tile_rank() {
  // expected-error@+1 {{'x86.amx.tile_zero' op result #0 must be tile of}}
  %flat = x86.amx.tile_zero : !x86.amx.tile<32xi8>
  return
}
// -----
// Per the test name, rows must span a multiple of 4 bytes; a 5-byte row
// (5 x i8) must be rejected by the column-width check.
func.func @tile_col_4_byte_multiple() {
  // expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 5}}
  %odd = x86.amx.tile_zero : !x86.amx.tile<16x5xi8>
  return
}
// -----
// Loading into a 16x17 f32 tile must be rejected: each row is 17 * 4 = 68
// bytes, and the reported column width is that byte count.
func.func @load_base_tile_size(%base: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{'x86.amx.tile_load' op bad column width: 68}}
  %tile = x86.amx.tile_load %base[%c0, %c0]
      : memref<?x?xf32> into !x86.amx.tile<16x17xf32>
  return
}
// -----
// Storing a 16x17 f32 tile must be rejected: each row is 17 * 4 = 68 bytes,
// and the reported column width is that byte count.
func.func @store_base_tile_size(%base: memref<?x?xf32>,
                                %tile: !x86.amx.tile<16x17xf32>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{'x86.amx.tile_store' op bad column width: 68}}
  x86.amx.tile_store %base[%c0, %c0], %tile
      : memref<?x?xf32>, !x86.amx.tile<16x17xf32>
  return
}
// -----
// tile_load on a 2-D memref given only a single index must be rejected.
func.func @load_base_index_size(%base: memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{'x86.amx.tile_load' op requires 2 indices}}
  %tile = x86.amx.tile_load %base[%c0]
      : memref<?x?xf32> into !x86.amx.tile<16x16xf32>
  return
}
// -----
// tile_store on a 2-D memref given only a single index must be rejected.
func.func @store_base_index_size(%base: memref<?x?xf32>,
                                 %tile: !x86.amx.tile<16x16xf32>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{'x86.amx.tile_store' op requires 2 indices}}
  x86.amx.tile_store %base[%c0], %tile
      : memref<?x?xf32>, !x86.amx.tile<16x16xf32>
  return
}
// -----
// tile_load from a 1-D memref must be rejected: a 2-D (or higher) base is needed.
func.func @load_base_rank(%base: memref<?xf32>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{'x86.amx.tile_load' op requires at least 2D memref}}
  %tile = x86.amx.tile_load %base[%c0]
      : memref<?xf32> into !x86.amx.tile<16x16xf32>
  return
}
// -----
// tile_store into a 1-D memref must be rejected: a 2-D (or higher) base is needed.
func.func @store_base_rank(%base: memref<?xf32>,
                           %tile: !x86.amx.tile<16x16xf32>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{'x86.amx.tile_store' op requires at least 2D memref}}
  x86.amx.tile_store %base[%c0], %tile
      : memref<?xf32>, !x86.amx.tile<16x16xf32>
  return
}
// -----
// The base memref's innermost stride is dynamic ('?'), so it is not known to
// be 1; tile_load must reject it.
func.func @load_base_non_unit_stride(
    %base: memref<?x?xf32, strided<[?, ?]>>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{'x86.amx.tile_load' op requires memref with unit innermost stride}}
  %tile = x86.amx.tile_load %base[%c0, %c0] : memref<?x?xf32, strided<[?, ?]>> into !x86.amx.tile<16x16xf32>
  return
}
// -----
// The destination memref's innermost stride is dynamic ('?'), so it is not
// known to be 1; tile_store must reject it.
func.func @store_base_non_unit_stride(%base: memref<?x?xf32, strided<[?, ?]>>, %tile: !x86.amx.tile<16x16xf32>) {
  %c0 = arith.constant 0 : index
  // expected-error@+1 {{'x86.amx.tile_store' op requires memref with unit innermost stride}}
  x86.amx.tile_store %base[%c0, %c0], %tile : memref<?x?xf32, strided<[?, ?]>>, !x86.amx.tile<16x16xf32>
  return
}
// -----
// Two 8x8 bf16 operand tiles combined with a 4x4 f32 accumulator have
// mismatched dimensions; tile_mulf must reject the shape.
func.func @mulf_shape() {
  %lhs = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
  %rhs = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
  %acc = x86.amx.tile_zero : !x86.amx.tile<4x4xf32>
  // expected-error@+1 {{'x86.amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
  %res = x86.amx.tile_mulf %lhs, %rhs, %acc
      : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x4xf32>
  return
}
// -----
// Mixing a bf16 lhs tile with an f16 rhs tile is not a supported element type
// combination for tile_mulf.
func.func @mulf_type_combination() {
  %lhs = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16>
  %rhs = x86.amx.tile_zero : !x86.amx.tile<4x8xf16>
  %acc = x86.amx.tile_zero : !x86.amx.tile<8x4xf32>
  // expected-error@+1 {{'x86.amx.tile_mulf' op unsupported type combination}}
  %res = x86.amx.tile_mulf %lhs, %rhs, %acc
      : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x8xf16>, !x86.amx.tile<8x4xf32>
  return
}
// -----
// Two 8x8 i8 operand tiles combined with a 4x4 i32 accumulator have
// mismatched dimensions; tile_muli must reject the shape.
func.func @muli_shape() {
  %lhs = x86.amx.tile_zero : !x86.amx.tile<8x8xi8>
  %rhs = x86.amx.tile_zero : !x86.amx.tile<8x8xi8>
  %acc = x86.amx.tile_zero : !x86.amx.tile<4x4xi32>
  // expected-error@+1 {{'x86.amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
  %res = x86.amx.tile_muli %lhs, %rhs, %acc
      : !x86.amx.tile<8x8xi8>, !x86.amx.tile<8x8xi8>, !x86.amx.tile<4x4xi32>
  return
}
// -----
// The rhs operand (operand #1) is an i32 tile, but tile_muli constrains it to
// 8-bit signless integer tiles; the operand type check must fail.
func.func @muli_type_combination() {
  %lhs = x86.amx.tile_zero : !x86.amx.tile<8x16xi8>
  %rhs = x86.amx.tile_zero : !x86.amx.tile<8x16xi32>
  %acc = x86.amx.tile_zero : !x86.amx.tile<2x2xi32>
  // expected-error@+1 {{'x86.amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
  %res = x86.amx.tile_muli %lhs, %rhs, %acc
      : !x86.amx.tile<8x16xi8>, !x86.amx.tile<8x16xi32>, !x86.amx.tile<2x2xi32>
  return
}