Unifies the two dialects that define x86 operations into a single one. The AMX dialect is moved into X86 in line with other x86 extensions. Following the dialect renaming, the X86 dialect is now a suitable home for a wider range of operations targeting specific hardware features. Moving the AMX definitions to the X86 dialect creates a single, centralized hub for defining all x86 intrinsic-like operations. The new grouping aims to eliminate the need for new dialects as new hardware extensions become available. The two dialects are simply merged together. X86 dialect refactoring will be addressed separately. List of changes: - operations: 'amx.tile_*' => 'x86.amx.tile_*' - types: '!amx.tile' => '!x86.amx.tile' - namespace: 'mlir::amx' => 'mlir::x86::amx' - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS' - vector lowering: AMX is enabled by default together with X86. The MLIR AMX tests are now nested under the X86 directory. To enable AMX integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.
216 lines
8.8 KiB
C++
216 lines
8.8 KiB
C++
//=- Transforms.h - X86 Dialect Transformation Entrypoints --------*- C++ -*-=//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef MLIR_DIALECT_X86_TRANSFORMS_H
|
|
#define MLIR_DIALECT_X86_TRANSFORMS_H
|
|
|
|
#include "mlir/IR/Value.h"
|
|
|
|
namespace mlir {
|
|
|
|
class ImplicitLocOpBuilder;
|
|
class LLVMConversionTarget;
|
|
class LLVMTypeConverter;
|
|
class RewritePatternSet;
|
|
|
|
namespace x86 {
|
|
|
|
/// Helper class to factor out the creation and extraction of masks from nibs
/// (the individual bit fields packed into an 8-bit immediate).
struct MaskHelper {
  /// Packs eight single-bit selectors into an 8-bit blend mask.
  /// b0 captures the lowest bit, b7 captures the highest bit.
  /// Every selector must be 0 or 1; this is checked at compile time.
  /// Meant to be used with instructions such as mm256BlendPs.
  template <uint8_t b0, uint8_t b1, uint8_t b2, uint8_t b3, uint8_t b4,
            uint8_t b5, uint8_t b6, uint8_t b7>
  static uint8_t blend() {
    static_assert(b0 <= 1 && b1 <= 1 && b2 <= 1 && b3 <= 1, "overflow");
    static_assert(b4 <= 1 && b5 <= 1 && b6 <= 1 && b7 <= 1, "overflow");
    return static_cast<uint8_t>((b7 << 7) | (b6 << 6) | (b5 << 5) | (b4 << 4) |
                                (b3 << 3) | (b2 << 2) | (b1 << 1) | b0);
  }

  /// Splits an 8-bit blend mask into its eight single-bit selectors.
  /// b0 receives the lowest bit, b7 receives the highest bit. Every output is
  /// normalized to 0 or 1, i.e. the same range that `blend` accepts.
  /// Meant to be used with instructions such as mm256BlendPs.
  static void extractBlend(uint8_t mask, uint8_t &b0, uint8_t &b1, uint8_t &b2,
                           uint8_t &b3, uint8_t &b4, uint8_t &b5, uint8_t &b6,
                           uint8_t &b7) {
    // Shift the requested bit down to position 0 before masking. Masking in
    // place would keep the bit at its original weight (e.g. 128 for b7)
    // instead of yielding a 0/1 selector.
    const auto bitAt = [mask](unsigned pos) {
      return static_cast<uint8_t>((mask >> pos) & 1u);
    };
    b7 = bitAt(7);
    b6 = bitAt(6);
    b5 = bitAt(5);
    b4 = bitAt(4);
    b3 = bitAt(3);
    b2 = bitAt(2);
    b1 = bitAt(1);
    b0 = bitAt(0);
  }

  /// Packs four 2-bit selectors into an 8-bit shuffle mask.
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  /// Every selector must fit in 2 bits; this is checked at compile time.
  /// Meant to be used with instructions such as mm256ShufflePs.
  template <unsigned b67, unsigned b45, unsigned b23, unsigned b01>
  static uint8_t shuffle() {
    static_assert(b01 <= 0x03, "overflow");
    static_assert(b23 <= 0x03, "overflow");
    static_assert(b45 <= 0x03, "overflow");
    static_assert(b67 <= 0x03, "overflow");
    return static_cast<uint8_t>((b67 << 6) | (b45 << 4) | (b23 << 2) | b01);
  }

  /// Splits an 8-bit shuffle mask into its four 2-bit selectors.
  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
  static void extractShuffle(uint8_t mask, uint8_t &b01, uint8_t &b23,
                             uint8_t &b45, uint8_t &b67) {
    b67 = (mask & (0x03 << 6)) >> 6;
    b45 = (mask & (0x03 << 4)) >> 4;
    b23 = (mask & (0x03 << 2)) >> 2;
    b01 = mask & 0x03;
  }

  /// Packs two 4-bit selectors into an 8-bit permute mask.
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  /// Every selector must fit in 4 bits; this is checked at compile time.
  /// Meant to be used with instructions such as mm256Permute2f128Ps.
  template <unsigned b47, unsigned b03>
  static uint8_t permute() {
    static_assert(b03 <= 0x0f, "overflow");
    static_assert(b47 <= 0x0f, "overflow");
    // The two fields do not overlap, so `|` is equivalent to addition here.
    return static_cast<uint8_t>((b47 << 4) | b03);
  }

