Vladislav Vinogradov ec03bbe8a7 [mlir] Fix bug in partial dialect conversion
The discussion on forum:
https://llvm.discourse.group/t/bug-in-partial-dialect-conversion/4115

`applyPartialConversion` didn't handle operations that were
marked as illegal inside a dynamic legality callback.
Instead of reporting an error when such an operation was not converted to the legal set,
the method just added it to `unconvertedSet`, in the same way as unknown operations.

This patch fixes that and handles dynamically illegal operations as well.

The patch includes 2 fixes for existing passes:

* `tensor-bufferize` - explicitly mark `std.return` as legal.
* `convert-parallel-loops-to-gpu` - ugly fix with marking visited operations
  to avoid recursive legality checks.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D108505
2021-09-20 10:39:10 +03:00

191 lines
7.2 KiB
C++

//===- Bufferize.cpp - Bufferization for `tensor` dialect ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements bufferization of `tensor` dialect ops
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Conversion pattern that replaces `tensor::CastOp` with `memref::CastOp`.
///
/// The result type of the new op is obtained by running the original op's
/// result type through the pattern's type converter; the first (already
/// converted) operand is used as the cast source.
class BufferizeCastOp : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // `convertType` yields a null type when the conversion fails. Report a
    // match failure in that case rather than creating a `memref.cast` with a
    // null result type.
    Type resultType = getTypeConverter()->convertType(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "failed to convert result type");
    rewriter.replaceOpWithNewOp<memref::CastOp>(op, resultType, operands[0]);
    return success();
  }
};
} // namespace
namespace {
/// Conversion pattern that replaces `tensor::DimOp` with `memref::DimOp`,
/// forwarding the converted source and index operands unchanged.
class BufferizeDimOp : public OpConversionPattern<tensor::DimOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::DimOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // View the converted operands through the op's named accessors.
    tensor::DimOp::Adaptor converted(operands);
    Value source = converted.source();
    Value index = converted.index();
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, source, index);
    return success();
  }
};
} // namespace
namespace {
/// Conversion pattern that replaces `tensor::ExtractOp` with
/// `memref::LoadOp`, reading from the converted tensor operand at the same
/// indices.
class BufferizeExtractOp : public OpConversionPattern<tensor::ExtractOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // View the converted operands through the op's named accessors.
    tensor::ExtractOp::Adaptor converted(operands);
    Value buffer = converted.tensor();
    ValueRange indices = converted.indices();
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, buffer, indices);
    return success();
  }
};
} // namespace
namespace {
/// Conversion pattern for `tensor::FromElementsOp`.
///
/// Allocates a one-dimensional memref with one slot per element operand,
/// stores each element at its position, and replaces the op with the
/// allocated memref.
class BufferizeFromElementsOp
    : public OpConversionPattern<tensor::FromElementsOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::FromElementsOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto elements = op.elements();
    int elementCount = elements.size();

    // Statically shaped 1-D buffer holding `elementCount` values of the
    // tensor's element type.
    auto bufferType = MemRefType::get(
        {elementCount}, op.getType().cast<TensorType>().getElementType());
    Value buffer = rewriter.create<memref::AllocOp>(op.getLoc(), bufferType);

    // Emit one constant index plus one store per element, in operand order.
    for (int position = 0; position < elementCount; ++position) {
      Value index = rewriter.create<ConstantIndexOp>(op.getLoc(), position);
      rewriter.create<memref::StoreOp>(op.getLoc(), elements[position], buffer,
                                       index);
    }

    rewriter.replaceOp(op, buffer);
    return success();
  }
};
} // namespace
namespace {
/// Conversion pattern for `tensor::GenerateOp`.
///
/// Allocates a memref with the tensor's shape and element type, then fills it
/// with an `scf.parallel` loop whose body is the original op's body with its
/// yield replaced by a `memref.store` into the allocation.
class BufferizeGenerateOp : public OpConversionPattern<tensor::GenerateOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::GenerateOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    tensor::GenerateOp::Adaptor adaptor(operands);
    ValueRange dynamicExtents = adaptor.dynamicExtents();

