
The revision adds an `isOneInteger` helper and simplifies the existing code using the two helpers (`isZeroInteger` and `isOneInteger`). It also removes some lambdas, which makes the code cleaner. Downstream users can update their code with the script below.

```bash
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
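For illustration only (this snippet is not part of the diff): call sites that previously spelled out a lambda over `getConstantIntValue` can now use the shared helpers. The surrounding function names below are hypothetical; only `isZeroInteger`/`isOneInteger` come from this revision, and they are assumed to take an `OpFoldResult`, as declared in `mlir/Dialect/Utils/StaticValueUtils.h`.

```cpp
// Hypothetical call sites, shown only to illustrate the cleanup; not code from
// this patch. Assumes bool isZeroInteger(OpFoldResult) and
// bool isOneInteger(OpFoldResult) are available from StaticValueUtils.h.
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

static bool hasUnitStrides(ArrayRef<OpFoldResult> strides) {
  // Before: an ad-hoc lambda at each call site, e.g.
  //   auto isOne = [](OpFoldResult ofr) { return getConstantIntValue(ofr) == 1; };
  //   return llvm::all_of(strides, isOne);
  // After: the shared helper.
  return llvm::all_of(strides, isOneInteger);
}

static bool hasZeroOffsets(ArrayRef<OpFoldResult> offsets) {
  return llvm::all_of(offsets, isZeroInteger); // formerly isZeroIndex
}
```

Note that the `**` globs in the sed script only recurse when `globstar` is enabled in bash (`shopt -s globstar`); otherwise `**` matches a single path component.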
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
|
|
/// manually defined C++ ops. Does not handle regions.
|
|
static ParseResult
|
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
|
SmallVectorImpl<Type> &inputTypes,
|
|
SmallVectorImpl<Type> &outputTypes,
|
|
bool addOperandSegmentSizes = true) {
|
|
SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
|
|
SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
|
|
outputsOperands;
|
|
|
|
if (succeeded(parser.parseOptionalLess())) {
|
|
if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
|
|
return failure();
|
|
}
|
|
attrsLoc = parser.getCurrentLocation();
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
return failure();
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("ins"))) {
|
|
if (parser.parseLParen())
|
|
return failure();
|
|
|
|
inputsOperandsLoc = parser.getCurrentLocation();
|
|
if (parser.parseOperandList(inputsOperands) ||
|
|
parser.parseColonTypeList(inputTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("outs"))) {
|
|
outputsOperandsLoc = parser.getCurrentLocation();
|
|
if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
|
|
parser.parseColonTypeList(outputTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
|
|
result.operands) ||
|
|
parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
|
|
result.operands))
|
|
return failure();
|
|
|
|
if (addOperandSegmentSizes) {
|
|
// This is a bit complex because we're trying to be backward compatible with
|
|
// operation syntax that mix the inherent attributes and the discardable
|
|
// ones in the same dictionary. If the properties are used, we append the
|
|
// operandSegmentSizes there directly. Otherwise we append it to the
|
|
// discardable attributes dictionary where it is handled by the generic
|
|
// Operation::create(...) method.
|
|
if (result.propertiesAttr) {
|
|
NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
|
|
attrs.append("operandSegmentSizes",
|
|
parser.getBuilder().getDenseI32ArrayAttr(
|
|
{static_cast<int32_t>(inputsOperands.size()),
|
|
static_cast<int32_t>(outputsOperands.size())}));
|
|
result.propertiesAttr = attrs.getDictionary(parser.getContext());
|
|
} else {
|
|
result.addAttribute("operandSegmentSizes",
|
|
parser.getBuilder().getDenseI32ArrayAttr(
|
|
{static_cast<int32_t>(inputsOperands.size()),
|
|
static_cast<int32_t>(outputsOperands.size())}));
|
|
}
|
|
}
|
|
if (!result.propertiesAttr) {
|
|
std::optional<RegisteredOperationName> info =
|
|
result.name.getRegisteredInfo();
|
|
if (info) {
|
|
if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
|
|
return parser.emitError(attrsLoc)
|
|
<< "'" << result.name.getStringRef() << "' op ";
|
|
})))
|
|
return failure();
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
|
|
ValueRange outputs) {
|
|
if (!inputs.empty())
|
|
p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
|
|
if (!outputs.empty())
|
|
p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Specific parsing and printing for named structured ops created by ods-gen.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static ParseResult parseNamedStructuredOpRegion(
|
|
OpAsmParser &parser, Region &region, unsigned numRegionArgs,
|
|
TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
|
|
RegionBuilderFn regionBuilder) {
|
|
if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
|
|
return parser.emitError(
|
|
parser.getCurrentLocation(),
|
|
llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
|
|
"region expects {0} args, got {1}",
|
|
numRegionArgs, inputTypes.size() + outputTypes.size()));
|
|
}
|
|
|
|
OpBuilder opBuilder(parser.getContext());
|
|
fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
|
|
regionBuilder);
|
|
return success();
|
|
}
|
|
|
|
static ParseResult
|
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
|
SmallVectorImpl<Type> &resultTypes) {
|
|
if (parser.parseOptionalArrowTypeList(resultTypes))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|
OperationState &result,
|
|
unsigned numRegionArgs,
|
|
RegionBuilderFn regionBuilder) {
|
|
// TODO: Enable when ods-gen supports captures.
|
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
|
return failure();
|
|
|
|
// Parse optional attributes.
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
return failure();
|
|
|
|
// TODO: consider merging results parsing into region parsing.
|
|
// Need to wait for declarative assembly resolution to decide.
|
|
SmallVector<Type, 1> outputTensorsTypes;
|
|
if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
|
|
return failure();
|
|
result.addTypes(outputTensorsTypes);
|
|
|
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
|
if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
|
|
outputTypes, result.attributes.getAttrs(),
|
|
regionBuilder))
|
|
return failure();
|
|
result.addRegion(std::move(region));
|
|
|
|
return success();
|
|
}
|
|
|
|
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
|
TypeRange resultTypes) {
|
|
if (resultTypes.empty())
|
|
return;
|
|
p.printOptionalArrowTypeList(resultTypes);
|
|
}
|
|
|
|
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<StringRef> elidedAttrs = {}) {
|
|
p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
|
|
|
// Printing is shared with generic ops, except for the region and
|
|
// attributes.
|
|
printCommonStructuredOpParts(p, inputs, outputs);
|
|
|
|
// Results printing.
|
|
printNamedStructuredOpResults(p, op->getResultTypes());
|
|
|
|
// Region is elided.
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Region builder helper.
|
|
// TODO: Move this to a utility library.
|
|
// The public methods on this class are referenced directly from generated code.
|
|
// Helpers to build the unary, binary, and type conversion functions defined by the
|
|
// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
|
|
// class.
|
|
//
|
|
// Implementations of the math functions must be polymorphic over numeric types,
|
|
// internally performing necessary casts. If the function application makes no
|
|
// sense, then the only recourse is to assert and return nullptr. This can be
|
|
// extended later if it becomes possible to fail construction of the region. The
|
|
// invariant should be enforced at a higher level.
|
|
//
|
|
// TODO: These helpers are currently type polymorphic over the class of integer
|
|
// and floating point types, but they will not internally cast within bit
|
|
// widths of a class (mixed precision such as i8->i32) or across classes
|
|
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
|
|
// to be handled with care and work is being considered to extend the op
|
|
// language to make such cases explicit. In the mean-time, violating this will
|
|
// fail verification, which is deemed acceptable.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
class RegionBuilderHelper {
|
|
public:
|
|
RegionBuilderHelper(OpBuilder &builder, Block &block)
|
|
: builder(builder), block(block) {}
|
|
|
|
// Build the unary functions defined by OpDSL.
|
|
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
|
|
if (!isFloatingPoint(arg))
|
|
llvm_unreachable("unsupported non numeric type");
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
switch (unaryFn) {
|
|
case UnaryFn::exp:
|
|
return builder.create<math::ExpOp>(arg.getLoc(), arg);
|
|
case UnaryFn::log:
|
|
return builder.create<math::LogOp>(arg.getLoc(), arg);
|
|
case UnaryFn::abs:
|
|
return builder.create<math::AbsFOp>(arg.getLoc(), arg);
|
|
case UnaryFn::ceil:
|
|
return builder.create<math::CeilOp>(arg.getLoc(), arg);
|
|
case UnaryFn::floor:
|
|
return builder.create<math::FloorOp>(arg.getLoc(), arg);
|
|
case UnaryFn::negf:
|
|
return builder.create<arith::NegFOp>(arg.getLoc(), arg);
|
|
case UnaryFn::reciprocal: {
|
|
Attribute oneAttr = builder.getOneAttr(arg.getType());
|
|
auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
|
|
::cast<TypedAttr>(oneAttr));
|
|
return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
|
|
}
|
|
case UnaryFn::round:
|
|
return builder.create<math::RoundOp>(arg.getLoc(), arg);
|
|
case UnaryFn::sqrt:
|
|
return builder.create<math::SqrtOp>(arg.getLoc(), arg);
|
|
case UnaryFn::rsqrt:
|
|
return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
|
|
case UnaryFn::square:
|
|
return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
|
|
case UnaryFn::tanh:
|
|
return builder.create<math::TanhOp>(arg.getLoc(), arg);
|
|
case UnaryFn::erf:
|
|
return builder.create<math::ErfOp>(arg.getLoc(), arg);
|
|
}
|
|
llvm_unreachable("unsupported unary function");
|
|
}
|
|
|
|
// Build the binary functions defined by OpDSL.
|
|
Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
|
|
bool allComplex = isComplex(arg0) && isComplex(arg1);
|
|
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
|
|
bool allInteger = isInteger(arg0) && isInteger(arg1);
|
|
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
|
|
arg1.getType().getIntOrFloatBitWidth() == 1;
|
|
if (!allComplex && !allFloatingPoint && !allInteger)
|
|
llvm_unreachable("unsupported non numeric type");
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
switch (binaryFn) {
|
|
case BinaryFn::add:
|
|
if (allComplex)
|
|
return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::sub:
|
|
if (allComplex)
|
|
return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
llvm_unreachable("unsupported operation: sub with bools");
|
|
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::mul:
|
|
if (allComplex)
|
|
return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::div:
|
|
if (allComplex)
|
|
return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
llvm_unreachable("unsupported operation: div with bools");
|
|
return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::div_unsigned:
|
|
if (!allInteger || allBool)
|
|
llvm_unreachable("unsupported operation: unsigned div not on uint");
|
|
return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::max_signed:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::min_signed:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::max_unsigned:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::min_unsigned:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::powf:
|
|
assert(allFloatingPoint);
|
|
return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
|
|
}
|
|
llvm_unreachable("unsupported binary function");
|
|
}
|
|
|
|
// Build the ternary functions defined by OpDSL.
|
|
Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
|
|
Value arg2) {
|
|
bool headBool =
|
|
isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
|
|
bool tailFloatingPoint =
|
|
isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
|
|
bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
switch (ternaryFn) {
|
|
case TernaryFn::select:
|
|
if (!headBool && !(tailFloatingPoint || tailInteger))
|
|
llvm_unreachable("unsupported non numeric type");
|
|
return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
|
|
}
|
|
llvm_unreachable("unsupported ternary function");
|
|
}
|
|
|
|
// Build the type functions defined by OpDSL.
|
|
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
|
|
switch (typeFn) {
|
|
case TypeFn::cast_signed:
|
|
return cast(toType, operand, false);
|
|
case TypeFn::cast_unsigned:
|
|
return cast(toType, operand, true);
|
|
}
|
|
llvm_unreachable("unsupported type conversion function");
|
|
}
|
|
|
|
void yieldOutputs(ValueRange values) {
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
Location loc = builder.getUnknownLoc();
|
|
builder.create<YieldOp>(loc, values);
|
|
}
|
|
|
|
Value constant(const std::string &value) {
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
Location loc = builder.getUnknownLoc();
|
|
Attribute valueAttr = parseAttribute(value, builder.getContext());
|
|
return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
|
|
}
|
|
|
|
Value index(int64_t dim) {
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
|
|
}
|
|
|
|
Type getIntegerType(unsigned width) {
|
|
return IntegerType::get(builder.getContext(), width);
|
|
}
|
|
|
|
Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
|
|
Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
|
|
|
|
private:
|
|
// Generates operations to cast the given operand to a specified type.
|
|
// If the cast cannot be performed, a warning will be issued and the
|
|
// operand returned as-is (which will presumably yield a verification
|
|
// issue downstream).
|
|
Value cast(Type toType, Value operand, bool isUnsignedCast) {
|
|
OpBuilder::InsertionGuard g(builder);
|
|
builder.setInsertionPointToEnd(&block);
|
|
auto loc = operand.getLoc();
|
|
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
|
|
}
|
|
|
|
bool isComplex(Value value) {
|
|
return llvm::isa<ComplexType>(value.getType());
|
|
}
|
|
bool isFloatingPoint(Value value) {
|
|
return llvm::isa<FloatType>(value.getType());
|
|
}
|
|
bool isInteger(Value value) {
|
|
return llvm::isa<IntegerType>(value.getType());
|
|
}
|
|
|
|
OpBuilder &builder;
|
|
Block &block;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CopyOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
|
|
using OpRewritePattern<CopyOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(CopyOp copyOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (copyOp.getInputs() != copyOp.getOutputs())
|
|
return rewriter.notifyMatchFailure(copyOp, "not a self copy");
|
|
if (copyOp.hasPureBufferSemantics())
|
|
rewriter.eraseOp(copyOp);
|
|
else
|
|
rewriter.replaceOp(copyOp, copyOp.getInputs());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<EraseSelfCopy>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FillOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
|
|
///
|
|
/// For such op chains, we can create new linalg.fill ops with the result
|
|
/// type of the tensor.expand/collapse_shape op.
|
|
template <typename TensorReshapeOp>
|
|
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
|
|
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
|
|
if (!oldFill)
|
|
return failure();
|
|
|
|
Location loc = oldFill.getLoc();
|
|
TensorReshapeOp newInit;
|
|
if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
|
|
|
|
newInit = rewriter.create<TensorReshapeOp>(
|
|
loc, reshapeOp.getResultType(), oldFill.output(),
|
|
reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
|
|
reshapeOp.getStaticOutputShape());
|
|
} else {
|
|
newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
|
|
oldFill.output(),
|
|
reshapeOp.getReassociation());
|
|
}
|
|
rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
|
|
ValueRange{newInit});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
|
|
/// filling value are the same.
|
|
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
|
|
if (!fillOp)
|
|
return failure();
|
|
|
|
// We can only fold if the padding value is the same as the original
|
|
// filling value.
|
|
Value padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue || fillOp.value() != padValue)
|
|
return failure();
|
|
|
|
ReifiedRankedShapedTypeDims reifiedShape;
|
|
if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
|
|
return rewriter.notifyMatchFailure(
|
|
padOp, "failed to reify tensor.pad op result shape");
|
|
|
|
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
|
|
padOp.getLoc(), reifiedShape.front(),
|
|
padOp.getResultType().getElementType());
|
|
Value replacement =
|
|
rewriter
|
|
.create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
|
|
ValueRange{emptyTensor})
|
|
.getResult(0);
|
|
if (replacement.getType() != padOp.getResultType()) {
|
|
replacement = rewriter.create<tensor::CastOp>(
|
|
fillOp.getLoc(), padOp.getResultType(), replacement);
|
|
}
|
|
rewriter.replaceOp(padOp, replacement);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
|
|
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
|
|
/// filling value are the same.
|
|
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
|
|
if (!srcPadOp)
|
|
return failure();
|
|
|
|
if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
|
|
return failure();
|
|
|
|
// Walk back the tensor.insert_slice chain and find the first destination
|
|
// value at the start of the chain.
|
|
Value firstDest = insertOp.getDest();
|
|
while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
|
|
if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
|
|
return failure();
|
|
|
|
// Make sure the range of values accessed are disjoint. Without this, we
|
|
// cannot fold tensor.pad away.
|
|
bool disjoint = false;
|
|
for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
|
|
// If the dimension has dynamic offset/size, we cannot guarantee
|
|
// disjoint. So just skip it.
|
|
if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
|
|
insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
|
|
prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
|
|
continue;
|
|
|
|
// Get the range start and end, inclusively for both.
|
|
int64_t prevStart = prevOp.getStaticOffset(i);
|
|
int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
|
|
prevOp.getStaticStride(i);
|
|
int64_t nextStart = insertOp.getStaticOffset(i);
|
|
int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
|
|
insertOp.getStaticStride(i);
|
|
if (prevEnd < nextStart || nextEnd < prevStart) {
|
|
disjoint = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!disjoint)
|
|
break;
|
|
firstDest = prevOp.getDest();
|
|
}
|
|
|
|
// Check whether the first destination is a fill op. For overlapped cases,
|
|
// this also cannot be true.
|
|
auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
|
|
if (!dstFillOp)
|
|
return failure();
|
|
|
|
// We can only fold if the padding value is the same as the original
|
|
// filling value.
|
|
Value padValue = srcPadOp.getConstantPaddingValue();
|
|
if (!padValue || dstFillOp.value() != padValue)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
|
|
SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
|
|
|
|
Location loc = insertOp.getLoc();
|
|
MLIRContext *context = getContext();
|
|
|
|
AffineExpr sym0, sym1;
|
|
bindSymbols(context, sym0, sym1);
|
|
auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
|
|
|
|
// Calculate the new offsets for the insert. It should be the old offsets
|
|
// plus low padding sizes.
|
|
SmallVector<OpFoldResult, 4> newOffsets;
|
|
for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
|
|
newOffsets.push_back(affine::makeComposedFoldedAffineApply(
|
|
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
|
|
}
|
|
|
|
RankedTensorType srcPadType = srcPadOp.getSourceType();
|
|
SmallVector<OpFoldResult, 4> newSizes;
|
|
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
|
|
if (srcPadType.isDynamicDim(i)) {
|
|
newSizes.push_back(
|
|
rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
|
|
.getResult());
|
|
} else {
|
|
newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
|
|
}
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
|
|
insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
|
|
newSizes, insertOp.getMixedStrides());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold tensor.extract(linalg.fill(<input>)) into <input>
|
|
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
|
|
public:
|
|
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// See if tensor input of tensor.extract op is the result of a linalg.fill
|
|
// op.
|
|
auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
|
|
if (!fillOp)
|
|
return failure();
|
|
|
|
// Get scalar input operand of linalg.fill op.
|
|
Value extractedScalar = fillOp.getInputs()[0];
|
|
|
|
// Replace tensor.extract op with scalar value used to fill the tensor.
|
|
rewriter.replaceOp(extractOp, extractedScalar);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Folds pack(fill) into a single fill op if
|
|
/// 1. The pack op does not have padding value, or
|
|
/// 2. The filled value and padding value are the same.
|
|
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
|
|
linalg::PackOp packOp) {
|
|
auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
|
|
if (!fillOp)
|
|
return failure();
|
|
|
|
if (auto paddingValue = packOp.getPaddingValue())
|
|
if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
|
|
return failure();
|
|
|
|
Value packOpDest = packOp.getDest();
|
|
if (!packOpDest.hasOneUse())
|
|
return failure();
|
|
|
|
return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
|
|
packOp.getDest());
|
|
}
|
|
|
|
/// Wrapper pattern that applies foldFillPackIntoFillOp method.
|
|
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
|
|
public:
|
|
FoldFillWithPack(MLIRContext *context)
|
|
: OpRewritePattern<linalg::PackOp>(context) {}
|
|
|
|
LogicalResult matchAndRewrite(linalg::PackOp packOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
|
|
if (failed(fillOp))
|
|
return failure();
|
|
rewriter.replaceOp(packOp, fillOp.value().result());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold fill with copy.
|
|
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
|
|
using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
|
|
rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
|
|
fillOp.getInputs(),
|
|
copyOp.getOutputs());
|
|
return success();
|
|
}
|
|
if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
|
|
rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
|
|
fillOp.getOutputs());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Fold fill with transpose.
|
|
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
|
|
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
|
|
rewriter.replaceOpWithNewOp<FillOp>(
|
|
transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
|
|
transposeOp.getDpsInitOperand(0)->get());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Fold a concat with all elements being fills of the same value
|
|
/// into a fill of the concat result shape.
|
|
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto concatOperands = concatOp.getInputs();
|
|
if (concatOperands.empty()) {
|
|
return failure();
|
|
}
|
|
|
|
auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
|
|
if (!firstFillOp) {
|
|
return failure();
|
|
}
|
|
// Prefetch the fill value.
|
|
OpFoldResult firstFillVal =
|
|
getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
|
|
// Collect all the outs values for the fill operations.
|
|
SmallVector<Value> allOuts;
|
|
allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
|
|
|
|
auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
|
|
auto fillOp = v.getDefiningOp<linalg::FillOp>();
|
|
if (!fillOp) {
|
|
return false;
|
|
}
|
|
|
|
OpFoldResult fillVal =
|
|
getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
|
|
if (fillVal != firstFillVal)
|
|
return false;
|
|
|
|
allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
|
|
return true;
|
|
};
|
|
if (!llvm::all_of(concatOperands.drop_front(),
|
|
isDefinedByCompatibleFillOp)) {
|
|
return rewriter.notifyMatchFailure(
|
|
concatOp, "not all operands are defined by a compatible fill op");
|
|
}
|
|
|
|
Value outsConcat = rewriter.create<tensor::ConcatOp>(
|
|
concatOp.getLoc(), concatOp.getDim(), allOuts);
|
|
rewriter.replaceOpWithNewOp<linalg::FillOp>(
|
|
concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
|
|
FoldFillWithPack, FoldFillWithPad,
|
|
FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
|
|
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
|
|
FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GenericOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static void buildGenericRegion(
|
|
OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
|
|
ValueRange outputs,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
|
SmallVector<Type, 4> blockArgTypes;
|
|
SmallVector<Location, 4> blockArgLocs;
|
|
for (ValueRange container : {inputs, outputs}) {
|
|
for (Value v : container) {
|
|
Type t = v.getType();
|
|
blockArgTypes.push_back(
|
|
isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
|
|
blockArgLocs.push_back(v.getLoc());
|
|
}
|
|
}
|
|
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
Block *bodyBlock =
|
|
builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
|
|
bodyBuild(builder, loc, bodyBlock->getArguments());
|
|
}
|
|
|
|
void GenericOp::getAsmBlockArgumentNames(Region &region,
|
|
OpAsmSetValueNameFn setNameFn) {
|
|
for (Value v : getRegionInputArgs())
|
|
setNameFn(v, "in");
|
|
for (Value v : getRegionOutputArgs())
|
|
setNameFn(v, "out");
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
|
|
ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
|
|
iteratorTypes, doc, libraryCall);
|
|
result.addAttributes(attributes);
|
|
if (bodyBuild)
|
|
buildGenericRegion(builder, result.location, *result.regions.front(),
|
|
inputs, outputs, bodyBuild);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
|
|
StringRef libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs,
|
|
builder.getAffineMapArrayAttr(indexingMaps),
|
|
builder.getArrayAttr(llvm::to_vector(llvm::map_range(
|
|
iteratorTypes,
|
|
[&](utils::IteratorType iter) -> mlir::Attribute {
|
|
return IteratorTypeAttr::get(builder.getContext(), iter);
|
|
}))),
|
|
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
|
|
libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
|
|
bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
|
|
StringRef libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
|
|
iteratorTypes, doc, libraryCall, bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<utils::IteratorType> iteratorTypes,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
|
|
/*doc=*/"",
|
|
/*libraryCall=*/"", bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<utils::IteratorType> iteratorTypes,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
|
|
iteratorTypes,
|
|
/*doc=*/"",
|
|
/*libraryCall=*/"", bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::print(OpAsmPrinter &p) {
|
|
p << " ";
|
|
|
|
// Print extra attributes.
|
|
auto genericAttrNames = linalgTraitAttrNames();
|
|
|
|
llvm::StringSet<> genericAttrNamesSet;
|
|
genericAttrNamesSet.insert_range(genericAttrNames);
|
|
SmallVector<NamedAttribute, 8> genericAttrs;
|
|
for (auto attr : (*this)->getAttrs()) {
|
|
if (attr.getName() == getIteratorTypesAttrName()) {
|
|
auto iteratorTypes =
|
|
llvm::cast<ArrayAttr>(attr.getValue())
|
|
.getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
|
|
// Convert IteratorType enums into the string representation. This is
|
|
// needed, because tests still use the old format when 'iterator_types'
|
|
// attribute is represented as an array of strings.
|
|
// TODO: Remove this conversion once tests are fixed.
|
|
SmallVector<Attribute> iteratorTypeNames =
|
|
llvm::to_vector(llvm::map_range(
|
|
iteratorTypes, [&](utils::IteratorType t) -> Attribute {
|
|
return StringAttr::get(getContext(), stringifyIteratorType(t));
|
|
}));
|
|
|
|
genericAttrs.emplace_back(
|
|
getIteratorTypesAttrName(),
|
|
ArrayAttr::get(getContext(), iteratorTypeNames));
|
|
} else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
|
|
genericAttrs.push_back(attr);
|
|
}
|
|
}
|
|
if (!genericAttrs.empty()) {
|
|
auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
|
|
p << genericDictAttr;
|
|
}
|
|
|
|
// Printing is shared with named ops, except for the region and attributes
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
|
|
genericAttrNames.push_back("operandSegmentSizes");
|
|
genericAttrNamesSet.insert(genericAttrNames.back());
|
|
|
|
bool hasExtraAttrs = false;
|
|
for (NamedAttribute n : (*this)->getAttrs()) {
|
|
if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
|
|
break;
|
|
}
|
|
if (hasExtraAttrs) {
|
|
p << " attrs = ";
|
|
p.printOptionalAttrDict((*this)->getAttrs(),
|
|
/*elidedAttrs=*/genericAttrNames);
|
|
}
|
|
|
|
// Print region.
|
|
if (!getRegion().empty()) {
|
|
p << ' ';
|
|
p.printRegion(getRegion());
|
|
}
|
|
|
|
// Print results.
|
|
printNamedStructuredOpResults(p, getResultTensors().getTypes());
|
|
}
|
|
|
|
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
DictionaryAttr dictAttr;
|
|
// Parse the core linalg traits that must check into a dictAttr.
|
|
// The name is unimportant as we will overwrite result.attributes.
|
|
// The core linalg traits must contain the information necessary to pass the
|
|
// verifier.
|
|
llvm::SMLoc attributeLocation = parser.getCurrentLocation();
|
|
if (parser.parseAttribute(dictAttr, "_", result.attributes))
|
|
return failure();
|
|
result.attributes.assign(dictAttr.getValue().begin(),
|
|
dictAttr.getValue().end());
|
|
|
|
// Convert array of string into an array of IteratorType enums. This is
|
|
// needed, because tests still use the old format when 'iterator_types'
|
|
// attribute is represented as an array of strings.
|
|
// TODO: Remove this conversion once tests are fixed.
|
|
auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
|
|
result.attributes.get(getIteratorTypesAttrName(result.name)));
|
|
if (!iteratorTypes) {
|
|
return parser.emitError(attributeLocation)
|
|
<< "expected " << getIteratorTypesAttrName(result.name)
|
|
<< " array attribute";
|
|
}
|
|
|
|
SmallVector<Attribute> iteratorTypeAttrs;
|
|
|
|
for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
|
|
auto maybeIteratorType = utils::symbolizeIteratorType(s);
|
|
if (!maybeIteratorType.has_value())
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "unexpected iterator_type (" << s << ")";
|
|
|
|
iteratorTypeAttrs.push_back(
|
|
IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
|
|
}
|
|
result.attributes.set(getIteratorTypesAttrName(result.name),
|
|
parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
|
|
|
|
// Parsing is shared with named ops, except for the region.
|
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
|
return failure();
|
|
|
|
// Optional attributes may be added.
|
|
if (succeeded(parser.parseOptionalKeyword("attrs")))
|
|
if (failed(parser.parseEqual()) ||
|
|
failed(parser.parseOptionalAttrDict(result.attributes)))
|
|
return failure();
|
|
|
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
|
if (parser.parseRegion(*region, {}))
|
|
return failure();
|
|
result.addRegion(std::move(region));
|
|
|
|
// Generic ops may specify that a subset of its outputs are tensors. Such
|
|
// outputs are specified in the result type.
|
|
// TODO: may need to move output parsing before region parsing.
|
|
// Need to wait for declarative assembly resolution to decide.
|
|
SmallVector<Type, 1> outputTensorsTypes;
|
|
if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
|
|
return failure();
|
|
result.addTypes(outputTensorsTypes);
|
|
|
|
return success();
|
|
}
|
|
|
|
static void getGenericEffectsImpl(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects,
|
|
LinalgOp linalgOp) {
|
|
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
|
|
if (!llvm::isa<MemRefType>(operand.getType()))
|
|
continue;
|
|
effects.emplace_back(
|
|
MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
|
|
/*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
|
|
}
|
|
|
|
for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
|
|
if (!llvm::isa<MemRefType>(operand.get().getType()))
|
|
continue;
|
|
if (linalgOp.payloadUsesValueFromOperand(&operand)) {
|
|
effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
}
|
|
|
|
void GenericOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
static Speculation::Speculatability
|
|
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
|
|
// Operands with value semantics are speculatable, while operands with memory
|
|
// semantics are not.
|
|
if (!linalgOp.hasPureTensorSemantics())
|
|
return Speculation::NotSpeculatable;
|
|
// The body of the op can still have speculation in its region.
|
|
return Speculation::RecursivelySpeculatable;
|
|
}
|
|
|
|
Speculation::Speculatability GenericOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
LogicalResult GenericOp::verify() { return success(); }
|
|
|
|
namespace {
|
|
|
|
/// Remove any linalg operation (on tensors) that are just copying
|
|
/// the values from inputs to the results. Requirements are
|
|
/// 1) All iterator types are parallel
|
|
/// 2) The body contains just a yield operation with the yielded values being
|
|
/// the arguments corresponding to the operands.
|
|
template <typename OpTy>
|
|
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy linalgOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// All indexing maps must be equal. It follows that they are permutations.
|
|
if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
|
|
return failure();
|
|
|
|
// Check that the body of the linalg operation is just a linalg.yield
|
|
// operation.
|
|
Block &body = linalgOp->getRegion(0).front();
|
|
if (!llvm::hasSingleElement(body))
|
|
return failure();
|
|
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
|
|
if (!yieldOp)
|
|
return failure();
|
|
|
|
// In the buffer case, we need to check exact buffer equality.
|
|
if (linalgOp.hasPureBufferSemantics()) {
|
|
if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
|
|
linalgOp.getDpsInputOperand(0)->get() ==
|
|
linalgOp.getDpsInitOperand(0)->get()) {
|
|
rewriter.eraseOp(linalgOp);
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
// Mixed semantics is not supported yet.
|
|
if (!linalgOp.hasPureTensorSemantics())
|
|
return failure();
|
|
|
|
// Get the argument number of the returned values. That is the operand
|
|
// number to use for replacing uses of this operation.
|
|
SmallVector<Value> returnedArgs;
|
|
for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
|
|
auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
|
|
if (!yieldArg || yieldArg.getOwner() != &body)
|
|
return failure();
|
|
unsigned argumentNumber = yieldArg.getArgNumber();
|
|
Value returnedArg = linalgOp->getOperand(argumentNumber);
|
|
Type resultType = linalgOp->getResult(yieldVal.index()).getType();
|
|
// The input can have a different type than the result, e.g. a dynamic
|
|
// input dimension can be turned into a static output dimension.
|
|
Type returnType = returnedArg.getType();
|
|
if (returnType != resultType) {
|
|
// Distinguish between sparse conversion or dense tensor casting.
|
|
// TODO: unify the two ops?
|
|
if (sparse_tensor::getSparseTensorEncoding(returnType) ||
|
|
sparse_tensor::getSparseTensorEncoding(resultType))
|
|
returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
|
|
linalgOp.getLoc(), resultType, returnedArg);
|
|
else {
|
|
if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
|
|
resultType))
|
|
return failure();
|
|
returnedArg = rewriter.create<tensor::CastOp>(
|
|
linalgOp.getLoc(), resultType, returnedArg);
|
|
}
|
|
}
|
|
returnedArgs.push_back(returnedArg);
|
|
}
|
|
|
|
if (returnedArgs.size() != linalgOp->getNumResults())
|
|
return failure();
|
|
rewriter.replaceOp(linalgOp, returnedArgs);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<EraseIdentityLinalgOp<GenericOp>>(context);
|
|
}
|
|
|
|
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MapOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static ParseResult parseDstStyleOp(
|
|
OpAsmParser &parser, OperationState &result,
|
|
function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
|
|
nullptr) {
|
|
// Parse `ins` and `outs`.
|
|
SmallVector<Type, 4> inputTypes, outputTypes;
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
|
|
/*addOperandSegmentSizes=*/false))
|
|
return failure();
|
|
|
|
// Add result types.
|
|
for (Type outputType : outputTypes) {
|
|
if (llvm::isa<RankedTensorType>(outputType))
|
|
result.addTypes(outputType);
|
|
}
|
|
|
|
// Parse required attributes.
|
|
if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
|
|
return failure();
|
|
|
|
// Parse optional attributes.
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
void MapOp::getAsmBlockArgumentNames(Region &region,
|
|
OpAsmSetValueNameFn setNameFn) {
|
|
for (Value v : getRegionInputArgs())
|
|
setNameFn(v, "in");
|
|
}
|
|
|
|
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
|
|
if (!getResults().empty())
|
|
setNameFn(getResults().front(), "mapped");
|
|
}
|
|
|
|
void MapOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, TypeRange{}, inputs, init);
|
|
result.addAttributes(attributes);
|
|
|
|
// Add output types for `RankedTensorType` output arguments.
|
|
Type initType = init.getType();
|
|
if (llvm::isa<RankedTensorType>(initType))
|
|
result.addTypes(initType);
|
|
|
|
if (bodyBuild)
|
|
buildGenericRegion(builder, result.location, *result.regions.front(),
|
|
inputs, /*outputs=*/{}, bodyBuild);
|
|
}
|
|
|
|
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
|
|
const OperationName &payloadOpName,
|
|
const NamedAttrList &payloadOpAttrs,
|
|
ArrayRef<Value> operands,
|
|
bool initFirst = false) {
|
|
OpBuilder b(parser.getContext());
|
|
Region *body = result.addRegion();
|
|
Block &block = body->emplaceBlock();
|
|
b.setInsertionPointToStart(&block);
|
|
for (auto &operand : operands) {
|
|
block.addArgument(
|
|
llvm::cast<ShapedType>(operand.getType()).getElementType(),
|
|
b.getUnknownLoc());
|
|
}
|
|
SmallVector<Value> payloadOpOperands;
|
|
// If initFirst flag is enabled, we consider init as the first position of
|
|
// payload operands.
|
|
if (initFirst) {
|
|
payloadOpOperands.push_back(block.getArguments().back());
|
|
for (const auto &arg : block.getArguments().drop_back())
|
|
payloadOpOperands.push_back(arg);
|
|
} else {
|
|
payloadOpOperands = {block.getArguments().begin(),
|
|
block.getArguments().end()};
|
|
}
|
|
|
|
Operation *payloadOp = b.create(
|
|
result.location, b.getStringAttr(payloadOpName.getStringRef()),
|
|
payloadOpOperands,
|
|
TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
|
|
.getElementType()},
|
|
payloadOpAttrs);
|
|
b.create<YieldOp>(result.location, payloadOp->getResults());
|
|
}
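
// For reference, the parser below accepts both an abbreviated form that names
// a single payload op and a long form with an explicit region. An illustrative
// sketch (the SSA names, types, and shapes are made up):
//
//   %add = linalg.map { arith.addf }
//            ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//            outs(%init : tensor<8xf32>)
//
//   %add = linalg.map
//            ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//            outs(%init : tensor<8xf32>)
//            (%in0: f32, %in1: f32) {
//              %0 = arith.addf %in0, %in1 : f32
//              linalg.yield %0 : f32
//            }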
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
std::optional<OperationName> payloadOpName;
|
|
NamedAttrList payloadOpAttrs;
|
|
if (succeeded(parser.parseOptionalLBrace())) {
|
|
FailureOr<OperationName> operationName = parser.parseCustomOperationName();
|
|
if (failed(operationName))
|
|
return failure();
|
|
if (parser.parseOptionalAttrDict(payloadOpAttrs))
|
|
return failure();
|
|
payloadOpName = operationName.value();
|
|
if (parser.parseRBrace())
|
|
return failure();
|
|
}
|
|
|
|
if (parseDstStyleOp(parser, result))
|
|
return failure();
|
|
|
|
if (payloadOpName.has_value()) {
|
|
if (!result.operands.empty())
|
|
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
|
|
payloadOpAttrs,
|
|
ArrayRef(result.operands).drop_back());
|
|
else
|
|
result.addRegion();
|
|
} else {
|
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
|
|
/*allowType=*/true, /*allowAttrs=*/true)) {
|
|
return failure();
|
|
}
|
|
Region *body = result.addRegion();
|
|
if (parser.parseRegion(*body, regionArgs))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
// Retrieve the operation from the body, if it is the only one (other than the
// terminating yield) and if it takes the same number of arguments as the body
// does. If the initFirst flag is set, additionally check that init takes the
// first position among the operands of the payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
|
|
if (body->getOperations().size() != 2)
|
|
return nullptr;
|
|
Operation &payload = body->getOperations().front();
|
|
assert(isa<YieldOp>(body->getOperations().back()));
|
|
|
|
if (payload.getNumOperands() == 0 ||
|
|
payload.getNumOperands() != body->getNumArguments())
|
|
return nullptr;
|
|
if (initFirst) {
|
|
// check init
|
|
if (payload.getOperands().back() != body->getArgument(0))
|
|
return nullptr;
|
|
// check rest
|
|
for (const auto &[operand, bbArg] :
|
|
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
|
|
if (bbArg != operand)
|
|
return nullptr;
|
|
}
|
|
} else {
|
|
for (const auto &[operand, bbArg] :
|
|
llvm::zip(payload.getOperands(), body->getArguments())) {
|
|
if (bbArg != operand)
|
|
return nullptr;
|
|
}
|
|
}
|
|
return &payload;
|
|
}
|
|
|
|
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
|
|
SmallVector<StringRef> elidedAttrs;
|
|
std::string attrToElide;
|
|
p << " { " << payloadOp->getName().getStringRef();
|
|
for (const auto &attr : payloadOp->getAttrs()) {
|
|
auto fastAttr =
|
|
llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
|
|
if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
|
|
attrToElide = attr.getName().str();
|
|
elidedAttrs.push_back(attrToElide);
|
|
break;
|
|
}
|
|
}
|
|
p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
|
|
p << " }";
|
|
}
|
|
|
|
void MapOp::print(OpAsmPrinter &p) {
|
|
Block *mapper = getBody();
|
|
Operation *payloadOp = findPayloadOp(mapper);
|
|
if (payloadOp) {
|
|
printShortForm(p, payloadOp);
|
|
}
|
|
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
p.printOptionalAttrDict((*this)->getAttrs());
|
|
|
|
if (!payloadOp) {
|
|
// Print region if the payload op was not detected.
|
|
p.increaseIndent();
|
|
p.printNewline();
|
|
p << "(";
|
|
llvm::interleaveComma(mapper->getArguments(), p,
|
|
[&](auto arg) { p.printRegionArgument(arg); });
|
|
p << ") ";
|
|
|
|
p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
|
|
p.decreaseIndent();
|
|
}
|
|
}
|
|
|
|
LogicalResult MapOp::verify() {
|
|
auto *bodyBlock = getBody();
|
|
auto blockArgs = bodyBlock->getArguments();
|
|
|
|
// Check that the number of `inputs` matches the arity of the `mapper` region.
|
|
if (getInputs().size() != blockArgs.size())
|
|
return emitOpError() << "expects number of operands to match the arity of "
|
|
"mapper, but got: "
|
|
<< getInputs().size() << " and " << blockArgs.size();
|
|
|
|
// The parameters of the mapper should all match the element types of the inputs.
|
|
for (const auto &[bbArgType, inputArg] :
|
|
llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
|
|
auto inputElemType =
|
|
llvm::cast<ShapedType>(inputArg.getType()).getElementType();
|
|
if (bbArgType != inputElemType) {
|
|
return emitOpError() << "expected element type of input " << inputElemType
|
|
<< " to match bbArg type " << bbArgType;
|
|
}
|
|
}
|
|
|
|
// The shape of each input must match the shape of the output.
|
|
auto outputShape = getInit().getType().getShape();
|
|
for (Type inputArgType : TypeRange{getInputs()}) {
|
|
auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
|
|
if (inputElemShape != outputShape) {
|
|
return emitOpError() << "expected shape of input (" << inputElemShape
|
|
<< ") to match shape of output (" << outputShape
|
|
<< ")";
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
|
|
int64_t rank = getInit().getType().getRank();
|
|
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
|
}
|
|
|
|
ArrayAttr MapOp::getIndexingMaps() {
|
|
Builder builder(getContext());
|
|
int64_t rank = getInit().getType().getRank();
|
|
int64_t numIndexingMaps = getOperands().size();
|
|
return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
|
|
numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
|
|
}
|
|
|
|
void MapOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability MapOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ReduceOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
|
|
OpAsmSetValueNameFn setNameFn) {
|
|
for (Value v : getRegionInputArgs())
|
|
setNameFn(v, "in");
|
|
for (Value v : getRegionOutputArgs())
|
|
setNameFn(v, "init");
|
|
}
|
|
|
|
void ReduceOp::getAsmResultNames(
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
if (!getResults().empty())
|
|
setNameFn(getResults().front(), "reduced");
|
|
}
|
|
|
|
void ReduceOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange inits, ArrayRef<int64_t> dimensions,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, TypeRange{}, inputs, inits, dimensions);
|
|
result.addAttributes(attributes);
|
|
|
|
// Add output types for `RankedTensorType` output arguments.
|
|
for (Value init : inits) {
|
|
Type initType = init.getType();
|
|
if (llvm::isa<RankedTensorType>(initType))
|
|
result.addTypes(initType);
|
|
}
|
|
|
|
if (bodyBuild)
|
|
buildGenericRegion(builder, result.location, *result.regions.front(),
|
|
inputs, inits, bodyBuild);
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
|
|
int64_t inputRank =
|
|
llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
|
|
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
|
|
utils::IteratorType::parallel);
|
|
for (int64_t reductionDim : getDimensions())
|
|
iteratorTypes[reductionDim] = utils::IteratorType::reduction;
|
|
return iteratorTypes;
|
|
}
|
|
|
|
ArrayAttr ReduceOp::getIndexingMaps() {
|
|
int64_t inputRank =
|
|
llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
|
|
SmallVector<AffineMap> affineMaps(
|
|
getNumDpsInputs(),
|
|
AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
|
|
AffineMap resultMap =
|
|
AffineMap::getMultiDimIdentityMap(inputRank, getContext())
|
|
.dropResults(getDimensions());
|
|
for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
|
|
affineMaps.push_back(resultMap);
|
|
return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
|
|
}
|
|
|
|
void ReduceOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability ReduceOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
|
|
NamedAttrList &attributes,
|
|
StringRef attributeName) {
|
|
if (parser.parseKeyword(attributeName) || parser.parseEqual())
|
|
return failure();
|
|
|
|
attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
|
|
return success();
|
|
}
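
// For reference, an illustrative short form accepted by the parser below,
// with made-up names and shapes; the named payload op receives the init value
// as its first operand:
//
//   %sum = linalg.reduce { arith.addf }
//            ins(%input : tensor<16x32x64xf32>)
//            outs(%init : tensor<16x64xf32>)
//            dimensions = [1]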
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
std::optional<OperationName> payloadOpName;
|
|
NamedAttrList payloadOpAttrs;
|
|
if (succeeded(parser.parseOptionalLBrace())) {
|
|
FailureOr<OperationName> operationName = parser.parseCustomOperationName();
|
|
if (failed(operationName))
|
|
return failure();
|
|
if (parser.parseOptionalAttrDict(payloadOpAttrs))
|
|
return failure();
|
|
payloadOpName = operationName.value();
|
|
if (parser.parseRBrace())
|
|
return failure();
|
|
}
|
|
|
|
if (parseDstStyleOp(
|
|
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
|
|
return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
|
|
}))
|
|
return failure();
|
|
|
|
if (payloadOpName.has_value()) {
|
|
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
|
|
ArrayRef(result.operands), /*initFirst=*/true);
|
|
} else {
|
|
SmallVector<OpAsmParser::Argument> regionArgs;
|
|
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
|
|
/*allowType=*/true, /*allowAttrs=*/true)) {
|
|
return failure();
|
|
}
|
|
|
|
Region *body = result.addRegion();
|
|
if (parser.parseRegion(*body, regionArgs))
|
|
return failure();
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
|
|
ArrayRef<int64_t> attributeValue) {
|
|
p << ' ' << attributeName << " = [" << attributeValue << "] ";
|
|
}
|
|
|
|
void ReduceOp::print(OpAsmPrinter &p) {
|
|
Block *mapper = getBody();
|
|
Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
|
|
if (payloadOp) {
|
|
printShortForm(p, payloadOp);
|
|
}
|
|
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
|
|
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
|
|
if (!payloadOp) {
|
|
// Print region if the payload op was not detected.
|
|
p.increaseIndent();
|
|
p.printNewline();
|
|
p << "(";
|
|
llvm::interleaveComma(mapper->getArguments(), p,
|
|
[&](auto arg) { p.printRegionArgument(arg); });
|
|
p << ") ";
|
|
|
|
p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
|
|
p.decreaseIndent();
|
|
}
|
|
}
|
|
|
|
LogicalResult ReduceOp::verify() {
|
|
ArrayRef<int64_t> dimensionsRef = getDimensions();
|
|
|
|
for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
|
|
if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
|
|
llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
|
|
return emitOpError() << "expects all inputs to have the same shapes. "
|
|
"Shape at input-index "
|
|
<< i
|
|
<< " is not equal to the shape at input-index 0.";
|
|
}
|
|
}
|
|
for (int64_t i = 1; i < getNumDpsInits(); ++i) {
|
|
if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
|
|
llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
|
|
return emitOpError() << "expects all outputs to have the same shapes. "
|
|
"Shape at output-index "
|
|
<< i
|
|
<< " is not equal to the shape at output-index 0.";
|
|
}
|
|
}
|
|
auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
|
|
auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
|
|
|
|
DenseSet<int64_t> dimensionsToReduce;
|
|
for (int64_t dimension : dimensionsRef) {
|
|
if (dimension < 0 || dimension >= inputType.getRank()) {
|
|
return emitOpError()
|
|
<< "dimensions for reduction should be in the range [0, "
|
|
<< inputType.getRank() - 1 << "].";
|
|
}
|
|
dimensionsToReduce.insert(dimension);
|
|
}
|
|
|
|
auto inputDims = inputType.getShape();
|
|
auto initDims = initType.getShape();
|
|
|
|
// Input dimensions that will be left after the reduction.
|
|
SmallVector<int64_t> reducedInputDims;
|
|
for (const auto &en : llvm::enumerate(inputDims)) {
|
|
if (!dimensionsToReduce.count(en.index()))
|
|
reducedInputDims.push_back(en.value());
|
|
}
|
|
|
|
if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
|
|
return emitOpError() << "number of dimensions after reduction "
|
|
<< reducedInputDims.size()
|
|
<< " doesn't match the init rank "
|
|
<< initType.getRank();
|
|
}
|
|
|
|
if (reducedInputDims != initDims)
|
|
return emitOpError() << "init dimensions [" << initDims
|
|
<< "] doesn't match input dimensions after reduction ["
|
|
<< reducedInputDims << "]";
|
|
|
|
Block *block = getBody();
|
|
if (block->getNumArguments() != this->getNumOperands())
|
|
return emitOpError()
|
|
<< "mismatching number of operands and block arguments";
|
|
|
|
// Check that the first block arguments match the element type of the inputs.
|
|
for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
|
|
Type inputElementType =
|
|
llvm::cast<ShapedType>(input.getType()).getElementType();
|
|
if (inputElementType != bbArg.getType())
|
|
return emitOpError()
|
|
<< "input element type " << inputElementType
|
|
<< " does not match corresponding block argument type "
|
|
<< bbArg.getType();
|
|
}
|
|
|
|
// Check that the last block arguments match the element type of the outputs.
|
|
for (auto [output, bbArg] : llvm::zip(
|
|
getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
|
|
auto outputElementType =
|
|
llvm::cast<ShapedType>(output.getType()).getElementType();
|
|
if (outputElementType != bbArg.getType())
|
|
return emitOpError()
|
|
<< "output element type " << outputElementType
|
|
<< " does not match corresponding block argument type "
|
|
<< bbArg.getType();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransposeOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static void buildIdentityRegion(OpBuilder &builder, Location loc,
|
|
Region ®ion, ValueRange inputs,
|
|
ValueRange outputs) {
|
|
buildGenericRegion(builder, loc, region, inputs, outputs,
|
|
[](OpBuilder &b, Location loc, ValueRange args) {
|
|
if (!args.empty())
|
|
b.create<linalg::YieldOp>(loc, args[0]);
|
|
});
|
|
}
|
|
|
|
void TransposeOp::build(::mlir::OpBuilder &builder,
|
|
::mlir::OperationState &result, Value input, Value init,
|
|
DenseI64ArrayAttr permutation,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
result.addOperands(input);
|
|
result.addOperands(init);
|
|
result.addAttribute(getPermutationAttrName(result.name), permutation);
|
|
result.addAttributes(attributes);
|
|
|
|
// Add output types for `RankedTensorType` output arguments.
|
|
Type initType = init.getType();
|
|
if (llvm::isa<RankedTensorType>(initType))
|
|
result.addTypes(initType);
|
|
|
|
buildIdentityRegion(builder, result.location, *result.addRegion(), input,
|
|
init);
|
|
}
|
|
|
|
void TransposeOp::build(::mlir::OpBuilder &builder,
|
|
::mlir::OperationState &result, Value input, Value init,
|
|
ArrayRef<int64_t> permutation,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
|
|
attributes);
|
|
}
|
|
|
|
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
if (failed(parseDstStyleOp(
|
|
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
|
|
return parseDenseI64ArrayAttr(parser, attributes, "permutation");
|
|
})))
|
|
return failure();
|
|
|
|
OpBuilder builder(parser.getContext());
|
|
buildIdentityRegion(builder, result.location, *result.addRegion(),
|
|
/*inputs=*/result.operands,
|
|
/*outputs=*/{});
|
|
return success();
|
|
}
|
|
|
|
void TransposeOp::getAsmResultNames(
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
if (!getResults().empty())
|
|
setNameFn(getResults().front(), "transposed");
|
|
}
|
|
|
|
void TransposeOp::print(OpAsmPrinter &p) {
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
|
|
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
|
|
}
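
// For reference, an illustrative form printed/parsed by the methods above,
// with made-up names and shapes:
//
//   %transposed = linalg.transpose
//                   ins(%input : tensor<16x64xf32>)
//                   outs(%init : tensor<64x16xf32>)
//                   permutation = [1, 0]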
LogicalResult TransposeOp::verify() {
|
|
ArrayRef<int64_t> permutationRef = getPermutation();
|
|
|
|
if (!isPermutationVector(permutationRef))
|
|
return emitOpError("permutation is not valid");
|
|
|
|
auto inputType = getInput().getType();
|
|
auto initType = getInit().getType();
|
|
|
|
int64_t rank = inputType.getRank();
|
|
|
|
if (rank != initType.getRank())
|
|
return emitOpError() << "input rank " << rank
|
|
<< " does not match init rank " << initType.getRank();
|
|
|
|
if (rank != static_cast<int64_t>(permutationRef.size()))
|
|
return emitOpError() << "size of permutation " << permutationRef.size()
|
|
<< " does not match the argument rank " << rank;
|
|
|
|
auto inputDims = inputType.getShape();
|
|
auto initDims = initType.getShape();
|
|
|
|
for (int64_t i = 0; i < rank; ++i) {
|
|
int64_t inputDim = inputDims[permutationRef[i]];
|
|
int64_t initDim = initDims[i];
|
|
|
|
if (inputDim != initDim) {
|
|
return emitOpError() << "dim(result, " << i << ") = " << initDim
|
|
<< " doesn't match dim(input, permutation[" << i
|
|
<< "]) = " << inputDim;
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
|
|
int64_t rank = getInit().getType().getRank();
|
|
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
|
}
|
|
|
|
ArrayAttr TransposeOp::getIndexingMaps() {
|
|
Builder builder(getContext());
|
|
int64_t rank = getInit().getType().getRank();
|
|
return builder.getAffineMapArrayAttr(
|
|
{inversePermutation(AffineMap::getPermutationMap(
|
|
llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
|
|
builder.getMultiDimIdentityMap(rank)});
|
|
}
|
|
|
|
void TransposeOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability TransposeOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
|
|
SmallVectorImpl<OpFoldResult> &result) {
|
|
// Only the tensor type is supported.
|
|
if (!isa<TensorType>(getInput().getType()))
|
|
return failure();
|
|
|
|
// Rank-0 transpose (empty permutation) is a no-op.
|
|
if (getPermutation().size() == 0) {
|
|
result.push_back(getInput());
|
|
return success();
|
|
}
|
|
// Identity permutation.
|
|
if (isIdentityPermutation(getPermutation())) {
|
|
result.push_back(getInput());
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
|
|
/// Fold transpose with transpose.
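///
/// The two permutations compose as foldedPerms[i] = defPerms[perms[i]]. An
/// illustrative example with made-up shapes:
///
///   %t1 = linalg.transpose ins(%x : tensor<2x3x4xf32>)
///           outs(%i1 : tensor<3x4x2xf32>) permutation = [1, 2, 0]
///   %t2 = linalg.transpose ins(%t1 : tensor<3x4x2xf32>)
///           outs(%i2 : tensor<2x3x4xf32>) permutation = [2, 0, 1]
///
/// folds into a single transpose of %x with permutation [0, 1, 2], which the
/// identity-permutation folder above can then erase entirely.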
struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
|
|
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
|
|
if (!defTransposeOp)
|
|
return failure();
|
|
ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
|
|
ArrayRef<int64_t> perms = transposeOp.getPermutation();
|
|
SmallVector<int64_t> foldedPerms;
|
|
foldedPerms.reserve(perms.size());
|
|
for (int64_t perm : perms)
|
|
foldedPerms.push_back(defPerms[perm]);
|
|
|
|
rewriter.replaceOpWithNewOp<TransposeOp>(
|
|
transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
|
|
foldedPerms);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// This pattern canonicalize transpose by swapping the order of
|
|
/// broadcast and transpose:
|
|
/// transpose(broadcast(input)) -> broadcast(transpose(input))
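///
/// An illustrative example with made-up shapes:
///
///   %b = linalg.broadcast ins(%arg0 : tensor<2x4xf32>)
///          outs(%i1 : tensor<2x3x4xf32>) dimensions = [1]
///   %t = linalg.transpose ins(%b : tensor<2x3x4xf32>)
///          outs(%i2 : tensor<4x3x2xf32>) permutation = [2, 1, 0]
///
/// becomes
///
///   %t = linalg.transpose ins(%arg0 : tensor<2x4xf32>)
///          outs(%i3 : tensor<4x2xf32>) permutation = [1, 0]
///   %b = linalg.broadcast ins(%t : tensor<4x2xf32>)
///          outs(%i2 : tensor<4x3x2xf32>) dimensions = [1]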
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
|
|
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
Value input = transposeOp.getInput();
|
|
BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
|
|
if (!input.hasOneUse() || !broadcastOp)
|
|
return failure();
|
|
|
|
ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
|
|
ArrayRef<int64_t> perms = transposeOp.getPermutation();
|
|
|
|
// Get new perms and new dimensions.
|
|
SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
|
|
SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
|
|
SmallVector<int64_t> resultDimensions;
|
|
unsigned dimensionSize = dimensions.size();
|
|
for (unsigned i = 0; i < dimensionSize; ++i)
|
|
resultDimensions.push_back(invertPerm[dimensions[i]]);
|
|
|
|
// Create transpose result.
|
|
Value broadcastInput = broadcastOp.getInput();
|
|
Location loc = transposeOp.getLoc();
|
|
MLIRContext *ctx = transposeOp.getContext();
|
|
SmallVector<OpFoldResult> dims;
|
|
auto broadcastInputTy =
|
|
mlir::cast<RankedTensorType>(broadcastInput.getType());
|
|
unsigned inputRank = broadcastInputTy.getRank();
|
|
for (unsigned i = 0; i < inputRank; ++i) {
|
|
if (broadcastInputTy.isDynamicDim(i)) {
|
|
dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
|
|
->getResult(0));
|
|
} else {
|
|
dims.push_back(IntegerAttr::get(IndexType::get(ctx),
|
|
broadcastInputTy.getDimSize(i)));
|
|
}
|
|
}
|
|
SmallVector<OpFoldResult> transposeResultShapes =
|
|
applyPermutation(dims, resultPerms);
|
|
Value transposeInit = rewriter.create<tensor::EmptyOp>(
|
|
transposeOp.getLoc(), transposeResultShapes,
|
|
broadcastInputTy.getElementType());
|
|
|
|
// Create broadcast(transpose(input)).
|
|
Value transposeResult =
|
|
rewriter
|
|
.create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
|
|
resultPerms)
|
|
->getResult(0);
|
|
rewriter.replaceOpWithNewOp<BroadcastOp>(
|
|
transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BroadcastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void BroadcastOp::build(::mlir::OpBuilder &builder,
|
|
::mlir::OperationState &result, Value input, Value init,
|
|
DenseI64ArrayAttr dimensions,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
result.addOperands(input);
|
|
result.addOperands(init);
|
|
result.addAttribute(getDimensionsAttrName(result.name), dimensions);
|
|
result.addAttributes(attributes);
|
|
|
|
// Add output types for `RankedTensorType` output arguments.
|
|
Type initType = init.getType();
|
|
if (llvm::isa<RankedTensorType>(initType))
|
|
result.addTypes(initType);
|
|
|
|
buildIdentityRegion(builder, result.location, *result.addRegion(), input,
|
|
init);
|
|
}
|
|
|
|
void BroadcastOp::build(::mlir::OpBuilder &builder,
|
|
::mlir::OperationState &result, Value input, Value init,
|
|
ArrayRef<int64_t> dimensions,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
|
|
attributes);
|
|
}
|
|
|
|
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
if (failed(parseDstStyleOp(
|
|
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
|
|
return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
|
|
})))
|
|
return failure();
|
|
|
|
OpBuilder builder(parser.getContext());
|
|
buildIdentityRegion(builder, result.location, *result.addRegion(),
|
|
/*inputs=*/result.operands,
|
|
/*outputs=*/{});
|
|
return success();
|
|
}
|
|
|
|
void BroadcastOp::getAsmResultNames(
|
|
function_ref<void(Value, StringRef)> setNameFn) {
|
|
if (!getResults().empty())
|
|
setNameFn(getResults().front(), "broadcasted");
|
|
}
|
|
|
|
void BroadcastOp::print(OpAsmPrinter &p) {
|
|
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
|
|
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
|
|
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
|
|
}
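
// For reference, an illustrative broadcast with made-up names and shapes;
// `dimensions` lists the init dimensions that are added, so the input rank
// plus the number of entries in `dimensions` must equal the init rank:
//
//   %bcast = linalg.broadcast
//              ins(%input : tensor<16x64xf32>)
//              outs(%init : tensor<16x32x64xf32>)
//              dimensions = [1]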
LogicalResult BroadcastOp::verify() {
|
|
ArrayRef<int64_t> dimensionsRef = getDimensions();
|
|
|
|
auto inputType = getInput().getType();
|
|
auto initType = getInit().getType();
|
|
|
|
int64_t inputRank = inputType.getRank();
|
|
int64_t initRank = initType.getRank();
|
|
|
|
auto inputShape = inputType.getShape();
|
|
auto initShape = initType.getShape();
|
|
|
|
if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
|
|
return emitOpError() << "input rank plus added dimensions does not "
|
|
"match init rank. input rank: "
|
|
<< inputRank
|
|
<< ", dimensions size: " << dimensionsRef.size()
|
|
<< ", init rank: " << initRank;
|
|
|
|
for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
|
|
if (dim < 0 || dim >= initRank)
|
|
return emitOpError() << "dimension " << idx
|
|
<< " is out of range. expected range: [0, "
|
|
<< initRank - 1 << "], got: " << dim;
|
|
}
|
|
|
|
// Mapping from input dims to init dims.
|
|
SmallVector<int64_t> dimMap;
|
|
for (auto dim : llvm::seq<int64_t>(0, initRank)) {
|
|
if (!llvm::is_contained(dimensionsRef, dim))
|
|
dimMap.push_back(dim);
|
|
}
|
|
|
|
for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
|
|
// This dimension is mapped from the input; the init and input dims should
// match.
if (inputShape[inputDimIdx] != initShape[initDimIdx])
|
|
return emitOpError() << "input dim " << inputDimIdx
|
|
<< " should match init dim " << initDimIdx
|
|
<< ". input: " << inputShape[inputDimIdx]
|
|
<< ", init: " << initShape[initDimIdx];
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
|
|
int64_t rank = getInit().getType().getRank();
|
|
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
|
}
|
|
|
|
ArrayAttr BroadcastOp::getIndexingMaps() {
|
|
Builder builder(getContext());
|
|
int64_t rank = getInit().getType().getRank();
|
|
return builder.getAffineMapArrayAttr(
|
|
{builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
|
|
builder.getMultiDimIdentityMap(rank)});
|
|
}
|
|
|
|
void BroadcastOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability BroadcastOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// YieldOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void linalg::YieldOp::print(OpAsmPrinter &p) {
|
|
if (getNumOperands() > 0)
|
|
p << ' ' << getOperands();
|
|
p.printOptionalAttrDict((*this)->getAttrs());
|
|
if (getNumOperands() > 0)
|
|
p << " : " << getOperandTypes();
|
|
}
|
|
|
|
ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
|
|
SmallVector<Type, 2> types;
|
|
SMLoc loc = parser.getCurrentLocation();
|
|
return failure(parser.parseOperandList(opInfo) ||
|
|
parser.parseOptionalAttrDict(result.attributes) ||
|
|
(!opInfo.empty() && parser.parseColonTypeList(types)) ||
|
|
parser.resolveOperands(opInfo, types, loc, result.operands));
|
|
}
|
|
|
|
// Check that the number and types of the yielded operands match the element
// types of the enclosing LinalgOp's shaped output (init) operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
|
|
if (op.getNumOperands() != linalgOp.getNumDpsInits())
|
|
return op.emitOpError("expected number of yield values (")
|
|
<< op.getNumOperands()
|
|
<< ") to match the number of inits / outs operands of the enclosing "
|
|
<< "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
|
|
|
|
for (OpOperand &opOperand : op->getOpOperands()) {
|
|
OpOperand *outputOperand =
|
|
linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
|
|
Type elementType = outputOperand->get().getType();
|
|
if (isa<MemRefType, RankedTensorType>(elementType))
|
|
elementType = getElementTypeOrSelf(outputOperand->get().getType());
|
|
if (opOperand.get().getType() != elementType)
|
|
return op.emitOpError("type of yield operand ")
|
|
<< (opOperand.getOperandNumber() + 1) << " ("
|
|
<< opOperand.get().getType() << ") doesn't match "
|
|
<< "the element type of the enclosing linalg.generic op ("
|
|
<< elementType << ")";
|
|
}
|
|
return success();
|
|
}
|
|
|
|
LogicalResult linalg::YieldOp::verify() {
|
|
auto *parentOp = (*this)->getParentOp();
|
|
if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
|
|
return emitOpError("expected single non-empty parent region");
|
|
|
|
if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
|
|
return verifyYield(*this, linalgOp);
|
|
|
|
return emitOpError("expected parent op with LinalgOp interface");
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IndexOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult IndexOp::verify() {
|
|
auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
|
|
if (!linalgOp)
|
|
return emitOpError("expected parent op with LinalgOp interface");
|
|
if (linalgOp.getNumLoops() <= getDim())
|
|
return emitOpError("expected dim (")
|
|
<< getDim() << ") to be lower than the number of loops ("
|
|
<< linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
|
|
return success();
|
|
}
|
|
|
|
OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
|
|
auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
|
|
// Bail out if `linalg.index` does not have a proper parent yet at this
|
|
// point, e.g., when calling `createOrFold` during IR construction in
|
|
// `genericOp::build`.
|
|
if (!linalgOp)
|
|
return OpFoldResult{};
|
|
|
|
// Index of unit dims is always 0.
|
|
SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
|
|
uint64_t dim = getDim();
|
|
assert(dim < loopBounds.size() && "Dim is out of bounds");
|
|
if (loopBounds[dim] == 1)
|
|
return IntegerAttr::get(IndexType::get(getContext()), 0);
|
|
|
|
return OpFoldResult{};
|
|
}
|
|
|
|
/////// Operations corresponding to library calls defined with Tablegen ////////
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
|
|
|
|
AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
|
|
unsigned rank,
|
|
MLIRContext *context) {
|
|
if (maybeMap)
|
|
return *maybeMap;
|
|
if (rank == 0)
|
|
return AffineMap::get(context);
|
|
return AffineMap::getMultiDimIdentityMap(rank, context);
|
|
}
|
|
|
|
SmallVector<AffineExpr, 4>
|
|
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
|
|
MLIRContext *context) {
|
|
SmallVector<AffineExpr, 4> res;
|
|
res.reserve(num);
|
|
for (unsigned i = 0; i < num; ++i)
|
|
res.push_back(getAffineDimExpr(startIdx++, context));
|
|
return res;
|
|
}
|
|
|
|
SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
|
|
ArrayRef<AffineExpr> b) {
|
|
auto rangeA = llvm::make_range(a.begin(), a.end());
|
|
auto rangeB = llvm::make_range(b.begin(), b.end());
|
|
auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
|
|
return llvm::to_vector<4>(concatRanges);
|
|
}
|
|
|
|
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
|
|
if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
|
|
ss << "view";
|
|
for (auto size : memref.getShape())
|
|
if (size < 0)
|
|
ss << "sx";
|
|
else
|
|
ss << size << "x";
|
|
if (failed(appendMangledType(ss, memref.getElementType())))
|
|
return failure();
|
|
if (auto as = memref.getMemorySpace()) {
|
|
if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
|
|
ss << "as" << attr.getInt();
|
|
else
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
if (auto vec = llvm::dyn_cast<VectorType>(t)) {
|
|
ss << "vector";
|
|
llvm::interleave(
|
|
vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
|
|
if (failed(appendMangledType(ss, vec.getElementType())))
|
|
return failure();
|
|
return success();
|
|
}
|
|
if (t.isSignlessIntOrIndexOrFloat()) {
|
|
ss << t;
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
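
// For illustration only (a rough sketch of the mangling, not a specification):
// an op named "linalg.copy" whose two operands are memref<?x?xf32> would
// produce something like "linalg_copy_viewsxsxf32_viewsxsxf32", i.e. the op
// name with '.' replaced by '_', an optional unary/binary fn suffix, and one
// mangled operand type per operand.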
std::string mlir::linalg::generateLibraryCallName(Operation *op) {
|
|
assert(isa<LinalgOp>(op));
|
|
std::string name(op->getName().getStringRef().str());
|
|
std::string fun = "";
|
|
for (NamedAttribute kv : op->getAttrs()) {
|
|
if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
|
|
fun = stringifyEnum(ufa.getValue()).str() + "_";
|
|
} else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
|
|
fun = stringifyEnum(bfa.getValue()).str() + "_";
|
|
}
|
|
}
|
|
name.reserve(128);
|
|
llvm::replace(name, '.', '_');
|
|
llvm::raw_string_ostream ss(name);
|
|
ss << "_" << fun;
|
|
for (Type t : op->getOperandTypes()) {
|
|
if (failed(appendMangledType(ss, t)))
|
|
return std::string();
|
|
ss << "_";
|
|
}
|
|
name.pop_back();
|
|
return name;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Canonicalizers and Folders.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
|
|
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(LinalgOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
for (OpOperand &opOperand : op->getOpOperands()) {
|
|
// Linalg "inputs" may be either tensor or memref type.
|
|
// tensor<0xelt_type> is a convention that may not always mean
|
|
// "0 iterations". Only erase in cases we see memref<...x0x...>.
|
|
auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
|
|
if (!mt)
|
|
continue;
|
|
if (llvm::is_contained(op.getShape(&opOperand), 0)) {
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
|
|
/// result that is more static than the linalg op.
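///
/// An illustrative sketch with made-up shapes:
///
///   %0 = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
///
/// is rewritten so that the init operand is cast to tensor<4x8xf32> first, the
/// linalg op produces the static type directly, and a cast back to
/// tensor<?x?xf32> replaces the remaining uses of the original result.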
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
|
|
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::CastOp castOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!tensor::canFoldIntoProducerOp(castOp))
|
|
return failure();
|
|
|
|
auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
|
|
if (!linalgOp)
|
|
return failure();
|
|
|
|
// The cast may be in a conditionally reachable region, in which case folding
// would generate invalid code. For now, conservatively fold only ops in the
// same block.
if (castOp->getBlock() != linalgOp->getBlock())
|
|
return failure();
|
|
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPoint(linalgOp);
|
|
|
|
Location loc = linalgOp.getLoc();
|
|
OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
|
|
unsigned resultNumber = resultValue.getResultNumber();
|
|
auto resultType =
|
|
llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
|
|
// Replace the `outs` for the result with a `tensor.cast`. This cast is now
|
|
// going from a more dynamic shape to a less dynamic shape. If the producer
|
|
// for this cast, i.e. producer of the out operand, is also an operation
|
|
// that folds with tensor.cast consumer (like this pattern), the cast will
|
|
// continue to propagate as far up the stack as it can go.
|
|
OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
|
|
Value newOperand =
|
|
rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
|
|
SmallVector<Value> newOperands = linalgOp.getDpsInputs();
|
|
SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
|
|
linalgOp.getDpsInits().end());
|
|
outputOperands[resultNumber] = newOperand;
|
|
newOperands.append(outputOperands.begin(), outputOperands.end());
|
|
|
|
SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
|
|
linalgOp->result_type_end());
|
|
resultTypes[resultNumber] = resultType;
|
|
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
|
|
|
|
// Create a tensor.cast operation back to the original type.
|
|
Value castBack = rewriter.create<tensor::CastOp>(
|
|
loc, resultValue.getType(), newOp->getResult(resultNumber));
|
|
|
|
SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
|
|
results[resultNumber] = castBack;
|
|
rewriter.replaceOp(linalgOp, results);
|
|
rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// For each operand in `operands`, this function maps the static sizes of its
/// dimensions to the corresponding affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
|
|
llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
|
|
for (OpOperand &opOperand : operands) {
|
|
if (linalgOp.isScalar(&opOperand))
|
|
continue;
|
|
Value src = opOperand.get();
|
|
auto sourceType = llvm::cast<RankedTensorType>(src.getType());
|
|
auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
|
|
|
|
// Get the `sourceShape` of the `sourceType`. If the operand is a result of
|
|
// `tensor.cast` operation and source of the cast operation has a static
|
|
// shape, then assign it to the `sourceShape`.
|
|
auto *parentOp = src.getDefiningOp();
|
|
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
|
if (parentOp) {
|
|
if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
|
|
Value castSource = castOp.getSource();
|
|
auto castSourceType =
|
|
llvm::dyn_cast<RankedTensorType>(castSource.getType());
|
|
if (castSourceType && castSourceType.hasStaticShape())
|
|
sourceShape = castSourceType.getShape();
|
|
}
|
|
}
|
|
|
|
// If the source shape's dimension has a static shape, map the affine dim
|
|
// expression to the known static size.
|
|
for (unsigned i = 0; i < sourceShape.size(); i++) {
|
|
if (sourceType.isDynamicDim(i))
|
|
continue;
|
|
if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
|
|
affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Creates a new operand for `opOperand` of `linalgOp` with the static sizes
/// mapped in `affineExprToSize`. New operands are appended to `newOperands`
/// and their result types are stored in `resultTypes`. If `opOperand` requires
/// no change, `changeNeeded` is left untouched and the same operand is added
/// to the `newOperands` list.
static void createNewOperandWithStaticSizes(
|
|
Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
|
|
llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
|
|
SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
|
|
bool &changeNeeded) {
|
|
Value src = opOperand->get();
|
|
newOperands.push_back(src);
|
|
if (linalgOp.isScalar(opOperand))
|
|
return;
|
|
auto sourceType = llvm::cast<RankedTensorType>(src.getType());
|
|
Type resultType = sourceType;
|
|
if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
|
|
resultTypes.push_back(resultType);
|
|
return;
|
|
}
|
|
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
|
AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
|
|
SmallVector<int64_t> newShape;
|
|
// If the operand is updated with a new shape, `newOperandNeeded` will be set
// to true.
|
|
bool newOperandNeeded = false;
|
|
for (unsigned i = 0; i < sourceShape.size(); i++) {
|
|
int64_t dimShape = sourceShape[i];
|
|
AffineExpr dimExpr = sourceMap.getResult(i);
|
|
if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
|
|
newShape.push_back(dimShape);
|
|
continue;
|
|
}
|
|
// Dimension has a dynamic shape and corresponding affine dim
|
|
// expression is present in the map. So assign the size for the
|
|
// given affine dim expression to the dimension.
|
|
newShape.push_back(affineExprToSize[dimExpr]);
|
|
newOperandNeeded = true;
|
|
}
|
|
resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
|
|
sourceType.getEncoding());
|
|
if (newOperandNeeded) {
|
|
changeNeeded = true;
|
|
// Get the new operand value given its size and element type by
|
|
// casting it.
|
|
Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
|
|
unsigned index = opOperand->getOperandNumber();
|
|
newOperands[index] = newOperand;
|
|
}
|
|
if (linalgOp.isDpsInit(opOperand))
|
|
resultTypes.push_back(resultType);
|
|
}
|
|
|
|
/// Static shapes for the operands can be inferred if any one of the operands
/// has a static shape. This is done by referring to the affine dim expressions
/// of the operands.
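///
/// An illustrative sketch with made-up shapes: for a linalg.generic with
/// identity indexing maps on %a : tensor<4x?xf32> and %b : tensor<?x8xf32>,
/// the loop bounds d0 = 4 and d1 = 8 are known, so both operands can be cast
/// to tensor<4x8xf32> before cloning the op, with casts back to the original
/// result types inserted afterwards.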
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
|
|
using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(LinalgOp linalgOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!linalgOp.hasPureTensorSemantics())
|
|
return failure();
|
|
|
|
// Maps must be projected permutations.
|
|
if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
|
|
return !map.isProjectedPermutation();
|
|
}))
|
|
return failure();
|
|
|
|
// Maps affine dim expressions to the static size of that dimension.
|
|
llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
|
|
Location loc = linalgOp.getLoc();
|
|
|
|
// For each affine dim expression, check if the size is known. If it is,
// record it in the map.
populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
|
|
|
|
SmallVector<Value> newOperands;
|
|
SmallVector<Type> resultTypes;
|
|
|
|
// `changeNeeded` is `false` if the operands of `linalgOp` require no
|
|
// change in their types.
|
|
bool changeNeeded = false;
|
|
newOperands.reserve(linalgOp->getNumOperands());
|
|
resultTypes.reserve(linalgOp.getNumDpsInits());
|
|
|
|
// Iterate over all the operands and update the static sizes.
|
|
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
|
|
createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
|
|
affineExprToSize, linalgOp, newOperands,
|
|
resultTypes, changeNeeded);
|
|
}
|
|
|
|
// If the generic op has all the required static information, no
|
|
// canonicalization needed.
|
|
if (!changeNeeded)
|
|
return failure();
|
|
|
|
// Clone op.
|
|
Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
|
|
SmallVector<Value> replacements;
|
|
replacements.reserve(newOp->getNumResults());
|
|
for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
|
|
Value newResult = std::get<1>(it);
|
|
Value oldResult = std::get<0>(it);
|
|
Type newType = newResult.getType();
|
|
Type oldType = oldResult.getType();
|
|
replacements.push_back(
|
|
(newType != oldType)
|
|
? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
|
|
: newResult);
|
|
}
|
|
rewriter.replaceOp(linalgOp, replacements);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
// All named ops canonicalizers and folders are auto-generated in the
|
|
// .cpp.inc.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SoftmaxOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult SoftmaxOp::verify() {
|
|
ShapedType inputType = getInputOperandType();
|
|
ShapedType outputType = getOutputOperandType();
|
|
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
if (failed(verifyCompatibleShape(inputShape, outputShape)))
|
|
return emitOpError("incompatible output shape");
|
|
|
|
int64_t inputRank = getInputOperandRank();
|
|
int64_t dimension = getDimension();
|
|
if ((dimension < 0) || (dimension >= inputRank))
|
|
return emitOpError("incorrect dimension specified");
|
|
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
|
|
int64_t operandRank = getInputOperandRank();
|
|
SmallVector<Range> loopBounds(operandRank);
|
|
Location loc = getLoc();
|
|
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
|
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
|
|
Value source = getInput();
|
|
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
|
|
loopBounds[dim].offset = zero;
|
|
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
|
|
loopBounds[dim].stride = one;
|
|
}
|
|
return loopBounds;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
|
|
SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
|
|
utils::IteratorType::parallel);
|
|
iteratorTypes[getDimension()] = utils::IteratorType::reduction;
|
|
return iteratorTypes;
|
|
}
|
|
|
|
FailureOr<TilingResult>
|
|
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
|
|
ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes) {
|
|
int64_t rank = getInputOperandRank();
|
|
auto oneAttr = builder.getI64IntegerAttr(1);
|
|
SmallVector<OpFoldResult> strides(rank, oneAttr);
|
|
SmallVector<Value> tiledOperands;
|
|
Operation *inputSlice =
|
|
getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
|
|
if (!inputSlice) {
|
|
return emitOpError("failed to compute input slice");
|
|
}
|
|
tiledOperands.emplace_back(inputSlice->getResult(0));
|
|
Operation *outputSlice =
|
|
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
|
|
if (!outputSlice) {
|
|
return emitOpError("failed to compute output slice");
|
|
}
|
|
tiledOperands.emplace_back(outputSlice->getResult(0));
|
|
|
|
SmallVector<Type, 4> resultTypes;
|
|
if (hasPureTensorSemantics())
|
|
resultTypes.push_back(tiledOperands[1].getType());
|
|
Operation *tiledOp =
|
|
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
|
|
|
|
return TilingResult{
|
|
{tiledOp},
|
|
SmallVector<Value>(tiledOp->getResults()),
|
|
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
|
|
}
|
|
|
|
LogicalResult SoftmaxOp::getResultTilePosition(
|
|
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) {
|
|
if (resultNumber == 0) {
|
|
resultOffsets.assign(offsets.begin(), offsets.end());
|
|
resultSizes.assign(sizes.begin(), sizes.end());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
// cast(dynamic) -> static.
|
|
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
LogicalResult
|
|
SoftmaxOp::reifyResultShapes(OpBuilder &b,
|
|
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
|
SmallVector<OpFoldResult> shapes;
|
|
Location loc = getOperation()->getLoc();
|
|
IRRewriter rewriter(b);
|
|
auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
|
|
auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
|
|
for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
|
|
if (!outputShapedType.isDynamicDim(dim)) {
|
|
// Static dim: Return IntegerAttr.
|
|
shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
|
|
} else {
|
|
// Dynamic dim: Return Value.
|
|
OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
|
|
shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
|
|
}
|
|
}
|
|
reifiedReturnShapes.emplace_back(std::move(shapes));
|
|
return success();
|
|
}
|
|
|
|
void SoftmaxOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
|
|
if (!llvm::isa<MemRefType>(operand.getType()))
|
|
continue;
|
|
effects.emplace_back(MemoryEffects::Read::get(),
|
|
&getOperation()->getOpOperand(index), /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
|
|
for (OpOperand &operand : getDpsInitsMutable()) {
|
|
if (!llvm::isa<MemRefType>(operand.get().getType()))
|
|
continue;
|
|
effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
|
|
/*effectOnFullRegion=*/true,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
}
|
|
|
|
// Helper functions for softmax decomposition.
|
|
// @{
|
|
|
|
// Helper function to produce the iterator types (reduction or parallel) and
|
|
// affine maps for the iterators used in the decomposition of softmax.
|
|
// This method creates:
|
|
// If allParallel == true:
|
|
// - iterator type: {parallel, ..., parallel}
|
|
// - affine maps:
|
|
// -- identity with inputRank dimensions.
|
|
// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
|
|
// where N == inputRank.
|
|
//
|
|
// If allParallel == false:
|
|
// - iterator type at dim(i) == parallel for i != \p dim and
|
|
// dim(dim) == reduction.
|
|
// - affine map:
|
|
// -- identity with inputRank dimensions.
|
|
// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
|
|
// where N == inputRank.
|
|
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
|
|
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
|
|
int64_t dim, bool allParallel = false) {
|
|
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
|
|
utils::IteratorType::parallel);
|
|
if (!allParallel)
|
|
iteratorTypes[dim] = utils::IteratorType::reduction;
|
|
MLIRContext *ctxt = builder.getContext();
|
|
auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
|
|
SmallVector<AffineExpr, 2> affineExprs;
|
|
for (int i = 0; i < inputRank; i++) {
|
|
if (i != dim)
|
|
affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
|
|
}
|
|
auto reductionMap =
|
|
AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
|
|
SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
|
|
return std::make_tuple(iteratorTypes, indexingMaps);
|
|
}
|
|
|
|
// Helper function to produce a linalg.generic that computes a reduction on
|
|
// dimension \p dim with the operation type \p T.
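// For example, reduce<arith::MaxNumFOp>(b, loc, input, init, dim) builds a
// linalg.generic whose body is `arith.maxnumf %in, %out`, i.e. a running
// maximum along `dim`; the softmax decomposition below uses it for both the
// max and the sum reductions (the latter with arith::AddFOp).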
template <typename T>
|
|
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
|
|
int64_t dim) {
|
|
auto inputType = cast<ShapedType>(input.getType());
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputRank = inputShape.size();
|
|
auto [iteratorTypes, indexingMaps] =
|
|
computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
|
|
assert(indexingMaps.size() == 2 &&
|
|
"We should have two maps: 1 for the input, 1 for the output");
|
|
assert(indexingMaps[0].isIdentity() && "input map should be identity");
|
|
|
|
auto genericOp = builder.create<linalg::GenericOp>(
|
|
loc, output.getType(), input, output, indexingMaps, iteratorTypes,
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
Value result = b.create<T>(loc, args[0], args[1]);
|
|
b.create<linalg::YieldOp>(loc, result);
|
|
});
|
|
return genericOp.getResult(0);
|
|
}
|
|
|
|
/// Produce a linalg generic that computes the second step of the softmax
|
|
/// decomposition: res = exp(input - max), where \p max is the max of \p input
|
|
/// on dimension \p dim.
|
|
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
|
|
Value max, Value output, int64_t dim) {
|
|
auto inputType = cast<ShapedType>(input.getType());
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputRank = inputShape.size();
|
|
auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
|
|
builder, inputRank, dim, /*allParallel=*/true);
|
|
assert(indexingMaps.size() == 2 && "We should have one map for each input");
|
|
assert(indexingMaps[0].isIdentity() && "input map should be identity");
|
|
// Add the affine map for the output argument.
|
|
indexingMaps.push_back(indexingMaps[0]);
|
|
auto genericOp = builder.create<linalg::GenericOp>(
|
|
loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
|
|
iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
|
|
Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
|
|
Value result = b.create<math::ExpOp>(loc, diff);
|
|
b.create<linalg::YieldOp>(loc, result);
|
|
});
|
|
return genericOp.getResult(0);
|
|
}
|
|
|
|
/// Produce a linalg generic that computes the final step of the softmax
|
|
/// decomposition.
|
|
/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
|
|
/// yield n / d
|
|
/// }
|
|
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
|
|
Value denominator, Value output, int64_t dim) {
|
|
auto inputType = cast<ShapedType>(numerator.getType());
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputRank = inputShape.size();
|
|
auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
|
|
builder, inputRank, dim, /*allParallel=*/true);
|
|
assert(indexingMaps.size() == 2 &&
|
|
"We should have one map for each input (2)");
|
|
assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
|
|
// Add the affine map for the output tensor.
|
|
indexingMaps.push_back(indexingMaps[0]);
|
|
auto genericOp = builder.create<linalg::GenericOp>(
|
|
loc, numerator.getType(), ValueRange{numerator, denominator}, output,
|
|
indexingMaps, iteratorTypes,
|
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
|
Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
|
|
b.create<linalg::YieldOp>(loc, result);
|
|
});
|
|
return genericOp.getResult(0);
|
|
}
|
|
// @} End helper functions for softmax decomposition.
|
|
|
|
/// Given an N-dimensional tensor x, this method converts
|
|
/// softmax(x) to the following sequence of operations:
|
|
///
|
|
/// 1. Compute the max of x along dimension d. This results
|
|
/// in a N-1 dimensional tensor m.
|
|
/// m = max(x, dim = d)
|
|
///
|
|
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
|
|
/// a N dimensional tensor z.
|
|
/// z = exp(x - m)
|
|
///
|
|
/// 3. Compute the sum of z along dimension d. This results in
|
|
/// a N-1 dimensional tensor l.
|
|
/// l = sum(z, dim = d)
|
|
///
|
|
/// 4. Divide z and l. This gives the N-dimensional softmax.
|
|
/// softmax = z / l
|
|
///
|
|
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
|
|
OpBuilder::InsertionGuard guard(b);
|
|
b.setInsertionPoint(*this);
|
|
Location loc = getLoc();
|
|
Value input = getInput();
|
|
ShapedType inputType = getInputOperandType();
|
|
Type elementType = inputType.getElementType();
|
|
int64_t reductionDim = getDimension();
|
|
SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
|
|
Value output = getOutput();
|
|
dims.erase(dims.begin() + reductionDim);
|
|
// Step 1: Compute max along dim.
|
|
Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
|
|
Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
|
|
elementType, b, loc,
|
|
/*useOnlyFiniteValue=*/true);
|
|
Value neutralForMaxFInit =
|
|
b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
|
|
.result();
|
|
Value max =
|
|
reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
|
|
|
|
// Step 2: Subtract max from input and exponentiate.
|
|
Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
|
|
|
|
// Step 3: Compute sum along dim.
|
|
Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
|
|
b, loc, /*useOnlyFiniteValue=*/true);
|
|
Value zeroInit =
|
|
b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
|
|
Value denominator =
|
|
reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
|
|
|
|
// Step 4: Compute softmax.
|
|
Value result =
|
|
buildDivOp(b, loc, numerator, denominator, output, reductionDim);
|
|
return SmallVector<Value>{result};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WinogradFilterTransformOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult WinogradFilterTransformOp::verify() {
|
|
auto filterType = cast<ShapedType>(getFilter().getType());
|
|
ArrayRef<int64_t> filterShape = filterType.getShape();
|
|
int64_t filterH = filterShape[getFilterHDim()];
|
|
int64_t filterW = filterShape[getFilterWDim()];
|
|
int64_t r = getR();
|
|
int64_t m = getM();
|
|
|
|
if (filterH != r && filterH != 1)
|
|
return emitOpError("expect filter height either equals to r or 1");
|
|
if (filterW != r && filterW != 1)
|
|
return emitOpError("expect filter width either equals to r or 1");
|
|
if (filterH == 1 && filterW == 1)
|
|
return emitOpError("expect either filter height or width equals to r");
|
|
|
|
SmallVector<int64_t> expectedOutputShape;
|
|
expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
|
|
expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
|
|
expectedOutputShape.push_back(filterShape[getFilterCDim()]);
|
|
expectedOutputShape.push_back(filterShape[getFilterFDim()]);
|
|
|
|
auto outputType = cast<ShapedType>(getOutput().getType());
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
|
|
return emitOpError("the output shape is not expected");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Range>
|
|
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
|
|
Location loc = getLoc();
|
|
IntegerAttr zeroAttr = builder.getIndexAttr(0);
|
|
IntegerAttr oneAttr = builder.getIndexAttr(1);
|
|
Value filter = getFilter();
|
|
int64_t filterRank = getFilterOperandRank();
|
|
SmallVector<Range> loopBounds(filterRank);
|
|
for (unsigned dim = 0; dim < filterRank; ++dim) {
|
|
loopBounds[dim].offset = zeroAttr;
|
|
loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
|
|
loopBounds[dim].stride = oneAttr;
|
|
}
|
|
return loopBounds;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType>
|
|
WinogradFilterTransformOp::getLoopIteratorTypes() {
|
|
int64_t filterRank = getFilterOperandRank();
|
|
SmallVector<utils::IteratorType> iteratorTypes(filterRank,
|
|
utils::IteratorType::parallel);
|
|
return iteratorTypes;
|
|
}
|
|
|
|
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
|
|
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) {
|
|
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
|
|
ShapedType filterType = getFilterOperandType();
|
|
ArrayRef<int64_t> filterShape = filterType.getShape();
|
|
int64_t filterH = filterShape[getFilterHDim()];
|
|
int64_t filterW = filterShape[getFilterWDim()];
|
|
int64_t m = getM();
|
|
int64_t r = getR();
|
|
int64_t alpha = m + r - 1;
|
|
int64_t alphaH = filterH != 1 ? alpha : 1;
|
|
int64_t alphaW = filterW != 1 ? alpha : 1;
|
|
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
|
|
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
|
|
|
|
resultOffsets.append(
|
|
{zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
|
|
resultSizes.append(
|
|
{alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Implement tiling for winograd_filter_transform.
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
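/// Illustrative example (all values assumed, for exposition only): with
/// m = 4 and r = 3, alpha = m + r - 1 = 6. Tiling a 128x3x3x64 filter
/// (F, KH, KW, C) with tile sizes F = 32 and C = 16 extracts a 32x3x3x16
/// filter slice and produces a 6x6x16x32 slice of the output tensor.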
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
|
|
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes) {
|
|
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
|
|
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
|
|
ShapedType filterType = getFilterOperandType();
|
|
ArrayRef<int64_t> filterShape = filterType.getShape();
|
|
int64_t filterH = filterShape[getFilterHDim()];
|
|
int64_t filterW = filterShape[getFilterWDim()];
|
|
IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
|
|
IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
|
|
SmallVector<Value> tiledOperands;
|
|
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
|
|
|
|
sliceOffsets.append(
|
|
{offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
|
|
sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
|
|
sizes[getFilterCDim()]});
|
|
int64_t filterRank = getFilterOperandRank();
|
|
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
|
|
Location loc = getLoc();
|
|
auto filterSlice = builder.create<tensor::ExtractSliceOp>(
|
|
loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
|
|
tiledOperands.emplace_back(filterSlice);
|
|
|
|
SmallVector<OpFoldResult> resultOffsets, resultSizes;
|
|
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
|
|
resultSizes)))
|
|
return failure();
|
|
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
|
|
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
|
|
loc, getOutput(), resultOffsets, resultSizes, outputStrides);
|
|
tiledOperands.emplace_back(outputSlice);
|
|
|
|
SmallVector<Type> resultTypes;
|
|
resultTypes.push_back(tiledOperands[1].getType());
|
|
Operation *tiledOp =
|
|
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
|
|
|
|
return TilingResult{
|
|
{tiledOp},
|
|
SmallVector<Value>(tiledOp->getResults()),
|
|
llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WinogradInputTransformOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult WinogradInputTransformOp::verify() {
|
|
auto inputType = cast<ShapedType>(getInput().getType());
|
|
ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputH = inputShape[getInputHDim()];
|
|
int64_t inputW = inputShape[getInputWDim()];
|
|
int m = getM();
|
|
int r = getR();
|
|
int64_t tileSize = m + r - 1;
|
|
|
|
auto outputType = cast<ShapedType>(getOutput().getType());
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
|
|
bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
|
|
|
|
SmallVector<int64_t> expectedOutputShape(6, inputH);
|
|
if (ShapedType::isDynamic(inputH)) {
|
|
expectedOutputShape[getOutputAlphaHDim()] = tileSize;
|
|
expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
|
|
} else {
|
|
expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
|
|
expectedOutputShape[getOutputTileHDim()] =
|
|
leftTransform ? (inputH - (r - 1)) / m : inputH;
|
|
}
|
|
if (ShapedType::isDynamic(inputW)) {
|
|
expectedOutputShape[getOutputAlphaWDim()] = tileSize;
|
|
expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
|
|
} else {
|
|
expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
|
|
expectedOutputShape[getOutputTileWDim()] =
|
|
rightTransform ? (inputW - (r - 1)) / m : inputW;
|
|
}
|
|
expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
|
|
expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
|
|
|
|
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
|
|
return emitOpError("the output shape is not expected");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Range>
|
|
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
|
|
Location loc = getLoc();
|
|
IntegerAttr zeroAttr = builder.getIndexAttr(0);
|
|
IntegerAttr oneAttr = builder.getIndexAttr(1);
|
|
Value output = getOutput();
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<Range> loopBounds(outputRank);
|
|
for (unsigned dim = 0; dim < outputRank; ++dim) {
|
|
loopBounds[dim].offset = zeroAttr;
|
|
// alphaH, alphaW, tileH, tileW, N, C
|
|
loopBounds[dim].size = getDimValue(builder, loc, output, dim);
|
|
loopBounds[dim].stride = oneAttr;
|
|
}
|
|
return loopBounds;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType>
|
|
WinogradInputTransformOp::getLoopIteratorTypes() {
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<utils::IteratorType> iteratorTypes(outputRank,
|
|
utils::IteratorType::parallel);
|
|
return iteratorTypes;
|
|
}
|
|
|
|
LogicalResult WinogradInputTransformOp::getResultTilePosition(
|
|
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) {
|
|
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
|
|
ShapedType outputType = getOutputOperandType();
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
|
|
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
|
|
|
|
int64_t m = getM();
|
|
int64_t r = getR();
|
|
int64_t alpha = m + r - 1;
|
|
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
|
|
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
|
|
|
|
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
|
|
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
|
|
|
|
resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
|
|
offsets[getOutputTileWDim()], offsets[getOutputNDim()],
|
|
offsets[getOutputCDim()]});
|
|
resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
|
|
sizes[getOutputTileWDim()], sizes[getOutputNDim()],
|
|
sizes[getOutputCDim()]});
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Implement tiling for winograd_input_transform.
/// The input of winograd_input_transform is (N, H, W, C).
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
/// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one
/// tile.
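/// Illustrative example (all values assumed, for exposition only): with
/// m = 4, r = 3 and a 1x14x14x32 input, the output is 6x6x3x3x1x32. Tiling
/// with tileH = tileW = 1 reads a 1x6x6x32 input slice starting at input
/// offsets (tileOffsetH * m, tileOffsetW * m), since one output tile covers
/// m + r - 1 = 6 consecutive input rows and columns.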
FailureOr<TilingResult>
|
|
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
|
|
ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes) {
|
|
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
|
|
int64_t m = getM();
|
|
int64_t r = getR();
|
|
|
|
ShapedType outputType = getOutputOperandType();
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
int64_t alphaH = outputShape[getOutputAlphaHDim()];
|
|
int64_t alphaW = outputShape[getOutputAlphaWDim()];
|
|
|
|
Location loc = getLoc();
|
|
MLIRContext *context = builder.getContext();
|
|
auto identityAffineMap =
|
|
AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
|
|
auto offsetAffineMap =
|
|
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
|
|
Value mappedOffsetH = affine::makeComposedAffineApply(
|
|
builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
|
|
offsets[getOutputTileHDim()]);
|
|
Value mappedOffsetW = affine::makeComposedAffineApply(
|
|
builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
|
|
offsets[getOutputTileWDim()]);
|
|
auto sizeAffineMap = AffineMap::get(
|
|
1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
|
|
Value mappedSizeH = affine::makeComposedAffineApply(
|
|
builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
|
|
Value mappedSizeW = affine::makeComposedAffineApply(
|
|
builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
|
|
|
|
SmallVector<Value> tiledOperands;
|
|
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
|
|
|
|
OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
|
|
OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
|
|
sliceOffsets.append(
|
|
{offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
|
|
OpFoldResult sizeH =
|
|
alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
|
|
OpFoldResult sizeW =
|
|
alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
|
|
sliceSizes.append(
|
|
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
|
|
int64_t inputRank = getInputOperandRank();
|
|
SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
|
|
auto inputSlice = builder.create<tensor::ExtractSliceOp>(
|
|
loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
|
|
tiledOperands.emplace_back(inputSlice);
|
|
|
|
SmallVector<OpFoldResult> resultOffsets, resultSizes;
|
|
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
|
|
resultSizes)))
|
|
return failure();
|
|
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
|
|
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
|
|
loc, getOutput(), resultOffsets, resultSizes, outputStrides);
|
|
tiledOperands.emplace_back(outputSlice);
|
|
|
|
SmallVector<Type> resultTypes;
|
|
resultTypes.push_back(tiledOperands[1].getType());
|
|
Operation *tiledOp =
|
|
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
|
|
|
|
return TilingResult{
|
|
{tiledOp},
|
|
SmallVector<Value>(tiledOp->getResults()),
|
|
llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// WinogradOutputTransformOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
LogicalResult WinogradOutputTransformOp::verify() {
|
|
auto valueType = cast<ShapedType>(getValue().getType());
|
|
ArrayRef<int64_t> valueShape = valueType.getShape();
|
|
int64_t valueH = valueShape[getValueAlphaHDim()];
|
|
int64_t valueW = valueShape[getValueAlphaWDim()];
|
|
int64_t valueTileH = valueShape[getValueTileHDim()];
|
|
int64_t valueTileW = valueShape[getValueTileWDim()];
|
|
int m = getM();
|
|
int r = getR();
|
|
bool leftTransform = valueH != 1;
|
|
bool rightTransform = valueW != 1;
|
|
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
|
|
if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
|
|
expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
|
|
} else {
|
|
if (valueH != (leftTransform ? m + r - 1 : 1))
|
|
return emitOpError("expect input height equals to input tile size");
|
|
expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
|
|
}
|
|
if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
|
|
expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
|
|
} else {
|
|
if (valueW != (rightTransform ? m + r - 1 : 1))
|
|
return emitOpError("expect input width equals to input tile size");
|
|
expectedOutputShape[getOutputWDim()] =
|
|
(rightTransform ? m : 1) * valueTileW;
|
|
}
|
|
expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
|
|
expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
|
|
|
|
auto outputType = cast<ShapedType>(getOutput().getType());
|
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
|
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
|
|
return emitOpError("the output shape is not expected");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
SmallVector<Range>
|
|
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
|
|
Location loc = getLoc();
|
|
IntegerAttr zeroAttr = builder.getIndexAttr(0);
|
|
IntegerAttr oneAttr = builder.getIndexAttr(1);
|
|
Value value = getValue();
|
|
int64_t valueRank = getValueOperandRank();
|
|
SmallVector<Range> loopBounds(valueRank);
|
|
for (unsigned dim = 0; dim < valueRank; ++dim) {
|
|
loopBounds[dim].offset = zeroAttr;
|
|
// alphaH, alphaW, tileH, tileW, N, F
|
|
loopBounds[dim].size = getDimValue(builder, loc, value, dim);
|
|
loopBounds[dim].stride = oneAttr;
|
|
}
|
|
return loopBounds;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType>
|
|
WinogradOutputTransformOp::getLoopIteratorTypes() {
|
|
int64_t valueRank = getValueOperandRank();
|
|
SmallVector<utils::IteratorType> iteratorTypes(valueRank,
|
|
utils::IteratorType::parallel);
|
|
return iteratorTypes;
|
|
}
|
|
|
|
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
|
|
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
|
SmallVector<OpFoldResult> &resultSizes) {
|
|
int64_t m = getM();
|
|
|
|
Location loc = getLoc();
|
|
MLIRContext *context = builder.getContext();
|
|
auto identityAffineMap =
|
|
AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
|
|
auto affineMap =
|
|
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
|
|
|
|
ShapedType valueType = getValueOperandType();
|
|
ArrayRef<int64_t> valueShape = valueType.getShape();
|
|
int64_t valueH = valueShape[0];
|
|
int64_t valueW = valueShape[1];
|
|
Value mappedOffsetH = affine::makeComposedAffineApply(
|
|
builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
|
|
offsets[getValueTileHDim()]);
|
|
Value mappedOffsetW = affine::makeComposedAffineApply(
|
|
builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
|
|
offsets[getValueTileWDim()]);
|
|
Value mappedSizeH = affine::makeComposedAffineApply(
|
|
builder, loc, affineMap, sizes[getValueTileHDim()]);
|
|
Value mappedSizeW = affine::makeComposedAffineApply(
|
|
builder, loc, affineMap, sizes[getValueTileWDim()]);
|
|
|
|
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
|
|
OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
|
|
OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
|
|
OpFoldResult sizeH =
|
|
valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
|
|
OpFoldResult sizeW =
|
|
valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
|
|
|
|
resultOffsets.append(
|
|
{offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
|
|
resultSizes.append(
|
|
{sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
|
|
return success();
|
|
}
|
|
|
|
/// Implement tiling for winograd_output_transform.
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
/// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
/// for the sizes of tileH, tileW, N, F for one tile.
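/// Illustrative example (all values assumed, for exposition only): with
/// m = 4 and a 6x6x3x3x1x32 value tensor, tiling with tileH = tileW = 1
/// reads a 6x6x1x1x1x32 value slice and writes a 1x4x4x32 slice of the
/// (N, H, W, F) output at offsets (0, tileOffsetH * m, tileOffsetW * m, 0).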
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
|
|
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
|
|
ArrayRef<OpFoldResult> sizes) {
|
|
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
|
|
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
|
|
Location loc = getLoc();
|
|
SmallVector<Value> tiledOperands;
|
|
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
|
|
|
|
ShapedType valueType = getValueOperandType();
|
|
ArrayRef<int64_t> valueShape = valueType.getShape();
|
|
int64_t alphaH = valueShape[getValueAlphaHDim()];
|
|
int64_t alphaW = valueShape[getValueAlphaWDim()];
|
|
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
|
|
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
|
|
|
|
sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
|
|
offsets[getValueTileWDim()], offsets[getValueNDim()],
|
|
offsets[getValueFDim()]});
|
|
sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
|
|
sizes[getValueTileWDim()], sizes[getValueNDim()],
|
|
sizes[getValueFDim()]});
|
|
int64_t valueRank = getValueOperandRank();
|
|
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
|
|
auto valueSlice = builder.create<tensor::ExtractSliceOp>(
|
|
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
|
|
tiledOperands.emplace_back(valueSlice);
|
|
|
|
SmallVector<OpFoldResult> resultOffsets, resultSizes;
|
|
if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
|
|
resultSizes)))
|
|
return failure();
|
|
|
|
int64_t outputRank = getOutputOperandRank();
|
|
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
|
|
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
|
|
loc, getOutput(), resultOffsets, resultSizes, strides);
|
|
tiledOperands.emplace_back(outputSlice);
|
|
|
|
SmallVector<Type> resultTypes;
|
|
resultTypes.push_back(tiledOperands[1].getType());
|
|
Operation *tiledOp =
|
|
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
|
|
|
|
return TilingResult{
|
|
{tiledOp},
|
|
SmallVector<Value>(tiledOp->getResults()),
|
|
llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// LinalgDialect
|
|
// TODO: Merge with the LinalgDialect block at the bottom
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Returns true if the result expressions of `subMap` are a subset of `fullMap`.
|
|
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
|
|
auto explicitRange = subMap.getResults();
|
|
auto defaultRange = fullMap.getResults();
|
|
DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
|
|
DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
|
|
llvm::set_union(explicitSet, defaultSet);
|
|
return explicitSet == defaultSet;
|
|
}
|
|
|
|
/// Check if the user-defined map is a valid broadcast map. Here broadcast
/// indexing maps are defined in the context of the corresponding default
/// indexing maps for the given op, so the check reduces to comparing the
/// number of result dims.
/// Returns true if the explicitMap is broadcasted with respect to the
/// defaultMap.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}
|
|
|
|
/// Verifies the broadcast and transpose semantics specified by the explicit
/// indexing map for the MatmulOp \p op for each operand specified by \p
/// opIndex.
|
|
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
|
|
unsigned opIndex) {
|
|
SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
|
|
SmallVector<AffineMap, 3> defaultIndexingMaps =
|
|
matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
|
|
|
|
auto opIndexingMap = opIndexingMaps[opIndex];
|
|
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
|
|
// Check general validity of indexing map results.
|
|
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
|
|
return matmulOp->emitOpError()
|
|
<< "Unexpected dim expression in map result.";
|
|
|
|
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
|
|
if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
|
|
return matmulOp->emitOpError()
|
|
<< "Invalid broadcast requested, should be (d2).";
|
|
}
|
|
return success();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
// Check general validity of input indexing map of
|
|
// BatchMatmulOp/BatchReduceMatmulOp.
|
|
template <typename OpTy>
|
|
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
|
|
AffineMap opIndexingMap,
|
|
AffineMap defaultIndexingMap, bool isLHS) {
|
|
assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
|
|
isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
|
|
"Expected BatchMatmulOp or BatchReduceMatmulOp");
|
|
// Check the result dims are valid.
|
|
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Unexpected result dim expression (outside the set of default "
|
|
"result dims).";
|
|
|
|
// Check for valid number of result dims of input maps.
|
|
if (opIndexingMap.getNumResults() > 3)
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "no. of result dim expressions exceeds 3.";
|
|
|
|
auto hasValidBatchDim = [](AffineMap map) {
|
|
AffineExpr batchDim = map.getResult(0);
|
|
return batchDim.isFunctionOfDim(0);
|
|
};
|
|
|
|
// Check if the requested broadcast is valid.
|
|
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
|
|
if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Invalid broadcast requested.";
|
|
} else if (!hasValidBatchDim(opIndexingMap)) {
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Invalid batch dimension expression.";
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// This function checks if the given AffineMap for the output of a
|
|
/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
|
|
/// dimensions and if the output map result dimensions are valid.
|
|
template <typename OpTy>
|
|
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
|
|
AffineMap opIndexingMap) {
|
|
assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
|
|
isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
|
|
"Expected BatchMatmulOp or BatchReduceMatmulOp");
|
|
if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
|
|
opIndexingMap.getNumResults() != 3) {
|
|
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
|
|
<< ").";
|
|
}
|
|
if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
|
|
opIndexingMap.getNumResults() != 2) {
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "expects 2 dims, but got (" << opIndexingMap.getNumResults()
|
|
<< ").";
|
|
}
|
|
|
|
auto areValidOutputResultDim = [&](AffineMap outputMap) {
|
|
return isa<BatchMatmulOp>(batchVariantMatmulOp)
|
|
? outputMap.getResult(0).isFunctionOfDim(0) &&
|
|
outputMap.getResult(1).isFunctionOfDim(1) &&
|
|
outputMap.getResult(2).isFunctionOfDim(2)
|
|
: outputMap.getResult(0).isFunctionOfDim(1) &&
|
|
outputMap.getResult(1).isFunctionOfDim(2);
|
|
};
|
|
|
|
if (!areValidOutputResultDim(opIndexingMap)) {
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Invalid output map result dimension.";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Verifies the broadcast and transpose semantic specified by the explicit
|
|
/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
|
|
/// specified by opIndex.
|
|
template <typename OpTy>
|
|
static LogicalResult
|
|
verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
|
|
unsigned opIndex) {
|
|
SmallVector<AffineMap, 3> opIndexingMaps =
|
|
batchVariantMatmulOp.getIndexingMapsArray();
|
|
SmallVector<AffineMap, 3> defaultIndexingMaps =
|
|
batchVariantMatmulOp.getDefaultIndexingMaps(
|
|
batchVariantMatmulOp->getContext());
|
|
|
|
if (opIndexingMaps.size() != 3)
|
|
return batchVariantMatmulOp->emitOpError()
|
|
<< "Indexing_map attribute must have 3 affine maps.";
|
|
|
|
auto opIndexingMap = opIndexingMaps[opIndex];
|
|
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
|
|
|
|
if (opIndex == 2 &&
|
|
failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
|
|
return failure();
|
|
|
|
if (opIndex != 2 &&
|
|
failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
|
|
defaultIndexingMap, opIndex == 0)))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace linalg {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MatMulOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns a list of AffineMaps with the typical matmul indexing characteristics.
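/// Sketch of the equivalent `indexing_maps` attribute built below (standard
/// (M, N, K) matmul iteration space):
///   affine_map<(d0, d1, d2) -> (d0, d2)>  // A
///   affine_map<(d0, d1, d2) -> (d2, d1)>  // B
///   affine_map<(d0, d1, d2) -> (d0, d1)>  // C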
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
|
|
AffineExpr d0, d1, d2;
|
|
SmallVector<AffineMap> indexingMaps;
|
|
bindDims(context, d0, d1, d2);
|
|
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
|
|
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
|
|
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
|
|
return indexingMaps;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
|
|
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
|
|
utils::IteratorType::parallel,
|
|
utils::IteratorType::reduction};
|
|
}
|
|
|
|
unsigned MatmulOp::getNumRegionArgs() { return 3; }
|
|
|
|
std::string MatmulOp::getLibraryCallName() {
|
|
return generateLibraryCallName(getOperation());
|
|
}
|
|
|
|
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
|
|
|
|
/// Check if the op has broadcast and/or transpose semantic. Returns true if
|
|
/// the user defined indexing maps are not equal to default map.
|
|
bool MatmulOp::hasUserDefinedMaps() {
|
|
SmallVector<AffineMap, 3> defaultMaps =
|
|
getDefaultIndexingMaps(this->getContext());
|
|
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
|
|
return defaultMaps != explicitMaps;
|
|
}
|
|
|
|
/// Implements the block region builder for the MatmulOp. This is called by
|
|
/// 'fillStructuredOpRegion'.
|
|
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
|
ArrayRef<NamedAttribute> attrs) {
|
|
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
|
|
RegionBuilderHelper helper(b, block);
|
|
SmallVector<Value> yields;
|
|
|
|
TypeFn castVal = TypeFn::cast_signed;
|
|
const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
|
return attr.getName() == "cast";
|
|
});
|
|
if (castIter != attrs.end()) {
|
|
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
|
|
castVal = attr.getValue();
|
|
}
|
|
|
|
Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
|
|
block.getArgument(0));
|
|
Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
|
|
block.getArgument(1));
|
|
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
|
|
Value value4 =
|
|
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
|
|
yields.push_back(value4);
|
|
helper.yieldOutputs(yields);
|
|
}
|
|
|
|
/// Returns true if the given bcastMap is a valid broadcast map. A valid
/// broadcast map must include the K dimension.
/// TODO: Strict inclusion of the K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to require the K dimension for one input matrix map and infer
/// the K dimension for the other input matrix map from the one that already
/// has it.
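/// Illustrative example (assumed maps, for exposition only): with matmul
/// dims (d0, d1, d2) = (M, N, K), the broadcast map
/// affine_map<(d0, d1, d2) -> (d2)> is accepted because its single result
/// depends on the reduction dim d2, whereas (d0) or (d1) would be rejected.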
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
|
|
assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
|
|
AffineExpr expr = bcastMap.getResult(0);
|
|
// Invalid map if the common dimension of matmul not found.
|
|
return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
|
|
}
|
|
|
|
FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
|
|
if (parser.parseOptionalKeyword("indexing_maps"))
|
|
return ArrayAttr{
|
|
nullptr}; // Success in case indexing_maps was not provided.
|
|
|
|
ArrayAttr arrayAttr;
|
|
if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
|
|
return failure();
|
|
|
|
if (llvm::any_of(arrayAttr,
|
|
[](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
|
|
return parser.emitError(parser.getCurrentLocation())
|
|
<< "element of indexing_maps array is not an affine_map";
|
|
|
|
return arrayAttr;
|
|
}
|
|
|
|
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
|
|
if (failed(indexingMapsAttr))
|
|
return failure();
|
|
|
|
if (*indexingMapsAttr == nullptr) {
|
|
auto indexingMapAttrs = llvm::map_to_vector(
|
|
MatmulOp::getDefaultIndexingMaps(parser.getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
|
|
}
|
|
|
|
result.addAttribute("indexing_maps", *indexingMapsAttr);
|
|
return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
|
|
MatmulOp::getRegionBuilder());
|
|
}
|
|
|
|
void MatmulOp::print(OpAsmPrinter &p) {
|
|
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
|
|
MatmulOp::getDefaultIndexingMaps(getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
if (!llvm::equal(getIndexingMaps(), indexingMaps))
|
|
p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
|
|
|
|
std::array<StringRef, 3> elidedAttrs = {
|
|
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
|
|
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
|
elidedAttrs);
|
|
}
|
|
|
|
/// Verify the user defined indexing maps.
|
|
LogicalResult MatmulOp::verify() {
|
|
// Verification of pure matmul is handled by verifyStructuredOpInterface().
|
|
if (!hasUserDefinedMaps())
|
|
return success();
|
|
|
|
for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
|
|
if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
void MatmulOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
if (hasPureTensorSemantics())
|
|
return;
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability MatmulOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ContractOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
|
|
AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
|
|
// On well-formed IR, indexing_maps is non-empty, contained affine_maps'
|
|
// domains are all the same, and each implements a projected permutation.
|
|
// Each iteration space dim must occur for at least one operand and either
|
|
// takes part in a contraction/reduction or else has parallel iteration type.
|
|
// We have that a dim is a contraction/reduction dim if and only if the dim
|
|
// occurs for the output operand. We use this fact for fast inference:
|
|
// NB: In case we allow dims to occur solely for one input, the above still
|
|
// holds: per the einsum semantics, these are reduction dims as well.
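  // Illustrative example (for exposition only): for plain matmul maps
  // (d0, d2), (d2, d1) -> (d0, d1), d2 does not appear in the output map,
  // so it becomes a reduction dim while d0 and d1 stay parallel.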
SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
|
|
for (auto result : outAffineMap.getResults()) {
|
|
auto dimExpr = dyn_cast<AffineDimExpr>(result);
|
|
assert(dimExpr && "affine_map is a projected permutation");
|
|
dimsInOutput[dimExpr.getPosition()] = true;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> iteratorTypes;
|
|
for (auto dimOccursInOutput : dimsInOutput)
|
|
iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
|
|
: utils::IteratorType::reduction);
|
|
|
|
return iteratorTypes;
|
|
}
|
|
|
|
unsigned ContractOp::getNumRegionArgs() { return 3; }
|
|
|
|
/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
|
|
void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
|
ArrayRef<NamedAttribute> attrs) {
|
|
assert(block.getNumArguments() == 3 &&
|
|
"ContractOp regionBuilder expects 3 args");
|
|
RegionBuilderHelper helper(b, block);
|
|
|
|
TypeFn castSignedness = TypeFn::cast_signed;
|
|
auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
|
return attr.getName() == "cast";
|
|
});
|
|
if (castIter != attrs.end()) {
|
|
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
|
|
castSignedness = attr.getValue();
|
|
}
|
|
|
|
// TODO: Support fields with operators besides mult & add.
|
|
Type outType = block.getArgument(2).getType();
|
|
Value lhsAtOutType =
|
|
helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
|
|
Value rhsAtOutType =
|
|
helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
|
|
Value productAtOutType =
|
|
helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
|
|
Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
|
|
productAtOutType);
|
|
helper.yieldOutputs({result});
|
|
}
|
|
|
|
ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
|
|
if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected 'indexing_maps' attribute");
|
|
result.addAttribute("indexing_maps", *indexingMapsAttr);
|
|
|
|
return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
|
|
regionBuilder);
|
|
}
|
|
|
|
void ContractOp::print(OpAsmPrinter &p) {
|
|
p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
|
|
printNamedStructuredOp(
|
|
p, getOperation(), getInputs(), getOutputs(),
|
|
/*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
|
|
}
|
|
|
|
LogicalResult ContractOp::verify() {
|
|
int iterationSpaceDims = -1;
|
|
// Map iter space dims to #occurrences in inputs' and output's affine_maps:
|
|
// e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
|
|
// access an input operand (so occurrence count can be at most 2) and
|
|
// outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
|
|
SmallVector<size_t> inOccurrences;
|
|
SmallVector<size_t> outOccurrences;
|
|
|
|
// A helper so that for each operand's affine_map and type we check that ...
|
|
auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
|
|
bool isInput) -> LogicalResult {
|
|
// ... the affine_map is a projected permutation;
|
|
if (!affineMap.isProjectedPermutation())
|
|
return emitError("provided affine_map is not a projected permutation");
|
|
|
|
// ... the rank of the affine_map's results and corresponding type match;
|
|
if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
|
|
if (affineMap.getNumResults() != shapedType.getRank())
|
|
return emitError("ranks of shaped operand and results of corresponding "
|
|
"affine_map differ");
|
|
} else if (affineMap.getNumResults() != 0) {
|
|
return emitError("affine_map specifies shaped access while operand has "
|
|
"non-shaped type");
|
|
}
|
|
|
|
// ... the rank of the affine_map's domain is the same as those seen prior;
|
|
if (iterationSpaceDims == -1) {
|
|
iterationSpaceDims = affineMap.getNumDims();
|
|
inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
|
|
outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
|
|
} else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
|
|
return emitError("iteration spaces of provided affine_maps differ");
|
|
}
|
|
|
|
// ... update counts of dims used to access either an input or the output.
|
|
for (AffineExpr affineExpr : affineMap.getResults()) {
|
|
auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
|
|
if (!affineDimExpr)
|
|
llvm_unreachable("affine_map is a projected permutation");
|
|
|
|
if (isInput)
|
|
inOccurrences[affineDimExpr.getPosition()] += 1;
|
|
else
|
|
outOccurrences[affineDimExpr.getPosition()] += 1;
|
|
}
|
|
|
|
return success();
|
|
};
|
|
|
|
for (auto &&[affineMap, operandType, isInput] :
|
|
llvm::zip(getIndexingMapsArray(), getOperandTypes(),
|
|
SmallVector<bool>{true, true, false})) {
|
|
if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
|
|
return failure(); // NB: checkAffineMapAndType will emit relevant error.
|
|
}
|
|
|
|
bool hasContractingDim = false;
|
|
for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
|
|
size_t inOccCount = inOccurrences[dimIndex];
|
|
size_t outOccCount = outOccurrences[dimIndex];
|
|
|
|
// We have a contracting dim if and only if ...
|
|
hasContractingDim |= inOccCount == 2 && outOccCount == 0;
|
|
|
|
if (inOccCount == 0 && outOccCount == 0)
|
|
return emitError() << "iteration space dim at index " << dimIndex
|
|
<< " not used to access any operand";
|
|
|
|
// NB: We disallow a dim which occurs for only one input operand and not
|
|
// for the output. In terms of einsum semantics such dims have a
|
|
// sensible meaning - namely an additional reduction per each such dim.
|
|
// By contrast, the ContractionOpInterface does not know about this
|
|
// iter type - cf. inferContractionDims' supported dim kinds. Similarly,
|
|
// while vector.contract's verifier accepts dims of this kind many of
|
|
// its lowerings give up on encountering these dims.
|
|
// TODO: Remove following once we have comprehensive support for input-only
|
|
// reduction dims, at both the linalg- and vector-dialect levels.
|
|
if (inOccCount == 1 && outOccCount != 1)
|
|
return emitError()
|
|
<< "iteration space dim at index " << dimIndex
|
|
<< " is neither a contracting dim nor of parallel iteration type";
|
|
}
|
|
|
|
if (!hasContractingDim)
|
|
return emitError("'indexing_maps' do not specify a contracting dimension");
|
|
|
|
return success();
|
|
}
|
|
|
|
LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
void ContractOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
if (hasPureTensorSemantics())
|
|
return;
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability ContractOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Implementation of BatchMatmulOp
|
|
//===----------------------------------------------------------------------===//
|
|
SmallVector<AffineMap>
|
|
BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
|
|
AffineExpr d0, d1, d2, d3;
|
|
SmallVector<AffineMap> indexingMaps;
|
|
bindDims(context, d0, d1, d2, d3);
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
|
|
indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
|
|
return indexingMaps;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
|
|
return SmallVector<utils::IteratorType>{
|
|
utils::IteratorType::parallel, utils::IteratorType::parallel,
|
|
utils::IteratorType::parallel, utils::IteratorType::reduction};
|
|
}
|
|
|
|
unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
|
|
|
|
std::string BatchMatmulOp::getLibraryCallName() {
|
|
return generateLibraryCallName(getOperation());
|
|
}
|
|
|
|
/// Check if the op has broadcast and/or transpose semantic. Returns true if
|
|
/// the user defined indexing maps are not equal to default map.
|
|
bool BatchMatmulOp::hasUserDefinedMaps() {
|
|
SmallVector<AffineMap, 3> defaultMaps =
|
|
getDefaultIndexingMaps(this->getContext());
|
|
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
|
|
return defaultMaps != explicitMaps;
|
|
}
|
|
|
|
/// Returns true if the given bcastMap is a valid broadcast map. A valid
/// broadcast map must include the K dimension.
/// TODO: Strict inclusion of the K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to require the K dimension for one input matrix map and infer
/// the K dimension for the other input matrix map from the one that already
/// has it.
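/// Illustrative example (assumed maps, for exposition only): with batch
/// matmul dims (d0, d1, d2, d3) = (B, M, N, K), the rank-1 map
/// affine_map<(d0, d1, d2, d3) -> (d3)> is accepted for either operand, and
/// a rank-2 LHS map with results (d1, d3) is accepted because it still ends
/// in the K dim d3.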
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
|
|
assert(bcastMap.getNumResults() < 3 &&
|
|
"Expected less than 3 result dim expr.");
|
|
bool isValid = false;
|
|
enum Indices { batchPos, mPos, nPos, kPos };
|
|
if (bcastMap.getNumResults() == 1) {
|
|
AffineExpr expr = bcastMap.getResult(0);
|
|
isValid = expr.isFunctionOfDim(kPos);
|
|
} else if (bcastMap.getNumResults() == 2) {
|
|
AffineExpr expr0 = bcastMap.getResult(0);
|
|
AffineExpr expr1 = bcastMap.getResult(1);
|
|
isValid =
|
|
isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
|
|
expr0.isFunctionOfDim(mPos)) &&
|
|
expr1.isFunctionOfDim(kPos))
|
|
: ((expr0.isFunctionOfDim(batchPos) &&
|
|
expr1.isFunctionOfDim(kPos)) ||
|
|
(expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
|
|
}
|
|
return isValid;
|
|
}
|
|
|
|
void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
|
ArrayRef<NamedAttribute> attrs) {
|
|
  assert(block.getNumArguments() == 3 &&
         "BatchMatmulOp regionBuilder expects 3 args");
|
|
RegionBuilderHelper helper(b, block);
|
|
SmallVector<Value> yields;
|
|
|
|
TypeFn castVal = TypeFn::cast_signed;
|
|
auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
|
return attr.getName() == "cast";
|
|
});
|
|
if (castIter != attrs.end()) {
|
|
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
|
|
castVal = attr.getValue();
|
|
}
|
|
|
|
auto toType = block.getArgument(2).getType();
|
|
Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
|
|
Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
|
|
Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
|
|
Value addVal =
|
|
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
|
|
yields.push_back(addVal);
|
|
helper.yieldOutputs(yields);
|
|
}
|
|
|
|
ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
SmallVector<Attribute, 3> indexingMapsAttr;
|
|
Attribute mapAttr;
|
|
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
|
|
if (parser.parseEqual())
|
|
return failure();
|
|
|
|
if (parser.parseLSquare())
|
|
return failure();
|
|
|
|
do {
|
|
if (parser.parseAttribute(mapAttr))
|
|
return failure();
|
|
if (!isa<AffineMapAttr>(mapAttr)) {
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected affine map attribute");
|
|
}
|
|
indexingMapsAttr.push_back(mapAttr);
|
|
|
|
if (parser.parseOptionalComma())
|
|
break;
|
|
} while (true);
|
|
|
|
if (parser.parseRSquare())
|
|
return failure();
|
|
}
|
|
// Initialize indexingMaps, if not supplied explicitly.
|
|
if (indexingMapsAttr.empty()) {
|
|
indexingMapsAttr = llvm::map_to_vector(
|
|
BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
}
|
|
result.addAttribute("indexing_maps",
|
|
parser.getBuilder().getArrayAttr(indexingMapsAttr));
|
|
|
|
return ::parseNamedStructuredOp(parser, result,
|
|
BatchMatmulOp::getNumRegionArgs(),
|
|
BatchMatmulOp::getRegionBuilder());
|
|
}
|
|
|
|
void BatchMatmulOp::print(OpAsmPrinter &p) {
|
|
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
|
|
BatchMatmulOp::getDefaultIndexingMaps(getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
if (!llvm::equal(getIndexingMaps(), indexingMaps))
|
|
p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
|
|
|
|
std::array<StringRef, 3> elidedAttrs = {
|
|
"operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
|
|
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
|
elidedAttrs);
|
|
}
|
|
|
|
/// Verify the user defined indexing maps.
|
|
LogicalResult BatchMatmulOp::verify() {
|
|
// Verification of pure batch_matmul is handled by
|
|
// verifyStructuredOpInterface().
|
|
if (!hasUserDefinedMaps())
|
|
return success();
|
|
|
|
for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
|
|
if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
LogicalResult BatchMatmulOp::fold(FoldAdaptor,
|
|
SmallVectorImpl<OpFoldResult> &) {
|
|
return memref::foldMemRefCast(*this);
|
|
}
|
|
|
|
void BatchMatmulOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
if (hasPureTensorSemantics())
|
|
return;
|
|
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
|
|
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ElementwiseOp
|
|
//===----------------------------------------------------------------------===//
|
|
namespace {
|
|
struct ArityGroupAndKind {
|
|
// The enum class {Unary, Binary, Ternary, ..}
|
|
ElementwiseArityGroup arityGroup;
|
|
|
|
// The kind (e.g. `exp` or `add`) belonging to the arity group.
|
|
union Kind {
|
|
UnaryFn unaryFn;
|
|
BinaryFn binaryFn;
|
|
TernaryFn ternaryFn;
|
|
} kind;
|
|
};
|
|
|
|
unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
|
|
return static_cast<unsigned>(arityGroup);
|
|
}
|
|
} // namespace
|
|
|
|
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
|
|
constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
|
|
constexpr int lastBinary =
|
|
static_cast<int>(ElementwiseCaseLimits::LastBinary);
|
|
constexpr int lastTernary =
|
|
static_cast<int>(ElementwiseCaseLimits::LastTernary);
|
|
|
|
int val = static_cast<int>(kind);
|
|
ArityGroupAndKind result;
|
|
|
|
if (val < lastUnary) {
|
|
result.arityGroup = ElementwiseArityGroup::Unary;
|
|
result.kind.unaryFn = static_cast<UnaryFn>(val);
|
|
return result;
|
|
}
|
|
if (val < lastBinary) {
|
|
result.arityGroup = ElementwiseArityGroup::Binary;
|
|
result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
|
|
return result;
|
|
}
|
|
if (val >= lastTernary) {
|
|
llvm_unreachable("unhandled ElementwiseFn");
|
|
}
|
|
result.arityGroup = ElementwiseArityGroup::Ternary;
|
|
result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
|
|
return result;
|
|
}
|
|
|
|
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
|
|
auto rank = getResultRank();
|
|
return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
|
|
}
|
|
|
|
SmallVector<AffineMap>
|
|
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
|
|
MLIRContext *context) {
|
|
auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
|
|
return SmallVector<AffineMap>(numMaps, map);
|
|
}
|
|
|
|
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
// Expect e.g. `kind = #linalg.elemwise_kind<add>`
|
|
Attribute attr;
|
|
mlir::linalg::ElementwiseKind elemwiseKindVal;
|
|
if (parser.parseKeyword("kind") || parser.parseEqual())
|
|
return failure();
|
|
|
|
if (succeeded(parser.parseAttribute(attr))) {
|
|
auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
|
|
if (!elemwiseKindAttr)
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected ElementwiseKind attribute");
|
|
elemwiseKindVal = elemwiseKindAttr.getValue();
|
|
} else {
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected operation 'kind' attribute");
|
|
}
|
|
result.addAttribute(
|
|
"kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
|
|
|
|
// Parse optional `indexing_maps`
|
|
SmallVector<Attribute, 3> indexingMapsAttr;
|
|
Attribute mapAttr;
|
|
if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
|
|
if (parser.parseEqual())
|
|
return failure();
|
|
if (parser.parseLSquare())
|
|
return failure();
|
|
do {
|
|
if (parser.parseAttribute(mapAttr))
|
|
return failure();
|
|
if (!isa<AffineMapAttr>(mapAttr))
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"expected affine map attribute");
|
|
indexingMapsAttr.push_back(mapAttr);
|
|
if (parser.parseOptionalComma())
|
|
break;
|
|
} while (true);
|
|
if (parser.parseRSquare())
|
|
return failure();
|
|
}
|
|
  // At this stage of parsing, the only way to infer the number of region
  // args is through the op kind, as the input/output tensors are not parsed
  // yet.
|
|
auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
|
|
int numRegionArgs =
|
|
getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
|
|
if (parseNamedStructuredOp(parser, result, numRegionArgs,
|
|
ElementwiseOp::getRegionBuilder())) {
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"unable to parse elemwise op");
|
|
}
|
|
|
|
// Initialize indexingMaps, if not supplied explicitly.
|
|
if (indexingMapsAttr.empty()) {
|
|
// We need to infer the numDims of the indexing maps from the output
|
|
// type which is already parsed by now.
|
|
auto resultType = result.operands[result.operands.size() - 1].getType();
|
|
auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
|
|
if (!shapedType)
|
|
return parser.emitError(parser.getCurrentLocation(),
|
|
"return type needs to be shaped type");
|
|
auto numDims = shapedType.getRank();
|
|
indexingMapsAttr = llvm::map_to_vector(
|
|
ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
|
|
parser.getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
}
|
|
|
|
result.addAttribute("indexing_maps",
|
|
parser.getBuilder().getArrayAttr(indexingMapsAttr));
|
|
return success();
|
|
}
|
|
|
|
void ElementwiseOp::print(OpAsmPrinter &p) {
|
|
p << " kind=";
|
|
p.printAttribute(getKindAttr());
|
|
SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
|
|
"indexing_maps"};
|
|
unsigned arity =
|
|
getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
|
|
unsigned numDims = getResultRank();
|
|
|
|
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
|
|
ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
|
|
getContext()),
|
|
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
|
|
|
|
if (!llvm::equal(getIndexingMaps(), indexingMaps))
|
|
p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
|
|
|
|
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
|
|
elidedAttrs);
|
|
}
|
|
|
|
LogicalResult ElementwiseOp::verify() {
|
|
// All necessary checks are done either by
|
|
// - EnumAttr (e.g. unknown operation kind)
|
|
// - verifyStructuredOpInterface (incorrect map, sizes).
|
|
return success();
|
|
}
|
|
|
|
/// Implements the block region builder for the ElementwiseOp. This is called
/// by 'fillStructuredOpRegion'.
void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                                  ArrayRef<NamedAttribute> attrs) {
  ElementwiseKind elemwiseKind;
  for (auto attr : attrs) {
    if (attr.getName() == b.getStringAttr("kind")) {
      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
      assert(kindAttr && "op kind attribute incorrectly set");
      elemwiseKind = kindAttr.getValue();
      break;
    }
  }

  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
  auto arityGroup = groupAndKind.arityGroup;
  auto kind = groupAndKind.kind;
  assert(block.getNumArguments() ==
             getArityGroupAsUInt(arityGroup) + 1 /*output*/
         && "Elementwise regionBuilder number of block args mismatch");

  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;
  Value result;

  if (arityGroup == ElementwiseArityGroup::Unary) {
    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));

  } else if (arityGroup == ElementwiseArityGroup::Binary) {
    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
                                  block.getArgument(1));

  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
                                   block.getArgument(1), block.getArgument(2));

  } else {
    assert(false && "found unhandled category in elemwise");
  }

  yields.push_back(result);
  helper.yieldOutputs(yields);
}

LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability ElementwiseOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//

// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
//  * a dim from newPackedTy is static, and
//  * the corresponding size from mixedTiles is still dynamic.
// Otherwise, the original tile size is preserved.
// Note - packed-type-dim and mixed-tile-size should always match!
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    // If the current result dim is static, update the dynamic mixed-size
    // (provided the original value is dynamic).
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // Already a constant
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
    }
  }

  return newMixedTileSizes;
}

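// Illustrative example (assumed values): if `newPackedTy` is
// tensor<4x8x16x32xf32> and `mixedTiles` holds {%c16 (a Value), 32 (an
// Attribute)}, the dynamic %c16 entry is replaced by the constant 16 taken
// from the now-static packed dim, while the already-constant 32 is kept.
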
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
  return success();
}

template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind the dimension `i` with the tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}

template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (!ShapedType::isDynamic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}

template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}

/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) It contains duplicates.
/// b) At least one dimension is out of bounds (a valid `dimPos` satisfies
///    0 <= dimPos < rank).
/// c) The number of elements in `dimsPos` is greater than `rank`.
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}

/// Returns true if each dimension of `sourceShape` is smaller than or equal
/// to the corresponding dimension of `limitShape` (dynamic dimensions are
/// always considered in bound).
static bool areAllInBound(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
}

template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(tiles, isZeroInteger);
  };

  // Verify tiles. Do not allow zero tiles.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be less than or equal to the input rank for pack (or
  // output rank for unpack), and must match the number of `inner_dims_pos`.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // Require output rank to match input rank + number of blocking factors.
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  // Verify that the result shape is at least as large as the minimum expected
  // by the pack operation, and that the output shape represents full tiles.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
  return success();
}

namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
// these. These may or may not become true foldings / canonicalizations
// depending on how aggressive we want to be in automatically folding
// transposes.
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}

//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}

LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}

DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}

SmallVector<OpFoldResult> PackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}

SmallVector<int64_t> PackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}

ArrayRef<int64_t> PackOp::getAllOuterDims() {
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);
}

SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getDestType().getShape();
  SmallVector<int64_t> res;

  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);

  return res;
}

bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);

    if (!constantTile) {
      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}

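// Illustrative example (assumed shapes): packing a tensor<13x15xf32> source
// with inner_dims_pos = [0, 1] and constant inner_tiles = [8, 16] requires a
// padding value, since neither 13 % 8 == 0 nor 15 % 16 == 0; a
// tensor<16x16xf32> source with the same tiles would not.
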
LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  // Verify padding value, and bail out if the tile does not divide the
  // dimension fully. In the case of dynamic tile factors or dimensions, having
  // a partial tile is undefined behavior.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}

/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
/// all Values to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    // Have to do this first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}

/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
/// the packed type. Having a shared helper helps implement these two methods
/// in a way that ensures that they agree on which dimensions are dynamic.
static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = llvm::divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  }

  // Swap tile loops if outer_dims_perm is available.
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultShape, outerDimsPerm);

  // Append the inner tile dimensions.
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
  return resultShape;
}

SmallVector<OpFoldResult> PackOp::getResultShape(
    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
    ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultDims, outerDimsPerm);
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

  SmallVector<int64_t> resultTypeShape =
      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
                               asShapeWithAnyValueAsDynamic(innerTileSizes),
                               innerDimsPos, outerDimsPerm);

  // Fix-up `resultDims` to ensure that they are Values if and only if the
  // result type shape says it's a dynamic dim. This is needed as callers may
  // use dispatchIndexOpFoldResults on the result, and rely on the exact
  // number of dynamic dims returned by that.
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (!ShapedType::isDynamic(resultTypeShape[i]))
      continue;
    resultDims[i] =
        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
  }

  return resultDims;
}

/// Get the expected packed type based on source type, tile factors, position
/// of the inner tiles and permutation of the outer tiled loop.
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}

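// Illustrative example (assumed shapes): for a tensor<20x32xf32> source with
// inner_dims_pos = [0, 1], inner_tiles = [8, 16] and an empty
// outer_dims_perm, inferPackedType returns tensor<3x2x8x16xf32>
// (20 ceildiv 8 = 3, 32 ceildiv 16 = 2, followed by the tile sizes).
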
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
                                                 {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(
          b.create<tensor::DimOp>(loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}

PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
}

/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

Speculation::Speculatability PackOp::getSpeculatability() {
  if (getPaddingValue())
    return Speculation::Speculatable;

  // The verifier already rejects operations where we can statically prove
  // that the tile sizes do not divide the dimension evenly; thus, only check
  // that the tiles and the tiled inner dimensions are constant.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
// dimensions for pack and unpack.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // Outer dims permutation is optional.
  // To compare unbalanced pack-unpack pair, treat no permutation as equal to
  // identity permutation.
  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
         isIdentityPermutation(unPackOp.getOuterDimsPerm());
}

// Return true if pack and unpack have the same tiles.
// Same SSA values or same integer constants.
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}

/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}

/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(packOp.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}

LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold a pack(unpack(x)) to x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Fold optional PaddingValue operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast if needed.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}

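// Illustrative sketch of the pack(unpack(x)) fold above (assumed shapes,
// attributes elided): an unpack of %0 : tensor<3x2x8x16xf32> into
// tensor<24x32xf32>, followed by a pack of that result back into
// tensor<3x2x8x16xf32>, canonicalizes to a direct use of %0, provided both
// ops use the same inner_dims_pos, outer_dims_perm and tile sizes, and the
// pack has no padding value.
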
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose dimensions.

  // Check that we are not transposing any dimensions.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't happen in order.
    return false;
  }

  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  // At this point we know that we are taking numPackedDims outer
  // dimensions and pushing them all the way as the innermost dimensions.
  // What's left on the outermost dimensions is, in this order:
  // - the factor of the packed dimensions, then
  // - the untouched dimensions
  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
  // if all the dimensions that bubble outward are ones.
  // Therefore check that all the dimensions but the numPackedDims innermost
  // ones are ones.
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}

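// Illustrative example (assumed shapes): a pack producing
// tensor<1x1x8x16xf32> from tensor<5x7xf32> with inner_dims_pos = [0, 1] and
// inner_tiles = [8, 16] is "like a pad": every outer (non-tile) dimension of
// the packed type is 1, so the op only pads 5x7 up to 8x16 and does not
// transpose data.
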
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}

/// Folds a tensor.cast op into a consuming PackOp if the `tensor.cast` has a
/// source that is more static than the consuming op.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
/// ```
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp op,
                                PatternRewriter &rewriter) const override {
    if (!tensor::hasFoldableTensorCastOperand(op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes =
        getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());

    // Clone op.
    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
    // this point. However, in practice, we use them for things that we'd like
    // to preserve. Implement a better abstraction.
    PackOp newOp = rewriter.create<PackOp>(
        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  op->getLoc(), oldResult.getType(), newResult)
                            : newResult;

    rewriter.replaceOp(op, {replacement});

    return success();
  }
};

//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//

void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}

LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}

DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}

SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}

SmallVector<int64_t> UnPackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}

ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
}

SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getSourceType().getShape();
  SmallVector<int64_t> res;

  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);

  return res;
}

LogicalResult UnPackOp::verify() {
  return commonVerifierPackAndUnPackOp(*this);
}

Speculation::Speculatability UnPackOp::getSpeculatability() {
  // See PackOp::getSpeculatability.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}

Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr sym0, sym1;
  bindSymbols(b.getContext(), sym0, sym1);
  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}

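// Illustrative example (assumed shapes): for a packed source
// tensor<3x2x8x16xf32> with inner_dims_pos = [0, 1], inner_tiles = [8, 16]
// and an empty outer_dims_perm, the destination created above is a
// tensor.empty of shape 24x32 (3 * 8 and 2 * 16).
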
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}

/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `op` and populates each with the inferred static shape.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}

LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  /// unpack(pack(x)) -> x
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x)
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }
  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser &&
        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
        extractSliceUser.getSourceType().getRank() ==
            extractSliceUser.getResultType().getRank()) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(unPackOp);
      auto newDest = rewriter.create<tensor::ExtractSliceOp>(
          unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(unPackOp, [&]() {
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(extractSliceUser, unPackOp);
      return success();
    }
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                               unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
    }
    Value newOp = rewriter.create<UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
        unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}

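// Illustrative note on the extract_slice fold above: it applies when the
// unpack has a single user that is a zero-offset, unit-stride,
// rank-preserving tensor.extract_slice. For example (assumed shapes), an
// unpack producing tensor<24x32xf32> whose only use extracts
// tensor<20x30xf32> is rewritten so that the unpack writes directly into the
// sliced destination and itself yields tensor<20x30xf32>.
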
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}

/// Folds a tensor.cast op into a consuming UnPackOp if the `tensor.cast` has
/// a source that is more static than the consuming op.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
/// ```
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp op,
                                PatternRewriter &rewriter) const override {
    if (!tensor::hasFoldableTensorCastOperand(op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
    Value sourceTensor = newOperands[0];

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
        rewriter, sourceTensor.getType(), op.getMixedTiles());

    // Clone op.
    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
    // this point. However, in practice, we use them for things that we'd like
    // to preserve. Implement a better abstraction.
    UnPackOp newOp = rewriter.create<UnPackOp>(
        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  op->getLoc(), oldResult.getType(), newResult)
                            : newResult;

    rewriter.replaceOp(op, {replacement});

    return success();
  }
};

//===----------------------------------------------------------------------===//
// BatchReduceMatmulOp
//===----------------------------------------------------------------------===//

SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::reduction, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
}

SmallVector<AffineMap>
BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2, d3);
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
  return indexingMaps;
}

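// For reference, the default maps built above correspond to
// (d0 = batch, d1 = m, d2 = n, d3 = k):
//   LHS:  (d0, d1, d2, d3) -> (d0, d1, d3)
//   RHS:  (d0, d1, d2, d3) -> (d0, d3, d2)
//   Init: (d0, d1, d2, d3) -> (d1, d2)
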
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

/// Check if the op has broadcast and/or transpose semantics. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
bool BatchReduceMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Returns true if the given bcastMap is a valid broadcast map. A valid
/// broadcast map must include the K dimension.
/// TODO: Strict inclusion of the K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to have the K dimension for one input matrix map and infer the
/// K dimension for the other input matrix map from the one already having
/// the K dimension.
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  if (bcastMap.getNumResults() == 1) {
    AffineExpr expr = bcastMap.getResult(0);
    isValid = expr.isFunctionOfDim(kPos);
  } else if (bcastMap.getNumResults() == 2) {
    AffineExpr expr0 = bcastMap.getResult(0);
    AffineExpr expr1 = bcastMap.getResult(1);
    isValid =
        isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
                  expr0.isFunctionOfDim(mPos)) &&
                 expr1.isFunctionOfDim(kPos))
              : ((expr0.isFunctionOfDim(batchPos) &&
                  expr1.isFunctionOfDim(kPos)) ||
                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
  }
  return isValid;
}

void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                                        ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  auto toType = block.getArgument(2).getType();
  Value castValA =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
  Value castValB =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
}

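// The region built above computes a multiply-accumulate per scalar element:
// both inputs are cast (signed) to the output element type, multiplied, and
// added to the running output value, i.e. out = out + cast(a) * cast(b).
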
ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();

    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return ::parseNamedStructuredOp(parser, result,
                                  BatchReduceMatmulOp::getNumRegionArgs(),
                                  BatchReduceMatmulOp::getRegionBuilder());
}

void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                           elidedAttrs);
}

/// Verify the user defined indexing maps.
LogicalResult BatchReduceMatmulOp::verify() {
  // Verification of pure batch_reduce_matmul is handled by
  // verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchReduceMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

} // namespace linalg
} // namespace mlir

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
              FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}