//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect fusion-on-tensors patterns and pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

/// Conditions for elementwise fusion of generic operations.
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
  // Producer and consumer must have tensor semantics.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isInputTensor(consumerOpOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Currently support only operations with single result.
  if (producer.getNumOutputs() != 1)
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  return producerResultIndexMap.isPermutation();
}
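// An illustrative example of a pair satisfying the conditions above
// (hypothetical IR, for exposition only; #id abbreviates the identity map
// affine_map<(d0, d1) -> (d0, d1)>):
//
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//          ins(%a : tensor<?x?xf32>) outs(%init0 : tensor<?x?xf32>) {...}
//   %1 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//          ins(%0, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//          outs(%init1 : tensor<?x?xf32>) {...}
//
// Here the producer %0 is all-parallel, has a single result whose indexing
// map is a permutation, and feeds an input operand of the consumer %1.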
/// Compute the indexing map to use in the fused operation for an operand of
/// the `producer`, given the indexing map of the result of the producer in
/// the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is same as the consumer loop. For each producer arg
  // the indexing map to be computed is a map from consumer loop -> producer
  // arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.region().push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           consumerOpOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));

  // 4.b. Producer output operand/map that is fused needs to be mapped to the
  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    BlockArgument bbArg = producerBlock.getArguments()
                              .drop_front(producer.getNumInputs())
                              // TODO: bbArg index of
                              .front();
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  }
  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
           .drop_front(consumerOpOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  // 6. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  // 7. All of producer's output operands except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation. (`without_terminator` already skips the yield.)
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  unsigned producerResultNumber = 0;
  Value replacement =
      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == yieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlElementwiseOpsFusionFn &controlFn,
                       PatternRewriter &rewriter) {
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
      !controlFn(producer->getResult(0), *consumerOpOperand))
    return llvm::None;

  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedOperands;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedOperands.reserve(producer->getNumOperands() +
                        consumer->getNumOperands());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
  SmallVector<OpOperand *>::iterator it =
      llvm::find(consumerInputs, consumerOpOperand);
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  for (OpOperand *opOperand : producer.getInputOperands()) {
    fusedOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 4.b. Producer output operand/map that is fused needs to be passed if it is
  // an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    fusedOperands.push_back(producer.getOutputOperand(0)->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        producer.getOutputOperand(0), producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 6. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}
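// A small worked illustration of the composition above (hypothetical maps,
// not tied to any particular test): if the producer result indexing map is
// (d0, d1) -> (d1, d0) and the consumer accesses that result with the
// identity map (d0, d1) -> (d0, d1), the inverse of the former is again
// (d0, d1) -> (d1, d0), and composing yields
// consumerToProducerLoopsMap = (d0, d1) -> (d1, d0), i.e. consumer loop 0
// plays the role of producer loop 1 and vice versa.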
/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's
/// dimensions are "row-major" ordered logically.
///
/// For example:
///
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] :
///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
  ArrayRef<int64_t> sourceShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  SmallVector<AffineExpr> resultExprs;
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!indices.empty());
    SmallVector<int64_t> sizes;
    SmallVector<AffineExpr> dimExprs;
    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
                             sourceExprs.slice(indices[0], indices.size()))) {
      if (std::get<0>(en) == 1)
        continue;
      sizes.push_back(std::get<0>(en));
      dimExprs.push_back(std::get<1>(en));
    }
    AffineExpr linearizedExpr =
        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
    resultExprs.push_back(linearizedExpr);
  }
  return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
                        resultExprs, context);
}

// TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a
// producer). Fusing when the operand has a higher rank would require the use
// of mods and divs in the indexing maps of the fused op, which would make it
// non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
    TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
  if (!asProducer)
    return false;
  return useIndexMap.isPermutation();
}

// TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a
// consumer).
static bool
isTensorReshapeOpFoldableByLinearization(TensorCollapseShapeOp collapseOp,
                                         AffineMap useIndexMap,
                                         bool asProducer) {
  if (asProducer)
    return false;
  return useIndexMap.isPermutation();
}

/// Check if the reshape operation is only expansion into/collapsing of
/// unit-dimension.
template <typename TensorReshapeOp>
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
  ArrayRef<int64_t> expandedShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    unsigned numUnitDims = 0;
    for (int64_t position : indices)
      if (expandedShape[position] == 1)
        numUnitDims++;
    if (numUnitDims != indices.size() - 1)
      return false;
  }
  return true;
}
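// For instance (illustrative shapes): collapsing tensor<?x1x?xf32> into
// tensor<?x?xf32> with reassociation [[0, 1], [2]] only folds away a unit
// dimension, so the check above succeeds; collapsing tensor<?x4x?xf32> the
// same way would fail it.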
/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
/// following fusion pattern.
///
///  Consider
///
///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x4x?x5x6x7xf32>
///
/// The reshape can be folded into the `genericOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor_expand_shape.
/// The indexing_map of the fused tensor in the `genericOp` and the
/// reassociation map help compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
/// %d = linalg.generic
///        ins(%0, %1 : tensor<5x6x7x?x4x?xf32>, tensor<5x6x7x?xf32>)
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the operands of the linalg generic are now higher-dimensional,
/// reshapes can be introduced to make it consistent
///
/// %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]]
///      : tensor<?x?x?xf32> into tensor<5x6x7x?x4x?xf32>
/// %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]]
///      : tensor<?x?xf32> into tensor<5x6x7x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimension of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {(*originalLoopRange)[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                    const ExpansionInfo &expansionInfo,
                                    PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op for a given `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}
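// A small worked illustration (hypothetical expansion): if loop d1 of the
// original op expands into (e1, e2) while d0 stays as (e0), an original
// indexing map (d0, d1) -> (d1, d0) becomes (e0, e1, e2) -> (e1, e2, e0):
// each dim expression is replaced in place by the list of dims it expanded
// into.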
/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `linalg.tensor_expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `linalg.tensor_collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the
/// original op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}
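// For example (illustrative): if original dimension d0 expanded into
// (e0, e1) with static inner size 4, the original `linalg.index 0` is
// recovered as idx(e0) * 4 + idx(e1); the loop above builds exactly that
// chain of affine.apply ops, starting from the most significant index.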
/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByExpansion`.
/// Assumes that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<TensorExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<TensorCollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), rewriter)))
    return llvm::None;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return llvm::None;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumInputs());
  for (OpOperand *opOperand : genericOp.getInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                               : collapsingReshapeOp.src());
      continue;
    }
    if (genericOp.isInputTensor(opOperand)) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                          indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
    RankedTensorType expandedOutputType =
        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                        indexingMap, expansionInfo);
    if (expandedOutputType != opOperand->get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      outputs.push_back(rewriter.create<TensorExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                       getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
                  genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<TensorCollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fold tensor_expand_shape op with its consumer by using the
/// source of the reshape op as the operand in the consumer (instead of the
/// result of the tensor_expand_shape). The corresponding index map in the
/// consumer needs to be modified to linearize the folded dimension.
///
/// For example,
///
///  #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///  %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]]
///      : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
///  %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
///  #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
///  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///  %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///        -> tensor<?x?x4x?xf32>
template <typename TensorReshapeOp, bool foldUnitDimReshapesOnly>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer=*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused are the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Accepted consumer maps are either identity or permutation.
      auto invMap = inversePermutation(fusedIndexMaps[en.index()]);

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resultant op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};

static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

/// Pattern to move rank reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops
/// and generic ops. This can only be done if there is no broadcast or
/// permutation within the dimensions we need to merge.
///
/// For example,
///
///  %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///    affine_map<(d0, d1, d2) -> (d2)>,
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///    ["parallel", "parallel", "parallel"]} {
///  } -> tensor<112x112x16xf32>
///
/// into
///
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1) -> (d0, d1)>,
///    affine_map<(d0, d1) -> (d1)>,
///    affine_map<(d0, d1) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///  } -> tensor<12544x16xf32>
///  %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]]
///      : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensor.
    if (!genericOp.hasTensorSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    TensorExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and record the dimensions
    // merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<TensorExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity map as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the reassociated reverse map.
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg and that we
    // don't need to create new reshape operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (auto en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip dimensions merged except for the last of the group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<TensorCollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp =
        rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                   newOperands, newOutputs, newMaps,
                                   iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that the types match the uses.
    SmallVector<Value> newResults;
    for (auto result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<TensorExpandShapeOp>(
          genericOp->getLoc(),
          genericOp.getOutputTensorTypes()[result.index()], result.value(),
          reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};

/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      TensorCollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
        continue;

      Optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, replacementValues.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The corresponding index map in the consumer needs to be modified
/// to linearize the folded dimension.
template <typename TensorReshapeOp, bool foldUnitDimReshapesOnly>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer=*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    auto invMap = inversePermutation(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};

/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
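/// For example (an illustrative sketch, not from the test suite):
///
///  %0 = linalg.generic ... -> tensor<?x?xf32>
///  %1 = linalg.tensor_expand_shape %0 [[0], [1, 2]]
///      : tensor<?x?xf32> into tensor<?x?x4xf32>
///
/// can be rewritten, when the conditions checked below hold, into a single
/// linalg.generic that produces tensor<?x?x4xf32> directly, with the loop
/// over the second dimension split in two.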
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<TensorExpandShapeOp> {
  FoldReshapeWithGenericOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<TensorExpandShapeOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}

  LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
    if (!producer || producer.getNumOutputs() != 1 ||
        !isFusableWithReshapeByDimExpansion(producer,
                                            producer.getOutputOperand(0)) ||
        !controlFoldingReshapes(producer->getResult(0),
                                reshapeOp->getOpOperand(0)))
      return failure();
    Optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(producer, reshapeOp,
                                   producer.getOutputOperand(0), rewriter);
    if (!replacementValues)
      return failure();
    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold a generic op with a splat constant/scalar constant. Does
/// not handle cases where the constant is not single-valued.
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the
      // same as the operands and indexing_maps of the generic operation, with
      // the values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for the
/// output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
  struct APIntOrFloat {
    Optional<APInt> apInt;
    Optional<APFloat> apFloat;
  };
  struct APIntOrFloatArray {
    SmallVector<APInt> apInts;
    SmallVector<APFloat> apFloats;
  };
  using RegionComputationFn =
      std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType =
        genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output types to be static given we are generating constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common and
    // it simplifies input/output access greatly: we can do the data shuffling
    // entirely in the compiler, without needing to turn all indices into
    // Values, and then do affine apply on them, and then match back the
    // constant again.
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the output.
    // This helps to avoid blowing up compiler memory usage: Attributes would
    // unify the following cases but they have the same lifetime as the
    // MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate spaces for compute function inputs. Initial values do not
    // matter here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, `dstLinearIndex` in place.
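    // For instance (illustrative numbers): with loopBounds = [2, 3], a
    // linearIndex of 4 delinearizes to indices = [1, 1]; those indices are
    // then permuted per the input/output indexing maps and re-linearized
    // against the corresponding shapes.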
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
      for (int i = 0; i < numInputs; ++i)
        inFpRanges.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
      }
    } else {
      SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
      for (int i = 0; i < numInputs; ++i)
        inIntRanges.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Transpose the input constant. Because we don't know its rank in
      // advance, we need to loop over the range [0, element count) and
      // delinearize the index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i)
          computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];

        // Invoke the computation to get the corresponding constant output
        // element.
        intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
      }
    }

    DenseElementsAttr outputAttr =
        isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
                : DenseElementsAttr::get(outputType, intOutputValues);
    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
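// For example (hypothetical IR): with %cst a constant tensor<2x3xf32>,
//
//   %0 = linalg.generic
//          {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
//                            affine_map<(d0, d1) -> (d0, d1)>],
//           iterator_types = ["parallel", "parallel"]}
//          ins(%cst : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
//        ^bb0(%in: f32, %out: f32):
//          linalg.yield %in : f32
//        } -> tensor<3x2xf32>
//
// folds to a constant of type tensor<3x2xf32> with the elements transposed.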
struct FoldConstantTranspose
    : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponding to the
    // input.
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), llvm::None};
      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
    };
  }

  ControlElementwiseOpsFusionFn controlFn;
};

} // namespace

static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<TensorCollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<TensorExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  return true;
}

namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto producer =
          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
      if (!producer || !producer.hasTensorSemantics())
        continue;
      Optional<SmallVector<Value>> fusedOpResults =
          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    ControlElementwiseOpsFusionFn allowFoldingFn =
        [](const OpResult &producer, const OpOperand &consumer) {
          return true;
        };
    populateElementwiseOpsFusionPatterns(
        patterns,
        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
            allowFoldingUnitDimReshapes ? allowFoldingFn
                                        : skipUnitDimReshape));

    // Use TopDownTraversal for compile time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
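/// For example (an illustrative sketch): when the payload never reads the
/// `outs` value,
///
///   %0 = linalg.generic ... outs(%arg0 : tensor<?xf32>) {...}
///
/// is rewritten to
///
///   %init = linalg.init_tensor [%d0] : tensor<?xf32>
///   %0 = linalg.generic ... outs(%init : tensor<?xf32>) {...}
///
/// which removes the use-def dependence on %arg0 and exposes more fusion
/// opportunities.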
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (auto dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};
} // namespace

void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<TensorCollapseShapeOp, false>,
           FoldProducerReshapeOpByLinearization<TensorExpandShapeOp, false>,
           FoldConsumerReshapeOpByLinearization<TensorCollapseShapeOp, false>,
           FoldConsumerReshapeOpByLinearization<TensorExpandShapeOp, false>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<TensorCollapseShapeOp, true>,
           FoldProducerReshapeOpByLinearization<TensorExpandShapeOp, true>,
           FoldConsumerReshapeOpByLinearization<TensorCollapseShapeOp, true>,
           FoldConsumerReshapeOpByLinearization<TensorExpandShapeOp, true>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
               FoldConstantTranspose>(context,
                                      options.controlElementwiseOpsFusionFn);
  patterns.add<RemoveOutsDependency>(context);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
      patterns);
}

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}