  /// Splits an 8-bit permute mask into its two 4-bit selectors.
  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
  static void extractPermute(uint8_t mask, uint8_t &b03, uint8_t &b47) {
    b47 = (mask & (0x0f << 4)) >> 4;
    b03 = mask & 0x0f;
  }
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// A set of patterns for specialized lowering of vector contraction
|
|
// operation to vector fused multiply and add (FMA) operation.
|
|
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
|
|
|
|
// A set of patterns for lowering 32-bit packed vector contraction operations
|
|
// to their corresponding packed-type dot-product operations, ultimately
|
|
// targeting the relevant x86 LLVM intrinsics (e.g., BF16 and Int8).
|
|
void populateVectorContractToPackedTypeDotProductPatterns(
|
|
RewritePatternSet &patterns);
|
|
|
|
// A set of patterns for lowering 32-bit packed BF16 vector contraction
|
|
// operations to vector fused multiply-add (FMA) operations, following
|
|
// the emulation-based approach using BF16 packed operations.
|
|
void populateVectorContractBF16ToFMAPatterns(RewritePatternSet &patterns);
|
|
|
|
// Performs forward scheduling of vector producer ops to minimize their live
|
|
// range by placing them at their earliest legal use site.
|
|
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);
|
|
|
|
// Shuffles FMAs with x86 operations as operands such that FMAs are
|
|
// grouped with respect to odd/even packed index.
|
|
void populateShuffleVectorFMAOpsPatterns(RewritePatternSet &patterns);
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// Helpers extracted from:
|
|
/// - clang/lib/Headers/avxintrin.h
|
|
/// - clang/test/CodeGen/X86/avx-builtins.c
|
|
/// - clang/test/CodeGen/X86/avx2-builtins.c
|
|
/// - clang/test/CodeGen/X86/avx-shuffle-builtins.c
|
|
/// as well as the Intel Intrinsics Guide
|
|
/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
|
|
/// make it easier to just implement known good lowerings.
|
|
/// All intrinsics correspond 1-1 to the Intel definition.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace avx2 {
|
|
|
|
namespace inline_asm {
|
|
//===----------------------------------------------------------------------===//
|
|
/// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
|
|
//===----------------------------------------------------------------------===//
|
|
/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
|
|
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2,
|
|
uint8_t mask);
|
|
|
|
} // namespace inline_asm
|
|
|
|
namespace intrin {
|
|
//===----------------------------------------------------------------------===//
|
|
/// Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
|
|
//===----------------------------------------------------------------------===//
|
|
/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
|
|
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
|
|
|
|
/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
|
|
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
|
|
|
|
/// a a b b a a b b
|
|
/// Take an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
|
|
/// 0:127 | 128:255
|
|
/// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4
|
|
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
|
|
|
|
// imm[0:1] out of imm[0:3] is:
|
|
// 0 1 2 3
|
|
// a[0:127] or a[128:255] or b[0:127] or b[128:255] |
|
|
// a[0:127] or a[128:255] or b[0:127] or b[128:255]
|
|
// 0 1 2 3
|
|
// imm[0:1] out of imm[4:7].
|
|
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
|
|
uint8_t mask);
|
|
|
|
/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
|
|
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
|
|
} // namespace intrin
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// Generic lowerings may either use intrin or inline_asm depending on needs.
|
|
//===----------------------------------------------------------------------===//
|
|
/// 4x8xf32-specific AVX2 transpose lowering.
|
|
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
|
|
|
|
/// 8x8xf32-specific AVX2 transpose lowering.
|
|
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
|
|
|
|
/// Structure to control the behavior of specialized AVX2 transpose lowering.
/// Each flag defaults to false and has a same-named setter that returns
/// `*this`, so several options can be configured in one chained expression.
struct TransposeLoweringOptions {
  /// Flag for the 4x8xf32 transpose lowering.
  bool lower4x8xf32_ = false;
  /// Sets `lower4x8xf32_` to `lower` (true when omitted) and returns `*this`.
  TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
    lower4x8xf32_ = lower;
    return *this;
  }

  /// Flag for the 8x8xf32 transpose lowering.
  bool lower8x8xf32_ = false;
  /// Sets `lower8x8xf32_` to `lower` (true when omitted) and returns `*this`.
  TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
    lower8x8xf32_ = lower;
    return *this;
  }
};
|
|
|
|
/// Options for controlling specialized AVX2 lowerings.
|
|
struct LoweringOptions {
|
|
/// Configure specialized vector lowerings.
|
|
TransposeLoweringOptions transposeOptions;
|
|
LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
|
|
transposeOptions = options;
|
|
return *this;
|
|
}
|
|
};
|
|
|
|
/// Insert specialized transpose lowering patterns.
|
|
void populateSpecializedTransposeLoweringPatterns(
|
|
RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
|
|
int benefit = 10);
|
|
|
|
} // namespace avx2
|
|
} // namespace x86
|
|
|
|
/// Collect a set of patterns to lower X86 ops to ops that map to LLVM
|
|
/// intrinsics.
|
|
void populateX86LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
|
|
RewritePatternSet &patterns);
|
|
|
|
/// Configure the target to support lowering X86 ops to ops that map to
|
|
/// LLVM intrinsics.
|
|
void configureX86LegalizeForExportTarget(LLVMConversionTarget &target);
|
|
|
|
/// Register LLVM conversion interface for X86 dialect.
|
|
void registerConvertX86ToLLVMInterface(DialectRegistry ®istry);
|
|
|
|
} // namespace mlir
|
|
|
|
#endif // MLIR_DIALECT_X86_TRANSFORMS_H
|