[mlir][amx] Restore conversion interface for AMX (#143871)
Restores mistakenly removed AMX interface which ensures that the custom tile type is converted to its LLVM equivalent within other operations such as control flow. Fix after #140559
This commit is contained in:
parent
013034cd0f
commit
d698ede748
@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
|
||||
/// intrinsics.
|
||||
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
|
||||
|
||||
/// Register LLVM conversion interface for AMX dialect.
|
||||
void registerConvertAMXToLLVMInterface(DialectRegistry ®istry);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
|
||||
|
||||
@ -32,6 +32,7 @@
|
||||
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
|
||||
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Dialect/AMX/Transforms.h"
|
||||
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
|
||||
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
|
||||
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
|
||||
@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) {
|
||||
registerConvertOpenMPToLLVMInterface(registry);
|
||||
registerConvertSCFToEmitCInterface(registry);
|
||||
ub::registerConvertUBToLLVMInterface(registry);
|
||||
registerConvertAMXToLLVMInterface(registry);
|
||||
gpu::registerConvertGpuToLLVMInterface(registry);
|
||||
NVVM::registerConvertGpuToNVVMInterface(registry);
|
||||
vector::registerConvertVectorToLLVMInterface(registry);
|
||||
|
||||
@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
|
||||
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
|
||||
target.addIllegalDialect<AMXDialect>();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Implement the interface to convert AMX to LLVM.
|
||||
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
|
||||
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
|
||||
|
||||
void populateConvertToLLVMConversionPatterns(
|
||||
ConversionTarget &target, LLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) const final {
|
||||
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
|
||||
dialect->addInterfaces<AMXToLLVMDialectInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
|
||||
amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define void @amx_tile_type_through_cf
|
||||
func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
|
||||
%idx: index, %cond: i1) {
|
||||
cf.cond_br %cond, ^bb1, ^bb2
|
||||
^bb1: // pred: ^bb0
|
||||
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
|
||||
%0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
|
||||
cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
|
||||
^bb2: // pred: ^bb0
|
||||
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
|
||||
%1 = amx.tile_zero : !amx.tile<16x64xi8>
|
||||
cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
|
||||
^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
|
||||
cf.br ^bb4
|
||||
^bb4: // pred: ^bb3
|
||||
// CHECK: call void @llvm.x86.tilestored64.internal
|
||||
amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
|
||||
return
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user