
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replace it with `matmul affine_maps [...]` whenever appropriate. This is in line with the [plan](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863), and can be done since #104783 merged. See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245 Issues investigated: * pad transform tests that could use `matmul` instead, so change to that. * ArmSME test using transpose actually needed it, so changed to `matmul` + affine maps. Arm tests validated by @banach-space (thanks!!).
6258 lines
246 KiB
C++
6258 lines
246 KiB
C++
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the Linalg operations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
|
|
#include "mlir/AsmParser/AsmParser.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SetOperations.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/InterleavedRange.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
#include "llvm/Support/MathExtras.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <cassert>
|
|
#include <optional>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
|
|
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
|
|
int64_t dim) {
|
|
auto type = cast<ShapedType>(v.getType());
|
|
if (!type.isDynamicDim(dim))
|
|
return builder.getIndexAttr(type.getDimSize(dim));
|
|
|
|
return getAsOpFoldResult(
|
|
TypeSwitch<Type, Value>(v.getType())
|
|
.Case<RankedTensorType>([&](RankedTensorType t) -> Value {
|
|
return tensor::DimOp::create(builder, loc, v, dim);
|
|
})
|
|
.Case<MemRefType>([&](MemRefType t) -> Value {
|
|
return memref::DimOp::create(builder, loc, v, dim);
|
|
}));
|
|
}
|
|
|
|
/// Returns a memref.subview or a tensor.extract_slice based on the type of the
|
|
/// `source`.
|
|
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
|
|
ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes,
|
|
ArrayRef<OpFoldResult> strides) {
|
|
return TypeSwitch<Type, Operation *>(source.getType())
|
|
.Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
|
|
return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
|
|
strides);
|
|
})
|
|
.Case<MemRefType>([&](MemRefType type) -> Operation * {
|
|
return memref::SubViewOp::create(b, loc, source, offsets, sizes,
|
|
strides);
|
|
})
|
|
.Default([&](Type t) -> Operation * { return nullptr; });
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helper functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
|
|
int64_t dim) {
|
|
if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
|
|
return b.createOrFold<memref::DimOp>(loc, source, dim);
|
|
if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
|
|
return b.createOrFold<tensor::DimOp>(loc, source, dim);
|
|
llvm_unreachable("Expected MemRefType or TensorType");
|
|
}
|
|
|
|
OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
|
|
int64_t dim) {
|
|
auto shapedType = llvm::cast<ShapedType>(source.getType());
|
|
if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
|
|
return createOrFoldDimOp(b, loc, source, dim);
|
|
return b.getIndexAttr(shapedType.getDimSize(dim));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Support for named Linalg ops defined in ods-gen.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
using RegionBuilderFn = llvm::function_ref<void(
|
|
ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
|
|
function_ref<InFlightDiagnostic()>)>;
|
|
|
|
/// Fills the region of a structured operation using the provided
|
|
/// `regionBuilder`. The method is used by both named structured ops created by
|
|
/// ods-gen and by manually defined C++ ops. It is called by both builders and
|
|
/// parsers and creates a block with arguments corresponding to the elemental
|
|
/// types of `inputTypes` and `outputTypes`.
|
|
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
|
|
TypeRange inputTypes, TypeRange outputTypes,
|
|
ArrayRef<NamedAttribute> attrs,
|
|
function_ref<InFlightDiagnostic()> emitError,
|
|
RegionBuilderFn regionBuilder) {
|
|
SmallVector<Type, 8> argTypes;
|
|
SmallVector<Location, 8> argLocs;
|
|
for (auto containers : {inputTypes, outputTypes}) {
|
|
for (auto t : containers) {
|
|
argTypes.push_back(
|
|
isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
|
|
|
|
// TODO: Pass in a proper location here.
|
|
argLocs.push_back(opBuilder.getUnknownLoc());
|
|
}
|
|
}
|
|
|
|
// RAII.
|
|
OpBuilder::InsertionGuard guard(opBuilder);
|
|
Block *body =
|
|
opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs);
|
|
|
|
opBuilder.setInsertionPointToStart(body);
|
|
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
|
|
regionBuilder(b, *body, attrs, emitError);
|
|
|
|
// indexing_maps is an auto-generated method.
|
|
|
|
// iterator_types is an auto-generated method.
|
|
}
|
|
|
|
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
|
|
/// The result types are derived automatically if `resultTensorTypes` is none.
|
|
/// The body of the operation is filled using `regionBuilder`. All ods-gen
|
|
/// created structured operations use the method to implement their builders.
|
|
static void buildStructuredOp(OpBuilder &b, OperationState &state,
|
|
std::optional<TypeRange> resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes,
|
|
RegionBuilderFn regionBuilder) {
|
|
// Derive the result types if needed.
|
|
SmallVector<Type> derivedResultTypes =
|
|
resultTensorTypes.value_or(TypeRange());
|
|
if (!resultTensorTypes)
|
|
copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
|
|
llvm::IsaPred<RankedTensorType>);
|
|
|
|
state.addOperands(inputs);
|
|
state.addOperands(outputs);
|
|
state.addTypes(derivedResultTypes);
|
|
|
|
state.addAttributes(attributes);
|
|
state.addAttribute(
|
|
"operandSegmentSizes",
|
|
b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
|
|
static_cast<int32_t>(outputs.size())}));
|
|
|
|
// Create and fill the region of the structured operation.
|
|
Region ®ion = *state.addRegion();
|
|
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
|
|
state.attributes.getAttrs(), /*emitError=*/{},
|
|
regionBuilder);
|
|
}
|
|
|
|
static void buildMatmulOp(OpBuilder &b, OperationState &state,
|
|
std::optional<TypeRange> resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes,
|
|
RegionBuilderFn regionBuilder,
|
|
ArrayRef<AffineMap> indexingMaps) {
|
|
// Initialize indexingMaps attribute, for MatmulOp.
|
|
SmallVector<Attribute, 3> indexingMapsAttrVal;
|
|
indexingMapsAttrVal =
|
|
llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
|
|
return AffineMapAttr::get(map);
|
|
});
|
|
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
|
|
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
|
|
attributes, regionBuilder);
|
|
}
|
|
|
|
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
|
|
std::optional<TypeRange> resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes,
|
|
RegionBuilderFn regionBuilder,
|
|
ArrayRef<AffineMap> indexingMaps) {
|
|
// Initialize indexingMaps attribute, for BatchMatmulOp.
|
|
SmallVector<Attribute, 4> indexingMapsAttrVal;
|
|
indexingMapsAttrVal =
|
|
llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
|
|
return AffineMapAttr::get(map);
|
|
});
|
|
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
|
|
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
|
|
attributes, regionBuilder);
|
|
}
|
|
|
|
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
|
|
std::optional<TypeRange> resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes,
|
|
RegionBuilderFn regionBuilder,
|
|
ArrayRef<AffineMap> indexingMaps) {
|
|
// Initialize indexingMaps attribute, for BatchReduceMatmulOp.
|
|
SmallVector<Attribute, 4> indexingMapsAttrVal;
|
|
indexingMapsAttrVal =
|
|
llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
|
|
return AffineMapAttr::get(map);
|
|
});
|
|
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
|
|
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
|
|
attributes, regionBuilder);
|
|
}
|
|
|
|
/// Common parsing used for both named structured ops created by ods-gen and by
|
|
/// manually defined C++ ops. Does not handle regions.
|
|
static ParseResult
|
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
|
SmallVectorImpl<Type> &inputTypes,
|
|
SmallVectorImpl<Type> &outputTypes,
|
|
bool addOperandSegmentSizes = true) {
|
|
SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
|
|
SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
|
|
outputsOperands;
|
|
|
|
if (succeeded(parser.parseOptionalLess())) {
|
|
if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
|
|
return failure();
|
|
}
|
|
attrsLoc = parser.getCurrentLocation();
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
return failure();
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("ins"))) {
|
|
if (parser.parseLParen())
|
|
return failure();
|
|
|
|
inputsOperandsLoc = parser.getCurrentLocation();
|
|
if (parser.parseOperandList(inputsOperands) ||
|
|
parser.parseColonTypeList(inputTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("outs"))) {
|
|
outputsOperandsLoc = parser.getCurrentLocation();
|
|
if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
|
|
parser.parseColonTypeList(outputTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
|
|
result.operands) ||
|
|
parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
|
|
result.operands))
|
|
return failure();
|
|
|
|
if (addOperandSegmentSizes) {
|
|
// This is a bit complex because we're trying to be backward compatible with
|
|
// operation syntax that mix the inherent attributes and the discardable
|
|
// ones in the same dictionary. If the properties are used, we append the
|
|
// operandSegmentSizes there directly. Otherwise we append it to the
|
|
// discardable attributes dictionary where it is handled by the generic
|
|
// Operation::create(...) method.
|
|
if (result.propertiesAttr) {
|
|
NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
|
|
attrs.append("operandSegmentSizes",
|
|
parser.getBuilder().getDenseI32ArrayAttr(
|
|
{static_cast<int32_t>(inputsOperands.size()),
|
|
static_cast<int32_t>(outputsOperands.size())}));
|
|
result.propertiesAttr = attrs.getDictionary(parser.getContext());
|
|
} else {
|
|
result.addAttribute("operandSegmentSizes",
|
|
parser.getBuilder().getDenseI32ArrayAttr(
|
|
{static_cast<int32_t>(inputsOperands.size()),
|
|
static_cast<int32_t>(outputsOperands.size())}));
|
|
}
|
|
}
|
|
if (!result.propertiesAttr) {
|
|
std::optional<RegisteredOperationName> info =
|
|
result.name.getRegisteredInfo();
|
|
if (info) {
|
|
if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
|
|
return parser.emitError(attrsLoc)
|
|
<< "'" << result.name.getStringRef() << "' op ";
|
|
})))
|
|
return failure();
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
|
|
ValueRange outputs) {
|
|
if (!inputs.empty())
|
|
p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
|
|
if (!outputs.empty())
|
|
p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Specific parsing and printing for named structured ops created by ods-gen.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static ParseResult parseNamedStructuredOpRegion(
|
|
OpAsmParser &parser, Region ®ion, unsigned numRegionArgs,
|
|
TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
|
|
RegionBuilderFn regionBuilder, SMLoc loc) {
|
|
if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
|
|
return parser.emitError(
|
|
parser.getCurrentLocation(),
|
|
llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
|
|
"region expects {0} args, got {1}",
|
|
numRegionArgs, inputTypes.size() + outputTypes.size()));
|
|
}
|
|
|
|
OpBuilder opBuilder(parser.getContext());
|
|
ParseResult result = success();
|
|
fillStructuredOpRegion(
|
|
opBuilder, region, inputTypes, outputTypes, attrs,
|
|
[&]() {
|
|
result = failure();
|
|
return parser.emitError(loc);
|
|
},
|
|
regionBuilder);
|
|
return result;
|
|
}
|
|
|
|
static ParseResult
|
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
|
SmallVectorImpl<Type> &resultTypes) {
|
|
if (parser.parseOptionalArrowTypeList(resultTypes))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|
OperationState &result,
|
|
unsigned numRegionArgs,
|
|
RegionBuilderFn regionBuilder) {
|
|
// TODO: Enable when ods-gen supports captures.
|
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
|
SMLoc loc = parser.getCurrentLocation();
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
|
return failure();
|
|
|
|
// Parse optional attributes.
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
return failure();
|
|
|
|
// TODO: consider merging results parsing into region parsing.
|
|
// Need to wait for declarative assembly resolution to decide.
|
|
SmallVector<Type, 1> outputTensorsTypes;
|
|
if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
|
|
return failure();
|
|
result.addTypes(outputTensorsTypes);
|
|
|
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
|
if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
|
|
outputTypes, result.attributes.getAttrs(),
|
|
regionBuilder, loc))
|
|
return failure();
|
|
result.addRegion(std::move(region));
|
|
|
|
return success();
|
|
}
|
|
|
|
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
|
TypeRange resultTypes) {
|
|
if (resultTypes.empty())
|
|
return;
|
|
p.printOptionalArrowTypeList(resultTypes);
|
|
}
|
|
|
|
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<StringRef> elidedAttrs = {}) {
|
|
p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
|
|
|
// Printing is shared with generic ops, except for the region and
|
|
// attributes.
|
|
printCommonStructuredOpParts(p, inputs, outputs);
|
|
|
|
// Results printing.
|
|
printNamedStructuredOpResults(p, op->getResultTypes());
|
|
|
|
// Region is elided.
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Region builder helper.
|
|
// TODO: Move this to a utility library.
|
|
// The public methods on this class are referenced directly from generated code.
|
|
// Helper build the unary, binary, and type conversion functions defined by the
|
|
// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
|
|
// class.
|
|
//
|
|
// Implementations of the math functions must be polymorphic over numeric types,
|
|
// internally performing necessary casts. If the function application makes no
|
|
// sense, then the only recourse is to assert and return nullptr. This can be
|
|
// extended later if it becomes possible to fail construction of the region. The
|
|
// invariant should be enforced at a higher level.
|
|
//
|
|
// TODO: These helpers are currently type polymorphic over the class of integer
|
|
// and floating point types, but they will not internally cast within bit
|
|
// widths of a class (mixed precision such as i8->i32) or across classes
|
|
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
|
|
// to be handled with care and work is being considered to extend the op
|
|
// language to make such cases explicit. In the mean-time, violating this will
|
|
// fail verification, which is deemed acceptable.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
class RegionBuilderHelper {
|
|
public:
|
|
RegionBuilderHelper(OpBuilder &builder, Block &block)
|
|
: builder(builder), block(block) {}
|
|
|
|
// Build the unary functions defined by OpDSL.
|
|
Value buildUnaryFn(UnaryFn unaryFn, Value arg,
|
|
function_ref<InFlightDiagnostic()> emitError = {}) {
|
|
if (!isFloatingPoint(arg)) {
|
|
if (emitError) {
|
|
emitError() << "unsupported non numeric type";
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
switch (unaryFn) {
|
|
case UnaryFn::exp:
|
|
return math::ExpOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::log:
|
|
return math::LogOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::abs:
|
|
return math::AbsFOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::ceil:
|
|
return math::CeilOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::floor:
|
|
return math::FloorOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::negf:
|
|
return arith::NegFOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::reciprocal: {
|
|
Attribute oneAttr = builder.getOneAttr(arg.getType());
|
|
auto one = arith::ConstantOp::create(builder, arg.getLoc(),
|
|
::cast<TypedAttr>(oneAttr));
|
|
return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
|
|
}
|
|
case UnaryFn::round:
|
|
return math::RoundOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::sqrt:
|
|
return math::SqrtOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::rsqrt:
|
|
return math::RsqrtOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::square:
|
|
return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
|
|
case UnaryFn::tanh:
|
|
return math::TanhOp::create(builder, arg.getLoc(), arg);
|
|
case UnaryFn::erf:
|
|
return math::ErfOp::create(builder, arg.getLoc(), arg);
|
|
}
|
|
if (emitError) {
|
|
emitError() << "unsupported unary function";
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported unary function");
|
|
}
|
|
|
|
// Build the binary functions defined by OpDSL.
|
|
// If emitError is provided, an error will be emitted if the operation is not
|
|
// supported and a nullptr will be returned, otherwise an assertion will be
|
|
// raised.
|
|
Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
|
|
function_ref<InFlightDiagnostic()> emitError = {}) {
|
|
bool allComplex = isComplex(arg0) && isComplex(arg1);
|
|
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
|
|
bool allInteger = isInteger(arg0) && isInteger(arg1);
|
|
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
|
|
arg1.getType().getIntOrFloatBitWidth() == 1;
|
|
if (!allComplex && !allFloatingPoint && !allInteger) {
|
|
if (emitError) {
|
|
emitError()
|
|
<< "Cannot build binary Linalg operation: expects allComplex, "
|
|
"allFloatingPoint, or allInteger, got "
|
|
<< arg0.getType() << " and " << arg1.getType();
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
switch (binaryFn) {
|
|
case BinaryFn::add:
|
|
if (allComplex)
|
|
return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::sub:
|
|
if (allComplex)
|
|
return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
if (allBool) {
|
|
if (emitError) {
|
|
emitError() << "unsupported operation: sub with bools";
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported operation: sub with bools");
|
|
}
|
|
return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::mul:
|
|
if (allComplex)
|
|
return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::div:
|
|
if (allComplex)
|
|
return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
if (allBool) {
|
|
if (emitError) {
|
|
emitError() << "unsupported operation: div with bools";
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported operation: div with bools");
|
|
}
|
|
return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::div_unsigned:
|
|
if (!allInteger || allBool) {
|
|
if (emitError) {
|
|
emitError() << "unsupported operation: unsigned div not on uint";
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported operation: unsigned div not on uint");
|
|
}
|
|
return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::max_signed:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::min_signed:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::max_unsigned:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::min_unsigned:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::powf:
|
|
assert(allFloatingPoint);
|
|
return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
|
|
}
|
|
if (emitError) {
|
|
emitError() << "unsupported binary function";
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported binary function");
|
|
}
|
|
|
|
// Build the ternary functions defined by OpDSL.
|
|
Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
|
|
function_ref<InFlightDiagnostic()> emitError = {}) {
|
|
bool headBool =
|
|
isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
|
|
bool tailFloatingPoint =
|
|
isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
|
|
bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
switch (ternaryFn) {
|
|
case TernaryFn::select:
|
|
if (!headBool && !(tailFloatingPoint || tailInteger))
|
|
llvm_unreachable("unsupported non numeric type");
|
|
return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
|
|
}
|
|
if (emitError) {
|
|
emitError() << "unsupported ternary function";
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported ternary function");
|
|
}
|
|
|
|
// Build the type functions defined by OpDSL.
|
|
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
|
|
function_ref<InFlightDiagnostic()> emitError = {}) {
|
|
switch (typeFn) {
|
|
case TypeFn::cast_signed:
|
|
return cast(toType, operand, false);
|
|
case TypeFn::cast_unsigned:
|
|
return cast(toType, operand, true);
|
|
}
|
|
if (emitError) {
|
|
emitError() << "unsupported type conversion function";
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("unsupported type conversion function");
|
|
}
|
|
|
|
void yieldOutputs(ValueRange values) {
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
Location loc = builder.getUnknownLoc();
|
|
YieldOp::create(builder, loc, values);
|
|
}
|
|
|
|
Value constant(const std::string &value) {
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
Location loc = builder.getUnknownLoc();
|
|
Attribute valueAttr = parseAttribute(value, builder.getContext());
|
|
return arith::ConstantOp::create(builder, loc,
|
|
::cast<TypedAttr>(valueAttr));
|
|
}
|
|
|
|
Value index(int64_t dim) {
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
return IndexOp::create(builder, builder.getUnknownLoc(), dim);
|
|
}
|
|
|
|
Type getIntegerType(unsigned width) {
|
|
return IntegerType::get(builder.getContext(), width);
|
|
}
|
|
|
|
Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
|
|
Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
|
|
|
|
private:
|
|
// Generates operations to cast the given operand to a specified type.
|
|
// If the cast cannot be performed, a warning will be issued and the
|
|
// operand returned as-is (which will presumably yield a verification
|
|
// issue downstream).
|
|
Value cast(Type toType, Value operand, bool isUnsignedCast) {
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
auto loc = operand.getLoc();
|
|
if (isa<UnknownLoc>(loc)) {
|
|
if (operand.getDefiningOp())
|
|
loc = operand.getDefiningOp()->getLoc();
|
|
else if (operand.getParentBlock() &&
|
|
operand.getParentBlock()->getParentOp())
|
|
loc = operand.getParentBlock()->getParentOp()->getLoc();
|
|
}
|
|
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
|
|
}
|
|
|
|
bool isComplex(Value value) {
|
|
return llvm::isa<ComplexType>(value.getType());
|
|
}
|
|
bool isFloatingPoint(Value value) {
|
|
return llvm::isa<FloatType>(value.getType());
|
|
}
|
|
bool isInteger(Value value) {
|
|
return llvm::isa<IntegerType>(value.getType());
|
|
}
|
|
|
|
OpBuilder &builder;
|
|
Block █
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CopyOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
|
|
using OpRewritePattern<CopyOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(CopyOp copyOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (copyOp.getInputs() != copyOp.getOutputs())
|
|
return rewriter.notifyMatchFailure(copyOp, "not a self copy");
|
|
if (copyOp.hasPureBufferSemantics())
|
|
rewriter.eraseOp(copyOp);
|
|
else
|
|
rewriter.replaceOp(copyOp, copyOp.getInputs());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<EraseSelfCopy>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FillOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
|
|
///
|
|
/// For such op chains, we can create new linalg.fill ops with the result
|
|
/// type of the tensor.expand/collapse_shape op.
|
|
template <typename TensorReshapeOp>
|
|
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
|
|
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
|
|
if (!oldFill)
|
|
return failure();
|
|
|
|
Location loc = oldFill.getLoc();
|
|
TensorReshapeOp newInit;
|
|
if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
|
|
|
|
newInit = TensorReshapeOp::create(
|
|
rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
|
|
reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
|
|
reshapeOp.getStaticOutputShape());
|
|
} else {
|
|
newInit = TensorReshapeOp::create(
|
|
rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
|
|
reshapeOp.getReassociation());
|
|
}
|
|
rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
|
|
ValueRange{newInit});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
|
|
/// filling value are the same.
|
|
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
|
|
if (!fillOp)
|
|
return failure();
|
|
|
|
// We can only fold if the padding value is the same as the original
|
|
// filling value.
|
|
Value padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue || fillOp.value() != padValue)
|
|
return failure();
|
|
|
|
ReifiedRankedShapedTypeDims reifiedShape;
|
|
if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
|
|
return rewriter.notifyMatchFailure(
|
|
padOp, "failed to reify tensor.pad op result shape");
|
|
|
|
auto emptyTensor =
|
|
tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
|
|
padOp.getResultType().getElementType());
|
|
Value replacement =
|
|
FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
|
|
ValueRange{emptyTensor})
|
|
.getResult(0);
|
|
if (replacement.getType() != padOp.getResultType()) {
|
|
replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
|
|
padOp.getResultType(), replacement);
|
|
}
|
|
rewriter.replaceOp(padOp, replacement);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
|
|
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
|
|
/// filling value are the same.
|
|
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
|
|
if (!srcPadOp)
|
|
return failure();
|
|
|
|
if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
|
|
return failure();
|
|
|
|
// Walk back the tensor.insert_slice chain and find the first destination
|
|
// value at the start of the chain.
|
|
Value firstDest = insertOp.getDest();
|
|
while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
|
|
if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
|
|
return failure();
|
|
|
|
// Make sure the range of values accessed are disjoint. Without this, we
|
|
// cannot fold tensor.pad away.
|
|
bool disjoint = false;
|
|
for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
|
|
// If the dimension has dynamic offset/size, we cannot guarantee
|
|
// disjoint. So just skip it.
|
|
if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
|
|
insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
|
|
prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
|
|
continue;
|
|
|
|
// Get the range start and end, inclusively for both.
|
|
int64_t prevStart = prevOp.getStaticOffset(i);
|
|
int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
|
|
prevOp.getStaticStride(i);
|
|
int64_t nextStart = insertOp.getStaticOffset(i);
|
|
int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
|
|
insertOp.getStaticStride(i);
|
|
if (prevEnd < nextStart || nextEnd < prevStart) {
|
|
disjoint = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!disjoint)
|
|
break;
|
|
firstDest = prevOp.getDest();
|
|
}
|
|
|
|
// Check whether the first destination is a fill op. For overlapped cases,
|
|
// this also cannot be true.
|
|
auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
|
|
if (!dstFillOp)
|
|
return failure();
|
|
|
|
// We can only fold if the padding value is the same as the original
|
|
// filling value.
|
|
Value padValue = srcPadOp.getConstantPaddingValue();
|
|
if (!padValue || dstFillOp.value() != padValue)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
|
|
SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
|
|
|
|
Location loc = insertOp.getLoc();
|
|
MLIRContext *context = getContext();
|
|
|
|
AffineExpr sym0, sym1;
|
|
bindSymbols(context, sym0, sym1);
|
|
auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
|
|
|
|
// Calculate the new offsets for the insert. It should be the old offsets
|
|
// plus low padding sizes.
|
|
SmallVector<OpFoldResult, 4> newOffsets;
|
|
for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
|
|
newOffsets.push_back(affine::makeComposedFoldedAffineApply(
|
|
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
|
|
}
|
|
|
|
RankedTensorType srcPadType = srcPadOp.getSourceType();
|
|
SmallVector<OpFoldResult, 4> newSizes;
|
|
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
|
|
if (srcPadType.isDynamicDim(i)) {
|
|
newSizes.push_back(
|
|
tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
|
|
.getResult());
|
|
} else {
|
|
newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
|
|
}
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
|
|
insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
|
|
newSizes, insertOp.getMixedStrides());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold tensor.extract(linalg.fill(<input>)) into <input>
|
|
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
|
|
public:
|
|
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// See if tensor input of tensor.extract op is the result of a linalg.fill
|
|
// op.
|
|
auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
|
|
if (!fillOp)
|
|
return failure();
|
|
|
|
// Get scalar input operand of linalg.fill op.
|
|
Value extractedScalar = fillOp.getInputs()[0];
|
|
|
|
// Replace tensor.extract op with scalar value used to fill the tensor.
|
|
rewriter.replaceOp(extractOp, extractedScalar);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Folds pack(fill) into a single fill op if
|
|
/// 1. The pack op does not have padding value, or
|
|
/// 2. The filled value and padding value are the same.
|
|
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
|
|
linalg::PackOp packOp) {
|
|
auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
|
|
if (!fillOp)
|
|
return failure();
|
|
|
|
if (auto paddingValue = packOp.getPaddingValue())
|
|
if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
|
|
return failure();
|
|
|
|
Value packOpDest = packOp.getDest();
|
|
if (!packOpDest.hasOneUse())
|
|
return failure();
|
|
|
|
return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
|
|
packOp.getDest());
|
|
}
|
|
|
|
/// Wrapper pattern that applies foldFillPackIntoFillOp method.
|
|
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
|
|
public:
|
|
FoldFillWithPack(MLIRContext *context)
|
|
: OpRewritePattern<linalg::PackOp>(context) {}
|
|
|
|
LogicalResult matchAndRewrite(linalg::PackOp packOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
|
|
if (failed(fillOp))
|
|
return failure();
|
|
rewriter.replaceOp(packOp, fillOp.value().result());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold fill with copy.
|
|
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
|
|
using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
|
|
rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
|
|
fillOp.getInputs(),
|
|
copyOp.getOutputs());
|
|
return success();
|
|
}
|
|
if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
|
|
rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
|
|
fillOp.getOutputs());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Fold fill with transpose.
|
|
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
|
|
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
|
|
rewriter.replaceOpWithNewOp<FillOp>(
|
|
transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
|
|
transposeOp.getDpsInitOperand(0)->get());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Fold a concat with all elements being fills of the same value
|
|
/// into a fill of the concat result shape.
|
|
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto concatOperands = concatOp.getInputs();
|
|
if (concatOperands.empty()) {
|
|
return failure();
|
|
}
|
|
|
|
auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
|
|
if (!firstFillOp) {
|
|
return failure();
|
|
}
|
|
// Prefetch the fill value.
|
|
OpFoldResult firstFillVal =
|
|
getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
|
|
// Collect all the outs values for the fill operations.
|
|
SmallVector<Value> allOuts;
|
|
allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
|
|
|
|
auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
|
|
auto fillOp = v.getDefiningOp<linalg::FillOp>();
|
|
if (!fillOp) {
|
|
return false;
|
|
}
|
|
|
|
OpFoldResult fillVal =
|
|
getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
|
|
if (fillVal != firstFillVal)
|
|
return false;
|
|
|
|
allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
|
|
return true;
|
|
};
|
|
if (!llvm::all_of(concatOperands.drop_front(),
|
|
isDefinedByCompatibleFillOp)) {
|
|
return rewriter.notifyMatchFailure(
|
|
concatOp, "not all operands are defined by a compatible fill op");
|
|
}
|
|
|
|
Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
|
|
concatOp.getDim(), allOuts);
|
|
rewriter.replaceOpWithNewOp<linalg::FillOp>(
|
|
concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
|
|
FoldFillWithPack, FoldFillWithPad,
|
|
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
|
|
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
|
|
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GenericOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static void buildGenericRegion(
|
|
OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs,
|
|
ValueRange outputs,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
|
SmallVector<Type, 4> blockArgTypes;
|
|
SmallVector<Location, 4> blockArgLocs;
|
|
for (ValueRange container : {inputs, outputs}) {
|
|
for (Value v : container) {
|
|
Type t = v.getType();
|
|
blockArgTypes.push_back(
|
|
isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
|
|
blockArgLocs.push_back(v.getLoc());
|
|
}
|
|
}
|
|
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
Block *bodyBlock =
|
|
builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
|
|
bodyBuild(builder, loc, bodyBlock->getArguments());
|
|
}
|
|
|
|
void GenericOp::getAsmBlockArgumentNames(Region ®ion,
|
|
OpAsmSetValueNameFn setNameFn) {
|
|
for (Value v : getRegionInputArgs())
|
|
setNameFn(v, "in");
|
|
for (Value v : getRegionOutputArgs())
|
|
setNameFn(v, "out");
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
|
|
ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
|
|
iteratorTypes, doc, libraryCall);
|
|
result.addAttributes(attributes);
|
|
if (bodyBuild)
|
|
buildGenericRegion(builder, result.location, *result.regions.front(),
|
|
inputs, outputs, bodyBuild);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
|
|
StringRef libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs,
|
|
builder.getAffineMapArrayAttr(indexingMaps),
|
|
builder.getArrayAttr(llvm::to_vector(llvm::map_range(
|
|
iteratorTypes,
|
|
[&](utils::IteratorType iter) -> mlir::Attribute {
|
|
return IteratorTypeAttr::get(builder.getContext(), iter);
|
|
}))),
|
|
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
|
|
libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
|
|
bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
|
|
StringRef libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
|
|
iteratorTypes, doc, libraryCall, bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<utils::IteratorType> iteratorTypes,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
|
|
/*doc=*/"",
|
|
/*libraryCall=*/"", bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<utils::IteratorType> iteratorTypes,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
|
|
iteratorTypes,
|
|
/*doc=*/"",
|
|
/*libraryCall=*/"", bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::print(OpAsmPrinter &p) {
|
|
p << " ";
|
|
|
|
// Print extra attributes.
|
|
auto genericAttrNames = linalgTraitAttrNames();
|
|
|
|
llvm::StringSet<> genericAttrNamesSet;
|
|
genericAttrNamesSet.insert_range(genericAttrNames);
|
|
SmallVector<NamedAttribute, 8> genericAttrs;
|
|
for (auto attr : (*this)->getAttrs()) {
|
|
if (attr.getName() == getIteratorTypesAttrName()) {
|
|
auto iteratorTypes =
|
|
llvm::cast<ArrayAttr>(attr.getValue())
|
|
.getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
|
|
// Convert IteratorType enums into the string representation. This is
|
|
// needed, because tests still use the old format when 'iterator_types'
|
|
// attribute is represented as an array of strings.
|
|
// TODO: Remove this conversion once tests are fixed.
|
|
SmallVector<Attribute> iteratorTypeNames =
|
|
llvm::to_vector(llvm::map_range(
|
|
iteratorTypes, [&](utils::IteratorType t) -> Attribute {
|
|
return StringAttr::get(getContext(), stringifyIteratorType(t));
|
|
}));
|
|
|
|
genericAttrs.emplace_back(
|
|
getIteratorTypesAttrName(),
|
|
ArrayAttr::get(getContext(), iteratorTypeNames));
|
|
} else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
|
|
genericAttrs.push_back(attr);
|
|
}
|
|
}
|
|
if (!genericAttrs.empty()) {
|
|
auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
|
|
p << genericDictAttr;
|
|
}
|
|
|
|
// Printing is shared with named ops, except for the region and attributes
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
|
|
genericAttrNames.push_back("operandSegmentSizes");
|
|
genericAttrNamesSet.insert(genericAttrNames.back());
|
|
|
|
bool hasExtraAttrs = false;
|
|
for (NamedAttribute n : (*this)->getAttrs()) {
|
|
if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
|
|
break;
|
|
}
|
|
if (hasExtraAttrs) {
|
|
p << " attrs = ";
|
|
p.printOptionalAttrDict((*this)->getAttrs(),
|
|
/*elidedAttrs=*/genericAttrNames);
|
|
}
|
|
|
|
// Print region.
|
|
if (!getRegion().empty()) {
|
|
p << ' ';
|
|
p.printRegion(getRegion());
|
|
}
|
|
|
|
// Print results.
|
|
printNamedStructuredOpResults(p, getResultTensors().getTypes());
|
|
}
|
|
|
|
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
DictionaryAttr dictAttr;
|
|
// Parse the core linalg traits that must check into a dictAttr.
|
|
// The name is unimportant as we will overwrite result.attributes.
|
|
// The core linalg traits must contain the information necessary to pass the
|
|
// verifier.
|
|
llvm::SMLoc attributeLocation = parser.getCurrentLocation();
|
|
if (parser.parseAttribute(dictAttr, "_", result.attributes))
|
|
return failure();
|
|
result.attributes.assign(dictAttr.getValue().begin(),
|
|
dictAttr.getValue().end());
|
|
|
|
// Convert array of string into an array of IteratorType enums. This is
|
|
// needed, because tests still use the old format when 'iterator_types'
|
|
// attribute is represented as an array of strings.
|
|
// TODO: Remove this conversion once tests are fixed.
|
|
auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
|
|
result.attributes.get(getIteratorTypesAttrName(result.name)));
|
|
if (!iteratorTypes) {
|
|
return parser.emitError(attributeLocation)
|
|
<< "expected " << getIteratorTypesAttrName(result.name)
|
|
<< " array attribute";
|
|
}
|
|
|
|
SmallVector<Attribute> iteratorTypeAttrs;
|
|
|
|
for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
|
|
auto maybeIteratorType = utils::symbolizeIteratorType(s);
|
|
if (!maybeIteratorType.has_value())
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "unexpected iterator_type (" << s << ")";
|
|
|
|
iteratorTypeAttrs.push_back(
|
|
IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
|
|
}
|
|
result.attributes.set(getIteratorTypesAttrName(result.name),
|
|
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
|
|
|
|
// Parsing is shared with named ops, except for the region.
|
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
|
return failure();
|
|
|
|
// Optional attributes may be added.
|
|
if (succeeded(parser.parseOptionalKeyword("attrs")))
|
|
if (failed(parser.parseEqual()) ||
|
|
failed(parser.parseOptionalAttrDict(result.attributes)))
|
|
return failure();
|
|
|
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
|
if (parser.parseRegion(*region, {}))
|
|
return failure();
|
|
result.addRegion(std::move(region));
|
|
|
|
// Generic ops may specify that a subset of its outputs are tensors. Such
|
|
// outputs are specified in the result type.
|
|
// TODO: may need to move output parsing before region parsing.
|
|
// Need to wait for declarative assembly resolution to decide.
|
|
SmallVector<Type, 1> outputTensorsTypes;
|
|
if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
|
|
return failure();
|
|
result.addTypes(outputTensorsTypes);
|
|
|
|
return success();
|
|
}
|
|
|
|
static void getGenericEffectsImpl(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects,
|
|
LinalgOp linalgOp) {
|
|
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
|
|
if (!llvm::isa<MemRefType>(operand.getType()))
|
|
continue;
|
|
effects.emplace_back(
|
|
MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
|
|
/*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
|
|
}
|
|
|
|
for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
|
|
if (!llvm::isa<MemRefType>(operand.get().getType()))
|
|
continue;
|
|
if (linalgOp.payloadUsesValueFromOperand(&operand)) {
|
|
effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
}
|
|
|
|
void GenericOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
static Speculation::Speculatability
|
|
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
|
|
// Operands with value semantics are speculatable, while operands with memory
|
|
// semantics are not.
|
|
if (!linalgOp.hasPureTensorSemantics())
|
|
return Speculation::NotSpeculatable;
|
|
// The body of the op can still have speculation in its region.
|
|
return Speculation::RecursivelySpeculatable;
|
|
}
|
|
|
|
Speculation::Speculatability GenericOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
LogicalResult GenericOp::verify() { return success(); }
|
|
|
|
namespace {
|
|
|
|
/// Remove linalg operations that are just copying the values from inputs to
|
|
/// results. In the memref case, the operation must be copying to and from the
|
|
/// same value. Requirements are:
|
|
/// 1) All iterator types are parallel
|
|
/// 2) The body contains just a yield operation with the yielded values being
|
|
/// the arguments corresponding to the operands.
|
|
template <typename OpTy>
|
|
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy linalgOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// All indexing maps must be equal. It follows that they are permutations.
|
|
if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
|
|
return failure();
|
|
|
|
// Check that the body of the linalg operation is just a linalg.yield
|
|
// operation.
|
|
Block &body = linalgOp->getRegion(0).front();
|
|
if (!llvm::hasSingleElement(body))
|
|
return failure();
|
|
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
|
|
if (!yieldOp)
|
|
return failure();
|
|
|
|
// In the buffer case, we need to check exact buffer equality.
|
|
if (linalgOp.hasPureBufferSemantics()) {
|
|
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
|
|
linalgOp.getDpsInputOperand(0)->get() !=
|
|
linalgOp.getDpsInitOperand(0)->get()) {
|
|
return rewriter.notifyMatchFailure(
|
|
linalgOp, "expected single input and output to be the same value");
|
|
}
|
|
|
|
auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
|
|
if (!yieldArg || yieldArg.getOwner() != &body) {
|
|
return rewriter.notifyMatchFailure(linalgOp,
|
|
"cannot fold fill-like op");
|
|
}
|
|
|
|
rewriter.eraseOp(linalgOp);
|
|
return success();
|
|
}
|
|
|
|
if (!linalgOp.hasPureTensorSemantics()) {
|
|
return rewriter.notifyMatchFailure(
|
|
linalgOp, "mixed semantics is not supported yet");
|
|
}
|
|
|
|
// Get the argument number of the returned values. That is the operand
|
|
// number to use for replacing uses of this operation.
|
|
SmallVector<Value> returnedArgs;
|
|
for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
|
|
auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
|
|
if (!yieldArg || yieldArg.getOwner() != &body)
|
|
return failure();
|
|
unsigned argumentNumber = yieldArg.getArgNumber();
|
|
Value returnedArg = linalgOp->getOperand(argumentNumber);
|
|
Type resultType = linalgOp->getResult(yieldVal.index()).getType();
|
|
// The input can have a different type than the result, e.g. a dynamic
|
|
// input dimension can be turned into a static output dimension.
|
|
Type returnType = returnedArg.getType();
|
|
if (returnType != resultType) {
|
|
// Distinguish between sparse conversion or dense tensor casting.
|
|
// TODO: unify the two ops?
|
|
if (sparse_tensor::getSparseTensorEncoding(returnType) ||
|
|
sparse_tensor::getSparseTensorEncoding(resultType))
|
|
returnedArg = sparse_tensor::ConvertOp::create(
|
|
rewriter, linalgOp.getLoc(), resultType, returnedArg);
|
|
else {
|
|
if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
|
|
resultType))
|
|
return failure();
|
|
returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
|
|
resultType, returnedArg);
|
|
}
|
|
}
|
|
returnedArgs.push_back(returnedArg);
|
|
}
|
|
|
|
if (returnedArgs.size() != linalgOp->getNumResults())
|
|
return failure();
|
|
rewriter.replaceOp(linalgOp, returnedArgs);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<EraseIdentityLinalgOp<GenericOp>>(context);
|
|
}
|
|
|
|
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MapOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static ParseResult parseDstStyleOp(
|
|
OpAsmParser &parser, OperationState &result,
|
|
function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
|
|
nullptr) {
|
|
// Parse `ins` and `outs`.
|
|
SmallVector<Type, 4> inputTypes, outputTypes;
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
|
|
/*addOperandSegmentSizes=*/false))
|
|
return failure();
|
|
|
|
// Add result types.
|
|
for (Type outputType : outputTypes) {
|
|
if (llvm::isa<RankedTensorType>(outputType))
|
|
result.addTypes(outputType);
|
|
}
|
|
|
|
// Parse required attributes.
|
|
if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
|
|
return failure();
|
|
|
|
// Parse optional attributes.
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
void MapOp::getAsmBlockArgumentNames(Region ®ion,
|
|
OpAsmSetValueNameFn setNameFn) {
|
|
for (Value v : getRegionInputArgs())
|
|
setNameFn(v, "in");
|
|
}
|
|
|
|
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
|
|
if (!getResults().empty())
|
|
setNameFn(getResults().front(), "mapped");
|
|
}
|
|
|
|
void MapOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, TypeRange{}, inputs, init);
|
|
result.addAttributes(attributes);
|
|
|
|
// Add output types for `RankedTensorType` output arguments.
|
|
Type initType = init.getType();
|
|
if (llvm::isa<RankedTensorType>(initType))
|
|
result.addTypes(initType);
|
|
|
|
if (bodyBuild)
|
|
buildGenericRegion(builder, result.location, *result.regions.front(),
|
|
inputs, /*outputs=*/{}, bodyBuild);
|
|
}
|
|
|
|
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
|
|
const OperationName &payloadOpName,
|
|
const NamedAttrList &payloadOpAttrs,
|
|
ArrayRef<Value> operands,
|
|
bool initFirst = false) {
|
|
OpBuilder b(parser.getContext());
|
|
Region *body = result.addRegion();
|
|
Block &block = body->emplaceBlock();
|
|
b.setInsertionPointToStart(&block);
|
|
for (auto &operand : operands) {
|
|
block.addArgument(
|
|
llvm::cast<ShapedType>(operand.getType()).getElementType(),
|
|
b.getUnknownLoc());
|
|
}
|
|
SmallVector<Value> payloadOpOperands;
|
|
// If initFirst flag is enabled, we consider init as the first position of
|
|
// payload operands.
|
|
if (initFirst) {
|
|
payloadOpOperands.push_back(block.getArguments().back());
|
|
for (const auto &arg : block.getArguments().drop_back())
|
|
payloadOpOperands.push_back(arg);
|
|
} else {
|
|
payloadOpOperands = {block.getArguments().begin(),
|
|
block.getArguments().end()};
|
|
}
|
|
|
|
Operation *payloadOp = b.create(
|
|
result.location, b.getStringAttr(payloadOpName.getStringRef()),
|
|
payloadOpOperands,
|
|
TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
|
|
.getElementType()},
|
|
payloadOpAttrs);
|
|
YieldOp::create(b, result.location, payloadOp->getResults());
|
|
}
|
|
|
|
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
std::optional<OperationName> payloadOpName;
|
|
NamedAttrList payloadOpAttrs;
|
|
if (succeeded(parser.parseOptionalLBrace())) {
|
|
FailureOr<OperationName> operationName = parser.parseCustomOperationName();
|
|
if (failed(operationName))
|
|
return failure();
|
|
if (parser.parseOptionalAttrDict(payloadOpAttrs))
|
|
return failure();
|
|
payloadOpName = operationName.value();
|
|
if (parser.parseRBrace())
|
|
return failure();
|
|
}
|
|
|
|
if (parseDstStyleOp(parser, result))
|
|
return failure();
|
|
|
|
if (payloadOpName.has_value()) {
|
|
if (!result.operands.empty())
|
|
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
|
|
payloadOpAttrs,
|
|
ArrayRef(result.operands).drop_back());
|
|
else
|
|
result.addRegion();
|
|
} else {
|
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
|
|
/*allowType=*/true, /*allowAttrs=*/true)) {
|
|
return failure();
|
|
}
|
|
Region *body = result.addRegion();
|
|
if (parser.parseRegion(*body, regionArgs))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
// Retrieve the operation from the body, if it is the only one (except
|
|
// yield) and if it gets the same amount of arguments as the body does.
|
|
// If initFirst flag is enabled, we check that init takes the first position in
|
|
// operands of payload.
|
|
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
|
|
if (body->getOperations().size() != 2)
|
|
return nullptr;
|
|
Operation &payload = body->getOperations().front();
|
|
assert(isa<YieldOp>(body->getOperations().back()));
|
|
|
|
if (payload.getNumOperands() == 0 ||
|
|
payload.getNumOperands() != body->getNumArguments())
|
|
return nullptr;
|
|
if (initFirst) {
|
|
// check init
|
|
if (payload.getOperands().back() != body->getArgument(0))
|
|
return nullptr;
|
|
// check rest
|
|
for (const auto &[operand, bbArg] :
|
|
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
|
|
if (bbArg != operand)
|
|
return nullptr;
|
|
}
|
|
} else {
|
|
for (const auto &[operand, bbArg] :
|
|
llvm::zip(payload.getOperands(), body->getArguments())) {
|
|
if (bbArg != operand)
|
|
return nullptr;
|
|
}
|
|
}
|
|
return &payload;
|
|
}
|
|
|
|
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
|
|
SmallVector<StringRef> elidedAttrs;
|
|
std::string attrToElide;
|
|
p << " { " << payloadOp->getName().getStringRef();
|
|
for (const auto &attr : payloadOp->getAttrs()) {
|
|
auto fastAttr =
|
|
llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
|
|
if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
|
|
attrToElide = attr.getName().str();
|
|
elidedAttrs.push_back(attrToElide);
|
|
break;
|
|
}
|
|
}
|
|
p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
|
|
p << " }";
|
|
}
|
|
|
|
void MapOp::print(OpAsmPrinter &p) {
|
|
Block *mapper = getBody();
|
|
Operation *payloadOp = findPayloadOp(mapper);
|
|
if (payloadOp) {
|
|
printShortForm(p, payloadOp);
|
|
}
|
|
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
p.printOptionalAttrDict((*this)->getAttrs());
|
|
|
|
if (!payloadOp) {
|
|
// Print region if the payload op was not detected.
|
|
p.increaseIndent();
|
|
p.printNewline();
|
|
p << "(";
|
|
llvm::interleaveComma(mapper->getArguments(), p,
|
|
[&](auto arg) { p.printRegionArgument(arg); });
|
|
p << ") ";
|
|
|
|
p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
|
|
p.decreaseIndent();
|
|
}
|
|
}
|
|
|
|
LogicalResult MapOp::verify() {
|
|
auto *bodyBlock = getBody();
|
|
auto blockArgs = bodyBlock->getArguments();
|
|
|
|
// Checks if the number of `inputs` match the arity of the `mapper` region.
|
|
if (getInputs().size() != blockArgs.size())
|
|
return emitOpError() << "expects number of operands to match the arity of "
|
|
"mapper, but got: "
|
|
<< getInputs().size() << " and " << blockArgs.size();
|
|
|
|
// The parameters of mapper should all match the element type of inputs.
|
|
for (const auto &[bbArgType, inputArg] :
|
|
llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
|
|
auto inputElemType =
|
|
llvm::cast<ShapedType>(inputArg.getType()).getElementType();
|
|
if (bbArgType != inputElemType) {
|
|
return emitOpError() << "expected element type of input " << inputElemType
|
|
<< " to match bbArg type " << bbArgType;
|
|
}
|
|
}
|
|
|
|
// The shape of each input must match the shape of the output.
|
|
auto outputShape = getInit().getType().getShape();
|
|
for (Type inputArgType : TypeRange{getInputs()}) {
|
|
auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
|
|
if (inputElemShape != outputShape) {
|
|
return emitOpError() << "expected shape of input (" << inputElemShape
|
|
<< ") to match shape of output (" << outputShape
|
|
<< ")";
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
|
|
int64_t rank = getInit().getType().getRank();
|
|
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
|
}
|
|
|
|
ArrayAttr MapOp::getIndexingMaps() {
|
|
Builder builder(getContext());
|
|
int64_t rank = getInit().getType().getRank();
|
|
int64_t numIndexingMaps = getOperands().size();
|
|
return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
|
|
numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
|
|
}
|
|
|
|
void MapOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability MapOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReduceOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
|
|
OpAsmSetValueNameFn setNameFn) {
|
|
for (Value v : getRegionInputArgs())
|
|
setNameFn(v, "in");
|
|
for (Value v : getRegionOutputArgs())
|
|
setNameFn(v, "init");
|
|
}
|
|
|
|
void ReduceOp::getAsmResultNames(
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
if (!getResults().empty())
|
|
setNameFn(getResults().front(), "reduced");
|
|
}
|
|
|
|
void ReduceOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange inits, ArrayRef<int64_t> dimensions,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, TypeRange{}, inputs, inits, dimensions);
|
|
result.addAttributes(attributes);
|
|
|
|
// Add output types for `RankedTensorType` output arguments.
|
|
for (Value init : inits) {
|
|
Type initType = init.getType();
|
|
if (llvm::isa<RankedTensorType>(initType))
|
|
result.addTypes(initType);
|
|
}
|
|
|
|
if (bodyBuild)
|
|
buildGenericRegion(builder, result.location, *result.regions.front(),
|
|
inputs, inits, bodyBuild);
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
|
|
int64_t inputRank =
|
|
llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
|
|
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
|
|
utils::IteratorType::parallel);
|
|
for (int64_t reductionDim : getDimensions())
|
|
iteratorTypes[reductionDim] = utils::IteratorType::reduction;
|
|
return iteratorTypes;
|
|
}
|
|
|
|
ArrayAttr ReduceOp::getIndexingMaps() {
|
|
int64_t inputRank =
|
|
llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
|
|
SmallVector<AffineMap> affineMaps(
|
|
getNumDpsInputs(),
|
|
AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
|
|
AffineMap resultMap =
|
|
AffineMap::getMultiDimIdentityMap(inputRank, getContext())
|
|
.dropResults(getDimensions());
|
|
for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
|
|
affineMaps.push_back(resultMap);
|
|
return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
|
|
}
|
|
|
|
void ReduceOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability ReduceOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
|
|
NamedAttrList &attributes,
|
|
StringRef attributeName) {
|
|
if (parser.parseKeyword(attributeName) || parser.parseEqual())
|
|
return failure();
|
|
|
|
attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
|
|
return success();
|
|
}
|
|
|
|
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
std::optional<OperationName> payloadOpName;
|
|
NamedAttrList payloadOpAttrs;
|
|
if (succeeded(parser.parseOptionalLBrace())) {
|
|
FailureOr<OperationName> operationName = parser.parseCustomOperationName();
|
|
if (failed(operationName))
|
|
return failure();
|
|
if (parser.parseOptionalAttrDict(payloadOpAttrs))
|
|
return failure();
|
|
payloadOpName = operationName.value();
|
|
if (parser.parseRBrace())
|
|
return failure();
|
|
}
|
|
|
|
if (parseDstStyleOp(
|
|
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
|
|
return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
|
|
}))
|
|
return failure();
|
|
|
|
if (payloadOpName.has_value()) {
|
|
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
|
|
ArrayRef(result.operands), /*initFirst=*/true);
|
|
} else {
|
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
|
|
/*allowType=*/true, /*allowAttrs=*/true)) {
|
|
return failure();
|
|
}
|
|
|
|
Region *body = result.addRegion();
|
|
if (parser.parseRegion(*body, regionArgs))
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
|
|
ArrayRef<int64_t> attributeValue) {
|
|
p << ' ' << attributeName << " = [" << attributeValue << "] ";
|
|
}
|
|
|
|
void ReduceOp::print(OpAsmPrinter &p) {
|
|
Block *mapper = getBody();
|
|
Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
|
|
if (payloadOp) {
|
|
printShortForm(p, payloadOp);
|
|
}
|
|
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
|
|
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
|
|
if (!payloadOp) {
|
|
// Print region if the payload op was not detected.
|
|
p.increaseIndent();
|
|
p.printNewline();
|
|
p << "(";
|
|
llvm::interleaveComma(mapper->getArguments(), p,
|
|
[&](auto arg) { p.printRegionArgument(arg); });
|
|
p << ") ";
|
|
|
|
p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
|
|
p.decreaseIndent();
|
|
}
|
|
}
|
|
|
|
LogicalResult ReduceOp::verify() {
|
|
ArrayRef<int64_t> dimensionsRef = getDimensions();
|
|
|
|
for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
|
|
if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
|
|
llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
|
|
return emitOpError() << "expects all inputs to have the same shapes. "
|
|
"Shape at input-index "
|
|
<< i
|
|
<< " is not equal to the shape at input-index 0.";
|
|
}
|
|
}
|
|
for (int64_t i = 1; i < getNumDpsInits(); ++i) {
|
|
if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
|
|
llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
|
|
return emitOpError() << "expects all outputs to have the same shapes. "
|
|
"Shape at output-index "
|
|
<< i
|
|
<< " is not equal to the shape at output-index 0.";
|
|
}
|
|
}
|
|
auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
|
|
auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
|
|
|
|
DenseSet<int64_t> dimensionsToReduce;
|
|
for (int64_t dimension : dimensionsRef) {
|
|
if (dimension < 0 || dimension >= inputType.getRank()) {
|
|
return emitOpError()
|
|
<< "dimensions for reduction should be in the range [0, "
|
|
<< inputType.getRank() - 1 << "].";
|
|
}
|
|
dimensionsToReduce.insert(dimension);
|
|
}
|
|
|
|
auto inputDims = inputType.getShape();
|
|
auto initDims = initType.getShape();
|
|
|
|
// Input dimensions that will be left after the reduction.
|
|
SmallVector<int64_t> reducedInputDims;
|
|
for (const auto &en : llvm::enumerate(inputDims)) {
|
|
if (!dimensionsToReduce.count(en.index()))
|
|
reducedInputDims.push_back(en.value());
|
|
}
|
|
|
|
if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
|
|
return emitOpError() << "number of dimensions after reduction "
|
|
<< reducedInputDims.size()
|
|
<< " doesn't match the init rank "
|
|
<< initType.getRank();
|
|
}
|
|
|
|
if (reducedInputDims != initDims)
|
|
return emitOpError() << "init dimensions [" << initDims
|
|
<< "] doesn't match input dimensions after reduction ["
|
|
<< reducedInputDims << "]";
|
|
|
|
Block *block = getBody();
|
|
if (block->getNumArguments() != this->getNumOperands())
|
|
return emitOpError()
|
|
<< "mismatching number of operands and block arguments";
|
|
|
|
// Check that the first block arguments match the element type of the inputs.
|
|
for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
|
|
Type inputElementType =
|
|
llvm::cast<ShapedType>(input.getType()).getElementType();
|
|
if (inputElementType != bbArg.getType())
|
|
return emitOpError()
|
|
<< "input element type " << inputElementType
|
|
<< " does not match corresponding block argument type "
|
|
<< bbArg.getType();
|
|
}
|
|
|
|
// Check that the last block arguments match the element type of the outputs.
|
|
for (auto [output, bbArg] : llvm::zip(
|
|
getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
|
|
auto outputElementType =
|
|
llvm::cast<ShapedType>(output.getType()).getElementType();
|
|
if (outputElementType != bbArg.getType())
|
|
return emitOpError()
|
|
<< "output element type " << outputElementType
|
|
<< " does not match corresponding block argument type "
|
|
<< bbArg.getType();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransposeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static void buildIdentityRegion(OpBuilder &builder, Location loc,
|
|
Region ®ion, ValueRange inputs,
|
|
ValueRange outputs) {
|
|
buildGenericRegion(builder, loc, region, inputs, outputs,
|
|
[](OpBuilder &b, Location loc, ValueRange args) {
|
|
if (!args.empty())
|
|
linalg::YieldOp::create(b, loc, args[0]);
|
|
});
|
|
}
|
|
|
|
void TransposeOp::build(::mlir::OpBuilder &builder,
|
|
::mlir::OperationState &result, Value input, Value init,
|
|
DenseI64ArrayAttr permutation,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
result.addOperands(input);
|
|
result.addOperands(init);
|
|
result.addAttribute(getPermutationAttrName(result.name), permutation);
|
|
result.addAttributes(attributes);
|
|
|
|
// Add output types for `RankedTensorType` output arguments.
|
|
Type initType = init.getType();
|
|
if (llvm::isa<RankedTensorType>(initType))
|
|
result.addTypes(initType);
|
|
|
|
buildIdentityRegion(builder, result.location, *result.addRegion(), input,
|
|
init);
|
|
}
|
|
|
|
void TransposeOp::build(::mlir::OpBuilder &builder,
|
|
::mlir::OperationState &result, Value input, Value init,
|
|
ArrayRef<int64_t> permutation,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
|
|
attributes);
|
|
}
|
|
|
|
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
if (failed(parseDstStyleOp(
|
|
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
|
|
return parseDenseI64ArrayAttr(parser, attributes, "permutation");
|
|
})))
|
|
return failure();
|
|
|
|
OpBuilder builder(parser.getContext());
|
|
buildIdentityRegion(builder, result.location, *result.addRegion(),
|
|
/*inputs=*/result.operands,
|
|
/*outputs=*/{});
|
|
return success();
|
|
}
|
|
|
|
void TransposeOp::getAsmResultNames(
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
if (!getResults().empty())
|
|
setNameFn(getResults().front(), "transposed");
|
|
}
|
|
|
|
void TransposeOp::print(OpAsmPrinter &p) {
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
|
|
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
|
|
}
|
|
|
|
LogicalResult TransposeOp::verify() {
|
|
ArrayRef<int64_t> permutationRef = getPermutation();
|
|
|
|
if (!isPermutationVector(permutationRef))
|
|
return emitOpError("permutation is not valid");
|
|
|
|
auto inputType = getInput().getType();
|
|
auto initType = getInit().getType();
|
|
|
|
int64_t rank = inputType.getRank();
|
|
|
|
if (rank != initType.getRank())
|
|
return emitOpError() << "input rank " << rank
|
|
<< " does not match init rank " << initType.getRank();
|
|
|
|
if (rank != static_cast<int64_t>(permutationRef.size()))
|
|
return emitOpError() << "size of permutation " << permutationRef.size()
|
|
<< " does not match the argument rank " << rank;
|
|
|
|
auto inputDims = inputType.getShape();
|
|
auto initDims = initType.getShape();
|
|
|
|
for (int64_t i = 0; i < rank; ++i) {
|
|
int64_t inputDim = inputDims[permutationRef[i]];
|
|
int64_t initDim = initDims[i];
|
|
|
|
if (inputDim != initDim) {
|
|
return emitOpError() << "dim(result, " << i << ") = " << initDim
|
|
<< " doesn't match dim(input, permutation[" << i
|
|
<< "]) = " << inputDim;
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
|
|
int64_t rank = getInit().getType().getRank();
|
|
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
|
}
|
|
|
|
ArrayAttr TransposeOp::getIndexingMaps() {
|
|
Builder builder(getContext());
|
|
int64_t rank = getInit().getType().getRank();
|
|
return builder.getAffineMapArrayAttr(
|
|
{inversePermutation(AffineMap::getPermutationMap(
|
|
llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
|
|
builder.getMultiDimIdentityMap(rank)});
|
|
}
|
|
|
|
void TransposeOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability TransposeOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
|
|
SmallVectorImpl<OpFoldResult> &result) {
|
|
// Only the tensor type is supported.
|
|
if (!isa<TensorType>(getInput().getType()))
|
|
return failure();
|
|
|
|
// Single dimension transpose.
|
|
if (getPermutation().size() == 0) {
|
|
result.push_back(getInput());
|
|
return success();
|
|
}
|
|
// Identity permutation.
|
|
if (isIdentityPermutation(getPermutation())) {
|
|
result.push_back(getInput());
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
/// Fold transpose with transpose.
|
|
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
|
|
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
|
|
if (!defTransposeOp)
|
|
return failure();
|
|
ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
|
|
ArrayRef<int64_t> perms = transposeOp.getPermutation();
|
|
SmallVector<int64_t> foldedPerms;
|
|
foldedPerms.reserve(perms.size());
|
|
for (int64_t perm : perms)
|
|
foldedPerms.push_back(defPerms[perm]);
|
|
|
|
rewriter.replaceOpWithNewOp<TransposeOp>(
|
|
transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
|
|
foldedPerms);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This pattern canonicalize transpose by swapping the order of
|
|
/// broadcast and transpose:
|
|
/// transpose(broadcast(input)) -> broadcast(transpose(input))
|
|
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
|
|
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
Value input = transposeOp.getInput();
|
|
BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
|
|
if (!input.hasOneUse() || !broadcastOp)
|
|
return failure();
|
|
|
|
ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
|
|
ArrayRef<int64_t> perms = transposeOp.getPermutation();
|
|
|
|
// Get new perms and new dimensions.
|
|
SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
|
|
SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
|
|
SmallVector<int64_t> resultDimensions;
|
|
unsigned dimensionSize = dimensions.size();
|
|
for (unsigned i = 0; i < dimensionSize; ++i)
|
|
resultDimensions.push_back(invertPerm[dimensions[i]]);
|
|
|
|
// Create transpose result.
|
|
Value broadcastInput = broadcastOp.getInput();
|
|
Location loc = transposeOp.getLoc();
|
|
MLIRContext *ctx = transposeOp.getContext();
|
|
SmallVector<OpFoldResult> dims;
|
|
auto broadcastInputTy =
|
|
mlir::cast<RankedTensorType>(broadcastInput.getType());
|
|
unsigned inputRank = broadcastInputTy.getRank();
|
|
for (unsigned i = 0; i < inputRank; ++i) {
|
|
if (broadcastInputTy.isDynamicDim(i)) {
|
|
dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
|
|
->getResult(0));
|
|
} else {
|
|
dims.push_back(IntegerAttr::get(IndexType::get(ctx),
|
|
broadcastInputTy.getDimSize(i)));
|
|
}
|
|
}
|
|
SmallVector<OpFoldResult> transposeResultShapes =
|
|
applyPermutation(dims, resultPerms);
|
|
Value transposeInit = tensor::EmptyOp::create(
|
|
rewriter, transposeOp.getLoc(), transposeResultShapes,
|
|
broadcastInputTy.getElementType());
|
|
|
|
// Create broadcast(transpose(input)).
|
|
Value transposeResult =
|
|
TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
|
|
transposeInit, resultPerms)
|
|
->getResult(0);
|
|
rewriter.replaceOpWithNewOp<BroadcastOp>(
|
|
transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BroadcastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void BroadcastOp::build(::mlir::OpBuilder &builder,
|
|
::mlir::OperationState &result, Value input, Value init,
|
|
DenseI64ArrayAttr dimensions,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
result.addOperands(input);
|
|
result.addOperands(init);
|
|
result.addAttribute(getDimensionsAttrName(result.name), dimensions);
|
|
result.addAttributes(attributes);
|
|
|
|
// Add output types for `RankedTensorType` output arguments.
|
|
Type initType = init.getType();
|
|
if (llvm::isa<RankedTensorType>(initType))
|
|
result.addTypes(initType);
|
|
|
|
buildIdentityRegion(builder, result.location, *result.addRegion(), input,
|
|
init);
|
|
}
|
|
|
|
void BroadcastOp::build(::mlir::OpBuilder &builder,
|
|
::mlir::OperationState &result, Value input, Value init,
|
|
ArrayRef<int64_t> dimensions,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
|
|
attributes);
|
|
}
|
|
|
|
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
if (failed(parseDstStyleOp(
|
|
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
|
|
return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
|
|
})))
|
|
return failure();
|
|
|
|
OpBuilder builder(parser.getContext());
|
|
buildIdentityRegion(builder, result.location, *result.addRegion(),
|
|
/*inputs=*/result.operands,
|
|
/*outputs=*/{});
|
|
return success();
|
|
}
|
|
|
|
void BroadcastOp::getAsmResultNames(
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
if (!getResults().empty())
|
|
setNameFn(getResults().front(), "broadcasted");
|
|
}
|
|
|
|
void BroadcastOp::print(OpAsmPrinter &p) {
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
|
|
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
|
|
}
|
|
|
|
LogicalResult BroadcastOp::verify() {
|
|
ArrayRef<int64_t> dimensionsRef = getDimensions();
|
|
|
|
auto inputType = getInput().getType();
|
|
auto initType = getInit().getType();
|
|
|
|
int64_t inputRank = inputType.getRank();
|
|
int64_t initRank = initType.getRank();
|
|
|
|
auto inputShape = inputType.getShape();
|
|
auto initShape = initType.getShape();
|
|
|
|
if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
|
|
return emitOpError() << "input rank plus added dimensions does not "
|
|
"match init rank. input rank: "
|
|
<< inputRank
|
|
<< ", dimensions size: " << dimensionsRef.size()
|
|
<< ", init rank: " << initRank;
|
|
|
|
for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
|
|
if (dim < 0 || dim >= initRank)
|
|
return emitOpError() << "dimension " << idx
|
|
<< " is out of range. expected range: [0, "
|
|
<< initRank - 1 << "], got: " << dim;
|
|
}
|
|
|
|
// Mapping from input dims to init dims.
|
|
SmallVector<int64_t> dimMap;
|
|
for (auto dim : llvm::seq<int64_t>(0, initRank)) {
|
|
if (!llvm::is_contained(dimensionsRef, dim))
|
|
dimMap.push_back(dim);
|
|
}
|
|
|
|
for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
|
|
// This dimensions is mapped from the input. Init and input dims should
|
|
// match.
|
|
if (inputShape[inputDimIdx] != initShape[initDimIdx])
|
|
return emitOpError() << "input dim " << inputDimIdx
|
|
<< " should match init dim " << initDimIdx
|
|
<< ". input: " << inputShape[inputDimIdx]
|
|
<< ", init: " << initShape[initDimIdx];
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
|
|
int64_t rank = getInit().getType().getRank();
|
|
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
|
}
|
|
|
|
ArrayAttr BroadcastOp::getIndexingMaps() {
|
|
Builder builder(getContext());
|
|
int64_t rank = getInit().getType().getRank();
|
|
return builder.getAffineMapArrayAttr(
|
|
{builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
|
|
builder.getMultiDimIdentityMap(rank)});
|
|
}
|
|
|
|
void BroadcastOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability BroadcastOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
/// Fold back-to-back broadcasts together.
|
|
struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
|
|
using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
|
|
if (!defBroadcastOp)
|
|
return failure();
|
|
ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
|
|
ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
|
|
SmallVector<int64_t> foldedDims(dimensions);
|
|
Value init = broadcastOp.getInit();
|
|
int64_t initRank = cast<ShapedType>(init.getType()).getRank();
|
|
// Mapping from input dims to init dims.
|
|
SmallVector<int64_t> dimMap;
|
|
for (auto dim : llvm::seq<int64_t>(0, initRank)) {
|
|
if (!llvm::is_contained(dimensions, dim))
|
|
dimMap.push_back(dim);
|
|
}
|
|
for (auto dim : defDimensions)
|
|
foldedDims.push_back(dimMap[dim]);
|
|
|
|
llvm::sort(foldedDims);
|
|
rewriter.replaceOpWithNewOp<BroadcastOp>(
|
|
broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// YieldOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void linalg::YieldOp::print(OpAsmPrinter &p) {
|
|
if (getNumOperands() > 0)
|
|
p << ' ' << getOperands();
|
|
p.printOptionalAttrDict((*this)->getAttrs());
|
|
if (getNumOperands() > 0)
|
|
p << " : " << getOperandTypes();
|
|
}
|
|
|
|
ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
|
|
SmallVector<Type, 2> types;
|
|
SMLoc loc = parser.getCurrentLocation();
|
|
return failure(parser.parseOperandList(opInfo) ||
|
|
parser.parseOptionalAttrDict(result.attributes) ||
|
|
(!opInfo.empty() && parser.parseColonTypeList(types)) ||
|
|
parser.resolveOperands(opInfo, types, loc, result.operands));
|
|
}
|
|
|
|
// Check the operand number and types must match the element types of the
|
|
// LinalgOp interface's shaped operands.
|
|
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
|
|
if (op.getNumOperands() != linalgOp.getNumDpsInits())
|
|
return op.emitOpError("expected number of yield values (")
|
|
<< op.getNumOperands()
|
|
<< ") to match the number of inits / outs operands of the enclosing "
|
|
<< "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
|
|
|
|
for (OpOperand &opOperand : op->getOpOperands()) {
|
|
OpOperand *outputOperand =
|
|
linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
|
|
Type elementType = outputOperand->get().getType();
|
|
if (isa<MemRefType, RankedTensorType>(elementType))
|
|
elementType = getElementTypeOrSelf(outputOperand->get().getType());
|
|
if (opOperand.get().getType() != elementType)
|
|
return op.emitOpError("type of yield operand ")
|
|
<< (opOperand.getOperandNumber() + 1) << " ("
|
|
<< opOperand.get().getType() << ") doesn't match "
|
|
<< "the element type of the enclosing linalg.generic op ("
|
|
<< elementType << ")";
|
|
}
|
|
return success();
|
|
}
|
|
|
|
LogicalResult linalg::YieldOp::verify() {
|
|
auto *parentOp = (*this)->getParentOp();
|
|
if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
|
|
return emitOpError("expected single non-empty parent region");
|
|
|
|
if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
|
|
return verifyYield(*this, linalgOp);
|
|
|
|
return emitOpError("expected parent op with LinalgOp interface");
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IndexOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult IndexOp::verify() {
|
|
auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
|
|
if (!linalgOp)
|
|
return emitOpError("expected parent op with LinalgOp interface");
|
|
if (linalgOp.getNumLoops() <= getDim())
|
|
return emitOpError("expected dim (")
|
|
<< getDim() << ") to be lower than the number of loops ("
|
|
<< linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
|
|
return success();
|
|
}
|
|
|
|
OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
|
|
auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
|
|
// Bail out if `linalg.index` does not have a proper parent yet at this
|
|
// point, e.g., when calling `createOrFold` during IR construction in
|
|
// `genericOp::build`.
|
|
if (!linalgOp)
|
|
return OpFoldResult{};
|
|
|
|
// Index of unit dims is always 0.
|
|
SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
|
|
uint64_t dim = getDim();
|
|
assert(dim < loopBounds.size() && "Dim is out of bounds");
|
|
if (loopBounds[dim] == 1)
|
|
return IntegerAttr::get(IndexType::get(getContext()), 0);
|
|
|
|
return OpFoldResult{};
|
|
}
|
|
|
|
/////// Operations corresponding to library calls defined with Tablegen ////////
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
|
|
|
|
AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
|
|
unsigned rank,
|
|
MLIRContext *context) {
|
|
if (maybeMap)
|
|
return *maybeMap;
|
|
if (rank == 0)
|
|
return AffineMap::get(context);
|
|
return AffineMap::getMultiDimIdentityMap(rank, context);
|
|
}
|
|
|
|
SmallVector<AffineExpr, 4>
|
|
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
|
|
MLIRContext *context) {
|
|
SmallVector<AffineExpr, 4> res;
|
|
res.reserve(num);
|
|
for (unsigned i = 0; i < num; ++i)
|
|
res.push_back(getAffineDimExpr(startIdx++, context));
|
|
return res;
|
|
}
|
|
|
|
SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
|
|
ArrayRef<AffineExpr> b) {
|
|
auto rangeA = llvm::make_range(a.begin(), a.end());
|
|
auto rangeB = llvm::make_range(b.begin(), b.end());
|
|
auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
|
|
return llvm::to_vector<4>(concatRanges);
|
|
}
|
|
|
|
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
|
|
if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
|
|
ss << "view";
|
|
for (auto size : memref.getShape())
|
|
if (size < 0)
|
|
ss << "sx";
|
|
else
|
|
ss << size << "x";
|
|
if (failed(appendMangledType(ss, memref.getElementType())))
|
|
return failure();
|
|
if (auto as = memref.getMemorySpace()) {
|
|
if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
|
|
ss << "as" << attr.getInt();
|
|
else
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
if (auto vec = llvm::dyn_cast<VectorType>(t)) {
|
|
ss << "vector";
|
|
llvm::interleave(
|
|
vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
|
|
if (failed(appendMangledType(ss, vec.getElementType())))
|
|
return failure();
|
|
return success();
|
|
}
|
|
if (t.isSignlessIntOrIndexOrFloat()) {
|
|
ss << t;
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
std::string mlir::linalg::generateLibraryCallName(Operation *op) {
|
|
assert(isa<LinalgOp>(op));
|
|
std::string name(op->getName().getStringRef().str());
|
|
std::string fun = "";
|
|
for (NamedAttribute kv : op->getAttrs()) {
|
|
if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
|
|
fun = stringifyEnum(ufa.getValue()).str() + "_";
|
|
} else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
|
|
fun = stringifyEnum(bfa.getValue()).str() + "_";
|
|
}
|
|
}
|
|
name.reserve(128);
|
|
llvm::replace(name, '.', '_');
|
|
llvm::raw_string_ostream ss(name);
|
|
ss << "_" << fun;
|
|
for (Type t : op->getOperandTypes()) {
|
|
if (failed(appendMangledType(ss, t)))
|
|
return std::string();
|
|
ss << "_";
|
|
}
|
|
name.pop_back();
|
|
return name;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Canonicalizers and Folders.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
|
|
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(LinalgOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
for (OpOperand &opOperand : op->getOpOperands()) {
|
|
// Linalg "inputs" may be either tensor or memref type.
|
|
// tensor<0xelt_type> is a convention that may not always mean
|
|
// "0 iterations". Only erase in cases we see memref<...x0x...>.
|
|
auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
|
|
if (!mt)
|
|
continue;
|
|
if (llvm::is_contained(op.getShape(&opOperand), 0)) {
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
|
|
/// result that is more static than the linalg op.
|
|
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
|
|
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::CastOp castOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!tensor::canFoldIntoProducerOp(castOp))
|
|
return failure();
|
|
|
|
auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
|
|
if (!linalgOp)
|
|
return failure();
|
|
|
|
// Cast can be in conditionally reachable region, if which case folding will
|
|
// generate invalid code. Only conservatively fold ops in same block for
|
|
// now.
|
|
if (castOp->getBlock() != linalgOp->getBlock())
|
|
return failure();
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPoint(linalgOp);
|
|
|
|
Location loc = linalgOp.getLoc();
|
|
OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
|
|
unsigned resultNumber = resultValue.getResultNumber();
|
|
auto resultType =
|
|
llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
|
|
// Replace the `outs` for the result with a `tensor.cast`. This cast is now
|
|
// going from a more dynamic shape to a less dynamic shape. If the producer
|
|
// for this cast, i.e. producer of the out operand, is also an operation
|
|
// that folds with tensor.cast consumer (like this pattern), the cast will
|
|
// continue to propagate as far up the stack as it can go.
|
|
OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
|
|
Value newOperand =
|
|
tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
|
|
SmallVector<Value> newOperands = linalgOp.getDpsInputs();
|
|
SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
|
|
linalgOp.getDpsInits().end());
|
|
outputOperands[resultNumber] = newOperand;
|
|
newOperands.append(outputOperands.begin(), outputOperands.end());
|
|
|
|
SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
|
|
linalgOp->result_type_end());
|
|
resultTypes[resultNumber] = resultType;
|
|
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
|
|
|
|
// Create a tensor.cast operation back to the original type.
|
|
Value castBack = tensor::CastOp::create(
|
|
rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
|
|
|
|
SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
|
|
results[resultNumber] = castBack;
|
|
rewriter.replaceOp(linalgOp, results);
|
|
rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// For each of the operand in `operands` this function maps the static sizes of
|
|
/// dimensions to their affine dim expressions.
|
|
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
|
|
llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
|
|
for (OpOperand &opOperand : operands) {
|
|
if (linalgOp.isScalar(&opOperand))
|
|
continue;
|
|
Value src = opOperand.get();
|
|
auto sourceType = llvm::cast<RankedTensorType>(src.getType());
|
|
auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
|
|
|
|
// Get the `sourceShape` of the `sourceType`. If the operand is a result of
|
|
// `tensor.cast` operation and source of the cast operation has a static
|
|
// shape, then assign it to the `sourceShape`.
|
|
auto *parentOp = src.getDefiningOp();
|
|
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
|
if (parentOp) {
|
|
if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
|
|
Value castSource = castOp.getSource();
|
|
auto castSourceType =
|
|
llvm::dyn_cast<RankedTensorType>(castSource.getType());
|
|
if (castSourceType && castSourceType.hasStaticShape())
|
|
sourceShape = castSourceType.getShape();
|
|
}
|
|
}
|
|
|
|
// If the source shape's dimension has a static shape, map the affine dim
|
|
// expression to the known static size.
|
|
for (unsigned i = 0; i < sourceShape.size(); i++) {
|
|
if (sourceType.isDynamicDim(i))
|
|
continue;
|
|
if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
|
|
affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
|
|
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
|
|
/// their result types is stored in `resultTypes`. If `opOperand` requires no
|
|
/// change then `changeNeeded` is false and same operand is added in the
|
|
/// `newOperands` list.
|
|
static void createNewOperandWithStaticSizes(
|
|
Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
|
|
llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
|
|
SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
|
|
bool &changeNeeded) {
|
|
Value src = opOperand->get();
|
|
newOperands.push_back(src);
|
|
if (linalgOp.isScalar(opOperand))
|
|
return;
|
|
auto sourceType = llvm::cast<RankedTensorType>(src.getType());
|
|
Type resultType = sourceType;
|
|
if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
|
|
resultTypes.push_back(resultType);
|
|
return;
|
|
}
|
|
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
|
AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
|
|
SmallVector<int64_t> newShape;
|
|
// If operand is updated with new shape, `newOperandNeeded` will be
|
|
// true.
|
|
bool newOperandNeeded = false;
|
|
for (unsigned i = 0; i < sourceShape.size(); i++) {
|
|
int64_t dimShape = sourceShape[i];
|
|
AffineExpr dimExpr = sourceMap.getResult(i);
|
|
if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
|
|
newShape.push_back(dimShape);
|
|
continue;
|
|
}
|
|
// Dimension has a dynamic shape and corresponding affine dim
|
|
// expression is present in the map. So assign the size for the
|
|
// given affine dim expression to the dimension.
|
|
newShape.push_back(affineExprToSize[dimExpr]);
|
|
newOperandNeeded = true;
|
|
}
|
|
resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
|
|
sourceType.getEncoding());
|
|
if (newOperandNeeded) {
|
|
changeNeeded = true;
|
|
// Get the new operand value given its size and element type by
|
|
// casting it.
|
|
Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
|
|
unsigned index = opOperand->getOperandNumber();
|
|
newOperands[index] = newOperand;
|
|
}
|
|
if (linalgOp.isDpsInit(opOperand))
|
|
resultTypes.push_back(resultType);
|
|
}
|
|
|
|
/// Static shapes for the operands can be inferred if any one of the operands
|
|
/// have a static shape. This can be done by referring to the affine dim
|
|
/// expressions for the operand.
|
|
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
|
|
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(LinalgOp linalgOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!linalgOp.hasPureTensorSemantics())
|
|
return failure();
|
|
|
|
// Maps must be projected permutations.
|
|
if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
|
|
return !map.isProjectedPermutation();
|
|
}))
|
|
return failure();
|
|
|
|
// Maps affine dim expressions to the static size of that dimension.
|
|
llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
|
|
Location loc = linalgOp.getLoc();
|
|
|
|
// For each of the affine dim expression, check if the size is known. If
|
|
// known add that in the map.
|
|
populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
|
|
|
|
SmallVector<Value> newOperands;
|
|
SmallVector<Type> resultTypes;
|
|
|
|
// `changeNeeded` is `false` if the operands of `linalgOp` require no
|
|
// change in their types.
|
|
bool changeNeeded = false;
|
|
newOperands.reserve(linalgOp->getNumOperands());
|
|
resultTypes.reserve(linalgOp.getNumDpsInits());
|
|
|
|
// Iterate over all the operands and update the static sizes.
|
|
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
|
|
createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
|
|
affineExprToSize, linalgOp, newOperands,
|
|
resultTypes, changeNeeded);
|
|
}
|
|
|
|
// If the generic op has all the required static information, no
|
|
// canonicalization needed.
|
|
if (!changeNeeded)
|
|
return failure();
|
|
|
|
// Clone op.
|
|
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
|
|
SmallVector<Value> replacements;
|
|
replacements.reserve(newOp->getNumResults());
|
|
for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
|
|
Value newResult = std::get<1>(it);
|
|
Value oldResult = std::get<0>(it);
|
|
Type newType = newResult.getType();
|
|
Type oldType = oldResult.getType();
|
|
replacements.push_back(
|
|
(newType != oldType)
|
|
? tensor::CastOp::create(rewriter, loc, oldType, newResult)
|
|
: newResult);
|
|
}
|
|
rewriter.replaceOp(linalgOp, replacements);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
// All named ops canonicalizers and folders are auto-generated in the
|
|
// .cpp.inc.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SoftmaxOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult SoftmaxOp::verify() {
|
|
ShapedType inputType = getInputOperandType();
|
|
ShapedType outputType = getOutputOperandType();
|
|
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
if (failed(verifyCompatibleShape(inputShape, outputShape)))
|
|
return emitOpError("incompatible output shape");
|
|
|
|
int64_t inputRank = getInputOperandRank();
|
|
int64_t dimension = getDimension();
|
|
if ((dimension < 0) || (dimension >= inputRank))
|
|
return emitOpError("incorrect dimension specified");
|
|
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
|
|
int64_t operandRank = getInputOperandRank();
|
|
SmallVector<Range> loopBounds(operandRank);
|
|
Location loc = getLoc();
|
|
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
|
|
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
|
|
Value source = getInput();
|
|
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
|
|
loopBounds[dim].offset = zero;
|
|
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
|
|
loopBounds[dim].stride = one;
|
|
}
|
|
return loopBounds;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
|
|
SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
|
|
utils::IteratorType::parallel);
|
|
iteratorTypes[getDimension()] = utils::IteratorType::reduction;
|
|
return iteratorTypes;
|
|
}
|
|
|
|
FailureOr<TilingResult>
|
|
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
|
|
ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes) {
|
|
int64_t rank = getInputOperandRank();
|
|
auto oneAttr = builder.getI64IntegerAttr(1);
|
|
SmallVector<OpFoldResult> strides(rank, oneAttr);
|
|
SmallVector<Value> tiledOperands;
|
|
Operation *inputSlice =
|
|
getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
|
|
if (!inputSlice) {
|
|
return emitOpError("failed to compute input slice");
|
|
}
|
|
tiledOperands.emplace_back(inputSlice->getResult(0));
|
|
Operation *outputSlice =
|
|
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
|
|
if (!outputSlice) {
|
|
return emitOpError("failed to compute output slice");
|
|
}
|
|
tiledOperands.emplace_back(outputSlice->getResult(0));
|
|
|
|
SmallVector<Type, 4> resultTypes;
|
|
if (hasPureTensorSemantics())
|
|
resultTypes.push_back(tiledOperands[1].getType());
|
|
Operation *tiledOp =
|
|
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
|
|
|
|
return TilingResult{
|
|
{tiledOp},
|
|
SmallVector<Value>(tiledOp->getResults()),
|
|
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
|
|
}
|
|
|
|
LogicalResult SoftmaxOp::getResultTilePosition(
|
|
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) {
|
|
if (resultNumber == 0) {
|
|
resultOffsets.assign(offsets.begin(), offsets.end());
|
|
resultSizes.assign(sizes.begin(), sizes.end());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
// cast(dynamic) -> static.
|
|
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
LogicalResult
|
|
SoftmaxOp::reifyResultShapes(OpBuilder &b,
|
|
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
|
SmallVector<OpFoldResult> shapes;
|
|
Location loc = getOperation()->getLoc();
|
|
IRRewriter rewriter(b);
|
|
auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
|
|
auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
|
|
for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
|
|
if (!outputShapedType.isDynamicDim(dim)) {
|
|
// Static dim: Return IntegerAttr.
|
|
shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
|
|
} else {
|
|
// Dynamic dim: Return Value.
|
|
OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
|
|
shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
|
|
}
|
|
}
|
|
reifiedReturnShapes.emplace_back(std::move(shapes));
|
|
return success();
|
|
}
|
|
|
|
void SoftmaxOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
|
|
if (!llvm::isa<MemRefType>(operand.getType()))
|
|
continue;
|
|
effects.emplace_back(MemoryEffects::Read::get(),
|
|
&getOperation()->getOpOperand(index), /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
|
|
for (OpOperand &operand : getDpsInitsMutable()) {
|
|
if (!llvm::isa<MemRefType>(operand.get().getType()))
|
|
continue;
|
|
effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
}
|
|
|
|
// Helper functions for softmax decomposition.
|
|
// @{
|
|
|
|
// Helper function to produce the iterator types (reduction or parallel) and
|
|
// affine maps for the iterators used in the decomposition of softmax.
|
|
// This method creates:
|
|
// If allParallel == true:
|
|
// - iterator type: {parallel, ..., parallel}
|
|
// - affine maps:
|
|
// -- identity with inputRank dimensions.
|
|
// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
|
|
// where N == inputRank.
|
|
//
|
|
// If allParallel == false:
|
|
// - iterator type at dim(i) == parallel for i != \p dim and
|
|
// dim(dim) == reduction.
|
|
// - affine map:
|
|
// -- identity with inputRank dimensions.
|
|
// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
|
|
// where N == inputRank.
|
|
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
|
|
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
|
|
int64_t dim, bool allParallel = false) {
|
|
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
|
|
utils::IteratorType::parallel);
|
|
if (!allParallel)
|
|
iteratorTypes[dim] = utils::IteratorType::reduction;
|
|
MLIRContext *ctxt = builder.getContext();
|
|
auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
|
|
SmallVector<AffineExpr, 2> affineExprs;
|
|
for (int i = 0; i < inputRank; i++) {
|
|
if (i != dim)
|
|
affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
|
|
}
|
|
auto reductionMap =
|
|
AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
|
|
SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
|
|
return std::make_tuple(iteratorTypes, indexingMaps);
|
|
}
|
|
|
|
// Helper function to produce a linalg.generic that computes a reduction on
|
|
// dimension \p dim with the operation type \p T.
|
|
template <typename T>
|
|
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
|
|
int64_t dim) {
|
|
auto inputType = cast<ShapedType>(input.getType());
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputRank = inputShape.size();
|
|
auto [iteratorTypes, indexingMaps] =
|
|
computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
|
|
assert(indexingMaps.size() == 2 &&
|
|
"We should have two maps: 1 for the input, 1 for the output");
|
|
assert(indexingMaps[0].isIdentity() && "input map should be identity");
|
|
|
|
auto genericOp = linalg::GenericOp::create(
|
|
builder, loc, output.getType(), input, output, indexingMaps,
|
|
iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
|
|
Value result = T::create(b, loc, args[0], args[1]);
|
|
linalg::YieldOp::create(b, loc, result);
|
|
});
|
|
return genericOp.getResult(0);
|
|
}
|
|
|
|
/// Produce a linalg generic that computes the second step of the softmax
|
|
/// decomposition: res = exp(input - max), where \p max is the max of \p input
|
|
/// on dimension \p dim.
|
|
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
|
|
Value max, Value output, int64_t dim) {
|
|
auto inputType = cast<ShapedType>(input.getType());
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputRank = inputShape.size();
|
|
auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
|
|
builder, inputRank, dim, /*allParallel=*/true);
|
|
assert(indexingMaps.size() == 2 && "We should have one map for each input");
|
|
assert(indexingMaps[0].isIdentity() && "input map should be identity");
|
|
// Add the affine map for the output argument.
|
|
indexingMaps.push_back(indexingMaps[0]);
|
|
auto genericOp = linalg::GenericOp::create(
|
|
builder, loc, input.getType(), ValueRange{input, max}, output,
|
|
indexingMaps, iteratorTypes,
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
|
|
Value result = math::ExpOp::create(b, loc, diff);
|
|
linalg::YieldOp::create(b, loc, result);
|
|
});
|
|
return genericOp.getResult(0);
|
|
}
|
|
|
|
/// Produce a linalg generic that computes the final step of the softmax
|
|
/// decomposition.
|
|
/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
|
|
/// yield n / d
|
|
/// }
|
|
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
|
|
Value denominator, Value output, int64_t dim) {
|
|
auto inputType = cast<ShapedType>(numerator.getType());
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputRank = inputShape.size();
|
|
auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
|
|
builder, inputRank, dim, /*allParallel=*/true);
|
|
assert(indexingMaps.size() == 2 &&
|
|
"We should have one map for each input (2)");
|
|
assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
|
|
// Add the affine map for the output tensor.
|
|
indexingMaps.push_back(indexingMaps[0]);
|
|
auto genericOp = linalg::GenericOp::create(
|
|
builder, loc, numerator.getType(), ValueRange{numerator, denominator},
|
|
output, indexingMaps, iteratorTypes,
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
|
|
linalg::YieldOp::create(b, loc, result);
|
|
});
|
|
return genericOp.getResult(0);
|
|
}
|
|
// @} End helper functions for softmax decomposition.
|
|
|
|
/// Given an N-dimensional tensor x, this method converts
|
|
/// softmax(x) to the following sequence of operations:
|
|
///
|
|
/// 1. Compute the max of x along dimension d. This results
|
|
/// in a N-1 dimensional tensor m.
|
|
/// m = max(x, dim = d)
|
|
///
|
|
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
|
|
/// a N dimensional tensor z.
|
|
/// z = exp(x - m)
|
|
///
|
|
/// 3. Compute the sum of z along dimension d. This results in
|
|
/// a N-1 dimensional tensor l.
|
|
/// l = sum(z, dim = d)
|
|
///
|
|
/// 4. Divide z and l. This gives the N-dimensional softmax.
|
|
/// softmax = z / l
|
|
///
|
|
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
|
|
OpBuilder::InsertionGuard guard(b);
|
|
b.setInsertionPoint(*this);
|
|
Location loc = getLoc();
|
|
Value input = getInput();
|
|
ShapedType inputType = getInputOperandType();
|
|
Type elementType = inputType.getElementType();
|
|
int64_t reductionDim = getDimension();
|
|
SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
|
|
Value output = getOutput();
|
|
dims.erase(dims.begin() + reductionDim);
|
|
// Step 1: Compute max along dim.
|
|
Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
|
|
Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
|
|
elementType, b, loc,
|
|
/*useOnlyFiniteValue=*/true);
|
|
Value neutralForMaxFInit =
|
|
linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
|
|
.result();
|
|
Value max =
|
|
reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
|
|
|
|
// Step 2: Subtract max from input and exponentiate.
|
|
Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
|
|
|
|
// Step 3: Compute sum along dim.
|
|
Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
|
|
b, loc, /*useOnlyFiniteValue=*/true);
|
|
Value zeroInit =
|
|
linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
|
|
Value denominator =
|
|
reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
|
|
|
|
// Step 4: Compute softmax.
|
|
Value result =
|
|
buildDivOp(b, loc, numerator, denominator, output, reductionDim);
|
|
return SmallVector<Value>{result};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WinogradFilterTransformOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult WinogradFilterTransformOp::verify() {
|
|
auto filterType = cast<ShapedType>(getFilter().getType());
|
|
ArrayRef<int64_t> filterShape = filterType.getShape();
|
|
int64_t filterH = filterShape[getFilterHDim()];
|
|
int64_t filterW = filterShape[getFilterWDim()];
|
|
WinogradConv2DFmr fmr = getFmr();
|
|
int64_t m, r;
|
|
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
|
|
|
|
if (filterH != r && filterH != 1)
|
|
return emitOpError("expect filter height either equals to r or 1");
|
|
if (filterW != r && filterW != 1)
|
|
return emitOpError("expect filter width either equals to r or 1");
|
|
if (filterH == 1 && filterW == 1)
|
|
return emitOpError("expect either filter height or width equals to r");
|
|
|
|
SmallVector<int64_t> expectedOutputShape;
|
|
expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
|
|
expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
|
|
expectedOutputShape.push_back(filterShape[getFilterCDim()]);
|
|
expectedOutputShape.push_back(filterShape[getFilterFDim()]);
|
|
|
|
auto outputType = cast<ShapedType>(getOutput().getType());
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
|
|
return emitOpError("the output shape is not expected");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Range>
|
|
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
|
|
Location loc = getLoc();
|
|
IntegerAttr zeroAttr = builder.getIndexAttr(0);
|
|
IntegerAttr oneAttr = builder.getIndexAttr(1);
|
|
Value filter = getFilter();
|
|
int64_t filterRank = getFilterOperandRank();
|
|
SmallVector<Range> loopBounds(filterRank);
|
|
for (unsigned dim = 0; dim < filterRank; ++dim) {
|
|
loopBounds[dim].offset = zeroAttr;
|
|
loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
|
|
loopBounds[dim].stride = oneAttr;
|
|
}
|
|
return loopBounds;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType>
|
|
WinogradFilterTransformOp::getLoopIteratorTypes() {
|
|
int64_t filterRank = getFilterOperandRank();
|
|
SmallVector<utils::IteratorType> iteratorTypes(filterRank,
|
|
utils::IteratorType::parallel);
|
|
return iteratorTypes;
|
|
}
|
|
|
|
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
|
|
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) {
|
|
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
|
|
ShapedType filterType = getFilterOperandType();
|
|
ArrayRef<int64_t> filterShape = filterType.getShape();
|
|
int64_t filterH = filterShape[getFilterHDim()];
|
|
int64_t filterW = filterShape[getFilterWDim()];
|
|
WinogradConv2DFmr fmr = getFmr();
|
|
int64_t m, r;
|
|
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
|
|
int64_t alpha = m + r - 1;
|
|
int64_t alphaH = filterH != 1 ? alpha : 1;
|
|
int64_t alphaW = filterW != 1 ? alpha : 1;
|
|
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
|
|
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
|
|
|
|
resultOffsets.append(
|
|
{zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
|
|
resultSizes.append(
|
|
{alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Implement tiling for winograd_filter_transform
|
|
/// The input of winograd_filter_transform is (F, KH, KW, C).
|
|
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
|
|
/// Users can specify the tile sizes of F and C.
|
|
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
|
|
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
|
|
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
|
|
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes) {
|
|
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
|
|
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
|
|
ShapedType filterType = getFilterOperandType();
|
|
ArrayRef<int64_t> filterShape = filterType.getShape();
|
|
int64_t filterH = filterShape[getFilterHDim()];
|
|
int64_t filterW = filterShape[getFilterWDim()];
|
|
IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
|
|
IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
|
|
SmallVector<Value> tiledOperands;
|
|
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
|
|
|
|
sliceOffsets.append(
|
|
{offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
|
|
sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
|
|
sizes[getFilterCDim()]});
|
|
int64_t filterRank = getFilterOperandRank();
|
|
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
|
|
Location loc = getLoc();
|
|
auto filterSlice = tensor::ExtractSliceOp::create(
|
|
builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
|
|
tiledOperands.emplace_back(filterSlice);
|
|
|
|
SmallVector<OpFoldResult> resultOffsets, resultSizes;
|
|
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
|
|
resultSizes)))
|
|
return failure();
|
|
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
|
|
auto outputSlice = tensor::ExtractSliceOp::create(
|
|
builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
|
|
tiledOperands.emplace_back(outputSlice);
|
|
|
|
SmallVector<Type> resultTypes;
|
|
resultTypes.push_back(tiledOperands[1].getType());
|
|
Operation *tiledOp =
|
|
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
|
|
|
|
return TilingResult{
|
|
{tiledOp},
|
|
SmallVector<Value>(tiledOp->getResults()),
|
|
llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WinogradInputTransformOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult WinogradInputTransformOp::verify() {
|
|
auto inputType = cast<ShapedType>(getInput().getType());
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputH = inputShape[getInputHDim()];
|
|
int64_t inputW = inputShape[getInputWDim()];
|
|
WinogradConv2DFmr fmr = getFmr();
|
|
int64_t m, r;
|
|
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
|
|
int64_t tileSize = m + r - 1;
|
|
|
|
auto outputType = cast<ShapedType>(getOutput().getType());
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
|
|
bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
|
|
|
|
SmallVector<int64_t> expectedOutputShape(6, inputH);
|
|
if (ShapedType::isDynamic(inputH)) {
|
|
expectedOutputShape[getOutputAlphaHDim()] = tileSize;
|
|
expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
|
|
} else {
|
|
expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
|
|
expectedOutputShape[getOutputTileHDim()] =
|
|
leftTransform ? (inputH - (r - 1)) / m : inputH;
|
|
}
|
|
if (ShapedType::isDynamic(inputW)) {
|
|
expectedOutputShape[getOutputAlphaWDim()] = tileSize;
|
|
expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
|
|
} else {
|
|
expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
|
|
expectedOutputShape[getOutputTileWDim()] =
|
|
rightTransform ? (inputW - (r - 1)) / m : inputW;
|
|
}
|
|
expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
|
|
expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
|
|
|
|
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
|
|
return emitOpError("the output shape is not expected");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Range>
|
|
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
|
|
Location loc = getLoc();
|
|
IntegerAttr zeroAttr = builder.getIndexAttr(0);
|
|
IntegerAttr oneAttr = builder.getIndexAttr(1);
|
|
Value output = getOutput();
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<Range> loopBounds(outputRank);
|
|
for (unsigned dim = 0; dim < outputRank; ++dim) {
|
|
loopBounds[dim].offset = zeroAttr;
|
|
// alphaH, alphaW, tileH, tileW, N, C
|
|
loopBounds[dim].size = getDimValue(builder, loc, output, dim);
|
|
loopBounds[dim].stride = oneAttr;
|
|
}
|
|
return loopBounds;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType>
|
|
WinogradInputTransformOp::getLoopIteratorTypes() {
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<utils::IteratorType> iteratorTypes(outputRank,
|
|
utils::IteratorType::parallel);
|
|
return iteratorTypes;
|
|
}
|
|
|
|
LogicalResult WinogradInputTransformOp::getResultTilePosition(
|
|
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) {
|
|
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
|
|
ShapedType outputType = getOutputOperandType();
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
|
|
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
|
|
|
|
WinogradConv2DFmr fmr = getFmr();
|
|
int64_t m, r;
|
|
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
|
|
int64_t alpha = m + r - 1;
|
|
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
|
|
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
|
|
|
|
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
|
|
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
|
|
|
|
resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
|
|
offsets[getOutputTileWDim()], offsets[getOutputNDim()],
|
|
offsets[getOutputCDim()]});
|
|
resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
|
|
sizes[getOutputTileWDim()], sizes[getOutputNDim()],
|
|
sizes[getOutputCDim()]});
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Implement tiling for winograd_input_transform
|
|
/// The input of winograd_input_transform is (N, H, W, C).
|
|
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
|
|
/// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
|
|
/// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
|
|
/// the values for the sizes of tileH, tileW, N, C for one tile.
|
|
FailureOr<TilingResult>
|
|
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
|
|
ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes) {
|
|
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
|
|
WinogradConv2DFmr fmr = getFmr();
|
|
int64_t m, r;
|
|
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
|
|
|
|
ShapedType outputType = getOutputOperandType();
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
int64_t alphaH = outputShape[getOutputAlphaHDim()];
|
|
int64_t alphaW = outputShape[getOutputAlphaWDim()];
|
|
|
|
Location loc = getLoc();
|
|
MLIRContext *context = builder.getContext();
|
|
auto identityAffineMap =
|
|
AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
|
|
auto offsetAffineMap =
|
|
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
|
|
Value mappedOffsetH = affine::makeComposedAffineApply(
|
|
builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
|
|
offsets[getOutputTileHDim()]);
|
|
Value mappedOffsetW = affine::makeComposedAffineApply(
|
|
builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
|
|
offsets[getOutputTileWDim()]);
|
|
auto sizeAffineMap = AffineMap::get(
|
|
1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
|
|
Value mappedSizeH = affine::makeComposedAffineApply(
|
|
builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
|
|
Value mappedSizeW = affine::makeComposedAffineApply(
|
|
builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
|
|
|
|
SmallVector<Value> tiledOperands;
|
|
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
|
|
|
|
OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
|
|
OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
|
|
sliceOffsets.append(
|
|
{offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
|
|
OpFoldResult sizeH =
|
|
alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
|
|
OpFoldResult sizeW =
|
|
alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
|
|
sliceSizes.append(
|
|
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
|
|
int64_t inputRank = getInputOperandRank();
|
|
SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
|
|
auto inputSlice = tensor::ExtractSliceOp::create(
|
|
builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
|
|
tiledOperands.emplace_back(inputSlice);
|
|
|
|
SmallVector<OpFoldResult> resultOffsets, resultSizes;
|
|
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
|
|
resultSizes)))
|
|
return failure();
|
|
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
|
|
auto outputSlice = tensor::ExtractSliceOp::create(
|
|
builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
|
|
tiledOperands.emplace_back(outputSlice);
|
|
|
|
SmallVector<Type> resultTypes;
|
|
resultTypes.push_back(tiledOperands[1].getType());
|
|
Operation *tiledOp =
|
|
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
|
|
|
|
return TilingResult{
|
|
{tiledOp},
|
|
SmallVector<Value>(tiledOp->getResults()),
|
|
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WinogradOutputTransformOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult WinogradOutputTransformOp::verify() {
|
|
auto valueType = cast<ShapedType>(getValue().getType());
|
|
ArrayRef<int64_t> valueShape = valueType.getShape();
|
|
int64_t valueH = valueShape[getValueAlphaHDim()];
|
|
int64_t valueW = valueShape[getValueAlphaWDim()];
|
|
int64_t valueTileH = valueShape[getValueTileHDim()];
|
|
int64_t valueTileW = valueShape[getValueTileWDim()];
|
|
WinogradConv2DFmr fmr = getFmr();
|
|
int64_t m, r;
|
|
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
|
|
bool leftTransform = valueH != 1;
|
|
bool rightTransform = valueW != 1;
|
|
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
|
|
if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
|
|
expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
|
|
} else {
|
|
if (valueH != (leftTransform ? m + r - 1 : 1))
|
|
return emitOpError("expect input height equals to input tile size");
|
|
expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
|
|
}
|
|
if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
|
|
expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
|
|
} else {
|
|
if (valueW != (rightTransform ? m + r - 1 : 1))
|
|
return emitOpError("expect input width equals to input tile size");
|
|
expectedOutputShape[getOutputWDim()] =
|
|
(rightTransform ? m : 1) * valueTileW;
|
|
}
|
|
expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
|
|
expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
|
|
|
|
auto outputType = cast<ShapedType>(getOutput().getType());
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
|
|
return emitOpError("the output shape is not expected");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Range>
|
|
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
|
|
Location loc = getLoc();
|
|
IntegerAttr zeroAttr = builder.getIndexAttr(0);
|
|
IntegerAttr oneAttr = builder.getIndexAttr(1);
|
|
Value value = getValue();
|
|
int64_t valueRank = getValueOperandRank();
|
|
SmallVector<Range> loopBounds(valueRank);
|
|
for (unsigned dim = 0; dim < valueRank; ++dim) {
|
|
loopBounds[dim].offset = zeroAttr;
|
|
// alphaH, alphaW, tileH, tileW, N, F
|
|
loopBounds[dim].size = getDimValue(builder, loc, value, dim);
|
|
loopBounds[dim].stride = oneAttr;
|
|
}
|
|
return loopBounds;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType>
|
|
WinogradOutputTransformOp::getLoopIteratorTypes() {
|
|
int64_t valueRank = getValueOperandRank();
|
|
SmallVector<utils::IteratorType> iteratorTypes(valueRank,
|
|
utils::IteratorType::parallel);
|
|
return iteratorTypes;
|
|
}
|
|
|
|
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
|
|
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) {
|
|
WinogradConv2DFmr fmr = getFmr();
|
|
int64_t m, r;
|
|
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
|
|
|
|
Location loc = getLoc();
|
|
MLIRContext *context = builder.getContext();
|
|
auto identityAffineMap =
|
|
AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
|
|
auto affineMap =
|
|
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
|
|
|
|
ShapedType valueType = getValueOperandType();
|
|
ArrayRef<int64_t> valueShape = valueType.getShape();
|
|
int64_t valueH = valueShape[0];
|
|
int64_t valueW = valueShape[1];
|
|
Value mappedOffsetH = affine::makeComposedAffineApply(
|
|
builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
|
|
offsets[getValueTileHDim()]);
|
|
Value mappedOffsetW = affine::makeComposedAffineApply(
|
|
builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
|
|
offsets[getValueTileWDim()]);
|
|
Value mappedSizeH = affine::makeComposedAffineApply(
|
|
builder, loc, affineMap, sizes[getValueTileHDim()]);
|
|
Value mappedSizeW = affine::makeComposedAffineApply(
|
|
builder, loc, affineMap, sizes[getValueTileWDim()]);
|
|
|
|
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
|
|
OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
|
|
OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
|
|
OpFoldResult sizeH =
|
|
valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
|
|
OpFoldResult sizeW =
|
|
valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
|
|
|
|
resultOffsets.append(
|
|
{offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
|
|
resultSizes.append(
|
|
{sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
|
|
return success();
|
|
}
|
|
|
|
/// Implement tiling for winograd_output_transform
|
|
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
|
|
/// F). The output of winograd_output_transform is (N, H, W, F) Users can
|
|
/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
|
|
/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
|
|
/// for the sizes of tileH, tileW, N, F for one tile.
|
|
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
|
|
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes) {
|
|
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
|
|
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
|
|
Location loc = getLoc();
|
|
SmallVector<Value> tiledOperands;
|
|
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
|
|
|
|
ShapedType valueType = getValueOperandType();
|
|
ArrayRef<int64_t> valueShape = valueType.getShape();
|
|
int64_t alphaH = valueShape[getValueAlphaHDim()];
|
|
int64_t alphaW = valueShape[getValueAlphaWDim()];
|
|
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
|
|
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
|
|
|
|
sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
|
|
offsets[getValueTileWDim()], offsets[getValueNDim()],
|
|
offsets[getValueFDim()]});
|
|
sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
|
|
sizes[getValueTileWDim()], sizes[getValueNDim()],
|
|
sizes[getValueFDim()]});
|
|
int64_t valueRank = getValueOperandRank();
|
|
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
|
|
auto valueSlice = tensor::ExtractSliceOp::create(
|
|
builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
|
|
tiledOperands.emplace_back(valueSlice);
|
|
|
|
SmallVector<OpFoldResult> resultOffsets, resultSizes;
|
|
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
|
|
resultSizes)))
|
|
return failure();
|
|
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
|
|
auto outputSlice = tensor::ExtractSliceOp::create(
|
|
builder, loc, getOutput(), resultOffsets, resultSizes, strides);
|
|
tiledOperands.emplace_back(outputSlice);
|
|
|
|
SmallVector<Type> resultTypes;
|
|
resultTypes.push_back(tiledOperands[1].getType());
|
|
Operation *tiledOp =
|
|
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
|
|
|
|
return TilingResult{
|
|
{tiledOp},
|
|
SmallVector<Value>(tiledOp->getResults()),
|
|
llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LinalgDialect
|
|
// TODO: Merge with the LinalgDialect block at the bottom
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns true if the result expression of `subMap` are a subset of `fullMap`.
|
|
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
|
|
auto explicitRange = subMap.getResults();
|
|
auto defaultRange = fullMap.getResults();
|
|
DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
|
|
DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
|
|
llvm::set_union(explicitSet, defaultSet);
|
|
return explicitSet == defaultSet;
|
|
}
|
|
|
|
/// Check if the user defined map is valid broadcast map. Here broadcast
|
|
/// indexing maps are defined in context of corresponding default indexing maps
|
|
/// for the given Op. This way the check becomes very simple i.e just check the
|
|
/// number of result dims.
|
|
/// Returns true if the explictMap is broadcasted with respect to the
|
|
/// defaultMap.
|
|
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
|
|
return explictMap.getNumResults() < defaultMap.getNumResults();
|
|
}
|
|
|
|
/// Verifies the broadcast and transpose semantic sepecified by the explicit
|
|
/// indexing map for the MatmulOp \p op for each operand specified by \p
|
|
/// opIndex.
|
|
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
|
|
unsigned opIndex) {
|
|
SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
|
|
SmallVector<AffineMap, 3> defaultIndexingMaps =
|
|
matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
|
|
|
|
auto opIndexingMap = opIndexingMaps[opIndex];
|
|
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
|
|
// Check general validity of indexing map results.
|
|
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
|
|
return matmulOp->emitOpError()
|
|
<< "Unexpected dim expression in map result.";
|
|
|
|
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
|
|
if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
|
|
return matmulOp->emitOpError()
|
|
<< "Invalid broadcast requested, should be (d2).";
|
|
}
|
|
return success();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
// Check general validity of input indexing map of
|
|
// BatchMatmulOp/BatchReduceMatmulOp.
|
|
template <typename OpTy>
|
|
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
|
|
AffineMap opIndexingMap,
|
|
AffineMap defaultIndexingMap, bool isLHS) {
|
|
assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
|
|
isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
|
|
"Expected BatchMatmulOp or BatchReduceMatmulOp");
|
|
// Check the result dims are valid.
|
|
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Unexpected result dim expression (outside the set of default "
|
|
"result dims).";
|
|
|
|
// Check for valid number of result dims of input maps.
|
|
if (opIndexingMap.getNumResults() > 3)
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "no. of result dim expressions exceeds 3.";
|
|
|
|
auto hasValidBatchDim = [](AffineMap map) {
|
|
AffineExpr batchDim = map.getResult(0);
|
|
return batchDim.isFunctionOfDim(0);
|
|
};
|
|
|
|
// Check if the requested broadcast is valid.
|
|
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
|
|
if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Invalid broadcast requested.";
|
|
} else if (!hasValidBatchDim(opIndexingMap)) {
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Invalid batch dimension expression.";
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// This function checks if the given AffineMap for the output of a
|
|
/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
|
|
/// dimensions and if the output map result dimensions are valid.
|
|
template <typename OpTy>
|
|
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
|
|
AffineMap opIndexingMap) {
|
|
assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
|
|
isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
|
|
"Expected BatchMatmulOp or BatchReduceMatmulOp");
|
|
if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
|
|
opIndexingMap.getNumResults() != 3) {
|
|
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
|
|
<< ").";
|
|
}
|
|
if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
|
|
opIndexingMap.getNumResults() != 2) {
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "expects 2 dims, but got (" << opIndexingMap.getNumResults()
|
|
<< ").";
|
|
}
|
|
|
|
auto areValidOutputResultDim = [&](AffineMap outputMap) {
|
|
return isa<BatchMatmulOp>(batchVariantMatmulOp)
|
|
? outputMap.getResult(0).isFunctionOfDim(0) &&
|
|
outputMap.getResult(1).isFunctionOfDim(1) &&
|
|
outputMap.getResult(2).isFunctionOfDim(2)
|
|
: outputMap.getResult(0).isFunctionOfDim(1) &&
|
|
outputMap.getResult(1).isFunctionOfDim(2);
|
|
};
|
|
|
|
if (!areValidOutputResultDim(opIndexingMap)) {
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Invalid output map result dimension.";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Verifies the broadcast and transpose semantic specified by the explicit
|
|
/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
|
|
/// specified by opIndex.
|
|
template <typename OpTy>
|
|
static LogicalResult
|
|
verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
|
|
unsigned opIndex) {
|
|
SmallVector<AffineMap, 3> opIndexingMaps =
|
|
batchVariantMatmulOp.getIndexingMapsArray();
|
|
SmallVector<AffineMap, 3> defaultIndexingMaps =
|
|
batchVariantMatmulOp.getDefaultIndexingMaps(
|
|
batchVariantMatmulOp->getContext());
|
|
|
|
if (opIndexingMaps.size() != 3)
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Indexing_map attribute must have 3 affine maps.";
|
|
|
|
auto opIndexingMap = opIndexingMaps[opIndex];
|
|
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
|
|
|
|
if (opIndex == 2 &&
|
|
failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
|
|
return failure();
|
|
|
|
if (opIndex != 2 &&
|
|
failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
|
|
defaultIndexingMap, opIndex == 0)))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace linalg {
|
|
|
|
std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
|
|
if (m == 2 && r == 3)
|
|
return WinogradConv2DFmr::F_2_3;
|
|
if (m == 4 && r == 3)
|
|
return WinogradConv2DFmr::F_4_3;
|
|
if (m == 2 && r == 5)
|
|
return WinogradConv2DFmr::F_2_5;
|
|
return std::nullopt;
|
|
}
|
|
|
|
std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
|
|
switch (fmr) {
|
|
case WinogradConv2DFmr::F_2_3:
|
|
return {2, 3};
|
|
case WinogradConv2DFmr::F_4_3:
|
|
return {4, 3};
|
|
case WinogradConv2DFmr::F_2_5:
|
|
return {2, 5};
|
|
}
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MatMulOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static FailureOr<SmallVector<SmallVector<int64_t>>>
|
|
getAffineResultPositions(ArrayAttr maps) {
|
|
SmallVector<SmallVector<int64_t>> positions;
|
|
for (auto map : maps) {
|
|
AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
|
|
if (!attr)
|
|
return failure();
|
|
SmallVector<int64_t> pos;
|
|
for (auto result : attr.getAffineMap().getResults()) {
|
|
auto dim = dyn_cast<AffineDimExpr>(result);
|
|
if (!dim)
|
|
return failure();
|
|
pos.push_back(dim.getPosition());
|
|
}
|
|
positions.push_back(pos);
|
|
}
|
|
return positions;
|
|
}
|
|
|
|
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
|
|
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
|
|
AffineExpr d0, d1, d2;
|
|
SmallVector<AffineMap> indexingMaps;
|
|
bindDims(context, d0, d1, d2);
|
|
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
|
|
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
|
|
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
|
|
return indexingMaps;
|
|
}
|
|
|
|
bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
|
|
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
|
|
if (!maps)
|
|
return false;
|
|
if (maps.size() != 3)
|
|
return false;
|
|
auto positions = getAffineResultPositions(maps);
|
|
if (failed(positions))
|
|
return false;
|
|
return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
|
|
(*positions)[1] == SmallVector<int64_t>{2, 1} &&
|
|
(*positions)[2] == SmallVector<int64_t>{0, 1};
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
|
|
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
|
|
utils::IteratorType::parallel,
|
|
utils::IteratorType::reduction};
|
|
}
|
|
|
|
unsigned MatmulOp::getNumRegionArgs() { return 3; }
|
|
|
|
std::string MatmulOp::getLibraryCallName() {
|
|
return generateLibraryCallName(getOperation());
|
|
}
|
|
|
|
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
|
|
|
|
/// Check if the op has broadcast and/or transpose semantic. Returns true if
|
|
/// the user defined indexing maps are not equal to default map.
|
|
bool MatmulOp::hasUserDefinedMaps() {
|
|
SmallVector<AffineMap, 3> defaultMaps =
|
|
getDefaultIndexingMaps(this->getContext());
|
|
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
|
|
return defaultMaps != explicitMaps;
|
|
}
|
|
|
|
/// Implements the block region builder for the MatmulOp. This is called by
|
|
/// 'fillStructuredOpRegion'.
|
|
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
|
ArrayRef<NamedAttribute> attrs,
|
|
function_ref<InFlightDiagnostic()> emitError) {
|
|
if (emitError && block.getNumArguments() != 3) {
|
|
emitError() << "MatmulOp regionBuilder expects 3 args, got "
|
|
<< block.getNumArguments();
|
|
return;
|
|
}
|
|
assert(block.getNumArguments() == 3 &&
|
|
"MatmulOp regionBuilder expects 3 args");
|
|
RegionBuilderHelper helper(b, block);
|
|
SmallVector<Value> yields;
|
|
|
|
TypeFn castVal = TypeFn::cast_signed;
|
|
const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
|
return attr.getName() == "cast";
|
|
});
|
|
if (castIter != attrs.end()) {
|
|
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
|
|
castVal = attr.getValue();
|
|
}
|
|
|
|
Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
|
|
block.getArgument(0));
|
|
Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
|
|
block.getArgument(1));
|
|
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
|
|
if (!value3)
|
|
return;
|
|
Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
|
|
value3, emitError);
|
|
if (!value4)
|
|
return;
|
|
yields.push_back(value4);
|
|
helper.yieldOutputs(yields);
|
|
}
|
|
|
|
/// Returns true if the given bcastMap map is a valid broadcast map. A valid
|
|
/// broadcast map must include K dimension.
|
|
/// TODO: Strict inclusion of K dimension in the broadcast map is not
|
|
/// necessary for both input matrices simultaneously. We can relax this
|
|
/// condition to have K dimension for one input matrix map and infer the K
|
|
/// dimension for other input matrix map from the one already having K
|
|
/// dimension.
|
|
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
|
|
assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
|
|
AffineExpr expr = bcastMap.getResult(0);
|
|
// Invalid map if the common dimension of matmul not found.
|
|
return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
|
|
}
|
|
|
|
FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
|
|
if (parser.parseOptionalKeyword("indexing_maps"))
|
|
return ArrayAttr{
|
|
nullptr}; // Success in case indexing_maps was not provided.
|
|
|
|
ArrayAttr arrayAttr;
|
|
if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
|
|
return failure();
|
|
|
|
if (llvm::any_of(arrayAttr,
|
|
[](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "element of indexing_maps array is not an affine_map";
|
|
|
|
return arrayAttr;
|
|
}
|
|
|
|
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
|
|
if (failed(indexingMapsAttr))
|
|
return failure();
|
|
|
|
if (*indexingMapsAttr == nullptr) {
|
|
auto indexingMapAttrs = llvm::map_to_vector(
|
|
MatmulOp::getDefaultIndexingMaps(parser.getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
|
|
}
|
|
|
|
result.addAttribute("indexing_maps", *indexingMapsAttr);
|
|
return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
|
|
MatmulOp::getRegionBuilder());
|
|
}
|
|
|
|
void MatmulOp::print(OpAsmPrinter &p) {
|
|
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
|
|
MatmulOp::getDefaultIndexingMaps(getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
if (!llvm::equal(getIndexingMaps(), indexingMaps))
|
|
p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
|
|
|
|
std::array<StringRef, 3> elidedAttrs = {
|
|
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
|
|
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
|
elidedAttrs);
|
|
}
|
|
|
|
/// Verify the user defined indexing maps.
|
|
LogicalResult MatmulOp::verify() {
|
|
// Verification of pure matmul is handled by verifyStructuredOpInterface().
|
|
if (!hasUserDefinedMaps())
|
|
return success();
|
|
|
|
for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
|
|
if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
void MatmulOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
if (hasPureTensorSemantics())
|
|
return;
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability MatmulOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
SmallVector<AffineMap>
|
|
MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
|
|
AffineExpr d0, d1, d2;
|
|
MLIRContext *context = builder.getContext();
|
|
bindDims(context, d0, d1, d2);
|
|
AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
|
|
AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
|
|
AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
|
|
return {mapLHS, mapRHS, mapOut};
|
|
}
|
|
|
|
bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
|
|
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
|
|
if (!maps)
|
|
return false;
|
|
if (maps.size() != 3)
|
|
return false;
|
|
auto positions = getAffineResultPositions(maps);
|
|
if (failed(positions))
|
|
return false;
|
|
return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
|
|
(*positions)[1] == SmallVector<int64_t>{2, 1} &&
|
|
(*positions)[2] == SmallVector<int64_t>{0, 1};
|
|
}
|
|
|
|
void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
|
|
OperationState &result,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
|
|
MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
MatmulTransposeAOp
|
|
MatmulTransposeAOp::create(OpBuilder &builder, Location location,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, inputs, outputs, attributes);
|
|
auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
|
|
OperationState &result,
|
|
TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
|
|
MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
MatmulTransposeAOp
|
|
MatmulTransposeAOp::create(OpBuilder &builder, Location location,
|
|
TypeRange resultTensorTypes, ValueRange inputs,
|
|
ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, resultTensorTypes, inputs, outputs, attributes);
|
|
auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
|
|
OperationState &result,
|
|
TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
Attribute cast,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
result.addAttribute("cast", cast);
|
|
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
|
|
MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
MatmulTransposeAOp
|
|
MatmulTransposeAOp::create(OpBuilder &builder, Location location,
|
|
TypeRange resultTensorTypes, ValueRange inputs,
|
|
ValueRange outputs, Attribute cast,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
|
|
auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
bool MatmulTransposeAOp::classof(Operation *op) {
|
|
return dyn_cast_or_null<linalg::MatmulOp>(op) &&
|
|
MatmulTransposeAOp::isDefaultIndexingMaps(
|
|
op->getAttr("indexing_maps"));
|
|
}
|
|
|
|
SmallVector<AffineMap>
|
|
MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
|
|
AffineExpr d0, d1, d2;
|
|
MLIRContext *context = builder.getContext();
|
|
bindDims(context, d0, d1, d2);
|
|
AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
|
|
AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
|
|
AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
|
|
return {mapLHS, mapRHS, mapOut};
|
|
}
|
|
|
|
bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
|
|
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
|
|
if (!maps)
|
|
return false;
|
|
if (maps.size() != 3)
|
|
return false;
|
|
auto positions = getAffineResultPositions(maps);
|
|
if (failed(positions))
|
|
return false;
|
|
return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
|
|
(*positions)[1] == SmallVector<int64_t>{1, 2} &&
|
|
(*positions)[2] == SmallVector<int64_t>{0, 1};
|
|
}
|
|
|
|
void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
|
|
OperationState &result,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
|
|
MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
MatmulTransposeBOp
|
|
MatmulTransposeBOp::create(OpBuilder &builder, Location location,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, inputs, outputs, attributes);
|
|
auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
|
|
OperationState &result,
|
|
TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
|
|
MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
MatmulTransposeBOp
|
|
MatmulTransposeBOp::create(OpBuilder &builder, Location location,
|
|
TypeRange resultTensorTypes, ValueRange inputs,
|
|
ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, resultTensorTypes, inputs, outputs, attributes);
|
|
auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
|
|
OperationState &result,
|
|
TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
Attribute cast,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
result.addAttribute("cast", cast);
|
|
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
|
|
MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
MatmulTransposeBOp
|
|
MatmulTransposeBOp::create(OpBuilder &builder, Location location,
|
|
TypeRange resultTensorTypes, ValueRange inputs,
|
|
ValueRange outputs, Attribute cast,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
|
|
auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
bool MatmulTransposeBOp::classof(Operation *op) {
|
|
return dyn_cast_or_null<linalg::MatmulOp>(op) &&
|
|
MatmulTransposeBOp::isDefaultIndexingMaps(
|
|
op->getAttr("indexing_maps"));
|
|
}
|
|
|
|
SmallVector<AffineMap>
|
|
BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
|
|
AffineExpr d0, d1, d2, d3;
|
|
MLIRContext *context = builder.getContext();
|
|
bindDims(context, d0, d1, d2, d3);
|
|
AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
|
|
AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
|
|
AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
|
|
return {mapLHS, mapRHS, mapOut};
|
|
}
|
|
|
|
bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
|
|
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
|
|
if (!maps)
|
|
return false;
|
|
if (maps.size() != 3)
|
|
return false;
|
|
auto positions = getAffineResultPositions(maps);
|
|
if (failed(positions))
|
|
return false;
|
|
return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
|
|
(*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
|
|
(*positions)[2] == SmallVector<int64_t>{0, 1, 2};
|
|
}
|
|
|
|
void linalg::BatchMatmulTransposeAOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
|
|
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
|
|
BatchMatmulOp::getRegionBuilder(),
|
|
getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
BatchMatmulTransposeAOp
|
|
BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, inputs, outputs, attributes);
|
|
auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
void linalg::BatchMatmulTransposeAOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
|
|
BatchMatmulOp::getRegionBuilder(),
|
|
getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
BatchMatmulTransposeAOp
|
|
BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
|
|
TypeRange resultTensorTypes, ValueRange inputs,
|
|
ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, resultTensorTypes, inputs, outputs, attributes);
|
|
auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
void linalg::BatchMatmulTransposeAOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, Attribute cast,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
result.addAttribute("cast", cast);
|
|
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
|
|
BatchMatmulOp::getRegionBuilder(),
|
|
getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
BatchMatmulTransposeAOp
|
|
BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
|
|
TypeRange resultTensorTypes, ValueRange inputs,
|
|
ValueRange outputs, Attribute cast,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
|
|
auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
bool BatchMatmulTransposeAOp::classof(Operation *op) {
|
|
return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
|
|
BatchMatmulTransposeAOp::isDefaultIndexingMaps(
|
|
op->getAttr("indexing_maps"));
|
|
}
|
|
|
|
SmallVector<AffineMap>
|
|
BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
|
|
AffineExpr d0, d1, d2, d3;
|
|
MLIRContext *context = builder.getContext();
|
|
bindDims(context, d0, d1, d2, d3);
|
|
AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
|
|
AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
|
|
AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
|
|
return {mapLHS, mapRHS, mapOut};
|
|
}
|
|
|
|
bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
|
|
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
|
|
if (!maps)
|
|
return false;
|
|
if (maps.size() != 3)
|
|
return false;
|
|
auto positions = getAffineResultPositions(maps);
|
|
if (failed(positions))
|
|
return false;
|
|
return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
|
|
(*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
|
|
(*positions)[2] == SmallVector<int64_t>{0, 1, 2};
|
|
}
|
|
|
|
void linalg::BatchMatmulTransposeBOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
|
|
buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
|
|
BatchMatmulOp::getRegionBuilder(),
|
|
getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
BatchMatmulTransposeBOp
|
|
BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, inputs, outputs, attributes);
|
|
auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
void linalg::BatchMatmulTransposeBOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
|
|
BatchMatmulOp::getRegionBuilder(),
|
|
getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
BatchMatmulTransposeBOp
|
|
BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
|
|
TypeRange resultTensorTypes, ValueRange inputs,
|
|
ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, resultTensorTypes, inputs, outputs, attributes);
|
|
auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
void linalg::BatchMatmulTransposeBOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, Attribute cast,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
result.addAttribute("cast", cast);
|
|
buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
|
|
BatchMatmulOp::getRegionBuilder(),
|
|
getDefaultIndexingMaps(builder));
|
|
}
|
|
|
|
BatchMatmulTransposeBOp
|
|
BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
|
|
TypeRange resultTensorTypes, ValueRange inputs,
|
|
ValueRange outputs, Attribute cast,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
OperationState state(location, getOperationName());
|
|
build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
|
|
auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
|
|
assert(res && "builder didn't return the right type");
|
|
return res;
|
|
}
|
|
|
|
bool BatchMatmulTransposeBOp::classof(Operation *op) {
|
|
return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
|
|
BatchMatmulTransposeBOp::isDefaultIndexingMaps(
|
|
op->getAttr("indexing_maps"));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ContractOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
|
|
AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
|
|
// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
|
|
// domains are all the same, and each implements a projected permutation.
|
|
// Each iteration space dim must occur for at least one operand and either
|
|
// takes part in a contraction/reduction or else has parallel iteration type.
|
|
// We have that a dim is a contraction/reduction dim if and only if the dim
|
|
// occurs for the output operand. We use this fact for fast inference:
|
|
// NB: In case we allow dims to occur solely for one input, the above still
|
|
// holds: per the einsum semantics, these are reduction dims as well.
|
|
SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
|
|
for (auto result : outAffineMap.getResults()) {
|
|
auto dimExpr = dyn_cast<AffineDimExpr>(result);
|
|
assert(dimExpr && "affine_map is a projected permutation");
|
|
dimsInOutput[dimExpr.getPosition()] = true;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> iteratorTypes;
|
|
for (auto dimOccursInOutput : dimsInOutput)
|
|
iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
|
|
: utils::IteratorType::reduction);
|
|
|
|
return iteratorTypes;
|
|
}
|
|
|
|
unsigned ContractOp::getNumRegionArgs() { return 3; }
|
|
|
|
/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
|
|
void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
|
ArrayRef<NamedAttribute> attrs,
|
|
function_ref<InFlightDiagnostic()> emitError) {
|
|
if (emitError && block.getNumArguments() != 3) {
|
|
emitError() << "ContractOp regionBuilder expects 3 args, got "
|
|
<< block.getNumArguments();
|
|
return;
|
|
}
|
|
assert(block.getNumArguments() == 3 &&
|
|
"ContractOp regionBuilder expects 3 args");
|
|
RegionBuilderHelper helper(b, block);
|
|
|
|
TypeFn castSignedness = TypeFn::cast_signed;
|
|
auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
|
return attr.getName() == "cast";
|
|
});
|
|
if (castIter != attrs.end()) {
|
|
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
|
|
castSignedness = attr.getValue();
|
|
}
|
|
|
|
// TODO: Support fields with operators besides mult & add.
|
|
Type outType = block.getArgument(2).getType();
|
|
Value lhsAtOutType =
|
|
helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
|
|
Value rhsAtOutType =
|
|
helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
|
|
Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
|
|
rhsAtOutType, emitError);
|
|
if (!productAtOutType)
|
|
return;
|
|
Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
|
|
productAtOutType, emitError);
|
|
if (!result)
|
|
return;
|
|
helper.yieldOutputs({result});
|
|
}
|
|
|
|
ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
|
|
if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected 'indexing_maps' attribute");
|
|
result.addAttribute("indexing_maps", *indexingMapsAttr);
|
|
|
|
return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
|
|
regionBuilder);
|
|
}
|
|
|
|
void ContractOp::print(OpAsmPrinter &p) {
|
|
p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
|
|
printNamedStructuredOp(
|
|
p, getOperation(), getInputs(), getOutputs(),
|
|
/*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
|
|
}
|
|
|
|
LogicalResult ContractOp::verify() {
|
|
int iterationSpaceDims = -1;
|
|
// Map iter space dims to #occurrences in inputs' and output's affine_maps:
|
|
// e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
|
|
// access an input operand (so occurrence count can be at most 2) and
|
|
// outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
|
|
SmallVector<size_t> inOccurrences;
|
|
SmallVector<size_t> outOccurrences;
|
|
|
|
// A helper so that for each operand's affine_map and type we check that ...
|
|
auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
|
|
bool isInput) -> LogicalResult {
|
|
// ... the affine_map is a projected permutation;
|
|
if (!affineMap.isProjectedPermutation())
|
|
return emitError("provided affine_map is not a projected permutation");
|
|
|
|
// ... the rank of the affine_map's results and corresponding type match;
|
|
if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
|
|
if (affineMap.getNumResults() != shapedType.getRank())
|
|
return emitError("ranks of shaped operand and results of corresponding "
|
|
"affine_map differ");
|
|
} else if (affineMap.getNumResults() != 0) {
|
|
return emitError("affine_map specifies shaped access while operand has "
|
|
"non-shaped type");
|
|
}
|
|
|
|
// ... the rank of the affine_map's domain is the same as those seen prior;
|
|
if (iterationSpaceDims == -1) {
|
|
iterationSpaceDims = affineMap.getNumDims();
|
|
inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
|
|
outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
|
|
} else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
|
|
return emitError("iteration spaces of provided affine_maps differ");
|
|
}
|
|
|
|
// ... update counts of dims used to access either an input or the output.
|
|
for (AffineExpr affineExpr : affineMap.getResults()) {
|
|
auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
|
|
if (!affineDimExpr)
|
|
llvm_unreachable("affine_map is a projected permutation");
|
|
|
|
if (isInput)
|
|
inOccurrences[affineDimExpr.getPosition()] += 1;
|
|
else
|
|
outOccurrences[affineDimExpr.getPosition()] += 1;
|
|
}
|
|
|
|
return success();
|
|
};
|
|
|
|
for (auto &&[affineMap, operandType, isInput] :
|
|
llvm::zip(getIndexingMapsArray(), getOperandTypes(),
|
|
SmallVector<bool>{true, true, false})) {
|
|
if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
|
|
return failure(); // NB: checkAffineMapAndType will emit relevant error.
|
|
}
|
|
|
|
bool hasContractingDim = false;
|
|
for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
|
|
size_t inOccCount = inOccurrences[dimIndex];
|
|
size_t outOccCount = outOccurrences[dimIndex];
|
|
|
|
// We have a contracting dim if and only if ...
|
|
hasContractingDim |= inOccCount == 2 && outOccCount == 0;
|
|
|
|
if (inOccCount == 0 && outOccCount == 0)
|
|
return emitError() << "iteration space dim at index " << dimIndex
|
|
<< " not used to access any operand";
|
|
|
|
// NB: We disallow a dim which occurs for only one input operand and not
|
|
// for the output. In terms of einsum semantics such dims have a
|
|
// sensible meaning - namely an additional reduction per each such dim.
|
|
// By contrast, the ContractionOpInterface does not know about this
|
|
// iter type - cf. inferContractionDims' supported dim kinds. Similarly,
|
|
// while vector.contract's verifier accepts dims of this kind many of
|
|
// its lowerings give up on encountering these dims.
|
|
// TODO: Remove following once we have comprehensive support for input-only
|
|
// reduction dims, at both the linalg- and vector-dialect levels.
|
|
if (inOccCount == 1 && outOccCount != 1)
|
|
return emitError()
|
|
<< "iteration space dim at index " << dimIndex
|
|
<< " is neither a contracting dim nor of parallel iteration type";
|
|
}
|
|
|
|
if (!hasContractingDim)
|
|
return emitError("'indexing_maps' do not specify a contracting dimension");
|
|
|
|
return success();
|
|
}
|
|
|
|
LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
void ContractOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
if (hasPureTensorSemantics())
|
|
return;
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability ContractOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Implementation of BatchMatmulOp
|
|
//===----------------------------------------------------------------------===//
|
|
SmallVector<AffineMap>
|
|
BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
|
|
AffineExpr d0, d1, d2, d3;
|
|
SmallVector<AffineMap> indexingMaps;
|
|
bindDims(context, d0, d1, d2, d3);
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
|
|
return indexingMaps;
|
|
}
|
|
|
|
bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
|
|
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
|
|
if (!maps)
|
|
return false;
|
|
if (maps.size() != 3)
|
|
return false;
|
|
auto positions = getAffineResultPositions(maps);
|
|
if (failed(positions))
|
|
return false;
|
|
return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
|
|
(*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
|
|
(*positions)[2] == SmallVector<int64_t>{0, 1, 2};
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
|
|
return SmallVector<utils::IteratorType>{
|
|
utils::IteratorType::parallel, utils::IteratorType::parallel,
|
|
utils::IteratorType::parallel, utils::IteratorType::reduction};
|
|
}
|
|
|
|
unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
|
|
|
|
std::string BatchMatmulOp::getLibraryCallName() {
|
|
return generateLibraryCallName(getOperation());
|
|
}
|
|
|
|
/// Check if the op has broadcast and/or transpose semantic. Returns true if
|
|
/// the user defined indexing maps are not equal to default map.
|
|
bool BatchMatmulOp::hasUserDefinedMaps() {
|
|
SmallVector<AffineMap, 3> defaultMaps =
|
|
getDefaultIndexingMaps(this->getContext());
|
|
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
|
|
return defaultMaps != explicitMaps;
|
|
}
|
|
|
|
/// Returns true if the given bcastMap map is a valid broadcast map. A valid
|
|
/// broadcast map must include K dimension.
|
|
/// TODO: Strict inclusion of K dimension in the broadcast map is not
|
|
/// necessary for both input matrices simultaneously. We can relax this
|
|
/// condition to have K dimension for one input matrix map and infer the K
|
|
/// dimension for other input matrix map from the one already having K
|
|
/// dimension.
|
|
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
|
|
assert(bcastMap.getNumResults() < 3 &&
|
|
"Expected less than 3 result dim expr.");
|
|
bool isValid = false;
|
|
enum Indices { batchPos, mPos, nPos, kPos };
|
|
if (bcastMap.getNumResults() == 1) {
|
|
AffineExpr expr = bcastMap.getResult(0);
|
|
isValid = expr.isFunctionOfDim(kPos);
|
|
} else if (bcastMap.getNumResults() == 2) {
|
|
AffineExpr expr0 = bcastMap.getResult(0);
|
|
AffineExpr expr1 = bcastMap.getResult(1);
|
|
isValid =
|
|
isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
|
|
expr0.isFunctionOfDim(mPos)) &&
|
|
expr1.isFunctionOfDim(kPos))
|
|
: ((expr0.isFunctionOfDim(batchPos) &&
|
|
expr1.isFunctionOfDim(kPos)) ||
|
|
(expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
|
|
}
|
|
return isValid;
|
|
}
|
|
|
|
void BatchMatmulOp::regionBuilder(
|
|
ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
|
|
function_ref<InFlightDiagnostic()> emitError) {
|
|
if (emitError && block.getNumArguments() != 3) {
|
|
emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
|
|
<< block.getNumArguments();
|
|
return;
|
|
}
|
|
assert(block.getNumArguments() == 3 &&
|
|
"BatchMatmulOp regionBuilder expects 3 args");
|
|
RegionBuilderHelper helper(b, block);
|
|
SmallVector<Value> yields;
|
|
|
|
TypeFn castVal = TypeFn::cast_signed;
|
|
auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
|
return attr.getName() == "cast";
|
|
});
|
|
if (castIter != attrs.end()) {
|
|
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
|
|
castVal = attr.getValue();
|
|
}
|
|
|
|
auto toType = block.getArgument(2).getType();
|
|
Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
|
|
Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
|
|
Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
|
|
Value addVal =
|
|
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
|
|
yields.push_back(addVal);
|
|
helper.yieldOutputs(yields);
|
|
}
|
|
|
|
ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
SmallVector<Attribute, 3> indexingMapsAttr;
|
|
Attribute mapAttr;
|
|
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
|
|
if (parser.parseEqual())
|
|
return failure();
|
|
|
|
if (parser.parseLSquare())
|
|
return failure();
|
|
|
|
do {
|
|
if (parser.parseAttribute(mapAttr))
|
|
return failure();
|
|
if (!isa<AffineMapAttr>(mapAttr)) {
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected affine map attribute");
|
|
}
|
|
indexingMapsAttr.push_back(mapAttr);
|
|
|
|
if (parser.parseOptionalComma())
|
|
break;
|
|
} while (true);
|
|
|
|
if (parser.parseRSquare())
|
|
return failure();
|
|
}
|
|
// Initialize indexingMaps, if not supplied explicitly.
|
|
if (indexingMapsAttr.empty()) {
|
|
indexingMapsAttr = llvm::map_to_vector(
|
|
BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
}
|
|
result.addAttribute("indexing_maps",
|
|
parser.getBuilder().getArrayAttr(indexingMapsAttr));
|
|
|
|
return ::parseNamedStructuredOp(parser, result,
|
|
BatchMatmulOp::getNumRegionArgs(),
|
|
BatchMatmulOp::getRegionBuilder());
|
|
}
|
|
|
|
void BatchMatmulOp::print(OpAsmPrinter &p) {
|
|
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
|
|
BatchMatmulOp::getDefaultIndexingMaps(getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
if (!llvm::equal(getIndexingMaps(), indexingMaps))
|
|
p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
|
|
|
|
std::array<StringRef, 3> elidedAttrs = {
|
|
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
|
|
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
|
elidedAttrs);
|
|
}
|
|
|
|
/// Verify the user defined indexing maps.
|
|
LogicalResult BatchMatmulOp::verify() {
|
|
// Verification of pure batch_matmul is handled by
|
|
// verifyStructuredOpInterface().
|
|
if (!hasUserDefinedMaps())
|
|
return success();
|
|
|
|
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
|
|
if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
LogicalResult BatchMatmulOp::fold(FoldAdaptor,
|
|
SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
void BatchMatmulOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
if (hasPureTensorSemantics())
|
|
return;
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ElementwiseOp
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
namespace {
|
|
struct ArityGroupAndKind {
|
|
// The enum class {Unary, Binary, Ternary, ..}
|
|
ElementwiseArityGroup arityGroup;
|
|
|
|
// The kind (e.g. `exp` or `add`) belonging to the arity group.
|
|
union Kind {
|
|
UnaryFn unaryFn;
|
|
BinaryFn binaryFn;
|
|
TernaryFn ternaryFn;
|
|
} kind;
|
|
};
|
|
|
|
unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
|
|
return static_cast<unsigned>(arityGroup);
|
|
}
|
|
} // namespace
|
|
|
|
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
|
|
constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
|
|
constexpr int lastBinary =
|
|
static_cast<int>(ElementwiseCaseLimits::LastBinary);
|
|
constexpr int lastTernary =
|
|
static_cast<int>(ElementwiseCaseLimits::LastTernary);
|
|
|
|
int val = static_cast<int>(kind);
|
|
ArityGroupAndKind result;
|
|
|
|
if (val < lastUnary) {
|
|
result.arityGroup = ElementwiseArityGroup::Unary;
|
|
result.kind.unaryFn = static_cast<UnaryFn>(val);
|
|
return result;
|
|
}
|
|
if (val < lastBinary) {
|
|
result.arityGroup = ElementwiseArityGroup::Binary;
|
|
result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
|
|
return result;
|
|
}
|
|
if (val >= lastTernary) {
|
|
llvm_unreachable("unhandled ElementwiseFn");
|
|
}
|
|
result.arityGroup = ElementwiseArityGroup::Ternary;
|
|
result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
|
|
return result;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
|
|
auto rank = getResultRank();
|
|
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
|
}
|
|
|
|
SmallVector<AffineMap>
|
|
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
|
|
MLIRContext *context) {
|
|
auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
|
|
return SmallVector<AffineMap>(numMaps, map);
|
|
}
|
|
|
|
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
// Expect e.g. `kind = #linalg.elemwise_kind<add>`
|
|
Attribute attr;
|
|
mlir::linalg::ElementwiseKind elemwiseKindVal;
|
|
if (parser.parseKeyword("kind") || parser.parseEqual())
|
|
return failure();
|
|
|
|
if (succeeded(parser.parseAttribute(attr))) {
|
|
auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
|
|
if (!elemwiseKindAttr)
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected ElementwiseKind attribute");
|
|
elemwiseKindVal = elemwiseKindAttr.getValue();
|
|
} else {
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected operation 'kind' attribute");
|
|
}
|
|
result.addAttribute(
|
|
"kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
|
|
|
|
// Parse optional `indexing_maps`
|
|
SmallVector<Attribute, 3> indexingMapsAttr;
|
|
Attribute mapAttr;
|
|
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
|
|
if (parser.parseEqual())
|
|
return failure();
|
|
if (parser.parseLSquare())
|
|
return failure();
|
|
do {
|
|
if (parser.parseAttribute(mapAttr))
|
|
return failure();
|
|
if (!isa<AffineMapAttr>(mapAttr))
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected affine map attribute");
|
|
indexingMapsAttr.push_back(mapAttr);
|
|
if (parser.parseOptionalComma())
|
|
break;
|
|
} while (true);
|
|
if (parser.parseRSquare())
|
|
return failure();
|
|
}
|
|
// At this stage of parsing the only way to infer number of region
|
|
// args is through op kind, as input output tensors are not parsed yet.
|
|
auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
|
|
int numRegionArgs =
|
|
getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
|
|
if (parseNamedStructuredOp(parser, result, numRegionArgs,
|
|
ElementwiseOp::getRegionBuilder())) {
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"unable to parse elemwise op");
|
|
}
|
|
|
|
// Initialize indexingMaps, if not supplied explicitly.
|
|
if (indexingMapsAttr.empty()) {
|
|
// We need to infer the numDims of the indexing maps from the output
|
|
// type which is already parsed by now.
|
|
auto resultType = result.operands[result.operands.size() - 1].getType();
|
|
auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
|
|
if (!shapedType)
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"return type needs to be shaped type");
|
|
auto numDims = shapedType.getRank();
|
|
indexingMapsAttr = llvm::map_to_vector(
|
|
ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
|
|
parser.getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
}
|
|
|
|
result.addAttribute("indexing_maps",
|
|
parser.getBuilder().getArrayAttr(indexingMapsAttr));
|
|
return success();
|
|
}
|
|
|
|
void ElementwiseOp::print(OpAsmPrinter &p) {
|
|
p << " kind=";
|
|
p.printAttribute(getKindAttr());
|
|
SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
|
|
"indexing_maps"};
|
|
unsigned arity =
|
|
getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
|
|
unsigned numDims = getResultRank();
|
|
|
|
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
|
|
ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
|
|
getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
|
|
if (!llvm::equal(getIndexingMaps(), indexingMaps))
|
|
p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
|
|
|
|
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
|
elidedAttrs);
|
|
}
|
|
|
|
LogicalResult ElementwiseOp::verify() {
|
|
// All necessary checks are done either by
|
|
// - EnumAttr (e.g. unknown operation kind)
|
|
// - verifyStructuredOpInterface (incorrect map, sizes).
|
|
return success();
|
|
}
|
|
|
|
/// Implements the block region builder for the ElementwiseOp. This is called by
|
|
/// 'fillStructuredOpRegion'.
|
|
void ElementwiseOp::regionBuilder(
|
|
ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
|
|
function_ref<InFlightDiagnostic()> emitError) {
|
|
ElementwiseKind elemwiseKind;
|
|
for (auto attr : attrs) {
|
|
if (attr.getName() == b.getStringAttr("kind")) {
|
|
auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
|
|
assert(kindAttr && "op kind attribute incorrectly set");
|
|
elemwiseKind = kindAttr.getValue();
|
|
break;
|
|
}
|
|
}
|
|
|
|
ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
|
|
auto arityGroup = groupAndKind.arityGroup;
|
|
auto kind = groupAndKind.kind;
|
|
if (emitError && block.getNumArguments() !=
|
|
getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
|
|
emitError() << "Elementwise regionBuilder expects "
|
|
<< (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
|
|
<< block.getNumArguments();
|
|
return;
|
|
}
|
|
assert(block.getNumArguments() ==
|
|
getArityGroupAsUInt(arityGroup) + 1 /*output*/
|
|
&& "Elementwise regionBuilder number of block args mismatch");
|
|
|
|
RegionBuilderHelper helper(b, block);
|
|
SmallVector<Value> yields;
|
|
Value result;
|
|
|
|
if (arityGroup == ElementwiseArityGroup::Unary) {
|
|
result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
|
|
|
|
} else if (arityGroup == ElementwiseArityGroup::Binary) {
|
|
result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
|
|
block.getArgument(1));
|
|
|
|
} else if (arityGroup == ElementwiseArityGroup::Ternary) {
|
|
result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
|
|
block.getArgument(1), block.getArgument(2));
|
|
|
|
} else {
|
|
assert(false && "found unhandled category in elemwise");
|
|
}
|
|
|
|
yields.push_back(result);
|
|
helper.yieldOutputs(yields);
|
|
}
|
|
|
|
LogicalResult ElementwiseOp::fold(FoldAdaptor,
|
|
SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
void ElementwiseOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
if (hasPureTensorSemantics())
|
|
return;
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability ElementwiseOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PackOp/UnPackOp Common
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename OpTy, typename>
|
|
SmallVector<int64_t>
|
|
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
|
|
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
|
|
? packOrUnPack.getDestType()
|
|
: packOrUnPack.getSourceType();
|
|
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
|
|
? packOrUnPack.getSourceType()
|
|
: packOrUnPack.getDestType();
|
|
SmallVector<int64_t> result(
|
|
packedType.getShape().take_front(unpackedType.getRank()));
|
|
if (!packOrUnPack.getOuterDimsPerm().empty()) {
|
|
applyPermutationToVector(
|
|
result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
|
|
}
|
|
return result;
|
|
}
|
|
template SmallVector<int64_t>
|
|
getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
|
|
template SmallVector<int64_t>
|
|
getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
|
|
|
|
// Given the (potentially) updated packed type, `newPackedTy`, generates an
|
|
// updated mixed-tile-sizes attribute. A tile size is updated only
|
|
// when:
|
|
// * a dim from newPackedTy is static, and
|
|
// * the corresponding size from mixedTiles is still dynamic.
|
|
// Otherwise, the original tile size is preserved.
|
|
// Note - packed-type-dim and mixed-tile-size should always match!
|
|
static SmallVector<OpFoldResult>
|
|
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
|
|
SmallVector<OpFoldResult> mixedTiles) {
|
|
SmallVector<OpFoldResult> newMixedTileSizes;
|
|
for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
|
|
.getShape()
|
|
.take_back(mixedTiles.size()),
|
|
mixedTiles)) {
|
|
int64_t shape = std::get<0>(it);
|
|
if (shape == ShapedType::kDynamic) {
|
|
newMixedTileSizes.push_back(std::get<1>(it));
|
|
continue;
|
|
}
|
|
|
|
// If the current result dim is static, update the dynamic mixed-size
|
|
// (provided the original value is dynamic).
|
|
OpFoldResult tile = std::get<1>(it);
|
|
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
|
|
// Already a constant
|
|
newMixedTileSizes.push_back(tile);
|
|
} else {
|
|
assert(getConstantIntValue(tile).value() == shape &&
|
|
"tile size and dim size don't match!");
|
|
newMixedTileSizes.push_back(
|
|
(rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
|
|
}
|
|
}
|
|
|
|
return newMixedTileSizes;
|
|
}
|
|
|
|
template <typename OpTy>
|
|
static LogicalResult
|
|
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
|
|
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
|
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
|
|
"applies to only pack or unpack operations");
|
|
int64_t destRank = op.getDestRank();
|
|
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
|
|
reifiedReturnShapes[0] =
|
|
tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
|
|
return success();
|
|
}
|
|
|
|
template <typename OpTy>
|
|
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
|
|
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
|
|
"applies to only pack or unpack operations");
|
|
DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
|
|
ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
|
|
SmallVector<OpFoldResult> tiles = op.getMixedTiles();
|
|
assert(tiles.size() == dimsToTile.size() &&
|
|
"tiles must match indices of dimension to block");
|
|
// bind the dimension `i` with the tile factor.
|
|
for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
|
|
dimAndTileMapping[dimsToTile[i]] = tiles[i];
|
|
return dimAndTileMapping;
|
|
}
|
|
|
|
template <typename OpTy>
|
|
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
|
|
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
|
|
"applies to only pack or unpack operations");
|
|
Builder builder(op);
|
|
SmallVector<OpFoldResult> mixedInnerTiles;
|
|
unsigned dynamicValIndex = 0;
|
|
for (int64_t staticTile : op.getStaticInnerTiles()) {
|
|
if (ShapedType::isStatic(staticTile))
|
|
mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
|
|
else
|
|
mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
|
|
}
|
|
return mixedInnerTiles;
|
|
}
|
|
|
|
template <typename OpTy>
|
|
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
|
|
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
|
|
"applies to only pack or unpack operations");
|
|
SmallVector<Value> dynamicTiles;
|
|
SmallVector<int64_t> staticTiles;
|
|
dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
|
|
return staticTiles;
|
|
}
|
|
|
|
/// Returns true if `dimsPos` is invalid. It is invalid when:
|
|
/// a) It contains duplicate.
|
|
/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
|
|
/// c) The number of elements in `dimsPos` is > than `rank`.
|
|
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
|
|
size_t rank) {
|
|
size_t dimsPosSize = dimsPos.size();
|
|
if (dimsPosSize > rank)
|
|
return true;
|
|
DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
|
|
if (dimsPosSize != uniqued.size())
|
|
return true;
|
|
return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
|
|
return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
|
|
});
|
|
}
|
|
|
|
template <typename OpTy>
|
|
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
|
|
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
|
|
"applies to only pack or unpack operations");
|
|
Operation *op = packOrUnPack.getOperation();
|
|
|
|
// Return true if we have a zero-value tile.
|
|
auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
|
|
return llvm::any_of(tiles, isZeroInteger);
|
|
};
|
|
|
|
// Verify tiles. Do not allow zero tiles.
|
|
SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
|
|
if (hasZeros(mixedTiles))
|
|
return op->emitError("invalid zero tile factor");
|
|
|
|
// Verify inner_dims_pos and outer_dims_perm.
|
|
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
|
|
? packOrUnPack.getSourceType()
|
|
: packOrUnPack.getDestType();
|
|
size_t unpackedRank = unpackedType.getRank();
|
|
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
|
|
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
|
|
if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
|
|
return op->emitError("invalid inner_dims_pos vector");
|
|
if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
|
|
return op->emitError("invalid outer_dims_perm vector");
|
|
if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
|
|
return op->emitError("outer_dims_perm must be a permutation or empty");
|
|
|
|
// Tiling factors must be less than or equal to the input rank for pack (or
|
|
// output rank for unpack), and must match the number of `inner_dims_pos`.
|
|
if (mixedTiles.size() > unpackedRank) {
|
|
return op->emitError("tiling factors must be less than or equal to the "
|
|
"input rank for pack or output rank for unpack");
|
|
}
|
|
if (mixedTiles.size() != innerDimsPos.size()) {
|
|
return op->emitError(
|
|
"tiling factors must equal the number of dimensions to tile");
|
|
}
|
|
|
|
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
|
|
? packOrUnPack.getDestType()
|
|
: packOrUnPack.getSourceType();
|
|
size_t packedRank = packedType.getRank();
|
|
// Require output rank to match input rank + number of blocking factors.
|
|
size_t expectedPackedRank = unpackedRank + mixedTiles.size();
|
|
if (expectedPackedRank != packedRank) {
|
|
return op->emitError(
|
|
"packed rank != (unpacked rank + num tiling factors), got ")
|
|
<< packedRank << " != " << expectedPackedRank;
|
|
}
|
|
|
|
// Verify result shape is greater than the minimum expected
|
|
// by the pack operation, and that the output shape
|
|
// represents full tiles.
|
|
RankedTensorType expectedPackedType = PackOp::inferPackedType(
|
|
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
|
|
if (!llvm::all_of(
|
|
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
|
|
mixedTiles),
|
|
[](std::tuple<int64_t, OpFoldResult> it) {
|
|
int64_t shape = std::get<0>(it);
|
|
if (Attribute attr =
|
|
llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
|
|
IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
|
|
int64_t staticTileSize = intAttr.getValue().getSExtValue();
|
|
return shape == staticTileSize;
|
|
}
|
|
return ShapedType::isDynamic(shape);
|
|
})) {
|
|
return op->emitError("mismatch in inner tile sizes specified and shaped of "
|
|
"tiled dimension in the packed type");
|
|
}
|
|
if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
|
|
packedType.getShape()))) {
|
|
return op->emitError("expected ")
|
|
<< expectedPackedType << " for the packed domain value, got "
|
|
<< packedType;
|
|
}
|
|
return success();
|
|
}
|
|
|
|
namespace {
|
|
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
|
|
/// various permutations to the op.
|
|
// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
|
|
// these. These may or may not become true foldings / canonicalizations
|
|
// depending on how aggressive we want to be in automatically folding
|
|
// transposes.
|
|
struct PackOrUnPackTransposeResult {
|
|
SmallVector<int64_t> innerDimsPos;
|
|
SmallVector<OpFoldResult> innerTiles;
|
|
SmallVector<int64_t> outerDimsPerm;
|
|
};
|
|
} // namespace
|
|
|
|
template <typename OpTy>
|
|
static PackOrUnPackTransposeResult
|
|
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
|
|
ArrayRef<int64_t> innerPermutation,
|
|
ArrayRef<int64_t> outerPermutation) {
|
|
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
|
|
"applies to only pack or unpack operations");
|
|
assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
|
|
"some permutation must be non-empty");
|
|
PackOrUnPackTransposeResult metadata;
|
|
metadata.innerDimsPos =
|
|
SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
|
|
metadata.innerTiles =
|
|
SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
|
|
int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
|
|
? packOrUnPackOp.getSourceRank()
|
|
: packOrUnPackOp.getDestRank();
|
|
metadata.outerDimsPerm =
|
|
packOrUnPackOp.getOuterDimsPerm().empty()
|
|
? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
|
|
: SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
|
|
if (!innerPermutation.empty()) {
|
|
assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
|
|
isPermutationVector(innerPermutation) &&
|
|
"invalid inner permutation");
|
|
applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
|
|
applyPermutationToVector(metadata.innerTiles, innerPermutation);
|
|
}
|
|
if (!outerPermutation.empty()) {
|
|
assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
|
|
isPermutationVector(outerPermutation) &&
|
|
"invalid outer permutation");
|
|
applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
|
|
}
|
|
return metadata;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PackOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
|
|
setNameFn(getResult(), "pack");
|
|
}
|
|
|
|
void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
|
|
Value dest, ArrayRef<int64_t> innerDimsPos,
|
|
ArrayRef<OpFoldResult> innerTiles,
|
|
std::optional<Value> paddingValue,
|
|
ArrayRef<int64_t> outerDimsPerm) {
|
|
assert(innerDimsPos.size() == innerTiles.size() &&
|
|
"number of tile sizes specified must match the specified number of "
|
|
"original dimensions to be tiled");
|
|
SmallVector<int64_t> staticTileSizes;
|
|
SmallVector<Value> dynamicTileSizes;
|
|
dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
|
|
build(builder, state, dest.getType(), source, dest,
|
|
paddingValue ? *paddingValue : nullptr,
|
|
outerDimsPerm.empty() ? nullptr
|
|
: builder.getDenseI64ArrayAttr(outerDimsPerm),
|
|
builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
|
|
builder.getDenseI64ArrayAttr(staticTileSizes));
|
|
}
|
|
|
|
LogicalResult
|
|
PackOp::reifyResultShapes(OpBuilder &builder,
|
|
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
|
return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
|
|
}
|
|
|
|
DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
|
|
return getDimAndTileMappingImpl(*this);
|
|
}
|
|
|
|
SmallVector<OpFoldResult> PackOp::getMixedTiles() {
|
|
return getMixedTilesImpl(*this);
|
|
}
|
|
|
|
SmallVector<int64_t> PackOp::getStaticTiles() {
|
|
return getStaticTilesImpl(*this);
|
|
}
|
|
|
|
ArrayRef<int64_t> PackOp::getAllOuterDims() {
|
|
ShapedType inputType = getSourceType();
|
|
int64_t inputRank = inputType.getRank();
|
|
return getDestType().getShape().take_front(inputRank);
|
|
}
|
|
|
|
SmallVector<int64_t> PackOp::getTiledOuterDims() {
|
|
auto innerDimsPos = getInnerDimsPos();
|
|
auto packedShape = getDestType().getShape();
|
|
SmallVector<int64_t> res;
|
|
|
|
for (auto index : innerDimsPos)
|
|
res.push_back(packedShape[index]);
|
|
|
|
return res;
|
|
}
|
|
|
|
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
|
|
ArrayRef<int64_t> innerDimsPos,
|
|
ArrayRef<int64_t> outputShape,
|
|
ArrayRef<int64_t> outerDimsPerm,
|
|
ArrayRef<OpFoldResult> innerTiles) {
|
|
SmallVector<int64_t> outputTileSizes(
|
|
outputShape.take_front(inputShape.size()));
|
|
if (!outerDimsPerm.empty()) {
|
|
assert(outerDimsPerm.size() == outputTileSizes.size() &&
|
|
"expected output and outer_dims_perm to have same size");
|
|
applyPermutationToVector(outputTileSizes,
|
|
invertPermutationVector(outerDimsPerm));
|
|
}
|
|
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
|
|
if (ShapedType::isDynamic(inputShape[pos]))
|
|
continue;
|
|
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
|
|
|
|
if (!constantTile) {
|
|
if (ShapedType::isStatic(outputTileSizes[pos]) &&
|
|
(inputShape[pos] % outputTileSizes[pos] != 0))
|
|
return true;
|
|
} else if (inputShape[pos] % (*constantTile) != 0) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
LogicalResult PackOp::verify() {
|
|
if (failed(commonVerifierPackAndUnPackOp(*this)))
|
|
return failure();
|
|
|
|
// Verify padding value, and bail out if the tile does not divide the
|
|
// dimension fully. In the case of dynamic tile factors or dimensions, having
|
|
// a partial tile is undefined behavior.
|
|
auto paddingValue = getPaddingValue();
|
|
if (paddingValue &&
|
|
paddingValue.getType() != getSourceType().getElementType()) {
|
|
return emitOpError("expected padding_value has ")
|
|
<< getSourceType().getElementType()
|
|
<< " but got: " << paddingValue.getType();
|
|
}
|
|
|
|
if (!paddingValue &&
|
|
requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
|
|
getDestType().getShape(), getOuterDimsPerm(),
|
|
getMixedTiles())) {
|
|
return emitOpError(
|
|
"invalid tile factor or output size provided. Only full tiles are "
|
|
"supported when padding_value is not set");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
|
|
/// Value's to kDynamic, even if they are arith.constant values.
|
|
static SmallVector<int64_t>
|
|
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
|
|
SmallVector<int64_t> result;
|
|
for (auto o : ofrs) {
|
|
// Have to do this first, as getConstantIntValue special-cases constants.
|
|
if (llvm::dyn_cast_if_present<Value>(o))
|
|
result.push_back(ShapedType::kDynamic);
|
|
else
|
|
result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
|
|
/// the packed type. Having a shared helper helps implement these two methods in
|
|
/// a way that ensures that they agree on which dimensions are dynamic.
|
|
static SmallVector<int64_t> getPackOpResultTypeShape(
|
|
ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
|
|
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
|
|
SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
|
|
for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
|
|
if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
|
|
continue;
|
|
if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
|
|
resultShape[tiledDim.value()] = ShapedType::kDynamic;
|
|
continue;
|
|
}
|
|
resultShape[tiledDim.value()] = llvm::divideCeilSigned(
|
|
resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
|
|
}
|
|
|
|
// Swap tile loops if outer_dims_perm is available.
|
|
if (!outerDimsPerm.empty())
|
|
applyPermutationToVector(resultShape, outerDimsPerm);
|
|
|
|
// Append the inner tile dimensions.
|
|
resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
|
|
return resultShape;
|
|
}
|
|
|
|
SmallVector<OpFoldResult> PackOp::getResultShape(
|
|
OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
|
|
ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
|
|
ArrayRef<int64_t> outerDimsPerm) {
|
|
SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
|
|
|
|
AffineExpr s0, s1;
|
|
bindSymbols(builder.getContext(), s0, s1);
|
|
AffineExpr ceilDivExpr = s0.ceilDiv(s1);
|
|
for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
|
|
resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
|
|
builder, loc, ceilDivExpr,
|
|
{resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
|
|
}
|
|
if (!outerDimsPerm.empty())
|
|
applyPermutationToVector(resultDims, outerDimsPerm);
|
|
resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
|
|
|
|
SmallVector<int64_t> resultTypeShape =
|
|
getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
|
|
asShapeWithAnyValueAsDynamic(innerTileSizes),
|
|
innerDimsPos, outerDimsPerm);
|
|
|
|
// Fix-up `resultDims` to ensure that they are Value's if and only if the
|
|
// result type shape says it's a dynamic dim. This is needed as callers may
|
|
// use dispatchIndexOpFoldResults on the result, and rely on exact number of
|
|
// dynamic dims returned by that.
|
|
for (unsigned i = 0; i < resultDims.size(); ++i) {
|
|
if (ShapedType::isStatic(resultTypeShape[i]))
|
|
continue;
|
|
resultDims[i] =
|
|
getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
|
|
}
|
|
|
|
return resultDims;
|
|
}
|
|
|
|
/// Get the expected packed type based on source type, tile factors, position of
|
|
/// the inner tiles and permutation of the outer tiled loop.
|
|
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
|
|
ArrayRef<int64_t> innerTileSizes,
|
|
ArrayRef<int64_t> innerDimsPos,
|
|
ArrayRef<int64_t> outerDimsPerm) {
|
|
SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
|
|
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
|
|
return RankedTensorType::get(resultShape, sourceType.getElementType());
|
|
}
|
|
|
|
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
|
|
ArrayRef<OpFoldResult> innerTileSizes,
|
|
ArrayRef<int64_t> innerDimsPos,
|
|
ArrayRef<int64_t> outerDimsPerm) {
|
|
AffineExpr dim0, dim1;
|
|
bindDims(b.getContext(), dim0, dim1);
|
|
auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
|
|
return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
|
|
{v1, v2});
|
|
};
|
|
|
|
SmallVector<OpFoldResult> mixedSizes;
|
|
for (auto [index, value] : llvm::enumerate(
|
|
llvm::cast<RankedTensorType>(source.getType()).getShape())) {
|
|
if (ShapedType::isDynamic(value))
|
|
mixedSizes.push_back(
|
|
tensor::DimOp::create(b, loc, source, index).getResult());
|
|
else
|
|
mixedSizes.push_back(b.getIndexAttr(value));
|
|
}
|
|
for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
|
|
int64_t dimPos = std::get<0>(it);
|
|
OpFoldResult tileSize = std::get<1>(it);
|
|
mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
|
|
}
|
|
if (!outerDimsPerm.empty())
|
|
applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
|
|
|
|
mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
|
|
auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
|
|
return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
|
|
}
|
|
|
|
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
|
|
ArrayRef<int64_t> innerPermutation,
|
|
ArrayRef<int64_t> outerPermutation) {
|
|
PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
|
|
*this, innerPermutation, outerPermutation);
|
|
Value transposedDest =
|
|
createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
|
|
metadata.innerDimsPos, metadata.outerDimsPerm);
|
|
return PackOp::create(b, loc, getSource(), transposedDest,
|
|
metadata.innerDimsPos, metadata.innerTiles,
|
|
getPaddingValue(), metadata.outerDimsPerm);
|
|
}
|
|
|
|
/// Returns true if the tiles and the tiled dims are constant.
|
|
template <typename OpTy>
|
|
bool areTilesAndTiledDimsAllConstant(OpTy op) {
|
|
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
|
|
"applies to only pack or unpack operations");
|
|
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
|
|
? op.getDestType()
|
|
: op.getSourceType();
|
|
SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
|
|
for (auto [dimDest, tile] : llvm::zip(
|
|
packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
|
|
std::optional<int64_t> constTileSize = getConstantIntValue(tile);
|
|
if (!constTileSize || ShapedType::isDynamic(dimDest))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
Speculation::Speculatability PackOp::getSpeculatability() {
|
|
if (getPaddingValue())
|
|
return Speculation::Speculatable;
|
|
|
|
// The verifier rejects already operations if we can statically prove that the
|
|
// sizes of the tiles do not divide perfectly the dimension; thus, check only
|
|
// to have constant tiles and tiled inner dimensions.
|
|
if (!areTilesAndTiledDimsAllConstant(*this))
|
|
return Speculation::NotSpeculatable;
|
|
|
|
return Speculation::Speculatable;
|
|
}
|
|
|
|
// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
|
|
// dimensions for pack and unpack.
|
|
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
|
|
if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
|
|
return false;
|
|
if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
|
|
return true;
|
|
// Outer dims permutation is optional.
|
|
// To compare unbalanced pack-unpack pair, treat no permutation as equal to
|
|
// identity permutation.
|
|
return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
|
|
isIdentityPermutation(unPackOp.getOuterDimsPerm());
|
|
}
|
|
|
|
// Return true if pack and unpack have the same tiles.
|
|
// Same SSA values or same integer constants.
|
|
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
|
|
auto packTiles = packOp.getMixedTiles();
|
|
auto unPackTiles = unPackOp.getMixedTiles();
|
|
if (packTiles.size() != unPackTiles.size())
|
|
return false;
|
|
for (size_t i = 0, e = packTiles.size(); i < e; i++) {
|
|
if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Returns true if the pack op does not need a padding value.
|
|
static bool paddingIsNotNeeded(PackOp op) {
|
|
auto srcType = op.getSourceType();
|
|
if (llvm::any_of(op.getInnerDimsPos(),
|
|
[&](int64_t pos) { return srcType.isDynamicDim(pos); }))
|
|
return false;
|
|
if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
|
|
return false;
|
|
return !PackOp::requirePaddingValue(
|
|
srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
|
|
op.getOuterDimsPerm(), op.getMixedTiles());
|
|
}
|
|
|
|
/// Returns true if the `srcShape` or `destShape` is different from the one in
|
|
/// `packOp` and populates each with the inferred static shape.
|
|
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
|
|
SmallVectorImpl<int64_t> &destShape) {
|
|
bool changeNeeded = false;
|
|
srcShape.assign(packOp.getSourceType().getShape().begin(),
|
|
packOp.getSourceType().getShape().end());
|
|
destShape.assign(packOp.getDestType().getShape().begin(),
|
|
packOp.getDestType().getShape().end());
|
|
llvm::SmallSetVector<int64_t, 4> innerDims;
|
|
innerDims.insert_range(packOp.getInnerDimsPos());
|
|
SmallVector<int64_t> inverseOuterDimsPerm;
|
|
if (!packOp.getOuterDimsPerm().empty())
|
|
inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
|
|
int srcRank = packOp.getSourceRank();
|
|
for (auto i : llvm::seq<int64_t>(0, srcRank)) {
|
|
if (innerDims.contains(i))
|
|
continue;
|
|
int64_t srcPos = i;
|
|
int64_t destPos = i;
|
|
if (!inverseOuterDimsPerm.empty())
|
|
destPos = inverseOuterDimsPerm[srcPos];
|
|
if (ShapedType::isDynamic(srcShape[srcPos]) ==
|
|
ShapedType::isDynamic(destShape[destPos])) {
|
|
continue;
|
|
}
|
|
int64_t size = srcShape[srcPos];
|
|
if (ShapedType::isDynamic(size))
|
|
size = destShape[destPos];
|
|
srcShape[srcPos] = size;
|
|
destShape[destPos] = size;
|
|
changeNeeded = true;
|
|
}
|
|
return changeNeeded;
|
|
}
|
|
|
|
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
|
|
// Fold an pack(unpack(x)) to x.
|
|
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
|
|
if (unPackOp.getSourceType() != packOp.getDestType())
|
|
return failure();
|
|
if (packOp.getPaddingValue() ||
|
|
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
|
|
!haveSameTiles(packOp, unPackOp))
|
|
return failure();
|
|
rewriter.replaceOp(packOp, unPackOp.getSource());
|
|
return success();
|
|
}
|
|
|
|
// Fold optional PaddingValue operand away if padding is not needed.
|
|
if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
|
|
rewriter.startOpModification(packOp);
|
|
packOp.getPaddingValueMutable().clear();
|
|
rewriter.finalizeOpModification(packOp);
|
|
return success();
|
|
}
|
|
|
|
// Insert tensor.cast ops if static shape inference is available..
|
|
SmallVector<int64_t> srcShape, destShape;
|
|
if (inferStaticShape(packOp, srcShape, destShape)) {
|
|
Location loc = packOp.getLoc();
|
|
Value source = packOp.getSource();
|
|
if (srcShape != packOp.getSourceType().getShape()) {
|
|
auto newSrcType = packOp.getSourceType().clone(srcShape);
|
|
source =
|
|
tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
|
|
}
|
|
Value dest = packOp.getDest();
|
|
RankedTensorType originalResultType = packOp.getDestType();
|
|
bool needUpdateDestType = (destShape != originalResultType.getShape());
|
|
if (needUpdateDestType) {
|
|
auto newDestType = packOp.getDestType().clone(destShape);
|
|
dest =
|
|
tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
|
|
}
|
|
rewriter.modifyOpInPlace(packOp, [&] {
|
|
packOp.getSourceMutable().assign(source);
|
|
packOp.getDestMutable().assign(dest);
|
|
packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
|
|
});
|
|
// Insert a cast if needed
|
|
if (needUpdateDestType) {
|
|
rewriter.setInsertionPointAfter(packOp);
|
|
auto castOp =
|
|
tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
|
|
rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
|
|
}
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
template <typename PackOrUnpackOp>
|
|
static bool isLikePadUnPad(PackOrUnpackOp packOp,
|
|
RankedTensorType packedTensorType) {
|
|
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
|
|
std::is_same<PackOrUnpackOp, UnPackOp>::value,
|
|
"Function meant for pack/unpack");
|
|
// This is a pad if packing only adds ones and we don't transpose dimensions.
|
|
|
|
// Check that we are not transposing any dimensions.
|
|
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
|
|
int64_t numPackedDims = innerDimsPos.size();
|
|
auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
|
|
if (orderedDims != innerDimsPos) {
|
|
// Dimensions don't happen in order.
|
|
return false;
|
|
}
|
|
|
|
ArrayRef<int64_t> packedShape = packedTensorType.getShape();
|
|
int64_t packedRank = packedTensorType.getRank();
|
|
// At this point we know that we are taking numPackedDims outer
|
|
// dimensions and pushing them all the way as the inner most dimensions.
|
|
// What's left on the outer most dimensions is, in this order:
|
|
// - the factor of the packed dimensions, then
|
|
// - the untouched dimensions
|
|
// This shifting inward of dimensions is a no-op (as opposed to a transpose)
|
|
// if all the dimensions that bubble outerward are ones.
|
|
// Therefore check that all the dimensions but the numPackedDims inner most
|
|
// ones are ones.
|
|
return llvm::all_of(
|
|
llvm::seq<int64_t>(0, packedRank - numPackedDims),
|
|
[&packedShape](int64_t i) { return packedShape[i] == 1; });
|
|
}
|
|
|
|
bool PackOp::isLikePad() {
|
|
auto packedTensorType =
|
|
llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
|
|
return isLikePadUnPad(*this, packedTensorType);
|
|
}
|
|
|
|
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
|
|
std::optional<Attribute> paddingValue;
|
|
if (auto pad = adaptor.getPaddingValue())
|
|
paddingValue = pad;
|
|
if (OpFoldResult reshapedSource = reshapeConstantSource(
|
|
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
|
|
getDestType(), paddingValue))
|
|
return reshapedSource;
|
|
return {};
|
|
}
|
|
|
|
/// Folds a tensor.cast op into a consuming PackOp op if the
|
|
/// `tensor.cast` has source that is more static than the consuming op.
|
|
///
|
|
/// Example:
|
|
/// ```mlir
|
|
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
|
|
/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
|
|
/// ```
|
|
///
|
|
/// folds into:
|
|
///
|
|
/// ```mlir
|
|
/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
|
|
/// ```
|
|
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
|
|
using OpRewritePattern<PackOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(PackOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!tensor::hasFoldableTensorCastOperand(op))
|
|
return failure();
|
|
|
|
SmallVector<Type> newResultTypes(op->getResultTypes());
|
|
SmallVector<Value> newOperands =
|
|
tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
|
|
|
|
// Get the updated mixed-tile-sizes attribute.
|
|
SmallVector<OpFoldResult> newMixedTileSizes =
|
|
getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
|
|
|
|
// Clone op.
|
|
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
|
|
// this point. However, in practice, we use them for things that we'd like
|
|
// to preserve. Implement a better abstraction.
|
|
PackOp newOp =
|
|
PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
|
|
op.getInnerDimsPos(), newMixedTileSizes,
|
|
op.getPaddingValue(), op.getOuterDimsPerm());
|
|
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
|
|
|
|
// Replace op.
|
|
Value oldResult = op.getResult();
|
|
Value newResult = newOp.getResult();
|
|
Value replacement =
|
|
(newResult.getType() != oldResult.getType())
|
|
? tensor::CastOp::create(rewriter, op->getLoc(),
|
|
oldResult.getType(), newResult)
|
|
: newResult;
|
|
|
|
rewriter.replaceOp(op, {replacement});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UnPackOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void UnPackOp::getAsmResultNames(
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
setNameFn(getResult(), "unpack");
|
|
}
|
|
|
|
LogicalResult
|
|
UnPackOp::reifyResultShapes(OpBuilder &builder,
|
|
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
|
return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
|
|
}
|
|
|
|
DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
|
|
return getDimAndTileMappingImpl(*this);
|
|
}
|
|
|
|
SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
|
|
return getMixedTilesImpl(*this);
|
|
}
|
|
|
|
SmallVector<int64_t> UnPackOp::getStaticTiles() {
|
|
return getStaticTilesImpl(*this);
|
|
}
|
|
|
|
ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
|
|
ShapedType destType = getDestType();
|
|
int64_t destRank = destType.getRank();
|
|
return getSourceType().getShape().take_front(destRank);
|
|
}
|
|
|
|
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
|
|
auto innerDimsPos = getInnerDimsPos();
|
|
auto packedShape = getSourceType().getShape();
|
|
SmallVector<int64_t> res;
|
|
|
|
for (auto index : innerDimsPos)
|
|
res.push_back(packedShape[index]);
|
|
|
|
return res;
|
|
}
|
|
|
|
LogicalResult UnPackOp::verify() {
|
|
return commonVerifierPackAndUnPackOp(*this);
|
|
}
|
|
|
|
Speculation::Speculatability UnPackOp::getSpeculatability() {
|
|
// See PackOp::getSpeculatability.
|
|
if (!areTilesAndTiledDimsAllConstant(*this))
|
|
return Speculation::NotSpeculatable;
|
|
|
|
return Speculation::Speculatable;
|
|
}
|
|
|
|
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
|
|
Value dest, ArrayRef<int64_t> innerDimsPos,
|
|
ArrayRef<OpFoldResult> innerTiles,
|
|
ArrayRef<int64_t> outerDimsPerm) {
|
|
assert(innerDimsPos.size() == innerTiles.size() &&
|
|
"number of tile sizes specified must match the specified number of "
|
|
"original dimensions to be tiled");
|
|
SmallVector<int64_t> staticTileSizes;
|
|
SmallVector<Value> dynamicTileSizes;
|
|
dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
|
|
build(builder, state, dest.getType(), source, dest,
|
|
outerDimsPerm.empty() ? nullptr
|
|
: builder.getDenseI64ArrayAttr(outerDimsPerm),
|
|
builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
|
|
builder.getDenseI64ArrayAttr(staticTileSizes));
|
|
}
|
|
|
|
Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
|
|
Value source,
|
|
ArrayRef<OpFoldResult> innerTileSizes,
|
|
ArrayRef<int64_t> innerDimsPos,
|
|
ArrayRef<int64_t> outerDimsPerm) {
|
|
AffineExpr sym0, sym1;
|
|
bindSymbols(b.getContext(), sym0, sym1);
|
|
auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
|
|
return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
|
|
};
|
|
|
|
SmallVector<OpFoldResult> mixedSizes;
|
|
auto srcType = llvm::cast<RankedTensorType>(source.getType());
|
|
for (auto i :
|
|
llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
|
|
if (srcType.isDynamicDim(i))
|
|
mixedSizes.push_back(
|
|
tensor::DimOp::create(b, loc, source, i).getResult());
|
|
else
|
|
mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
|
|
}
|
|
if (!outerDimsPerm.empty()) {
|
|
applyPermutationToVector<OpFoldResult>(
|
|
mixedSizes, invertPermutationVector(outerDimsPerm));
|
|
}
|
|
|
|
for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
|
|
mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
|
|
|
|
auto elemType = srcType.getElementType();
|
|
return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
|
|
}
|
|
|
|
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
|
|
Value transposedSource,
|
|
ArrayRef<int64_t> innerPermutation,
|
|
ArrayRef<int64_t> outerPermutation) {
|
|
PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
|
|
*this, innerPermutation, outerPermutation);
|
|
return UnPackOp::create(b, loc, transposedSource, getDest(),
|
|
metadata.innerDimsPos, metadata.innerTiles,
|
|
metadata.outerDimsPerm);
|
|
}
|
|
|
|
/// Returns true if the `srcShape` or `destShape` is different from the one in
|
|
/// `op` and populates each with the inferred static shape.
|
|
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
|
|
SmallVectorImpl<int64_t> &destShape) {
|
|
bool changeNeeded = false;
|
|
srcShape.assign(op.getSourceType().getShape().begin(),
|
|
op.getSourceType().getShape().end());
|
|
destShape.assign(op.getDestType().getShape().begin(),
|
|
op.getDestType().getShape().end());
|
|
llvm::SmallSetVector<int64_t, 4> innerDims;
|
|
innerDims.insert_range(op.getInnerDimsPos());
|
|
SmallVector<int64_t> inverseOuterDimsPerm;
|
|
if (!op.getOuterDimsPerm().empty())
|
|
inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
|
|
int destRank = op.getDestRank();
|
|
for (auto i : llvm::seq<int64_t>(0, destRank)) {
|
|
if (innerDims.contains(i))
|
|
continue;
|
|
int64_t srcPos = i;
|
|
int64_t destPos = i;
|
|
if (!inverseOuterDimsPerm.empty())
|
|
srcPos = inverseOuterDimsPerm[destPos];
|
|
if (ShapedType::isDynamic(srcShape[srcPos]) ==
|
|
ShapedType::isDynamic(destShape[destPos])) {
|
|
continue;
|
|
}
|
|
int64_t size = srcShape[srcPos];
|
|
if (ShapedType::isDynamic(size))
|
|
size = destShape[destPos];
|
|
srcShape[srcPos] = size;
|
|
destShape[destPos] = size;
|
|
changeNeeded = true;
|
|
}
|
|
return changeNeeded;
|
|
}
|
|
|
|
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
|
|
PatternRewriter &rewriter) {
|
|
/// unpack(pack(x)) -> x
|
|
if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
|
|
if (packOp.getSourceType() != unPackOp.getDestType())
|
|
return failure();
|
|
if (packOp.getPaddingValue() ||
|
|
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
|
|
!haveSameTiles(packOp, unPackOp))
|
|
return failure();
|
|
rewriter.replaceOp(unPackOp, packOp.getSource());
|
|
return success();
|
|
}
|
|
/// unpack(destinationStyleOp(x)) -> unpack(x)
|
|
if (auto dstStyleOp =
|
|
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
|
|
auto destValue = cast<OpResult>(unPackOp.getDest());
|
|
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
|
|
rewriter.modifyOpInPlace(unPackOp,
|
|
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
|
|
return success();
|
|
}
|
|
/// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
|
|
if (unPackOp->hasOneUse()) {
|
|
auto extractSliceUser =
|
|
dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
|
|
if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(unPackOp);
|
|
auto newDest = tensor::ExtractSliceOp::create(
|
|
rewriter, unPackOp->getLoc(), unPackOp.getDest(),
|
|
extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
|
|
extractSliceUser.getMixedStrides());
|
|
rewriter.modifyOpInPlace(unPackOp, [&]() {
|
|
unPackOp.setDpsInitOperand(0, newDest);
|
|
unPackOp.getResult().setType(newDest.getType());
|
|
});
|
|
rewriter.replaceOp(extractSliceUser, unPackOp);
|
|
return success();
|
|
}
|
|
}
|
|
|
|
// Insert tensor.cast ops if static shape inference is available..
|
|
SmallVector<int64_t> srcShape, destShape;
|
|
if (inferStaticShape(unPackOp, srcShape, destShape)) {
|
|
Location loc = unPackOp.getLoc();
|
|
Value source = unPackOp.getSource();
|
|
if (srcShape != unPackOp.getSourceType().getShape()) {
|
|
auto newSrcType = unPackOp.getSourceType().clone(srcShape);
|
|
source = tensor::CastOp::create(rewriter, loc, newSrcType,
|
|
unPackOp.getSource());
|
|
}
|
|
Value dest = unPackOp.getDest();
|
|
if (destShape != unPackOp.getDestType().getShape()) {
|
|
auto newDestType = unPackOp.getDestType().clone(destShape);
|
|
dest = tensor::CastOp::create(rewriter, loc, newDestType,
|
|
unPackOp.getDest());
|
|
}
|
|
Value newOp = UnPackOp::create(
|
|
rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
|
|
unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
|
unPackOp, unPackOp.getResult().getType(), newOp);
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
|
|
// Rank-reduced folding is not supported.
|
|
if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
|
|
return false;
|
|
if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
|
|
!areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
|
|
return false;
|
|
RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
|
|
SmallVector<int64_t> outerShapeWithoutTranspose =
|
|
getPackedOuterShapeWithoutTransposition(*this);
|
|
for (auto [pos, tileSize] :
|
|
llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
|
|
if (unpackedTypeAfterFold.isDynamicDim(pos))
|
|
return false;
|
|
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
|
|
return false;
|
|
if (ShapedType::isDynamic(tileSize))
|
|
return false;
|
|
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
|
|
unpackedTypeAfterFold.getDimSize(pos);
|
|
if (paddingSize >= tileSize)
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool UnPackOp::isLikeUnPad() {
|
|
RankedTensorType packedTensorType = getSourceType();
|
|
return isLikePadUnPad(*this, packedTensorType);
|
|
}
|
|
|
|
OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
|
|
if (OpFoldResult reshapedSource = reshapeConstantSource(
|
|
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
|
|
getResult().getType()))
|
|
return reshapedSource;
|
|
return {};
|
|
}
|
|
|
|
/// Folds a tensor.cast op into a consuming UnPackOp op if the
|
|
/// `tensor.cast` has source that is more static than the consuming op.
|
|
///
|
|
/// Example:
|
|
/// ```mlir
|
|
/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
|
|
/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
|
|
/// ```
|
|
///
|
|
/// folds into:
|
|
///
|
|
/// ```mlir
|
|
/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
|
|
/// ```
|
|
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
|
|
using OpRewritePattern<UnPackOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(UnPackOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!tensor::hasFoldableTensorCastOperand(op))
|
|
return failure();
|
|
|
|
SmallVector<Type> newResultTypes(op->getResultTypes());
|
|
SmallVector<Value> newOperands =
|
|
tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
|
|
Value sourceTensor = newOperands[0];
|
|
|
|
// Get the updated mixed-tile-sizes attribute.
|
|
SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
|
|
rewriter, sourceTensor.getType(), op.getMixedTiles());
|
|
|
|
// Clone op.
|
|
// TODO: Strictly speaking, discardable attributes should be _discarded_ at
|
|
// this point. However, in practice, we use them for things that we'd like
|
|
// to preserve. Implement a better abstraction.
|
|
UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
|
|
newOperands[1], op.getInnerDimsPos(),
|
|
newMixedTileSizes, op.getOuterDimsPerm());
|
|
newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
|
|
|
|
// Replace op.
|
|
Value oldResult = op.getResult();
|
|
Value newResult = newOp.getResult();
|
|
Value replacement =
|
|
(newResult.getType() != oldResult.getType())
|
|
? tensor::CastOp::create(rewriter, op->getLoc(),
|
|
oldResult.getType(), newResult)
|
|
: newResult;
|
|
|
|
rewriter.replaceOp(op, {replacement});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BatchReduceMatmulOp
|
|
//===----------------------------------------------------------------------===//
|
|
SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
|
|
return SmallVector<utils::IteratorType>{
|
|
utils::IteratorType::reduction, utils::IteratorType::parallel,
|
|
utils::IteratorType::parallel, utils::IteratorType::reduction};
|
|
}
|
|
|
|
SmallVector<AffineMap>
|
|
BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
|
|
AffineExpr d0, d1, d2, d3;
|
|
SmallVector<AffineMap> indexingMaps;
|
|
bindDims(context, d0, d1, d2, d3);
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
|
|
return indexingMaps;
|
|
}
|
|
|
|
bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
|
|
ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
|
|
if (!maps)
|
|
return false;
|
|
if (maps.size() != 3)
|
|
return false;
|
|
auto positions = getAffineResultPositions(maps);
|
|
if (failed(positions))
|
|
return false;
|
|
return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
|
|
(*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
|
|
(*positions)[2] == SmallVector<int64_t>{1, 2};
|
|
}
|
|
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
|
|
|
|
std::string BatchReduceMatmulOp::getLibraryCallName() {
|
|
return generateLibraryCallName(getOperation());
|
|
}
|
|
|
|
/// Check if the op has broadcast and/or transpose semantic. Returns true if
|
|
/// the user defined indexing maps are not equal to default map.
|
|
bool BatchReduceMatmulOp::hasUserDefinedMaps() {
|
|
SmallVector<AffineMap, 3> defaultMaps =
|
|
getDefaultIndexingMaps(this->getContext());
|
|
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
|
|
return defaultMaps != explicitMaps;
|
|
}
|
|
|
|
/// Returns true if the given bcastMap map is a valid broadcast map. A valid
|
|
/// broadcast map must include K dimension.
|
|
/// TODO: Strict inclusion of K dimension in the broadcast map is not
|
|
/// necessary for both input matrices simultaneously. We can relax this
|
|
/// condition to have K dimension for one input matrix map and infer the K
|
|
/// dimension for other input matrix map from the one already having K
|
|
/// dimension.
|
|
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
|
|
bool isLHS) {
|
|
assert(bcastMap.getNumResults() < 3 &&
|
|
"Expected less than 3 result dim expr.");
|
|
bool isValid = false;
|
|
enum Indices { batchPos, mPos, nPos, kPos };
|
|
if (bcastMap.getNumResults() == 1) {
|
|
AffineExpr expr = bcastMap.getResult(0);
|
|
isValid = expr.isFunctionOfDim(kPos);
|
|
} else if (bcastMap.getNumResults() == 2) {
|
|
AffineExpr expr0 = bcastMap.getResult(0);
|
|
AffineExpr expr1 = bcastMap.getResult(1);
|
|
isValid =
|
|
isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
|
|
expr0.isFunctionOfDim(mPos)) &&
|
|
expr1.isFunctionOfDim(kPos))
|
|
: ((expr0.isFunctionOfDim(batchPos) &&
|
|
expr1.isFunctionOfDim(kPos)) ||
|
|
(expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
|
|
}
|
|
return isValid;
|
|
}
|
|
|
|
void BatchReduceMatmulOp::regionBuilder(
|
|
ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
|
|
function_ref<InFlightDiagnostic()> emitError) {
|
|
if (emitError && block.getNumArguments() != 3) {
|
|
emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
|
|
<< block.getNumArguments();
|
|
return;
|
|
}
|
|
assert(block.getNumArguments() == 3 &&
|
|
"BatchReduceMatmulOp regionBuilder expects 3 args");
|
|
RegionBuilderHelper helper(b, block);
|
|
SmallVector<Value> yields;
|
|
|
|
auto toType = block.getArgument(2).getType();
|
|
Value castValA =
|
|
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
|
|
Value castValB =
|
|
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
|
|
Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
|
|
Value addVal =
|
|
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
|
|
yields.push_back(addVal);
|
|
helper.yieldOutputs(yields);
|
|
}
|
|
|
|
ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
|
|
OperationState &result) {
|
|
SmallVector<Attribute, 3> indexingMapsAttr;
|
|
Attribute mapAttr;
|
|
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
|
|
if (parser.parseEqual())
|
|
return failure();
|
|
if (parser.parseLSquare())
|
|
return failure();
|
|
|
|
do {
|
|
if (parser.parseAttribute(mapAttr))
|
|
return failure();
|
|
if (!isa<AffineMapAttr>(mapAttr)) {
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected affine map attribute");
|
|
}
|
|
indexingMapsAttr.push_back(mapAttr);
|
|
|
|
if (parser.parseOptionalComma())
|
|
break;
|
|
} while (true);
|
|
|
|
if (parser.parseRSquare())
|
|
return failure();
|
|
}
|
|
// Initialize indexingMaps, if not supplied explicitly.
|
|
if (indexingMapsAttr.empty()) {
|
|
indexingMapsAttr = llvm::map_to_vector(
|
|
BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
}
|
|
result.addAttribute("indexing_maps",
|
|
parser.getBuilder().getArrayAttr(indexingMapsAttr));
|
|
return ::parseNamedStructuredOp(parser, result,
|
|
BatchReduceMatmulOp::getNumRegionArgs(),
|
|
BatchReduceMatmulOp::getRegionBuilder());
|
|
}
|
|
|
|
void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
|
|
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
|
|
BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
|
|
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
|
|
p << " indexing_maps = [";
|
|
llvm::interleaveComma(getIndexingMaps(), p,
|
|
[&](Attribute attr) { p.printAttribute(attr); });
|
|
p << "]";
|
|
}
|
|
|
|
SmallVector<StringRef, 3> elidedAttrs = {
|
|
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
|
|
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
|
elidedAttrs);
|
|
}
|
|
|
|
/// Verify the user defined indexing maps.
|
|
LogicalResult BatchReduceMatmulOp::verify() {
|
|
// Verification of pure batch_reduce_matmul is handled by
|
|
// verifyStructuredOpInterface().
|
|
if (!hasUserDefinedMaps())
|
|
return success();
|
|
|
|
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
|
|
if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
|
|
SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
void BatchReduceMatmulOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
if (hasPureTensorSemantics())
|
|
return;
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
} // namespace linalg
|
|
} // namespace mlir
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LinalgDialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void LinalgDialect::getCanonicalizationPatterns(
|
|
RewritePatternSet &results) const {
|
|
results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
|
|
FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
|
|
}
|
|
|
|
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
|
|
Attribute value, Type type,
|
|
Location loc) {
|
|
return arith::ConstantOp::materialize(builder, value, type, loc);
|
|
}
|