//===- SimplifyExtractStridedMetadata.cpp - Simplify this operation -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // /// This pass simplifies extract_strided_metadata(other_op(memref)) to /// extract_strided_metadata(memref) when it is possible to express the effect /// of other_op using affine apply on the results of /// extract_strided_metadata(memref). //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" namespace mlir { namespace memref { #define GEN_PASS_DEF_SIMPLIFYEXTRACTSTRIDEDMETADATA #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" } // namespace memref } // namespace mlir using namespace mlir; namespace { /// Replace `baseBuffer, offset, sizes, strides = /// extract_strided_metadata(subview(memref, subOffset, /// subSizes, subStrides))` /// With /// /// \verbatim /// baseBuffer, baseOffset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// strides#i = baseStrides#i * subStrides#i /// offset = baseOffset + sum(subOffset#i * baseStrides#i) /// sizes = subSizes /// \endverbatim /// /// In other words, get rid of the subview in that expression and canonicalize /// on its effects on the offset, the sizes, and the strides using affine.apply.
struct ExtractStridedMetadataOpSubviewFolder : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, PatternRewriter &rewriter) const override { auto subview = op.getSource().getDefiningOp(); if (!subview) return failure(); // Build a plain extract_strided_metadata(memref) from // extract_strided_metadata(subview(memref)). Location origLoc = op.getLoc(); IndexType indexType = rewriter.getIndexType(); Value source = subview.getSource(); auto sourceType = source.getType().cast(); unsigned sourceRank = sourceType.getRank(); SmallVector sizeStrideTypes(sourceRank, indexType); auto newExtractStridedMetadata = rewriter.create( origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes, sizeStrideTypes, source); SmallVector sourceStrides; int64_t sourceOffset; bool hasKnownStridesAndOffset = succeeded(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)); (void)hasKnownStridesAndOffset; assert(hasKnownStridesAndOffset && "getStridesAndOffset must work on valid subviews"); // Compute the new strides and offset from the base strides and offset: // newStride#i = baseStride#i * subStride#i // offset = baseOffset + sum(subOffsets#i * newStrides#i) SmallVector strides; SmallVector subStrides = subview.getMixedStrides(); auto origStrides = newExtractStridedMetadata.getStrides(); // Hold the affine symbols and values for the computation of the offset. SmallVector values(3 * sourceRank + 1); SmallVector symbols(3 * sourceRank + 1); detail::bindSymbolsList(rewriter.getContext(), symbols); AffineExpr expr = symbols.front(); values[0] = ShapedType::isDynamicStrideOrOffset(sourceOffset) ? 
getAsOpFoldResult(newExtractStridedMetadata.getOffset()) : rewriter.getIndexAttr(sourceOffset); SmallVector subOffsets = subview.getMixedOffsets(); AffineExpr s0 = rewriter.getAffineSymbolExpr(0); AffineExpr s1 = rewriter.getAffineSymbolExpr(1); for (unsigned i = 0; i < sourceRank; ++i) { // Compute the stride. OpFoldResult origStride = ShapedType::isDynamicStrideOrOffset(sourceStrides[i]) ? origStrides[i] : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i])); strides.push_back(makeComposedFoldedAffineApply( rewriter, origLoc, s0 * s1, {subStrides[i], origStride})); // Build up the computation of the offset. unsigned baseIdxForDim = 1 + 3 * i; unsigned subOffsetForDim = baseIdxForDim; unsigned subStrideForDim = baseIdxForDim + 1; unsigned origStrideForDim = baseIdxForDim + 2; expr = expr + symbols[subOffsetForDim] * symbols[subStrideForDim] * symbols[origStrideForDim]; values[subOffsetForDim] = subOffsets[i]; values[subStrideForDim] = subStrides[i]; values[origStrideForDim] = origStride; } // Compute the offset. OpFoldResult finalOffset = makeComposedFoldedAffineApply(rewriter, origLoc, expr, values); SmallVector results; // The final result is . // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all // the values. auto subType = subview.getType().cast(); unsigned subRank = subType.getRank(); // Properly size the array so that we can do random insertions // at the right indices. // We do that to populate the non-dropped sizes and strides in one go. results.resize_for_overwrite(subRank * 2 + 2); results[0] = newExtractStridedMetadata.getBaseBuffer(); results[1] = getValueOrCreateConstantIndexOp(rewriter, origLoc, finalOffset); // The sizes of the final type are defined directly by the input sizes of // the subview. // Moreover subviews can drop some dimensions, some strides and sizes may // not end up in the final value that we are // replacing. // Do the filtering here. 
SmallVector subSizes = subview.getMixedSizes(); const unsigned sizeStartIdx = 2; const unsigned strideStartIdx = sizeStartIdx + subRank; unsigned insertedDims = 0; llvm::SmallBitVector droppedDims = subview.getDroppedDims(); for (unsigned i = 0; i < sourceRank; ++i) { if (droppedDims.test(i)) continue; results[sizeStartIdx + insertedDims] = getValueOrCreateConstantIndexOp(rewriter, origLoc, subSizes[i]); results[strideStartIdx + insertedDims] = getValueOrCreateConstantIndexOp(rewriter, origLoc, strides[i]); ++insertedDims; } assert(insertedDims == subRank && "Should have populated all the values at this point"); rewriter.replaceOp(op, results); return success(); } }; /// Compute the expanded sizes of the given \p expandShape for the /// \p groupId-th reassociation group. /// \p origSizes hold the sizes of the source shape as values. /// This is used to compute the new sizes in cases of dynamic shapes. /// /// sizes#i = /// baseSizes#groupId / product(expandShapeSizes#j, /// for j in group excluding reassIdx#i) /// Where reassIdx#i is the reassociation index at index i in \p groupId. /// /// \post result.size() == expandShape.getReassociationIndices()[groupId].size() /// /// TODO: Move this utility function directly within ExpandShapeOp. For now, /// this is not possible because this function uses the Affine dialect and the /// MemRef dialect cannot depend on the Affine dialect. static SmallVector getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder, ArrayRef origSizes, unsigned groupId) { SmallVector reassocGroup = expandShape.getReassociationIndices()[groupId]; assert(!reassocGroup.empty() && "Reassociation group should have at least one dimension"); unsigned groupSize = reassocGroup.size(); SmallVector expandedSizes(groupSize); uint64_t productOfAllStaticSizes = 1; Optional dynSizeIdx; MemRefType expandShapeType = expandShape.getResultType(); // Fill up all the statically known sizes. 
for (unsigned i = 0; i < groupSize; ++i) { uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]); if (ShapedType::isDynamic(dimSize)) { assert(!dynSizeIdx && "There must be at most one dynamic size per group"); dynSizeIdx = i; continue; } productOfAllStaticSizes *= dimSize; expandedSizes[i] = builder.getIndexAttr(dimSize); } // Compute the dynamic size using the original size and all the other known // static sizes: // expandSize = origSize / productOfAllStaticSizes. if (dynSizeIdx) { AffineExpr s0 = builder.getAffineSymbolExpr(0); expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply( builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes), origSizes[groupId]); } return expandedSizes; } /// Compute the expanded strides of the given \p expandShape for the /// \p groupId-th reassociation group. /// \p origStrides and \p origSizes hold respectively the strides and sizes /// of the source shape as values. /// This is used to compute the strides in cases of dynamic shapes and/or /// dynamic stride for this reassociation group. /// /// strides#i = /// origStrides#reassDim * product(expandShapeSizes#j, for j in /// reassIdx#i+1..reassIdx#i+group.size-1) /// /// Where reassIdx#i is the reassociation index for at index i in \p groupId /// and expandShapeSizes#j is either: /// - The constant size at dimension j, derived directly from the result type of /// the expand_shape op, or /// - An affine expression: baseSizes#reassDim / product of all constant sizes /// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic /// element.) /// /// \post result.size() == expandShape.getReassociationIndices()[groupId].size() /// /// TODO: Move this utility function directly within ExpandShapeOp. For now, /// this is not possible because this function uses the Affine dialect and the /// MemRef dialect cannot depend on the Affine dialect. 
SmallVector getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder, ArrayRef origSizes, ArrayRef origStrides, unsigned groupId) { SmallVector reassocGroup = expandShape.getReassociationIndices()[groupId]; assert(!reassocGroup.empty() && "Reassociation group should have at least one dimension"); unsigned groupSize = reassocGroup.size(); MemRefType expandShapeType = expandShape.getResultType(); Optional dynSizeIdx; // Fill up the expanded strides, with the information we can deduce from the // resulting shape. uint64_t currentStride = 1; SmallVector expandedStrides(groupSize); for (int i = groupSize - 1; i >= 0; --i) { expandedStrides[i] = builder.getIndexAttr(currentStride); uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]); if (ShapedType::isDynamic(dimSize)) { assert(!dynSizeIdx && "There must be at most one dynamic size per group"); dynSizeIdx = i; continue; } currentStride *= dimSize; } // Collect the statically known information about the original stride. Value source = expandShape.getSrc(); auto sourceType = source.getType().cast(); SmallVector strides; int64_t offset; bool hasKnownStridesAndOffset = succeeded(getStridesAndOffset(sourceType, strides, offset)); (void)hasKnownStridesAndOffset; assert(hasKnownStridesAndOffset && "getStridesAndOffset must work on valid expand_shape"); OpFoldResult origStride = ShapedType::isDynamicStrideOrOffset(strides[groupId]) ? origStrides[groupId] : builder.getIndexAttr(strides[groupId]); // Apply the original stride to all the strides. int64_t doneStrideIdx = 0; // If we saw a dynamic dimension, we need to fix-up all the strides up to // that dimension with the dynamic size. 
if (dynSizeIdx) { int64_t productOfAllStaticSizes = currentStride; assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) && "We shouldn't be able to change dynamicity"); OpFoldResult origSize = origSizes[groupId]; AffineExpr s0 = builder.getAffineSymbolExpr(0); AffineExpr s1 = builder.getAffineSymbolExpr(1); for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) { int64_t baseExpandedStride = expandedStrides[doneStrideIdx] .get() .cast() .getInt(); expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( builder, expandShape.getLoc(), (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1, {origSize, origStride}); } } // Now apply the origStride to the remaining dimensions. AffineExpr s0 = builder.getAffineSymbolExpr(0); for (; doneStrideIdx < groupSize; ++doneStrideIdx) { int64_t baseExpandedStride = expandedStrides[doneStrideIdx] .get() .cast() .getInt(); expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride}); } return expandedStrides; } /// Produce an OpFoldResult object with \p builder at \p loc representing /// `prod(valueOrConstant#i, for i in {indices})`, /// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false, /// values[i] otherwise. /// /// \pre for all index in indices: index < values.size() /// \pre for all index in indices: index < maybeConstants.size() static OpFoldResult getProductOfValues(ArrayRef indices, OpBuilder &builder, Location loc, ArrayRef maybeConstants, ArrayRef values, llvm::function_ref isDynamic) { AffineExpr productOfValues = builder.getAffineConstantExpr(1); SmallVector inputValues; unsigned numberOfSymbols = 0; unsigned groupSize = indices.size(); for (unsigned i = 0; i < groupSize; ++i) { productOfValues = productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++); unsigned srcIdx = indices[i]; int64_t maybeConstant = maybeConstants[srcIdx]; inputValues.push_back(isDynamic(maybeConstant) ? 
values[srcIdx] : builder.getIndexAttr(maybeConstant)); } return makeComposedFoldedAffineApply(builder, loc, productOfValues, inputValues); } /// Compute the collapsed size of the given \p collpaseShape for the /// \p groupId-th reassociation group. /// \p origSizes hold the sizes of the source shape as values. /// This is used to compute the new sizes in cases of dynamic shapes. /// /// Conceptually this helper function computes: /// `prod(origSizes#i, for i in {ressociationGroup[groupId]})`. /// /// \post result.size() == 1, in other words, each group collapse to one /// dimension. /// /// TODO: Move this utility function directly within CollapseShapeOp. For now, /// this is not possible because this function uses the Affine dialect and the /// MemRef dialect cannot depend on the Affine dialect. static SmallVector getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder, ArrayRef origSizes, unsigned groupId) { SmallVector collapsedSize; MemRefType collapseShapeType = collapseShape.getResultType(); uint64_t size = collapseShapeType.getDimSize(groupId); if (!ShapedType::isDynamic(size)) { collapsedSize.push_back(builder.getIndexAttr(size)); return collapsedSize; } // We are dealing with a dynamic size. // Build the affine expr of the product of the original sizes involved in that // group. Value source = collapseShape.getSrc(); auto sourceType = source.getType().cast(); SmallVector reassocGroup = collapseShape.getReassociationIndices()[groupId]; collapsedSize.push_back(getProductOfValues( reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(), origSizes, ShapedType::isDynamic)); return collapsedSize; } /// Compute the collapsed stride of the given \p collpaseShape for the /// \p groupId-th reassociation group. /// \p origStrides and \p origSizes hold respectively the strides and sizes /// of the source shape as values. 
/// This is used to compute the strides in cases of dynamic shapes and/or /// dynamic stride for this reassociation group. /// /// Conceptually this helper function returns the stride of the inner most /// dimension of that group in the original shape. /// /// \post result.size() == 1, in other words, each group collapse to one /// dimension. static SmallVector getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, ArrayRef origSizes, ArrayRef origStrides, unsigned groupId) { SmallVector reassocGroup = collapseShape.getReassociationIndices()[groupId]; assert(!reassocGroup.empty() && "Reassociation group should have at least one dimension"); Value source = collapseShape.getSrc(); auto sourceType = source.getType().cast(); SmallVector strides; int64_t offset; bool hasKnownStridesAndOffset = succeeded(getStridesAndOffset(sourceType, strides, offset)); (void)hasKnownStridesAndOffset; assert(hasKnownStridesAndOffset && "getStridesAndOffset must work on valid collapse_shape"); SmallVector collapsedStride; int64_t innerMostDimForGroup = reassocGroup.back(); int64_t innerMostStrideForGroup = strides[innerMostDimForGroup]; collapsedStride.push_back( ShapedType::isDynamicStrideOrOffset(innerMostStrideForGroup) ? origStrides[innerMostDimForGroup] : builder.getIndexAttr(innerMostStrideForGroup)); return collapsedStride; } /// Replace `baseBuffer, offset, sizes, strides = /// extract_strided_metadata(reshapeLike(memref))` /// With /// /// \verbatim /// baseBuffer, offset, baseSizes, baseStrides = /// extract_strided_metadata(memref) /// sizes = getReshapedSizes(reshapeLike) /// strides = getReshapedStrides(reshapeLike) /// \endverbatim /// /// /// Notice that `baseBuffer` and `offset` are unchanged. /// /// In other words, get rid of the expand_shape in that expression and /// materialize its effects on the sizes and the strides using affine apply. 
template (*getReshapedSizes)( ReassociativeReshapeLikeOp, OpBuilder &, ArrayRef /*origSizes*/, unsigned /*groupId*/), SmallVector (*getReshapedStrides)( ReassociativeReshapeLikeOp, OpBuilder &, ArrayRef /*origSizes*/, ArrayRef /*origStrides*/, unsigned /*groupId*/)> struct ExtractStridedMetadataOpReshapeFolder : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, PatternRewriter &rewriter) const override { auto reshape = op.getSource().getDefiningOp(); if (!reshape) return failure(); // Build a plain extract_strided_metadata(memref) from // extract_strided_metadata(reassociative_reshape_like(memref)). Location origLoc = op.getLoc(); IndexType indexType = rewriter.getIndexType(); Value source = reshape.getSrc(); auto sourceType = source.getType().cast(); unsigned sourceRank = sourceType.getRank(); SmallVector sizeStrideTypes(sourceRank, indexType); auto newExtractStridedMetadata = rewriter.create( origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes, sizeStrideTypes, source); // Collect statically known information. SmallVector strides; int64_t offset; bool hasKnownStridesAndOffset = succeeded(getStridesAndOffset(sourceType, strides, offset)); (void)hasKnownStridesAndOffset; assert(hasKnownStridesAndOffset && "getStridesAndOffset must work on valid reassociative_reshape_like"); MemRefType reshapeType = reshape.getResultType(); unsigned reshapeRank = reshapeType.getRank(); // The result value will start with the base_buffer and offset. unsigned baseIdxInResult = 2; SmallVector results(baseIdxInResult + reshapeRank * 2); results[0] = newExtractStridedMetadata.getBaseBuffer(); results[1] = ShapedType::isDynamicStrideOrOffset(offset) ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) : rewriter.getIndexAttr(offset); // Get the special case of 0-D out of the way. 
if (sourceRank == 0) { Value constantOne = getValueOrCreateConstantIndexOp( rewriter, origLoc, rewriter.getIndexAttr(1)); SmallVector resultValues(baseIdxInResult + reshapeRank * 2, constantOne); for (unsigned i = 0; i < baseIdxInResult; ++i) resultValues[i] = getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]); rewriter.replaceOp(op, resultValues); return success(); } // Compute the reshaped strides and sizes from the base strides and sizes. SmallVector origSizes = getAsOpFoldResult(newExtractStridedMetadata.getSizes()); SmallVector origStrides = getAsOpFoldResult(newExtractStridedMetadata.getStrides()); unsigned idx = 0, endIdx = reshape.getReassociationIndices().size(); for (; idx != endIdx; ++idx) { SmallVector reshapedSizes = getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx); SmallVector reshapedStrides = getReshapedStrides( reshape, rewriter, origSizes, origStrides, /*groupId=*/idx); unsigned groupSize = reshapedSizes.size(); const unsigned sizeStartIdx = baseIdxInResult; const unsigned strideStartIdx = sizeStartIdx + reshapeRank; for (unsigned i = 0; i < groupSize; ++i) { results[sizeStartIdx + i] = reshapedSizes[i]; results[strideStartIdx + i] = reshapedStrides[i]; } baseIdxInResult += groupSize; } assert(((isa(reshape) && idx == sourceRank) || (isa(reshape) && idx == reshapeRank)) && "We should have visited all the input dimensions"); assert(baseIdxInResult == reshapeRank + 2 && "We should have populated all the values"); rewriter.replaceOp( op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results)); return success(); } }; /// Helper function to perform the replacement of all constant uses of `values` /// by a materialized constant extracted from `maybeConstants`. /// `values` and `maybeConstants` are expected to have the same size. 
template bool replaceConstantUsesOf(PatternRewriter &rewriter, Location loc, Container values, ArrayRef maybeConstants, llvm::function_ref isDynamic) { assert(values.size() == maybeConstants.size() && " expected values and maybeConstants of the same size"); bool atLeastOneReplacement = false; for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) { // Don't materialize a constant if there are no uses: this would indice // infinite loops in the driver. if (isDynamic(maybeConstant) || result.use_empty()) continue; Value constantVal = rewriter.create(loc, maybeConstant); for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { rewriter.startRootUpdate(op); // updateRootInplace: lambda cannot capture structured bindings in C++17 // yet. op->replaceUsesOfWith(result, constantVal); rewriter.finalizeRootUpdate(op); atLeastOneReplacement = true; } } return atLeastOneReplacement; } // Forward propagate all constants information from an ExtractStridedMetadataOp. struct ForwardStaticMetadata : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp, PatternRewriter &rewriter) const override { auto memrefType = metadataOp.getSource().getType().cast(); SmallVector strides; int64_t offset; LogicalResult res = getStridesAndOffset(memrefType, strides, offset); (void)res; assert(succeeded(res) && "must be a strided memref type"); bool atLeastOneReplacement = replaceConstantUsesOf( rewriter, metadataOp.getLoc(), ArrayRef>(metadataOp.getOffset()), ArrayRef(offset), ShapedType::isDynamicStrideOrOffset); atLeastOneReplacement |= replaceConstantUsesOf( rewriter, metadataOp.getLoc(), metadataOp.getSizes(), memrefType.getShape(), ShapedType::isDynamic); atLeastOneReplacement |= replaceConstantUsesOf( rewriter, metadataOp.getLoc(), metadataOp.getStrides(), strides, ShapedType::isDynamicStrideOrOffset); return success(atLeastOneReplacement); } }; /// Replace `base, offset, 
sizes, strides = /// extract_strided_metadata(allocLikeOp)` /// /// With /// /// ``` /// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref /// offset = 0 /// sizes = allocSizes /// strides#i = prod(allocSizes#j, for j in {i+1..rank-1}) /// ``` /// /// The transformation only applies if the allocLikeOp has been normalized. /// In other words, the affine_map must be an identity. template struct ExtractStridedMetadataOpAllocFolder : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, PatternRewriter &rewriter) const override { auto allocLikeOp = op.getSource().getDefiningOp(); if (!allocLikeOp) return failure(); auto memRefType = allocLikeOp.getResult().getType().template cast(); if (!memRefType.getLayout().isIdentity()) return rewriter.notifyMatchFailure( allocLikeOp, "alloc-like operations should have been normalized"); Location loc = op.getLoc(); int rank = memRefType.getRank(); // Collect the sizes. ValueRange dynamic = allocLikeOp.getDynamicSizes(); SmallVector sizes; sizes.reserve(rank); unsigned dynamicPos = 0; for (int64_t size : memRefType.getShape()) { if (ShapedType::isDynamic(size)) sizes.push_back(dynamic[dynamicPos++]); else sizes.push_back(rewriter.getIndexAttr(size)); } // Strides (just creates identity strides). SmallVector strides(rank, rewriter.getIndexAttr(1)); AffineExpr expr = rewriter.getAffineConstantExpr(1); unsigned symbolNumber = 0; for (int i = rank - 2; i >= 0; --i) { expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++); assert(i + 1 + symbolNumber == sizes.size() && "The ArrayRef should encompass the last #symbolNumber sizes"); ArrayRef sizesInvolvedInStride(&sizes[i + 1], symbolNumber); strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr, sizesInvolvedInStride); } // Put all the values together to replace the results. 
SmallVector results; results.reserve(rank * 2 + 2); auto baseBufferType = op.getBaseBuffer().getType().cast(); int64_t offset = 0; if (allocLikeOp.getType() == baseBufferType) results.push_back(allocLikeOp); else results.push_back(rewriter.create( loc, baseBufferType, allocLikeOp, offset, /*sizes=*/ArrayRef(), /*strides=*/ArrayRef())); // Offset. results.push_back(rewriter.create(loc, offset)); for (OpFoldResult size : sizes) results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size)); for (OpFoldResult stride : strides) results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride)); rewriter.replaceOp(op, results); return success(); } }; /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the /// source of the ViewLikeOp. class RewriteExtractAlignedPointerAsIndexOfViewLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp, PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp(); if (!viewLikeOp) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.updateRootInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); }); return success(); } }; } // namespace void memref::populateSimplifyExtractStridedMetadataOpPatterns( RewritePatternSet &patterns) { patterns .add, ExtractStridedMetadataOpReshapeFolder< memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>, ForwardStaticMetadata, ExtractStridedMetadataOpAllocFolder, ExtractStridedMetadataOpAllocFolder, RewriteExtractAlignedPointerAsIndexOfViewLikeOp>( patterns.getContext()); } //===----------------------------------------------------------------------===// // Pass registration //===----------------------------------------------------------------------===// namespace { struct SimplifyExtractStridedMetadataPass final : public 
memref::impl::SimplifyExtractStridedMetadataBase< SimplifyExtractStridedMetadataPass> { void runOnOperation() override; }; } // namespace void SimplifyExtractStridedMetadataPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateSimplifyExtractStridedMetadataOpPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), std::move(patterns)); } std::unique_ptr memref::createSimplifyExtractStridedMetadataPass() { return std::make_unique(); }