[mlir][Linalg] Improve region support in Linalg ops.

This revision takes advantage of the newly extended `ref` directive in assembly format
to allow better region handling for LinalgOps. Specifically, FillOp and CopyOp now build their regions explicitly which allows retiring older behavior that relied on specific op knowledge in both lowering to loops and vectorization.

Differential Revision: https://reviews.llvm.org/D96598
This commit is contained in:
Nicolas Vasilache 2021-02-12 13:50:10 +00:00
parent ee4dd0f876
commit 973e133b76
11 changed files with 317 additions and 281 deletions

View File

@ -1056,20 +1056,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
//===------------------------------------------------------------------===//
// Other static interface methods.
//===------------------------------------------------------------------===//
StaticInterfaceMethod<
/*desc=*/[{
Create an operation of the current type with the given location,
operands, and attributes.
}],
/*retTy=*/"Operation *",
/*methodName=*/"create",
(ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes), [{
return builder.create<ConcreteOp>(
loc, resultTypes, operands, attributes);
}]
>,
InterfaceMethod<
/*desc=*/[{
Clone the current operation with the given location and operands. This
@ -1082,14 +1068,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
(ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
"ValueRange":$operands),
[{
BlockAndValueMapping map;
unsigned numRegions = $_op->getNumRegions();
Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs());
assert(res->getNumRegions() == numRegions && "inconsistent # regions");
for (unsigned ridx = 0; ridx < numRegions; ++ridx)
$_op->getRegion(ridx).cloneInto(
&res->getRegion(ridx), map);
return res;
BlockAndValueMapping bvm;
OperationState state(
loc, ConcreteOp::getOperationName(), operands, resultTypes,
$_op->getAttrs());
for (Region &r : $_op->getRegions())
r.cloneInto(state.addRegion(), bvm);
return b.createOperation(state);
}]
>,
StaticInterfaceMethod<
@ -1098,7 +1083,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
Returns a null function if this named op does not define a region
builder.
}],
/*retTy=*/"std::function<void(Block &)>",
/*retTy=*/"std::function<void(Block &, ValueRange)>",
/*methodName=*/"getRegionBuilder",
(ins),
[{ return ConcreteOp::getRegionBuilder(); }]

View File

