[MLIR][Linalg] Harden parsing Linalg named ops (#145337)
This thread through proper error handling / reporting capabilities to avoid hitting llvm_unreachable while parsing linalg ops. Fixes #132755 Fixes #132740 Fixes #129185
This commit is contained in:
parent
ac29858e2d
commit
ff0dcc4614
@ -16,6 +16,7 @@
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
@ -26,6 +27,9 @@
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/TilingInterface.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@ -52,7 +52,8 @@ def Linalg_Dialect : Dialect {
|
||||
kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
|
||||
|
||||
using RegionBuilderFunType = llvm::function_ref<
|
||||
void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
|
||||
void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>;
|
||||
RegionBuilderFunType getRegionBuilder(StringRef name) {
|
||||
return namedStructuredOpRegionBuilders.lookup(name);
|
||||
}
|
||||
|
||||
@ -720,7 +720,7 @@ def LinalgStructuredInterface
|
||||
Returns a null function if this named op does not define a region
|
||||
builder.
|
||||
}],
|
||||
/*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
|
||||
/*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()>)>",
|
||||
/*methodName=*/"getRegionBuilder",
|
||||
(ins),
|
||||
[{ return ConcreteOp::getRegionBuilder(); }]
|
||||
|
||||
@ -192,7 +192,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
|
||||
}
|
||||
|
||||
static std::function<void(ImplicitLocOpBuilder &,
|
||||
Block &, ArrayRef<NamedAttribute>)>
|
||||
Block &, ArrayRef<NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return nullptr;
|
||||
}
|
||||
@ -300,7 +301,8 @@ def MapOp : LinalgStructuredBase_Op<"map", [
|
||||
}
|
||||
|
||||
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>)>
|
||||
mlir::ArrayRef<mlir::NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return nullptr;
|
||||
}
|
||||
@ -380,7 +382,8 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
|
||||
|
||||
// Implement functions necessary for DestinationStyleOpInterface.
|
||||
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>)>
|
||||
mlir::ArrayRef<mlir::NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return nullptr;
|
||||
}
|
||||
@ -449,13 +452,14 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
|
||||
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
|
||||
|
||||
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>) {
|
||||
mlir::ArrayRef<mlir::NamedAttribute>, function_ref<InFlightDiagnostic()> emitError) {
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
|
||||
}
|
||||
|
||||
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>)>
|
||||
mlir::ArrayRef<mlir::NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
@ -521,13 +525,15 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
|
||||
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
|
||||
|
||||
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>) {
|
||||
mlir::ArrayRef<mlir::NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()> emitError) {
|
||||
OpBuilder::InsertionGuard guard(b);
|
||||
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
|
||||
}
|
||||
|
||||
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>)>
|
||||
mlir::ArrayRef<mlir::NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
@ -631,10 +637,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
|
||||
/// Implements the block region builder for the elementwiseOp. This is
|
||||
/// called by the 'fillStructuredOpRegion'.
|
||||
static void regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs);
|
||||
Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError);
|
||||
|
||||
static std::function<void(ImplicitLocOpBuilder &,
|
||||
Block &, ArrayRef<NamedAttribute>)>
|
||||
Block &, ArrayRef<NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
@ -771,7 +779,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
|
||||
|
||||
/// Implements the block region builder.
|
||||
static void regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs);
|
||||
Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError);
|
||||
|
||||
/// Returns a list of AffineMap with the default matmul indexing charactristic.
|
||||
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
|
||||
@ -780,7 +789,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
|
||||
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
|
||||
|
||||
static std::function<void(ImplicitLocOpBuilder &,
|
||||
Block &, ArrayRef<NamedAttribute>)>
|
||||
Block &, ArrayRef<NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
@ -916,10 +926,12 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
|
||||
static unsigned getNumRegionArgs();
|
||||
|
||||
static void regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs);
|
||||
Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError);
|
||||
|
||||
static std::function<void(ImplicitLocOpBuilder &,
|
||||
Block &, ArrayRef<NamedAttribute>)>
|
||||
Block &, ArrayRef<NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
@ -1033,9 +1045,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
|
||||
|
||||
SmallVector<utils::IteratorType> getIteratorTypesArray();
|
||||
static void regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs);
|
||||
Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError);
|
||||
static std::function<void(ImplicitLocOpBuilder &,
|
||||
Block &, ArrayRef<NamedAttribute>)>
|
||||
Block &, ArrayRef<NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
@ -1161,7 +1175,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
|
||||
|
||||
/// Implements the block region builder.
|
||||
static void regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs);
|
||||
Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError);
|
||||
|
||||
/// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
|
||||
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
|
||||
@ -1170,7 +1185,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
|
||||
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
|
||||
|
||||
static std::function<void(ImplicitLocOpBuilder &,
|
||||
Block &, ArrayRef<NamedAttribute>)>
|
||||
Block &, ArrayRef<NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
|
||||
@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
|
||||
Region ®ion = op->getRegion(0);
|
||||
Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs);
|
||||
b.setInsertionPointToStart(body);
|
||||
fun(b, *body, op->getAttrs());
|
||||
fun(b, *body, op->getAttrs(), /*emitError=*/{});
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
|
||||
|
||||
@ -117,8 +117,9 @@ OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
|
||||
// Support for named Linalg ops defined in ods-gen.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
|
||||
ArrayRef<NamedAttribute>)>;
|
||||
using RegionBuilderFn = llvm::function_ref<void(
|
||||
ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
|
||||
function_ref<InFlightDiagnostic()>)>;
|
||||
|
||||
/// Fills the region of a structured operation using the provided
|
||||
/// `regionBuilder`. The method is used by both named structured ops created by
|
||||
@ -128,6 +129,7 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
|
||||
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
|
||||
TypeRange inputTypes, TypeRange outputTypes,
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError,
|
||||
RegionBuilderFn regionBuilder) {
|
||||
SmallVector<Type, 8> argTypes;
|
||||
SmallVector<Location, 8> argLocs;
|
||||
@ -148,7 +150,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
|
||||
|
||||
opBuilder.setInsertionPointToStart(body);
|
||||
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
|
||||
regionBuilder(b, *body, attrs);
|
||||
regionBuilder(b, *body, attrs, emitError);
|
||||
|
||||
// indexing_maps is an auto-generated method.
|
||||
|
||||
@ -184,7 +186,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
|
||||
// Create and fill the region of the structured operation.
|
||||
Region ®ion = *state.addRegion();
|
||||
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
|
||||
state.attributes.getAttrs(), regionBuilder);
|
||||
state.attributes.getAttrs(), /*emitError=*/{},
|
||||
regionBuilder);
|
||||
}
|
||||
|
||||
static void buildMatmulOp(OpBuilder &b, OperationState &state,
|
||||
@ -329,7 +332,7 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
|
||||
static ParseResult parseNamedStructuredOpRegion(
|
||||
OpAsmParser &parser, Region ®ion, unsigned numRegionArgs,
|
||||
TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
|
||||
RegionBuilderFn regionBuilder) {
|
||||
RegionBuilderFn regionBuilder, SMLoc loc) {
|
||||
if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
|
||||
return parser.emitError(
|
||||
parser.getCurrentLocation(),
|
||||
@ -339,9 +342,15 @@ static ParseResult parseNamedStructuredOpRegion(
|
||||
}
|
||||
|
||||
OpBuilder opBuilder(parser.getContext());
|
||||
fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
|
||||
regionBuilder);
|
||||
return success();
|
||||
ParseResult result = success();
|
||||
fillStructuredOpRegion(
|
||||
opBuilder, region, inputTypes, outputTypes, attrs,
|
||||
[&]() {
|
||||
result = failure();
|
||||
return parser.emitError(loc);
|
||||
},
|
||||
regionBuilder);
|
||||
return result;
|
||||
}
|
||||
|
||||
static ParseResult
|
||||
@ -358,6 +367,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
||||
RegionBuilderFn regionBuilder) {
|
||||
// TODO: Enable when ods-gen supports captures.
|
||||
SmallVector<Type, 1> inputTypes, outputTypes;
|
||||
SMLoc loc = parser.getCurrentLocation();
|
||||
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
||||
return failure();
|
||||
|
||||
@ -375,7 +385,7 @@ static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
||||
std::unique_ptr<Region> region = std::make_unique<Region>();
|
||||
if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
|
||||
outputTypes, result.attributes.getAttrs(),
|
||||
regionBuilder))
|
||||
regionBuilder, loc))
|
||||
return failure();
|
||||
result.addRegion(std::move(region));
|
||||
|
||||
@ -435,9 +445,15 @@ public:
|
||||
: builder(builder), block(block) {}
|
||||
|
||||
// Build the unary functions defined by OpDSL.
|
||||
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
|
||||
if (!isFloatingPoint(arg))
|
||||
Value buildUnaryFn(UnaryFn unaryFn, Value arg,
|
||||
function_ref<InFlightDiagnostic()> emitError = {}) {
|
||||
if (!isFloatingPoint(arg)) {
|
||||
if (emitError) {
|
||||
emitError() << "unsupported non numeric type";
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported non numeric type");
|
||||
}
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointToEnd(&block);
|
||||
switch (unaryFn) {
|
||||
@ -472,18 +488,34 @@ public:
|
||||
case UnaryFn::erf:
|
||||
return builder.create<math::ErfOp>(arg.getLoc(), arg);
|
||||
}
|
||||
if (emitError) {
|
||||
emitError() << "unsupported unary function";
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported unary function");
|
||||
}
|
||||
|
||||
// Build the binary functions defined by OpDSL.
|
||||
Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
|
||||
// If emitError is provided, an error will be emitted if the operation is not
|
||||
// supported and a nullptr will be returned, otherwise an assertion will be
|
||||
// raised.
|
||||
Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
|
||||
function_ref<InFlightDiagnostic()> emitError = {}) {
|
||||
bool allComplex = isComplex(arg0) && isComplex(arg1);
|
||||
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
|
||||
bool allInteger = isInteger(arg0) && isInteger(arg1);
|
||||
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
|
||||
arg1.getType().getIntOrFloatBitWidth() == 1;
|
||||
if (!allComplex && !allFloatingPoint && !allInteger)
|
||||
if (!allComplex && !allFloatingPoint && !allInteger) {
|
||||
if (emitError) {
|
||||
emitError()
|
||||
<< "Cannot build binary Linalg operation: expects allComplex, "
|
||||
"allFloatingPoint, or allInteger, got "
|
||||
<< arg0.getType() << " and " << arg1.getType();
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported non numeric type");
|
||||
}
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointToEnd(&block);
|
||||
switch (binaryFn) {
|
||||
@ -500,8 +532,13 @@ public:
|
||||
return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allBool)
|
||||
if (allBool) {
|
||||
if (emitError) {
|
||||
emitError() << "unsupported operation: sub with bools";
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported operation: sub with bools");
|
||||
}
|
||||
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::mul:
|
||||
if (allComplex)
|
||||
@ -516,12 +553,22 @@ public:
|
||||
return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allFloatingPoint)
|
||||
return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
|
||||
if (allBool)
|
||||
if (allBool) {
|
||||
if (emitError) {
|
||||
emitError() << "unsupported operation: div with bools";
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported operation: div with bools");
|
||||
}
|
||||
return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::div_unsigned:
|
||||
if (!allInteger || allBool)
|
||||
if (!allInteger || allBool) {
|
||||
if (emitError) {
|
||||
emitError() << "unsupported operation: unsigned div not on uint";
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported operation: unsigned div not on uint");
|
||||
}
|
||||
return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
|
||||
case BinaryFn::max_signed:
|
||||
assert(!allComplex);
|
||||
@ -547,12 +594,16 @@ public:
|
||||
assert(allFloatingPoint);
|
||||
return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
|
||||
}
|
||||
if (emitError) {
|
||||
emitError() << "unsupported binary function";
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported binary function");
|
||||
}
|
||||
|
||||
// Build the ternary functions defined by OpDSL.
|
||||
Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
|
||||
Value arg2) {
|
||||
Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
|
||||
function_ref<InFlightDiagnostic()> emitError = {}) {
|
||||
bool headBool =
|
||||
isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
|
||||
bool tailFloatingPoint =
|
||||
@ -566,17 +617,26 @@ public:
|
||||
llvm_unreachable("unsupported non numeric type");
|
||||
return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
|
||||
}
|
||||
if (emitError) {
|
||||
emitError() << "unsupported ternary function";
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported ternary function");
|
||||
}
|
||||
|
||||
// Build the type functions defined by OpDSL.
|
||||
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
|
||||
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
|
||||
function_ref<InFlightDiagnostic()> emitError = {}) {
|
||||
switch (typeFn) {
|
||||
case TypeFn::cast_signed:
|
||||
return cast(toType, operand, false);
|
||||
case TypeFn::cast_unsigned:
|
||||
return cast(toType, operand, true);
|
||||
}
|
||||
if (emitError) {
|
||||
emitError() << "unsupported type conversion function";
|
||||
return nullptr;
|
||||
}
|
||||
llvm_unreachable("unsupported type conversion function");
|
||||
}
|
||||
|
||||
@ -617,6 +677,13 @@ private:
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPointToEnd(&block);
|
||||
auto loc = operand.getLoc();
|
||||
if (isa<UnknownLoc>(loc)) {
|
||||
if (operand.getDefiningOp())
|
||||
loc = operand.getDefiningOp()->getLoc();
|
||||
else if (operand.getParentBlock() &&
|
||||
operand.getParentBlock()->getParentOp())
|
||||
loc = operand.getParentBlock()->getParentOp()->getLoc();
|
||||
}
|
||||
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
|
||||
}
|
||||
|
||||
@ -3693,9 +3760,15 @@ bool MatmulOp::hasUserDefinedMaps() {
|
||||
/// Implements the block region builder for the MatmulOp. This is called by
|
||||
/// 'fillStructuredOpRegion'.
|
||||
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
assert(3 > 0 && block.getNumArguments() == 3 &&
|
||||
"MatmulOp regionBuilder expects 3 (>=0) args");
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError) {
|
||||
if (emitError && block.getNumArguments() != 3) {
|
||||
emitError() << "MatmulOp regionBuilder expects 3 args, got "
|
||||
<< block.getNumArguments();
|
||||
return;
|
||||
}
|
||||
assert(block.getNumArguments() == 3 &&
|
||||
"MatmulOp regionBuilder expects 3 args");
|
||||
RegionBuilderHelper helper(b, block);
|
||||
SmallVector<Value> yields;
|
||||
|
||||
@ -3712,9 +3785,13 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
block.getArgument(0));
|
||||
Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
|
||||
block.getArgument(1));
|
||||
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
|
||||
Value value4 =
|
||||
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
|
||||
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
|
||||
if (!value3)
|
||||
return;
|
||||
Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
|
||||
value3, emitError);
|
||||
if (!value4)
|
||||
return;
|
||||
yields.push_back(value4);
|
||||
helper.yieldOutputs(yields);
|
||||
}
|
||||
@ -3842,7 +3919,13 @@ unsigned ContractOp::getNumRegionArgs() { return 3; }
|
||||
|
||||
/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
|
||||
void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError) {
|
||||
if (emitError && block.getNumArguments() != 3) {
|
||||
emitError() << "ContractOp regionBuilder expects 3 args, got "
|
||||
<< block.getNumArguments();
|
||||
return;
|
||||
}
|
||||
assert(block.getNumArguments() == 3 &&
|
||||
"ContractOp regionBuilder expects 3 args");
|
||||
RegionBuilderHelper helper(b, block);
|
||||
@ -3862,10 +3945,14 @@ void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
|
||||
Value rhsAtOutType =
|
||||
helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
|
||||
Value productAtOutType =
|
||||
helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
|
||||
Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
|
||||
rhsAtOutType, emitError);
|
||||
if (!productAtOutType)
|
||||
return;
|
||||
Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
|
||||
productAtOutType);
|
||||
productAtOutType, emitError);
|
||||
if (!result)
|
||||
return;
|
||||
helper.yieldOutputs({result});
|
||||
}
|
||||
|
||||
@ -4057,10 +4144,16 @@ bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
|
||||
return isValid;
|
||||
}
|
||||
|
||||
void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
void BatchMatmulOp::regionBuilder(
|
||||
ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError) {
|
||||
if (emitError && block.getNumArguments() != 3) {
|
||||
emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
|
||||
<< block.getNumArguments();
|
||||
return;
|
||||
}
|
||||
assert(block.getNumArguments() == 3 &&
|
||||
"BatchMatmulOp regionBuilder expects 3 (>=0) args");
|
||||
"BatchMatmulOp regionBuilder expects 3 args");
|
||||
RegionBuilderHelper helper(b, block);
|
||||
SmallVector<Value> yields;
|
||||
|
||||
@ -4332,8 +4425,9 @@ LogicalResult ElementwiseOp::verify() {
|
||||
|
||||
/// Implements the block region builder for the ElementwiseOp. This is called by
|
||||
/// 'fillStructuredOpRegion'.
|
||||
void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
void ElementwiseOp::regionBuilder(
|
||||
ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError) {
|
||||
ElementwiseKind elemwiseKind;
|
||||
for (auto attr : attrs) {
|
||||
if (attr.getName() == b.getStringAttr("kind")) {
|
||||
@ -4347,6 +4441,13 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
|
||||
auto arityGroup = groupAndKind.arityGroup;
|
||||
auto kind = groupAndKind.kind;
|
||||
if (emitError && block.getNumArguments() !=
|
||||
getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
|
||||
emitError() << "Elementwise regionBuilder expects "
|
||||
<< (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
|
||||
<< block.getNumArguments();
|
||||
return;
|
||||
}
|
||||
assert(block.getNumArguments() ==
|
||||
getArityGroupAsUInt(arityGroup) + 1 /*output*/
|
||||
&& "Elementwise regionBuilder number of block args mismatch");
|
||||
@ -5530,10 +5631,16 @@ bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
|
||||
return isValid;
|
||||
}
|
||||
|
||||
void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
void BatchReduceMatmulOp::regionBuilder(
|
||||
ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError) {
|
||||
if (emitError && block.getNumArguments() != 3) {
|
||||
emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
|
||||
<< block.getNumArguments();
|
||||
return;
|
||||
}
|
||||
assert(block.getNumArguments() == 3 &&
|
||||
"BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
|
||||
"BatchReduceMatmulOp regionBuilder expects 3 args");
|
||||
RegionBuilderHelper helper(b, block);
|
||||
SmallVector<Value> yields;
|
||||
|
||||
|
||||
@ -1868,9 +1868,51 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// linalg.reduce
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
func.func @reduce_non_operation_name(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||
// expected-error @below {{expected bare identifier or keyword}}
|
||||
%0 = linalg.reduce {@reduce_fusion_elementwise} ins(
|
||||
%arg0: tensor<4xf32>) outs(%arg1: tensor<f32>) dimensions = [0]
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tests for generic infrastructure for named Ops. The actual Ops used are
|
||||
// secondary - we merely want to ensure that the diagnostic infra triggers
|
||||
// correctly.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
module {
|
||||
func.func @add_invalid_mixed_types(%in_f32: memref<3xf32>, %in_i32 : memref< 3xi32>, %out_f32: memref<3xf32>, %arg3: memref<3xf32>) {
|
||||
// expected-error @below {{Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'f32' and 'i32'}}
|
||||
linalg.add ins(%in_f32, %in_i32 : memref<3xf32>, memref< 3xi32>) outs(%out_f32 : memref<3xf32>)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @elemwise_unary_invalid_mixed_types(%arg0 : tensor<?xi32>) -> tensor<?xi32> {
|
||||
// expected-error @below {{unsupported non numeric type}}
|
||||
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?xi32>) outs(%arg0 : tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
|
||||
-> (tensor<?xf16>, vector<4xf16>)
|
||||
{
|
||||
// expected-warning @unknown {{could not cast operand of type 'f16' to 'vector<4xf16>'}}
|
||||
// expected-error @below {{Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'vector<4xf16>' and 'f16'}}
|
||||
%0 = linalg.matmul ins(%t, %t : tensor<?xf16>, tensor<?xf16>)
|
||||
outs(%f : vector<4xf16>) -> tensor<?xf16>
|
||||
func.return %0, %f : tensor<?xf16>, vector<4xf16>
|
||||
}
|
||||
|
||||
@ -2737,12 +2737,14 @@ def TestLinalgConvOp :
|
||||
bool hasIndexSemantics() { return false; }
|
||||
|
||||
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
|
||||
mlir::ArrayRef<mlir::NamedAttribute> attrs) {
|
||||
mlir::ArrayRef<mlir::NamedAttribute> attrs,
|
||||
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) {
|
||||
b.create<mlir::linalg::YieldOp>(block.getArguments().back());
|
||||
}
|
||||
|
||||
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>)>
|
||||
mlir::ArrayRef<mlir::NamedAttribute>,
|
||||
llvm::function_ref<mlir::InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return ®ionBuilder;
|
||||
}
|
||||
@ -2798,12 +2800,14 @@ def TestLinalgFillOp :
|
||||
bool hasIndexSemantics() { return false; }
|
||||
|
||||
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
|
||||
mlir::ArrayRef<mlir::NamedAttribute> attrs) {
|
||||
mlir::ArrayRef<mlir::NamedAttribute> attrs,
|
||||
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) {
|
||||
b.create<mlir::linalg::YieldOp>(block.getArguments().back());
|
||||
}
|
||||
|
||||
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
|
||||
mlir::ArrayRef<mlir::NamedAttribute>)>
|
||||
mlir::ArrayRef<mlir::NamedAttribute>,
|
||||
llvm::function_ref<mlir::InFlightDiagnostic()>)>
|
||||
getRegionBuilder() {
|
||||
return ®ionBuilder;
|
||||
}
|
||||
|
||||
@ -87,7 +87,8 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# ODS-NEXT: }
|
||||
|
||||
# IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
|
||||
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
|
||||
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
# IMPL-NEXT: function_ref<InFlightDiagnostic()> emitError)
|
||||
# IMPL: TypeFn castVal = TypeFn::cast_signed;
|
||||
# IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
|
||||
# IMPL-NEXT: return attr.getName() == "cast"; });
|
||||
@ -97,10 +98,10 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# IMPL-NEXT: }
|
||||
|
||||
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
|
||||
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
|
||||
# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]], emitError);
|
||||
# IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1);
|
||||
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
|
||||
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]);
|
||||
# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]], emitError);
|
||||
# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]], emitError);
|
||||
|
||||
|
||||
# @linalg_structured_op
|
||||
@ -186,7 +187,8 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# IMPL: "incorrect element type for index attribute 'strides'"
|
||||
# IMPL: "incorrect shape for index attribute 'strides'"
|
||||
# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b,
|
||||
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
|
||||
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
# IMPL-NEXT: function_ref<InFlightDiagnostic()> emitError)
|
||||
# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 &&
|
||||
|
||||
# IMPL: yields.push_back(block.getArgument(0));
|
||||
@ -315,13 +317,18 @@ structured_op: !LinalgStructuredOpConfig
|
||||
# ODS-NEXT: $_state.addAttribute("binary_fun", binary_fun)
|
||||
|
||||
# IMPL-LABEL: void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
|
||||
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs)
|
||||
# IMPL-NEXT: Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
# IMPL-NEXT: function_ref<InFlightDiagnostic()> emitError)
|
||||
# IMPL: UnaryFn unary_funVal = UnaryFn::exp
|
||||
# IMPL: BinaryFn binary_funVal = BinaryFn::add
|
||||
|
||||
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
|
||||
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
|
||||
# IMPL-NEXT: yields.push_back([[VAL1]])
|
||||
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0), emitError);
|
||||
# IMPL-NEXT: if (![[VAL0]])
|
||||
# IMPL-NEXT: return;
|
||||
# IMPL: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0), emitError);
|
||||
# IMPL-NEXT: if (![[VAL1]])
|
||||
# IMPL-NEXT: return;
|
||||
# IMPL: yields.push_back([[VAL1]])
|
||||
|
||||
# @linalg_structured_op
|
||||
# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
|
||||
|
||||
@ -559,9 +559,10 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
|
||||
SmallVector<utils::IteratorType> getIteratorTypesArray();
|
||||
ArrayAttr getIndexingMaps();
|
||||
static void regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs);
|
||||
Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError);
|
||||
static std::function<void(ImplicitLocOpBuilder &,
|
||||
Block &, ArrayRef<NamedAttribute>)>
|
||||
Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()> emitError)>
|
||||
getRegionBuilder() {{
|
||||
return regionBuilder;
|
||||
}
|
||||
@ -1010,7 +1011,8 @@ LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
|
||||
// {3}: Statements
|
||||
static const char structuredOpRegionBuilderFormat[] = R"FMT(
|
||||
void {0}::regionBuilder(ImplicitLocOpBuilder &b,
|
||||
Block &block, ArrayRef<NamedAttribute> attrs) {{
|
||||
Block &block, ArrayRef<NamedAttribute> attrs,
|
||||
function_ref<InFlightDiagnostic()> emitError) {{
|
||||
assert({1} > 0 && block.getNumArguments() == {1} &&
|
||||
"{0} regionBuilder expects {1} (>=0) args");
|
||||
RegionBuilderHelper helper(b, block);
|
||||
@ -1137,8 +1139,13 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b,
|
||||
// Call the function builder.
|
||||
std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
|
||||
stmts.push_back(llvm::formatv(
|
||||
"Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName,
|
||||
funcType, interleaveToString(operandCppValues, ", ")));
|
||||
R"mlir(
|
||||
Value {0} = helper.build{1}({2}, {3}, emitError);
|
||||
if (!{0})
|
||||
return;
|
||||
)mlir",
|
||||
cppIdent, enumName, funcType,
|
||||
interleaveToString(operandCppValues, ", ")));
|
||||
return cppIdent;
|
||||
}
|
||||
emitError(genContext.getLoc()) << "unknown ScalarExpression type";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user