//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains definitions of patterns to lower GPU Subgroup MMA ops to // SPIRV Dialect ops. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/TypeUtilities.h" using namespace mlir; // See SPV_NV_cooperative_matrix for supported element wise ops. static void createElementWiseOp(ConversionPatternRewriter &builder, gpu::SubgroupMmaElementwiseOp op, spirv::CooperativeMatrixNVType coopType, ValueRange operands) { switch (op.getOpType()) { case gpu::MMAElementwiseOp::ADDF: builder.replaceOpWithNewOp(op, coopType, operands); return; case gpu::MMAElementwiseOp::ADDI: builder.replaceOpWithNewOp(op, coopType, operands); return; case gpu::MMAElementwiseOp::SUBF: builder.replaceOpWithNewOp(op, coopType, operands); return; case gpu::MMAElementwiseOp::SUBI: builder.replaceOpWithNewOp(op, coopType, operands); return; case gpu::MMAElementwiseOp::DIVF: builder.replaceOpWithNewOp(op, coopType, operands); return; case gpu::MMAElementwiseOp::DIVS: builder.replaceOpWithNewOp(op, coopType, operands); return; case gpu::MMAElementwiseOp::DIVU: builder.replaceOpWithNewOp(op, coopType, operands); return; case gpu::MMAElementwiseOp::NEGATEF: builder.replaceOpWithNewOp(op, coopType, operands); return; case gpu::MMAElementwiseOp::NEGATES: builder.replaceOpWithNewOp(op, coopType, operands); return; default: llvm_unreachable("unknown op"); } } namespace { /// This class implements the conversion of GPU MMA loadOp to /// CooperativeMatrixLoad op in the SPIRV dialect. struct WmmaLoadOpToSPIRVLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = subgroupMmaLoadMatrixOp->getLoc(); gpu::MMAMatrixType retType = subgroupMmaLoadMatrixOp.getRes().getType().cast(); auto memrefType = subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast(); Value bufferPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter); auto coopType = convertMMAToSPIRVType(retType); int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue(); auto i32Type = rewriter.getI32Type(); auto strideValue = rewriter.create( loc, i32Type, IntegerAttr::get(i32Type, stride)); auto coloumnMajor = rewriter.create( loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, coloumnMajor, spirv::MemoryAccessAttr()); return success(); } }; /// This class implements the conversion of GPU MMA StoreOp to /// CooperativeMatrixStore op in the SPIRV dialect. struct WmmaStoreOpToSPIRVLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = subgroupMmaStoreMatrixOp->getLoc(); auto memrefType = subgroupMmaStoreMatrixOp.getDstMemref().getType().cast(); Value bufferPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter); int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue(); auto i32Type = rewriter.getI32Type(); auto strideValue = rewriter.create( loc, i32Type, IntegerAttr::get(i32Type, stride)); auto coloumnMajor = rewriter.create( loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue, coloumnMajor, spirv::MemoryAccessAttr()); return success(); } }; /// This class implements the conversion of GPU MMA Compute to /// CooperativeMatrixMulAdd op in the SPIRV dialect. struct WmmaMmaOpToSPIRVLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(), adaptor.getOpB(), adaptor.getOpC()); return success(); } }; /// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops. struct WmmaConstantOpToSPIRVLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value cst = adaptor.getOperands()[0]; auto coopType = convertMMAToSPIRVType( subgroupMmaConstantMatrixOp.getType().cast()); rewriter.replaceOpWithNewOp( subgroupMmaConstantMatrixOp, coopType, cst); return success(); } }; /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops. struct WmmaElementwiseOpToSPIRVLowering : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // All operands should be of cooperative matrix types. for (Value operand : adaptor.getOperands()) { if (!operand.getType().isa()) return failure(); } auto coopType = convertMMAToSPIRVType( subgroupMmaElementwiseOp.getType().cast()); createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType, adaptor.getOperands()); return success(); } }; } // namespace /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. mlir::spirv::CooperativeMatrixNVType mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) { ArrayRef retTypeShape = type.getShape(); Type elementType = type.getElementType(); return spirv::CooperativeMatrixNVType::get( elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]); } void mlir::populateGpuWMMAToSPIRVConversionPatterns( SPIRVTypeConverter &converter, RewritePatternSet &patterns) { patterns.add(converter, patterns.getContext()); }