//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/Support/Casting.h" using namespace mlir; /// Returns true if 'val' is a splat of zero, false otherwise. static bool isSplatZero(Type elemType, DenseElementsAttr val) { if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); return false; } /// Generates a for loop over ZA tile slices where the induction variable is /// the tile slice index. Sets the IR Builder insertion point as the loop body. /// Callers of this method are responsible for restoring it if needed. static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Type eltType) { auto step = rewriter.create(loc, 1); auto minTileSlices = rewriter.create( loc, arm_sme::getSMETileSliceMinNumElts(eltType)); auto vscale = rewriter.create(loc, rewriter.getIndexType()); auto lowerBound = rewriter.create(loc, 0); auto numTileSlices = rewriter.create(loc, minTileSlices, vscale); auto forOp = rewriter.create(loc, lowerBound, numTileSlices, step); rewriter.setInsertionPointToStart(forOp.getBody()); return forOp; } /// Returns a tile of the given vector type. static arm_sme::CastTileToVector getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc, VectorType type) { unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth(); // Create 'arm_sme.get_tile' op. auto tileId = rewriter.create( loc, rewriter.getIntegerType(tileElementWidth)); // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type. return rewriter.create(loc, type, tileId); } namespace { /// Conversion pattern for vector.transfer_read. /// /// --- /// /// Example 1: op with identity permutation map to horizontal /// arm_sme.tile_load: /// /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1) /// /// is converted to: /// /// arm_sme.tile_load ... /// /// --- /// /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load /// (in-flight transpose): /// /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0) /// /// is converted to: /// /// arm_sme.tile_load ... layout struct TransferReadToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, PatternRewriter &rewriter) const final { // The permutation map must have two results. if (transferReadOp.getTransferRank() != 2) return rewriter.notifyMatchFailure(transferReadOp, "not a 2 result permutation map"); auto vectorType = transferReadOp.getVectorType(); if (!arm_sme::isValidSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(transferReadOp, "not a valid vector type for SME"); if (!llvm::isa(transferReadOp.getSource().getType())) return rewriter.notifyMatchFailure(transferReadOp, "not a memref source"); // Out-of-bounds dims are not supported. if (transferReadOp.hasOutOfBoundsDim()) return rewriter.notifyMatchFailure(transferReadOp, "not inbounds transfer read"); arm_sme::TileSliceLayout layout; AffineExpr d0, d1; bindDims(transferReadOp.getContext(), d0, d1); AffineMap map = transferReadOp.getPermutationMap(); if (map.isIdentity()) layout = arm_sme::TileSliceLayout::Horizontal; else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0}, transferReadOp.getContext())) layout = arm_sme::TileSliceLayout::Vertical; else return rewriter.notifyMatchFailure(transferReadOp, "unsupported permutation map"); // Padding isn't optional for transfer_read, but is only used in the case // of out-of-bounds accesses (not supported here) and/or masking. Mask is // optional, if it's not present don't pass padding. auto mask = transferReadOp.getMask(); auto padding = mask ? transferReadOp.getPadding() : nullptr; rewriter.replaceOpWithNewOp( transferReadOp, vectorType, transferReadOp.getSource(), transferReadOp.getIndices(), padding, mask, layout); return success(); } }; /// Conversion pattern for vector.transfer_write. /// /// --- /// /// Example 1: op with identity permutation map to horizontal /// arm_sme.tile_store: /// /// vector.transfer_write %vector, %source[%c0, %c0] /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref /// /// is converted to: /// /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref, /// vector<[16]x[16]xi8> /// --- /// /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store /// (in-flight transpose): /// /// vector.transfer_write %vector, %source[%c0, %c0] /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref /// /// is converted to: /// /// arm_sme.tile_store %vector, %source[%c0, %c0] layout /// : memref, vector<[16]x[16]xi8> struct TransferWriteToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const final { auto vType = writeOp.getVectorType(); if (!arm_sme::isValidSMETileVectorType(vType)) return failure(); if (!llvm::isa(writeOp.getSource().getType())) return failure(); // Out-of-bounds dims are not supported. if (writeOp.hasOutOfBoundsDim()) return rewriter.notifyMatchFailure(writeOp, "not inbounds transfer write"); AffineExpr d0, d1; bindDims(writeOp.getContext(), d0, d1); AffineMap map = writeOp.getPermutationMap(); bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0}, writeOp.getContext())); if (!map.isIdentity() && !isTranspose) return rewriter.notifyMatchFailure(writeOp, "unsupported permutation map"); arm_sme::TileSliceLayout layout = isTranspose ? arm_sme::TileSliceLayout::Vertical : arm_sme::TileSliceLayout::Horizontal; rewriter.replaceOpWithNewOp( writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(), writeOp.getMask(), layout); return success(); } }; /// Conversion pattern for vector.load. struct VectorLoadToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::LoadOp load, PatternRewriter &rewriter) const override { if (!arm_sme::isValidSMETileVectorType(load.getVectorType())) return failure(); rewriter.replaceOpWithNewOp( load, load.getVectorType(), load.getBase(), load.getIndices()); return success(); } }; /// Conversion pattern for vector.store. struct VectorStoreToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::StoreOp store, PatternRewriter &rewriter) const override { if (!arm_sme::isValidSMETileVectorType(store.getVectorType())) return failure(); rewriter.replaceOpWithNewOp( store, store.getValueToStore(), store.getBase(), store.getIndices()); return success(); } }; /// Conversion pattern for dense arith.constant. struct ConstantOpToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(arith::ConstantOp constantOp, PatternRewriter &rewriter) const final { auto tileType = dyn_cast(constantOp.getType()); if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); auto denseAttr = dyn_cast(constantOp.getValueAttr()); if (!denseAttr || !denseAttr.isSplat()) return failure(); auto tileElementType = tileType.getElementType(); // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op. if (isSplatZero(tileElementType, denseAttr)) { rewriter.replaceOpWithNewOp(constantOp, tileType); return success(); } // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice' // ops that broadcast the constant to each tile slice. OpBuilder::InsertionGuard g(rewriter); auto loc = constantOp.getLoc(); // Unpack 1-d vector type from 2-d vector type. auto tileSliceType = VectorType::get(tileType.getShape().drop_front(), tileElementType, /*scalableDims=*/{true}); auto denseAttr1D = DenseElementsAttr::get( tileSliceType, denseAttr.getSplatValue()); auto constantOp1D = rewriter.create(loc, denseAttr1D); arm_sme::CastTileToVector tile = getSMETileAndCastToVector(rewriter, loc, tileType); auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType); auto tileSliceIndex = forOp.getInductionVar(); // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice. rewriter.create( loc, tileType, constantOp1D, tile, tileSliceIndex); rewriter.replaceOp(constantOp, tile); return success(); } }; /// Conversion pattern for vector.broadcast. /// /// Example: /// /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32> /// /// is converted to: /// /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> /// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 { /// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile, /// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> /// } /// /// Supports scalar, 0-d vector, and 1-d vector broadcasts. struct BroadcastOpToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, PatternRewriter &rewriter) const final { auto tileType = broadcastOp.getResultVectorType(); if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); OpBuilder::InsertionGuard g(rewriter); auto loc = broadcastOp.getLoc(); auto srcType = broadcastOp.getSourceType(); auto srcVectorType = dyn_cast(srcType); auto tileElementType = tileType.getElementType(); Value broadcastOp1D; if (srcType.isIntOrFloat() || (srcVectorType && (srcVectorType.getRank() == 0))) { // Broadcast scalar or 0-d vector to 1-d vector. auto tileSliceType = VectorType::get(tileType.getShape().drop_front(), tileElementType, /*scalableDims=*/{true}); broadcastOp1D = rewriter.create( loc, tileSliceType, broadcastOp.getSource()); } else if (srcVectorType && (srcVectorType.getRank() == 1)) // Value to broadcast is already a 1-d vector, nothing to do. broadcastOp1D = broadcastOp.getSource(); else return failure(); arm_sme::CastTileToVector tile = getSMETileAndCastToVector(rewriter, loc, tileType); // Create a loop over ZA tile slices. auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType); auto tileSliceIndex = forOp.getInductionVar(); // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each // tile slice. rewriter.create( loc, tileType, broadcastOp1D, tile, tileSliceIndex); rewriter.replaceOp(broadcastOp, tile); return success(); } }; /// Conversion pattern for vector.splat. /// /// Example: /// /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> /// /// is converted to: /// /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> /// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 { /// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile, /// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> /// } /// /// This is identical to vector.broadcast of a scalar. struct SplatOpToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::SplatOp splatOp, PatternRewriter &rewriter) const final { auto tileType = splatOp.getResult().getType(); if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); OpBuilder::InsertionGuard g(rewriter); auto loc = splatOp.getLoc(); auto srcType = splatOp.getOperand().getType(); auto tileElementType = tileType.getElementType(); assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); // Avoid unused-variable warning when building without assertions. (void)srcType; // First, broadcast the scalar to a 1-d vector. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); Value broadcastOp1D = rewriter.create( loc, tileSliceType, splatOp.getInput()); arm_sme::CastTileToVector tile = getSMETileAndCastToVector(rewriter, loc, tileType); // Next, create a loop over ZA tile slices and "move" the generated 1-d // vector to each slice. auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType); auto tileSliceIndex = forOp.getInductionVar(); rewriter.create( loc, tileType, broadcastOp1D, tile, tileSliceIndex); rewriter.replaceOp(splatOp, tile); return success(); } }; /// Conversion pattern for vector.transpose. /// /// Stores the input tile to memory and reloads vertically. /// /// Example: /// /// %transposed_src = vector.transpose %src, [1, 0] /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> /// /// is converted to: /// /// %alloca = memref.alloca(%svl_s, %svl_s) : memref /// %arm_sme.tile_store %src, , %alloca[%c0, %c0] /// : memref, vector<[4]x[4]xi32> /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0] /// layout : memref, vector<[4]x[4]xi32> /// /// NOTE: Tranposing via memory is obviously expensive, the current intention /// is to avoid the transpose if possible, this is therefore intended as a /// fallback and to provide base support for Vector ops. If it turns out /// transposes can't be avoided then this should be replaced with a more optimal /// implementation, perhaps with tile <-> vector (MOVA) ops. struct TransposeOpToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const final { auto tileType = transposeOp.getResultVectorType(); if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); SmallVector transp; for (auto attr : transposeOp.getTransp()) transp.push_back(cast(attr).getInt()); // Bail unless this is a true 2-D matrix transpose. if (transp[0] != 1 || transp[1] != 0) return failure(); OpBuilder::InsertionGuard g(rewriter); auto loc = transposeOp.getLoc(); // Allocate buffer to store input tile to. Value vscale = rewriter.create(loc, rewriter.getIndexType()); Value minTileSlices = rewriter.create( loc, rewriter.getIndexAttr(tileType.getDimSize(0))); Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value numTileSlices = rewriter.create(loc, vscale, minTileSlices); auto bufferType = MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, tileType.getElementType()); auto buffer = rewriter.create( loc, bufferType, ValueRange{numTileSlices, numTileSlices}); Value input = transposeOp.getVector(); // Store input tile. auto tileStoreOp = rewriter.create( loc, input, buffer, ValueRange{c0, c0}); // Reload input tile vertically. rewriter.replaceOpWithNewOp( transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(), arm_sme::TileSliceLayout::Vertical); return success(); } }; /// Conversion pattern for vector.outerproduct. /// /// If the vector.outerproduct is masked (and the mask is from a /// vector.create_mask), then the mask is decomposed into two 1-D masks for the /// operands. /// /// Example: /// /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1> /// %result = vector.mask %mask { /// vector.outerproduct %vecA, %vecB /// : vector<[4]xf32>, vector<[4]xf32> /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> /// /// is converted to: /// /// %maskA = vector.create_mask %dimA : vector<[4]xi1> /// %maskB = vector.create_mask %dimB : vector<[4]xi1> /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) /// : vector<[4]xf32>, vector<[4]xf32> /// /// Unmasked outerproducts can be directly replaced with the arm_sme op. /// /// Example: /// /// %result = vector.outerproduct %vecA, %vecB /// : vector<[4]xf32>, vector<[4]xf32> /// /// is converted to: /// /// %result = arm_sme.outerproduct %vecA, %vecB /// : vector<[4]xf32>, vector<[4]xf32> /// struct VectorOuterProductToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp, PatternRewriter &rewriter) const override { // We don't yet support lowering AXPY operations to SME. These could be // lowered by masking out all but the first element of the LHS. if (!isa(outerProductOp.getOperandTypeRHS())) return outerProductOp.emitError("AXPY operations not supported"); if (!arm_sme::isValidSMETileVectorType( outerProductOp.getResultVectorType())) return outerProductOp.emitError( "outer product does not fit into SME tile"); auto kind = outerProductOp.getKind(); if (kind != vector::CombiningKind::ADD) return outerProductOp.emitError( "unsupported kind (lowering to SME only supports ADD at the moment)"); Value lhsMask = {}; Value rhsMask = {}; Operation *rootOp = outerProductOp; auto loc = outerProductOp.getLoc(); if (outerProductOp.isMasked()) { auto maskOp = outerProductOp.getMaskingOp(); rewriter.setInsertionPoint(maskOp); rootOp = maskOp; auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter); if (failed(operandMasks)) return failure(); std::tie(lhsMask, rhsMask) = *operandMasks; } rewriter.replaceOpWithNewOp( rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(), outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc()); return success(); } static FailureOr> decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) { // Attempt to extract masks from vector.create_mask. // TODO: Add support for other mask sources. auto createMaskOp = mask.getDefiningOp(); if (!createMaskOp) return failure(); auto maskType = createMaskOp.getVectorType(); Value lhsMaskDim = createMaskOp.getOperand(0); Value rhsMaskDim = createMaskOp.getOperand(1); VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); Value lhsMask = rewriter.create(loc, operandMaskType, lhsMaskDim); Value rhsMask = rewriter.create(loc, operandMaskType, rhsMaskDim); return std::make_pair(lhsMask, rhsMask); } }; } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { patterns.add(&ctx); }