[mlir][amx] Restore conversion interface for AMX (#143871)

Restores mistakenly removed AMX interface which ensures that the custom
tile type is converted to its LLVM equivalent within other operations
such as control flow.

Fix after #140559
This commit is contained in:
Adam Siemieniuk 2025-06-12 13:45:19 +02:00 committed by GitHub
parent 013034cd0f
commit d698ede748
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 44 additions and 0 deletions

View File

@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
/// intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
/// Register LLVM conversion interface for AMX dialect.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
} // namespace mlir
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

View File

@ -32,6 +32,7 @@
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertOpenMPToLLVMInterface(registry);
registerConvertSCFToEmitCInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
registerConvertAMXToLLVMInterface(registry);
gpu::registerConvertGpuToLLVMInterface(registry);
NVVM::registerConvertGpuToNVVMInterface(registry);
vector::registerConvertVectorToLLVMInterface(registry);

View File

@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
target.addIllegalDialect<AMXDialect>();
}
namespace {
/// Implement the interface to convert AMX to LLVM.
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
void populateConvertToLLVMConversionPatterns(
ConversionTarget &target, LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns) const final {
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
}
};
} // namespace
void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
dialect->addInterfaces<AMXToLLVMDialectInterface>();
});
}

View File

@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
return
}
// CHECK-LABEL: define void @amx_tile_type_through_cf
func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
%idx: index, %cond: i1) {
cf.cond_br %cond, ^bb1, ^bb2
^bb1: // pred: ^bb0
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
%0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
^bb2: // pred: ^bb0
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
%1 = amx.tile_zero : !amx.tile<16x64xi8>
cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
cf.br ^bb4
^bb4: // pred: ^bb3
// CHECK: call void @llvm.x86.tilestored64.internal
amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
return
}