    // Allocate the destination buffer; the dynamic extents are passed along
    // to the allocation.
    auto tensorType = op.getType().cast<RankedTensorType>();
    auto bufferType =
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
    Value buffer =
        rewriter.create<memref::AllocOp>(loc, bufferType, dynamicExtents);

    // Build the loop bounds: every dimension iterates over [0, extent) with
    // unit step.
    int64_t rank = tensorType.getRank();
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
    SmallVector<Value, 4> lowerBounds(rank, zero);
    SmallVector<Value, 4> steps(rank, one);
    SmallVector<Value, 4> upperBounds;
    int nextDynamicExtent = 0;
    for (int dim = 0; dim < rank; ++dim) {
      // Dynamic dimensions consume the next dynamic extent operand in order.
      if (tensorType.isDynamicDim(dim)) {
        upperBounds.push_back(dynamicExtents[nextDynamicExtent++]);
        continue;
      }
      // Static dimensions get a constant holding the known size.
      Value staticExtent =
          rewriter.create<ConstantIndexOp>(loc, bufferType.getDimSize(dim));
      upperBounds.push_back(staticExtent);
    }

    // Populate the buffer with a parallel loop that stores into each element.
    //
    // The body cannot simply be cloned: cloned ops would have to be
    // legalized, yet the body may contain arbitrary ops for which this
    // conversion has no legalization patterns. Instead, the op's body is
    // "moved" into the scf.parallel body with mergeBlockBefore.
    auto loopNest =
        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
    Block *loopBody = loopNest.getBody();
    rewriter.mergeBlockBefore(op.getBody(), loopBody->getTerminator(),
                              loopBody->getArguments());

    // The scf.parallel builder already placed an scf.yield at the end of the
    // body, so the inlined yield is the op right before that terminator.
    // Swap it for a store of the yielded value at the current loop indices.
    Operation *inlinedYield = loopBody->getTerminator()->getPrevNode();
    rewriter.setInsertionPointAfter(inlinedYield);
    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        inlinedYield, inlinedYield->getOperand(0), buffer,
        loopBody->getArguments());

    rewriter.replaceOp(op, buffer);
    return success();
  }
};
} // namespace
/// Adds all tensor bufferization patterns defined in this file to `patterns`,
/// each constructed with `typeConverter` and the pattern set's context.
void mlir::populateTensorBufferizePatterns(
    BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
               BufferizeFromElementsOp, BufferizeGenerateOp>(typeConverter,
                                                             context);
}
namespace {
/// Pass that bufferizes `tensor` dialect ops in a function by running a
/// partial dialect conversion with the tensor bufferization patterns.
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
  void runOnFunction() override {
    MLIRContext &ctx = getContext();
    BufferizeTypeConverter typeConverter;
    RewritePatternSet patterns(&ctx);
    ConversionTarget target(ctx);

    populateBufferizeMaterializationLegality(target);
    populateTensorBufferizePatterns(typeConverter, patterns);

    // Tensor ops that are marked illegal for this conversion.
    target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
                        tensor::FromElementsOp, tensor::GenerateOp>();
    target.addLegalDialect<memref::MemRefDialect>();

    // Standard dialect ops are legal only when the type converter reports
    // the op as legal.
    auto isLegalForConverter = [&typeConverter](Operation *op) {
      return typeConverter.isLegal(op);
    };
    target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalForConverter);
    // `ReturnOp` is explicitly marked legal.
    target.addLegalOp<ReturnOp>();
    target.addLegalDialect<scf::SCFDialect>();

    LogicalResult status =
        applyPartialConversion(getFunction(), target, std::move(patterns));
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace
/// Creates a new instance of the tensor bufferization pass.
std::unique_ptr<Pass> mlir::createTensorBufferizePass() {
  std::unique_ptr<Pass> pass = std::make_unique<TensorBufferizePass>();
  return pass;
}