@ -110,14 +110,13 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
AnyStridedMemRef:$output,
OptionalAttr<AffineMapAttr>:$inputPermutation,
OptionalAttr<AffineMapAttr>:$outputPermutation);
let regions = (region AnyRegion:$region);
// TODO: this should go away once the usage of OptionalAttr triggers emission
// of builders with default arguments left unspecified.
let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output),
[{
return build(
$_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
}]>];
let builders = [
OpBuilderDAG<(ins "Value":$input, "Value":$output,
CArg<"AffineMap", "AffineMap()">:$inputPermutation,
CArg<"AffineMap", "AffineMap()">:$outputPermutation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let extraClassDeclaration = structuredOpsDecls # [{
ValueRange inputs() { return getOperands().take_front(); }
@ -146,24 +145,31 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
Value getSource() { return input();}
Value getTarget() { return output(); }
static std::function<void(Block &)> getRegionBuilder() {
return nullptr;
static void regionBuilder(Block &block, ValueRange captures);
static std::function<void(Block &block, ValueRange captures)>
getRegionBuilder() {
return &regionBuilder;
}
static unsigned getNumRegionArgs() { return 2; }
}];
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = [{
`(` operands `)` attr-dict `:` type(operands)
`(` $input `,` $output `)` attr-dict `:`
type($input) `,` type($output)
custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let skipDefaultBuilders = 1;
}
def FillOp : LinalgStructured_Op<"fill", []> {
let arguments = (ins AnyShaped:$output,
AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
let results = (outs Optional<AnyRankedTensor>:$result);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = structuredOpsDecls # [{
ValueRange inputs() { return {}; }
ValueRange outputs() { return getOperands().take_front(); }
@ -183,13 +189,18 @@ def FillOp : LinalgStructured_Op<"fill", []> {
extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
}
static std::function<void(Block &)> getRegionBuilder() {
return nullptr;
static void regionBuilder(Block &block, ValueRange captures);
static std::function<void(Block &block, ValueRange captures)>
getRegionBuilder() {
return &regionBuilder;
}
static unsigned getNumRegionArgs() { return 1; }
}];
let assemblyFormat = [{
`(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
`(` $output `,` $value `)` attr-dict `:`
type($output) `,` type($value) (`->` type($result)^)?
custom<FillOpRegion>($region, ref(type($output)), ref($value))
}];
let builders = [
@ -268,7 +279,8 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
return padding().getValue().getValue<int64_t>({i, 1});
}
static std::function<void(Block &)> getRegionBuilder() {
static std::function<void(Block &, ValueRange captures)> getRegionBuilder()
{
return nullptr;
}
}];
@ -519,7 +531,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
library_call()->str() : "op_has_no_registered_library_name";
}
static std::function<void(Block &)> getRegionBuilder() {
static std::function<void(Block &, ValueRange)> getRegionBuilder() {
return nullptr;
}
}];

View File

@ -154,7 +154,13 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
if (in == op.input() && out == op.output())
return failure();
rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
if (!libraryCallName)
return failure();
rewriter.replaceOpWithNewOp<mlir::CallOp>(
op, libraryCallName.getValue(), TypeRange(),
createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
return success();
}

View File

@ -27,8 +27,6 @@ Operation *mlir::edsc::makeGenericLinalgOp(
ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
ArrayRef<Attribute> otherAttributes) {
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
// Build maps
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
exprsList.reserve(inputs.size() + outputs.size());
@ -54,13 +52,10 @@ Operation *mlir::edsc::makeGenericLinalgOp(
resultTensorTypes,
inputValues,
outputValues,
builder.getAffineMapArrayAttr(maps),
builder.getStrArrayAttr(iteratorStrTypes),
StringAttr() /*doc*/,
StringAttr() /*library_call*/,
ArrayAttr() /*sparse*/
/* TODO: other attributes in op */
)
maps,
iteratorStrTypes,
""/*doc*/,
""/*library_call*/)
.getOperation();
// clang-format on

View File

@ -33,32 +33,53 @@ using namespace mlir;
using namespace mlir::linalg;
/// Forward declarations.
template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
OperationState &result,
TypeRange inputTypes,
TypeRange outputTypes);
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
TypeRange outputTypes, ValueRange captures = {},
std::function<void(unsigned, unsigned)> errorHandler = [](unsigned,
unsigned) {});
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
static void
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures = {});
/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
SmallVectorImpl<Type> &outputTypes);
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op);
/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes);
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<OpAsmParser::OperandType> captures = {});
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes);
template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result);
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op);
static ParseResult
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
ArrayRef<OpAsmParser::OperandType> captures = {});
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes);
@ -83,14 +104,136 @@ static LogicalResult foldMemRefCast(Operation *op) {
return success(folded);
}
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
void CopyOp::regionBuilder(Block &block, ValueRange captures) {
using namespace edsc::intrinsics;
assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
(linalg_yield(block.getArgument(0)));
}
void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
Value output, AffineMap inputPermutation,
AffineMap outputPermutation,
ArrayRef<NamedAttribute> namedAttrs) {
result.addOperands({input, output});
result.addAttributes(namedAttrs);
if (inputPermutation)
result.addAttribute("inputPermutation",
AffineMapAttr::get(inputPermutation));
if (outputPermutation)
result.addAttribute("outputPermutation",
AffineMapAttr::get(outputPermutation));
result.addRegion();
fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
TypeRange{input.getType()},
TypeRange{output.getType()});
}
ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
Type outputType) {
OpBuilder opBuilder(parser.getBuilder().getContext());
fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
TypeRange{outputType});
return success();
}
/// CopyOp region is elided when printing.
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
static LogicalResult verify(CopyOp op) {
auto outputViewType = op.getOutputShapedType(0);
auto inputViewType = op.getInputShapedType(0);
if (inputViewType.getElementType() != outputViewType.getElementType())
return op.emitOpError("expects views of the same type");
if (inputViewType.getRank() != outputViewType.getRank())
return op.emitOpError("expects views of the same rank");
auto rank = op.getNumParallelLoops();
auto inputPermutationMap = op.inputPermutation();
if (inputPermutationMap) {
if (inputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional input_permutation map of rank ")
<< rank;
if (!inputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional input_permutation map to be a permutation");
}
auto outputPermutationMap = op.outputPermutation();
if (outputPermutationMap) {
if (outputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional output_permutation map of rank ")
<< rank;
if (!outputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional output_permutation map to be a permutation");
}
if (rank == 0 && inputPermutationMap)
return op.emitOpError("expected no input permutation when rank == 0");
if (rank == 0 && outputPermutationMap)
return op.emitOpError("expected no output permutation when rank == 0");
return success();
}
void CopyOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), input(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), output(),
SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
void FillOp::regionBuilder(Block &block, ValueRange captures) {
using namespace edsc::intrinsics;
assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
(linalg_yield(captures));
}
void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
Value value) {
build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
value);
fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
TypeRange{output.getType()}, value);
}
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
OpAsmParser::OperandType valueRef) {
OpBuilder opBuilder(parser.getBuilder().getContext());
// Resolve `valueRef` into `value` at parse time so we can build the region
// with captures.
SmallVector<Value> value;
parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value);
fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
TypeRange{outputType}, value);
return success();
}
/// FillOp region is elided when printing.
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputShapedType(0);
auto fillType = op.value().getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
return op.emitOpError(
"expected fill op with no result value to use memref type");
}
return success();
}
void FillOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (output().getType().isa<MemRefType>())
effects.emplace_back(MemoryEffects::Write::get(), output(),
SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
@ -397,7 +540,6 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
// InitTensorOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(InitTensorOp op) {
RankedTensorType resultType = op.getType();
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
@ -1396,68 +1538,6 @@ static LogicalResult verify(linalg::YieldOp op) {
/////// Operations corresponding to library calls defined with Tablegen ////////
void FillOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (output().getType().isa<MemRefType>())
effects.emplace_back(MemoryEffects::Write::get(), output(),
SideEffects::DefaultResource::get());
}
static LogicalResult verify(FillOp op) {
auto viewType = op.getOutputShapedType(0);
auto fillType = op.value().getType();
if (viewType.getElementType() != fillType)
return op.emitOpError("expects fill type to match view elemental type");
if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
return op.emitOpError(
"expected fill op with no result value to use memref type");
}
return success();
}
void CopyOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), input(),
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), output(),
SideEffects::DefaultResource::get());
}
static LogicalResult verify(CopyOp op) {
auto outputViewType = op.getOutputShapedType(0);
auto inputViewType = op.getInputShapedType(0);
if (inputViewType.getElementType() != outputViewType.getElementType())
return op.emitOpError("expects views of the same type");
if (inputViewType.getRank() != outputViewType.getRank())
return op.emitOpError("expects views of the same rank");
auto rank = op.getNumParallelLoops();
auto inputPermutationMap = op.inputPermutation();
if (inputPermutationMap) {
if (inputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional input_permutation map of rank ")
<< rank;
if (!inputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional input_permutation map to be a permutation");
}
auto outputPermutationMap = op.outputPermutation();
if (outputPermutationMap) {
if (outputPermutationMap->getNumInputs() != rank)
return op.emitOpError("expects optional output_permutation map of rank ")
<< rank;
if (!outputPermutationMap->isPermutation())
return op.emitOpError(
"expects optional output_permutation map to be a permutation");
}
if (rank == 0 && inputPermutationMap)
return op.emitOpError("expected no input permutation when rank == 0");
if (rank == 0 && outputPermutationMap)
return op.emitOpError("expected no output permutation when rank == 0");
return success();
}
template <typename LinalgPoolingOp>
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
ArrayRef<Attribute> attrs,
@ -1690,14 +1770,25 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
// Auto-generated Linalg named ops.
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
template <typename NamedStructuredOpType>
static void buildNamedStructuredOpRegionAndAttributesImpl(
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
TypeRange outputTypes,
std::function<void(unsigned, unsigned)> errorHandler) {
static void
fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ValueRange captures,
std::function<void(unsigned, unsigned)> errorHandler) {
assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
// TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to.
SmallVector<Type, 8> argTypes;
@ -1707,7 +1798,7 @@ static void buildNamedStructuredOpRegionAndAttributesImpl(
// RAII.
OpBuilder::InsertionGuard guard(opBuilder);
Block *body = opBuilder.createBlock(&region, {}, argTypes);
Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes);
unsigned actual = body->getNumArguments();
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
if (expected != actual)
@ -1715,53 +1806,30 @@ static void buildNamedStructuredOpRegionAndAttributesImpl(
opBuilder.setInsertionPointToStart(body);
mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
NamedStructuredOpType::regionBuilder(*body);
NamedStructuredOpType::regionBuilder(*body, captures);
// indexing_maps is an auto-generated method.
// iterator_types is an auto-generated method.
}
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
OperationState &result,
TypeRange inputTypes,
TypeRange outputTypes) {
void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
OperationState &result,
TypeRange inputTypes,
TypeRange outputTypes,
ValueRange captures) {
Region &region = *result.addRegion();
buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes,
fillStructuredOpRegion<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes, captures,
[&](unsigned expected, unsigned actual) {
llvm::errs() << "region expects " << expected << " args, got "
<< actual;
assert(expected != actual && "incorrect number of arguments");
});
}
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes) {
ParseResult res = success();
OpBuilder opBuilder(parser.getBuilder().getContext());
buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes,
[&](unsigned expected, unsigned actual) {
res = parser.emitError(parser.getCurrentLocation(),
llvm::formatv("region expects {0} args, got {1}",
expected, actual));
});
return res;
}
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes) {
if (succeeded(parser.parseOptionalArrow()))
if (parser.parseTypeList(resultTypes))
return failure();
return success();
}
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
@ -1802,8 +1870,56 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
}
template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result) {
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op) {
if (!op.inputs().empty())
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
if (!op.outputs().empty())
p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
}
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<OpAsmParser::OperandType> captures) {
ParseResult res = success();
OpBuilder opBuilder(parser.getBuilder().getContext());
// Resolve `captures` into `capturedValues` at parse time so we can build the
// region with captures.
SmallVector<Value> capturedValues;
fillStructuredOpRegion<NamedStructuredOpType>(
opBuilder, region, inputTypes, outputTypes, capturedValues,
[&](unsigned expected, unsigned actual) {
res = parser.emitError(
parser.getCurrentLocation(),
llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
"region expects {0} args, got {1}",
expected, actual));
region.front().dump();
});
return res;
}
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes) {
if (succeeded(parser.parseOptionalArrow()))
if (parser.parseTypeList(resultTypes))
return failure();
return success();
}
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
ArrayRef<OpAsmParser::OperandType> captures) {
// TODO: Enable when ods-gen supports captures.
assert(captures.empty() && "unexpected captures for named structured ops");
SmallVector<Type, 1> inputTypes, outputTypes;
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
return failure();
@ -1817,7 +1933,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
std::unique_ptr<Region> region = std::make_unique<Region>();
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
parser, *region, inputTypes, outputTypes))
parser, *region, inputTypes, outputTypes, captures))
return failure();
result.addRegion(std::move(region));
@ -1831,15 +1947,6 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
p.printOptionalArrowTypeList(resultTypes);
}
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op) {
if (!op.inputs().empty())
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
if (!op.outputs().empty())
p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
}
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
p << op.getOperationName();
@ -1861,6 +1968,10 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
return verifyGenericOp<NamedStructuredOpType>(op);
}
//===----------------------------------------------------------------------===//
// Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
namespace {
struct EraseDeadLinalgOp : public RewritePattern {
EraseDeadLinalgOp(PatternBenefit benefit = 1)

View File

@ -49,7 +49,7 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
indexingMaps, iterators,
[&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
edsc::ScopedContext scope(bodyBuilder, loc);
regionBuilder(*bodyBuilder.getBlock());
regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{});
});
}

View File

@ -52,14 +52,6 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
return res;
}
static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
Optional<AffineMap> permutation) {
return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
ScopedContext::getLocation(),
permutation.getValue(), ivs)
: SmallVector<Value, 4>(ivs.begin(), ivs.end());
}
template <typename IndexedValueType, typename OpType>
static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
ArrayRef<SmallVector<Value, 8>> indexing,
@ -178,40 +170,6 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
outputBuffers);
}
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
assert(copyOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = copyOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto inputIvs =
permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
auto outputIvs =
permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
// clang-format off
nPar > 0 ? O(oivs) = I(iivs) :
O() = I();
// clang-format on
}
template <typename IndexedValueType>
static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
assert(fillOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
auto nPar = fillOp.getNumParallelLoops();
assert(nPar == allIvs.size());
auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
IndexedValueType O(fillOp.getOutputBuffer(0));
// Emit the proper scalar assignment, whether we are dealing with a 0-D or
// an n-D loop nest; with or without permutations.
nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
}
// Create a padded view into the given `input` tensor using the 'indices'
// to access the tensor. `skipPadding` lists the dimensions for which no padding
// is needed e.g. the non-spatial dimensions for convolutions.
@ -533,8 +491,8 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
assert(iterArgs.empty() && "unexpected iterArgs");
allIvs.append(ivs.begin(), ivs.end());
llvm::TypeSwitch<Operation *>(op)
.Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
IndexedGenericOp, LinalgOp>([&](auto op) {
emitScalarImplementation<IndexedValueTy>(allIvs, op);
})
.Default([&](Operation *op) { assert(false && "unexpected op"); });

View File

@ -267,7 +267,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
llvm::map_range(linalgOp.getShapedOperandTypes(),
[](ShapedType t) { return t.getElementType(); }));
block->addArguments(elementTypes);
linalgOp.getRegionBuilder()(*block);
linalgOp.getRegionBuilder()(*block, /*captures=*/{});
}
Block *block = &region->front();
@ -333,24 +333,26 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
// Return true if the op is an element-wise linalg op.
static bool isElementwise(Operation *op) {
auto genericOp = dyn_cast<linalg::GenericOp>(op);
if (!genericOp)
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
if (!linalgOp)
return false;
if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
return false;
// TODO: relax the restrictions on indexing map.
for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
if (!genericOp.getOutputIndexingMap(i).isIdentity())
for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
if (!linalgOp.getOutputIndexingMap(i).isIdentity())
return false;
}
// Currently bound the input indexing map to minor identity as other
// permutations might require adding transpose ops to convert the vector read
// to the right shape.
for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
return false;
}
return hasOnlyScalarElementwiseOp(genericOp.getRegion());
if (linalgOp->getNumRegions() != 1)
return false;
return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
}
static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
@ -393,9 +395,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
if (isa<linalg::FillOp, linalg::CopyOp>(op))
return success();
if (isElementwise(op))
return success();
return success(isaContractionOpInterface(linalgOp));
@ -407,43 +406,12 @@ Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
return llvm::None;
edsc::ScopedContext scope(builder, op->getLoc());
// In the case of 0-D memrefs, return null and special case to scalar load or
// store later.
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
// Vectorize fill as a vector.broadcast.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg.fill as vector.broadcast: " << *op);
VectorizedLinalgOp res;
if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
res.tensorResults.push_back(v);
return res;
}
if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
// Vectorize copy as a vector.transfer_read+vector.transfer_write.
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Rewrite linalg.copy as vector.transfer_read + "
"vector.transfer_write: "
<< *op);
Value vector = buildVectorRead(builder, copyOp.input());
VectorizedLinalgOp res;
if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
res.tensorResults.push_back(v);
return res;
}
if (isElementwise(op)) {
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
<< "Vectorize linalg op as a generic: " << *op);
return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
}
// TODO: as soon as Copy and FillOp. get a region builder, replace all the
// above by:
// if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
// LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
// << "Vectorize linalg op as a generic: " << *op);
// return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
// }
return vectorizeContraction(builder, cast<LinalgOp>(op));
}

View File

@ -1,5 +1,4 @@
// RUN: mlir-opt -copy-removal -split-input-file %s
//| FileCheck %s
// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s
// All linalg copies except the linalg.copy(%1, %9) must be removed since the
// defining operation of %1 and its DeallocOp have been defined in another block.
@ -256,7 +255,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>)
%tmp2 = math.exp %gen2_arg0 : f32
linalg.yield %tmp2 : f32
}
"linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32>
dealloc %temp : memref<2xf32>
// CHECK: return
return
@ -292,7 +291,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
linalg.yield %tmp2 : f32
}
// CHECK: linalg.copy
"linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> ()
linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32>
dealloc %temp : memref<2xf32>
return
}
@ -355,7 +354,7 @@ func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg
}
// CHECK-NOT: linalg.copy
// CHECK-NOT: dealloc
"linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32>
dealloc %0 : memref<4xf32>
//CHECK: return
return

View File

@ -23,7 +23,7 @@
// IMPL-NEXT: map2 = simplifyAffineMap(map2);
// IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
//
// IMPL: void Test1Op::regionBuilder(Block &block) {
// IMPL: void Test1Op::regionBuilder(Block &block, ValueRange captures) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
@ -47,7 +47,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
// IMPL: AffineMap::get(3, 3, {d2, d1}, context)
// IMPL: AffineMap::get(3, 3, {d0, d1}, context)
//
// IMPL: Test2Op::regionBuilder(Block &block) {
// IMPL: Test2Op::regionBuilder(Block &block, ValueRange captures) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);
@ -71,7 +71,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
// IMPL: AffineMap::get(4, 4, {d3, d2}, context)
// IMPL: AffineMap::get(4, 4, {d0, d1, d2}, context)
//
// IMPL: Test3Op::regionBuilder(Block &block) {
// IMPL: Test3Op::regionBuilder(Block &block, ValueRange captures) {
// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]);
// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]);

View File

@ -1871,11 +1871,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
buildNamedStructuredOpRegionAndAttributes<{0}>(
createAndFillStructuredOpRegion<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs));
TypeRange(outputs)/*, TODO: support captures*/);
}]>,
OpBuilderDAG<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@ -1889,11 +1889,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
buildNamedStructuredOpRegionAndAttributes<{0}>(
createAndFillStructuredOpRegion<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs));
TypeRange(outputs)/*, TODO: support captures*/);
}]>,
OpBuilderDAG<
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@ -1907,7 +1907,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
{6}
];
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
let parser = [{{
return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
@ -1915,8 +1917,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
// Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(Block &block);
static std::function<void(Block &)> getRegionBuilder() {{
static void regionBuilder(Block &block, ValueRange captures);
static std::function<void(Block &, ValueRange)> getRegionBuilder() {{
return regionBuilder;
}
@ -1980,11 +1982,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
$_builder.getI32VectorAttr({{
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
buildNamedStructuredOpRegionAndAttributes<{0}>(
createAndFillStructuredOpRegion<{0}>(
$_builder,
$_state,
TypeRange(inputs),
TypeRange(outputs));
TypeRange(outputs)/*, TODO: support captures*/);
{2}
}]>
)FMT";
@ -2311,7 +2313,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
};
const char *regionBuilderFmt = R"FMT(
void {0}::regionBuilder(Block &block) {
void {0}::regionBuilder(Block &block, ValueRange captures) {
using namespace edsc;
using namespace intrinsics;
auto args = block.getArguments();