From d698ede748e66f5519cb8481abc2df89a994a059 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Thu, 12 Jun 2025 13:45:19 +0200 Subject: [PATCH] [mlir][amx] Restore conversion interface for AMX (#143871) Restores mistakenly removed AMX interface which ensures that the custom tile type is converted to its LLVM equivalent within other operations such as control flow. Fix after #140559 --- mlir/include/mlir/Dialect/AMX/Transforms.h | 3 +++ mlir/include/mlir/InitAllExtensions.h | 2 ++ .../AMX/Transforms/LegalizeForLLVMExport.cpp | 19 ++++++++++++++++++ mlir/test/Target/LLVMIR/amx.mlir | 20 +++++++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h index 4a751d99ceee..7391ec2ff6b1 100644 --- a/mlir/include/mlir/Dialect/AMX/Transforms.h +++ b/mlir/include/mlir/Dialect/AMX/Transforms.h @@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, /// intrinsics. void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target); +/// Register LLVM conversion interface for AMX dialect. +void registerConvertAMXToLLVMInterface(DialectRegistry ®istry); + } // namespace mlir #endif // MLIR_DIALECT_AMX_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 7dcbabe8aafa..f356b91b1b6c 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -32,6 +32,7 @@ #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" @@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { registerConvertOpenMPToLLVMInterface(registry); registerConvertSCFToEmitCInterface(registry); ub::registerConvertUBToLLVMInterface(registry); + registerConvertAMXToLLVMInterface(registry); gpu::registerConvertGpuToLLVMInterface(registry); NVVM::registerConvertGpuToNVVMInterface(registry); vector::registerConvertVectorToLLVMInterface(registry); diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index 7471dc797e0f..37aebc9fab3e 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns( void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { target.addIllegalDialect(); } + +namespace { +/// Implement the interface to convert AMX to LLVM. +struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { + dialect->addInterfaces(); + }); +} diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir index 094475040436..abdf2fe3bd53 100644 --- a/mlir/test/Target/LLVMIR/amx.mlir +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref, %matB: memref, amx.tile_store %out[%c16, %c16], %res3 : memref, !amx.tile<16x16xi32> return } + +// CHECK-LABEL: define void @amx_tile_type_through_cf +func.func @amx_tile_type_through_cf(%src: memref, %out: memref, + %idx: index, %cond: i1) { + cf.cond_br %cond, ^bb1, ^bb2 +^bb1: // pred: ^bb0 + // CHECK: call x86_amx @llvm.x86.tileloadd64.internal + %0 = amx.tile_load %src[%idx, %idx] : memref into !amx.tile<16x64xi8> + cf.br ^bb3(%0 : !amx.tile<16x64xi8>) +^bb2: // pred: ^bb0 + // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) + %1 = amx.tile_zero : !amx.tile<16x64xi8> + cf.br ^bb3(%1 : !amx.tile<16x64xi8>) +^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2 + cf.br ^bb4 +^bb4: // pred: ^bb3 + // CHECK: call void @llvm.x86.tilestored64.internal + amx.tile_store %out[%idx, %idx], %2 : memref, !amx.tile<16x64xi8> + return +}