[mlir][tosa] Enhance TosaInferShapes pass for simple shape inference (#178418)
This commit enhances the TosaInferShapes pass with two new options: - fold-shape-expressions - convert-function-boundaries The "fold-shape-expressions" option enables greedily folding the newly added TOSA shape operations when possible. Folding these operations directly within TosaInferShapes is useful since it allows shapes of later operations to be inferred in a single pass. The "convert-function-boundaries" updates the return types of a function to the newly inferred output shapes. This avoids the need for additional tensor.cast operations at function boundaries. This option is particularly useful when wanting to resolve a dynamic function to fully static. When both of these options are used in conjunction with the "tosa-input-shapes" pass option, it's possible to resolve a dynamic function to static in a single pass. Note: This PR is split into two commits. [68888db](68888dbb1e) is a simple refactor and consists of no logic changes. [db9a6a3](db9a6a323a) includes the changes for shape inference.
This commit is contained in:
parent
0ac8e41ee1
commit
6874e945fe
@ -42,6 +42,16 @@ def TosaInferShapesPass : Pass<"tosa-infer-shapes", "func::FuncOp"> {
|
||||
"tensor::TensorDialect",
|
||||
"tosa::TosaDialect",
|
||||
];
|
||||
|
||||
let options = [
|
||||
Option<"foldShapeExpressions", "fold-shape-expressions", "bool",
|
||||
/*default=*/"false",
|
||||
"Fold TOSA shape operations when they have known input values">,
|
||||
Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool",
|
||||
/*default=*/"false",
|
||||
"If enabled, the pass will convert function I/O types as well. Otherwise casts will"
|
||||
"be inserted at the I/O boundaries.">
|
||||
];
|
||||
}
|
||||
|
||||
def TosaMakeBroadcastablePass
|
||||
|
||||
@ -18,8 +18,10 @@
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Iterators.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tosa {
|
||||
@ -128,179 +130,6 @@ private:
|
||||
llvm::SmallVector<std::pair<Value, Type>> oldTypes;
|
||||
};
|
||||
|
||||
void propagateShapesInRegion(Region ®ion, TypeModificationState &state);
|
||||
|
||||
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
|
||||
IfOp ifOp = dyn_cast<IfOp>(op);
|
||||
if (!ifOp)
|
||||
return;
|
||||
|
||||
for (auto ®ion : op.getRegions()) {
|
||||
Block &frontBlock = region.front();
|
||||
if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
|
||||
return;
|
||||
|
||||
for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
|
||||
auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
|
||||
auto blockArg = frontBlock.getArgument(i - 1);
|
||||
auto oldType = cast<ShapedType>(blockArg.getType());
|
||||
|
||||
if (inferredTy.hasRank()) {
|
||||
Type newType = oldType.clone(inferredTy.getShape());
|
||||
state.setType(blockArg, newType);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
|
||||
ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
|
||||
ifOp.getOperand(i + 1).getType());
|
||||
ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
|
||||
frontBlock.getArgument(i).getType());
|
||||
ValueKnowledge joinedKnowledge =
|
||||
ValueKnowledge::join(operandKnowledge, blockKnowledge);
|
||||
if (!joinedKnowledge)
|
||||
continue;
|
||||
state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
|
||||
}
|
||||
|
||||
propagateShapesInRegion(region, state);
|
||||
}
|
||||
}
|
||||
|
||||
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
|
||||
WhileOp whileOp = dyn_cast<WhileOp>(op);
|
||||
if (!whileOp)
|
||||
return;
|
||||
|
||||
// Determine what the expected argument types are to the cond/body blocks.
|
||||
// The expected arguments should be compatible with ever iteration of the
|
||||
// loop body / condition for tosa.while.
|
||||
SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
|
||||
|
||||
bool hasNewTypes = true;
|
||||
while (hasNewTypes) {
|
||||
TypeModificationState localState;
|
||||
|
||||
// Set types on the block args.
|
||||
Region &bodyRegion = op.getRegion(1);
|
||||
Block &block = bodyRegion.front();
|
||||
for (int i = 0, s = argTypes.size(); i < s; i++) {
|
||||
localState.setType(block.getArgument(i), argTypes[i]);
|
||||
}
|
||||
|
||||
// Propagate to the end.
|
||||
propagateShapesInRegion(bodyRegion, localState);
|
||||
|
||||
// Find all the tosa yield types and verify there is a single one.
|
||||
llvm::SmallVector<YieldOp> yieldOps;
|
||||
for (auto &block : bodyRegion)
|
||||
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
|
||||
yieldOps.push_back(yieldOp);
|
||||
|
||||
assert(yieldOps.size() == 1 && "missing or non-unique yield op");
|
||||
// Using the new tosa.yield operand types, infer the new subtypes.
|
||||
llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
|
||||
for (auto ty : argTypes) {
|
||||
yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
|
||||
}
|
||||
|
||||
for (auto yieldOp : yieldOps) {
|
||||
for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
|
||||
auto newKnowledge =
|
||||
ValueKnowledge::getKnowledgeFromType(it.value().getType());
|
||||
yieldTypeInfo[it.index()] =
|
||||
ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
|
||||
}
|
||||
}
|
||||
|
||||
// This should never happen.
|
||||
if (yieldTypeInfo.size() != argTypes.size()) {
|
||||
op.emitWarning("has a tosa.yield with the incorrect number of operands");
|
||||
return;
|
||||
}
|
||||
|
||||
// Determine the new block args and see if any changed.
|
||||
hasNewTypes = false;
|
||||
for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
|
||||
Type newType = yieldTypeInfo[i].getType();
|
||||
hasNewTypes |= (newType != argTypes[i]);
|
||||
argTypes[i] = newType;
|
||||
}
|
||||
|
||||
// Roll back all changes made during the speculative part of the algorithm.
|
||||
localState.rollBack();
|
||||
}
|
||||
|
||||
// We now set the block arguments according to the most recent shape
|
||||
// inference results. This gives us the block arg types for the next
|
||||
// iteration.
|
||||
for (auto ®ion : op.getRegions()) {
|
||||
for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
|
||||
state.setType(region.front().getArgument(i), argTypes[i]);
|
||||
}
|
||||
|
||||
propagateShapesInRegion(region, state);
|
||||
}
|
||||
}
|
||||
|
||||
void propagateShapesInRegion(Region ®ion, TypeModificationState &state) {
|
||||
Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
|
||||
|
||||
for (auto &block : region) {
|
||||
for (Operation &op : block) {
|
||||
if (op.getDialect() != tosaDialect)
|
||||
continue;
|
||||
|
||||
propagateShapesToTosaIf(op, state);
|
||||
propagateShapesToTosaWhile(op, state);
|
||||
|
||||
InferShapedTypeOpInterface shapeInterface =
|
||||
dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shapeInterface)
|
||||
continue;
|
||||
|
||||
SmallVector<ShapedTypeComponents> returnedShapes;
|
||||
|
||||
if (shapeInterface
|
||||
.inferReturnTypeComponents(
|
||||
op.getContext(), op.getLoc(), op.getOperands(),
|
||||
op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
|
||||
op.getRegions(), returnedShapes)
|
||||
.succeeded()) {
|
||||
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
|
||||
Value result = std::get<0>(it);
|
||||
ShapedTypeComponents predictedShape = std::get<1>(it);
|
||||
|
||||
// Determine the knowledge based on the output type.
|
||||
// TODO: should also query WIP type probably
|
||||
Type resultTy = result.getType();
|
||||
auto currentKnowledge =
|
||||
ValueKnowledge::getKnowledgeFromType(resultTy);
|
||||
|
||||
// Compute the knowledge based on the inferred type.
|
||||
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
|
||||
inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
|
||||
inferredKnowledge.hasRank = predictedShape.hasRank();
|
||||
if (predictedShape.hasRank()) {
|
||||
for (auto dim : predictedShape.getDims()) {
|
||||
inferredKnowledge.sizes.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the new type based on the joined version.
|
||||
auto newKnowledge =
|
||||
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
|
||||
if (!newKnowledge)
|
||||
continue;
|
||||
|
||||
// Set new type
|
||||
state.setType(result, newKnowledge.getType());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
|
||||
/// and all nested regions
|
||||
void validateSameOperandsAndResultRankTrait(Region ®ion) {
|
||||
@ -333,13 +162,261 @@ void validateSameOperandsAndResultRankTrait(Region ®ion) {
|
||||
struct TosaInferShapes
|
||||
: public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
|
||||
public:
|
||||
explicit TosaInferShapes() = default;
|
||||
explicit TosaInferShapes(const TosaInferShapesPassOptions &options)
|
||||
: TosaInferShapes() {
|
||||
this->foldShapeExpressions = options.foldShapeExpressions;
|
||||
this->convertFunctionBoundaries = options.convertFunctionBoundaries;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
func::FuncOp func = getOperation();
|
||||
TypeModificationState state;
|
||||
propagateShapesInRegion(func.getBody(), state);
|
||||
state.commit();
|
||||
|
||||
if (foldShapeExpressions) {
|
||||
// Folding shape expressions may leave dead tosa.const_shape operations
|
||||
func.walk<WalkOrder::PostOrder, ReverseIterator>(
|
||||
[](tosa::ConstShapeOp op) {
|
||||
if (isOpTriviallyDead(op))
|
||||
op->erase();
|
||||
});
|
||||
}
|
||||
|
||||
validateSameOperandsAndResultRankTrait(func.getBody());
|
||||
|
||||
if (convertFunctionBoundaries)
|
||||
convertFunctionReturnTypes(func);
|
||||
}
|
||||
|
||||
private:
|
||||
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
|
||||
IfOp ifOp = dyn_cast<IfOp>(op);
|
||||
if (!ifOp)
|
||||
return;
|
||||
|
||||
for (auto ®ion : op.getRegions()) {
|
||||
Block &frontBlock = region.front();
|
||||
if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
|
||||
return;
|
||||
|
||||
for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
|
||||
auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
|
||||
auto blockArg = frontBlock.getArgument(i - 1);
|
||||
auto oldType = cast<ShapedType>(blockArg.getType());
|
||||
|
||||
if (inferredTy.hasRank()) {
|
||||
Type newType = oldType.clone(inferredTy.getShape());
|
||||
state.setType(blockArg, newType);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
|
||||
ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
|
||||
ifOp.getOperand(i + 1).getType());
|
||||
ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
|
||||
frontBlock.getArgument(i).getType());
|
||||
ValueKnowledge joinedKnowledge =
|
||||
ValueKnowledge::join(operandKnowledge, blockKnowledge);
|
||||
if (!joinedKnowledge)
|
||||
continue;
|
||||
state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
|
||||
}
|
||||
|
||||
propagateShapesInRegion(region, state);
|
||||
}
|
||||
}
|
||||
|
||||
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
|
||||
WhileOp whileOp = dyn_cast<WhileOp>(op);
|
||||
if (!whileOp)
|
||||
return;
|
||||
|
||||
// Determine what the expected argument types are to the cond/body blocks.
|
||||
// The expected arguments should be compatible with ever iteration of the
|
||||
// loop body / condition for tosa.while.
|
||||
SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
|
||||
|
||||
bool hasNewTypes = true;
|
||||
while (hasNewTypes) {
|
||||
TypeModificationState localState;
|
||||
|
||||
// Set types on the block args.
|
||||
Region &bodyRegion = op.getRegion(1);
|
||||
Block &block = bodyRegion.front();
|
||||
for (int i = 0, s = argTypes.size(); i < s; i++) {
|
||||
localState.setType(block.getArgument(i), argTypes[i]);
|
||||
}
|
||||
|
||||
// Propagate to the end.
|
||||
propagateShapesInRegion(bodyRegion, localState);
|
||||
|
||||
// Find all the tosa yield types and verify there is a single one.
|
||||
llvm::SmallVector<YieldOp> yieldOps;
|
||||
for (auto &block : bodyRegion)
|
||||
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
|
||||
yieldOps.push_back(yieldOp);
|
||||
|
||||
assert(yieldOps.size() == 1 && "missing or non-unique yield op");
|
||||
// Using the new tosa.yield operand types, infer the new subtypes.
|
||||
llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
|
||||
for (auto ty : argTypes) {
|
||||
yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
|
||||
}
|
||||
|
||||
for (auto yieldOp : yieldOps) {
|
||||
for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
|
||||
auto newKnowledge =
|
||||
ValueKnowledge::getKnowledgeFromType(it.value().getType());
|
||||
yieldTypeInfo[it.index()] =
|
||||
ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
|
||||
}
|
||||
}
|
||||
|
||||
// This should never happen.
|
||||
if (yieldTypeInfo.size() != argTypes.size()) {
|
||||
op.emitWarning(
|
||||
"has a tosa.yield with the incorrect number of operands");
|
||||
return;
|
||||
}
|
||||
|
||||
// Determine the new block args and see if any changed.
|
||||
hasNewTypes = false;
|
||||
for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
|
||||
Type newType = yieldTypeInfo[i].getType();
|
||||
hasNewTypes |= (newType != argTypes[i]);
|
||||
argTypes[i] = newType;
|
||||
}
|
||||
|
||||
// Roll back all changes made during the speculative part of the
|
||||
// algorithm.
|
||||
localState.rollBack();
|
||||
}
|
||||
|
||||
// We now set the block arguments according to the most recent shape
|
||||
// inference results. This gives us the block arg types for the next
|
||||
// iteration.
|
||||
for (auto ®ion : op.getRegions()) {
|
||||
for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
|
||||
state.setType(region.front().getArgument(i), argTypes[i]);
|
||||
}
|
||||
|
||||
propagateShapesInRegion(region, state);
|
||||
}
|
||||
}
|
||||
|
||||
void propagateShapesInRegion(Region ®ion, TypeModificationState &state) {
|
||||
MLIRContext *ctx = region.getContext();
|
||||
Dialect *tosaDialect = ctx->getLoadedDialect<TosaDialect>();
|
||||
OperationFolder folder(ctx);
|
||||
|
||||
for (auto &block : region) {
|
||||
// The loop body may erase operations, so we need to be careful
|
||||
// when iterating. Fetch the next operation before the current
|
||||
// operation is modified.
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
Operation &op = *it++;
|
||||
if (op.getDialect() != tosaDialect)
|
||||
continue;
|
||||
|
||||
propagateShapesToTosaIf(op, state);
|
||||
propagateShapesToTosaWhile(op, state);
|
||||
|
||||
if (foldShapeExpressions &&
|
||||
op.hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
|
||||
(void)folder.tryToFold(&op);
|
||||
continue;
|
||||
}
|
||||
|
||||
InferShapedTypeOpInterface shapeInterface =
|
||||
dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shapeInterface)
|
||||
continue;
|
||||
|
||||
SmallVector<ShapedTypeComponents> returnedShapes;
|
||||
|
||||
if (shapeInterface
|
||||
.inferReturnTypeComponents(
|
||||
op.getContext(), op.getLoc(), op.getOperands(),
|
||||
op.getDiscardableAttrDictionary(),
|
||||
op.getPropertiesStorage(), op.getRegions(), returnedShapes)
|
||||
.succeeded()) {
|
||||
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
|
||||
Value result = std::get<0>(it);
|
||||
ShapedTypeComponents predictedShape = std::get<1>(it);
|
||||
|
||||
// Determine the knowledge based on the output type.
|
||||
// TODO: should also query WIP type probably
|
||||
Type resultTy = result.getType();
|
||||
auto currentKnowledge =
|
||||
ValueKnowledge::getKnowledgeFromType(resultTy);
|
||||
|
||||
// Compute the knowledge based on the inferred type.
|
||||
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
|
||||
inferredKnowledge.dtype =
|
||||
cast<ShapedType>(resultTy).getElementType();
|
||||
inferredKnowledge.hasRank = predictedShape.hasRank();
|
||||
if (predictedShape.hasRank()) {
|
||||
for (auto dim : predictedShape.getDims()) {
|
||||
inferredKnowledge.sizes.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the new type based on the joined version.
|
||||
auto newKnowledge =
|
||||
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
|
||||
if (!newKnowledge)
|
||||
continue;
|
||||
|
||||
// Set new type
|
||||
state.setType(result, newKnowledge.getType());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void convertFunctionReturnTypes(func::FuncOp func) {
|
||||
IRRewriter rewriter(func.getContext());
|
||||
SmallVector<Type> newReturnTypes;
|
||||
|
||||
// Rewrite func.return ops, removing dead tensor.cast ops if possible
|
||||
func.walk([&rewriter, &newReturnTypes](func::ReturnOp ret) {
|
||||
SmallVector<Value> newReturnValues;
|
||||
SmallVector<Value> maybeDeadCasts;
|
||||
OperandRange returnOperands = ret.getOperands();
|
||||
newReturnValues.reserve(returnOperands.size());
|
||||
maybeDeadCasts.reserve(returnOperands.size());
|
||||
newReturnTypes.reserve(newReturnTypes.size() + returnOperands.size());
|
||||
|
||||
for (const Value &v : returnOperands) {
|
||||
Value newReturnValue = v;
|
||||
if (auto castOp = v.getDefiningOp<tensor::CastOp>()) {
|
||||
newReturnValue = castOp.getSource();
|
||||
maybeDeadCasts.push_back(castOp);
|
||||
}
|
||||
newReturnValues.push_back(newReturnValue);
|
||||
newReturnTypes.push_back(newReturnValue.getType());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(ret);
|
||||
rewriter.replaceOpWithNewOp<func::ReturnOp>(ret, newReturnValues);
|
||||
|
||||
if (!maybeDeadCasts.empty()) {
|
||||
llvm::for_each(maybeDeadCasts, [&](Value castVal) {
|
||||
if (castVal.use_empty()) {
|
||||
rewriter.eraseOp(castVal.getDefiningOp());
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Update function return types with newly inferred types
|
||||
const FunctionType oldType = func.getFunctionType();
|
||||
const FunctionType newType = FunctionType::get(
|
||||
func.getContext(), oldType.getInputs(), newReturnTypes);
|
||||
func.setType(newType);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -0,0 +1,62 @@
|
||||
// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,DEFAULT
|
||||
// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,FUNCBOUND
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_simple_shape_expression
|
||||
func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
|
||||
// CHECK-NOT: tosa.dim
|
||||
// CHECK-NOT: tosa.add_shape
|
||||
// CHECK: %[[SHAPE:.+]] = tosa.const_shape {values = dense<84> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-NOT: tosa.const_shape {values = dense<4> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-NOT: tosa.const_shape {values = dense<80> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<84xi32>
|
||||
// CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[SHAPE]] : (tensor<84xi32>, !tosa.shape<1>) -> tensor<7056xi32>
|
||||
// DEFAULT: %[[CAST:.+]] = tensor.cast %[[TILE]] : tensor<7056xi32> to tensor<?xi32>
|
||||
// DEFAULT: return %[[CAST]] : tensor<?xi32>
|
||||
// FUNCBOUND-NOT: tensor.cast
|
||||
// FUNCBOUND: return %[[TILE]] : tensor<7056xi32>
|
||||
%a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
|
||||
%b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
|
||||
%c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
|
||||
%d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
%e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
|
||||
%f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
return %f : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_cond_if_with_shape_expressions
|
||||
func.func @test_cond_if_with_shape_expressions(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
|
||||
// CHECK: %[[CONST_SHAPE:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> {
|
||||
%0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
|
||||
// CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
|
||||
^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
|
||||
// CHECK-NOT: tosa.dim
|
||||
%0 = tosa.dim %arg3 {axis = 0 : i32} : (tensor<?xf32>) -> !tosa.shape<1>
|
||||
// CHECK: %[[RESHAPE:.*]] = tosa.reshape %arg3, %[[CONST_SHAPE]] : (tensor<3xf32>, !tosa.shape<1>) -> tensor<3xf32>
|
||||
%1 = tosa.reshape %arg3, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
|
||||
// CHECK: tosa.yield %[[RESHAPE]] : tensor<3xf32>
|
||||
tosa.yield %1 : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
|
||||
^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
|
||||
// CHECK: tosa.yield %arg4 : tensor<3xf32>
|
||||
tosa.yield %arg4 : tensor<?xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_no_fold_shape_expression
|
||||
func.func @test_no_fold_shape_expression(%arg0: tensor<1x?x3xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: tosa.dim
|
||||
%0 = tosa.dim %arg0 {axis = 1: i32} : (tensor<1x?x3xf32>) -> !tosa.shape<1>
|
||||
// CHECK: tosa.tile
|
||||
%1 = tosa.tile %arg1, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
|
||||
// CHECK: return %{{.*}} : tensor<?xf32>
|
||||
return %1 : tensor<?xf32>
|
||||
}
|
||||
@ -1,9 +1,12 @@
|
||||
// RUN: mlir-opt --split-input-file --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s
|
||||
// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s --allow-unused-prefixes --check-prefixes=CHECK,DEFAULT
|
||||
// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries" --allow-unregistered-dialect %s | FileCheck %s --allow-unused-prefixes --check-prefixes=CHECK,FUNCBOUND
|
||||
|
||||
// CHECK-LABEL: @test_return
|
||||
func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
|
||||
// CHECK: [[LOG:%.+]] = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32>
|
||||
// CHECK: %[[LOG:.+]] = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
|
||||
// DEFAULT: %[[CAST:.+]] = tensor.cast %[[LOG]] : tensor<4xf32> to tensor<*xf32>
|
||||
// DEFAULT: return %[[CAST]] : tensor<*xf32>
|
||||
// FUNCBOUND: return %[[LOG]] : tensor<4xf32>
|
||||
%0 = tosa.log %arg0 : (tensor<4xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -12,13 +15,13 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
|
||||
|
||||
// CHECK-LABEL: @test_multiple
|
||||
func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<1xf32>) -> tensor<*xf32> {
|
||||
// CHECK: [[ADD:%.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
|
||||
// CHECK: %[[ADD:.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
|
||||
%0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: [[LOG:%.+]] = tosa.log %0 : (tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK: %[[LOG:.+]] = tosa.log %[[ADD]] : (tensor<4xf32>) -> tensor<4xf32>
|
||||
%1 = tosa.log %0 : (tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
// CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
|
||||
// CHECK: %[[SUB:.+]] = tosa.sub %[[ADD]], %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
|
||||
%2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -364,12 +367,15 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten
|
||||
|
||||
// CHECK-LABEL: @test_accepts_unranked_scalar_tensor
|
||||
func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
|
||||
// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
|
||||
// CHECK-DAG: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
|
||||
%0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
|
||||
%1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
|
||||
// CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
|
||||
// CHECK: %[[PAD:.*]] = tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
|
||||
%2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32>
|
||||
// DEFAULT: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<1x3x3xf32> to tensor<*xf32>
|
||||
// DEFAULT: return %[[CAST]] : tensor<*xf32>
|
||||
// FUNCBOUND: return %[[PAD]] : tensor<1x3x3xf32>
|
||||
return %2 : tensor<*xf32>
|
||||
}
|
||||
|
||||
@ -406,18 +412,16 @@ func.func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>)
|
||||
|
||||
// CHECK-LABEL: @test_static_reshape
|
||||
func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
|
||||
// CHECK: %[[CONST3:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-DAG: %[[CONSTSHAPE3:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE1]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
|
||||
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE2]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
|
||||
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE3]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
|
||||
%3 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK: tosa.reshape %arg0, %[[CONST3]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
|
||||
%0 = tosa.reshape %arg0, %3 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
|
||||
|
||||
// CHECK: %[[CONST4:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK: tosa.reshape %arg0, %[[CONST4]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
|
||||
%4 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
%1 = tosa.reshape %arg0, %4 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
|
||||
|
||||
// CHECK: %[[CONST5:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
// CHECK: tosa.reshape %arg0, %[[CONST5]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
|
||||
%5 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
%2 = tosa.reshape %arg0, %5 : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
|
||||
|
||||
@ -428,19 +432,17 @@ func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
|
||||
|
||||
// CHECK-LABEL: @test_dynamic_reshape
|
||||
func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
|
||||
// CHECK: %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-DAG: %[[CONSTSHAPE3:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE1]] : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
|
||||
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE2]] : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE3]] : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
|
||||
%0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK: %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
|
||||
%1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
|
||||
// CHECK: %2 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
%2 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK: %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
%3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
|
||||
// CHECK: %4 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
%4 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
// CHECK: %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
|
||||
%5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<?x?xi32>
|
||||
|
||||
return
|
||||
@ -563,9 +565,9 @@ func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: @test_slice
|
||||
func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
|
||||
// CHECK: %0 = tosa.const_shape {values = dense<1> : tensor<1xindex>}
|
||||
// CHECK: %1 = tosa.const_shape {values = dense<2> : tensor<1xindex>}
|
||||
// CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
|
||||
// CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
// CHECK: %[[SLICE:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
|
||||
%0 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
%1 = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
|
||||
%2= tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
@ -576,8 +578,8 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
|
||||
|
||||
// CHECK-LABEL: @test_slice_size_minus_one
|
||||
func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
|
||||
// CHECK: %[[START:.+]] = tosa.const_shape
|
||||
// CHECK: %[[SIZE:.+]] = tosa.const_shape
|
||||
// CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<-1> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
// CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[0, 1, -1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
// CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<?x8x8x8xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x7x?x?xi32>
|
||||
// this checks following
|
||||
// dim 0: size=-1, input dim=? => inferred output dim is ?
|
||||
@ -594,9 +596,9 @@ func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
|
||||
|
||||
// CHECK-LABEL: @test_slice_dynamic
|
||||
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
|
||||
// CHECK: %0 = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>}
|
||||
// CHECK: %1 = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>}
|
||||
// CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
|
||||
// CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
|
||||
// CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
|
||||
// CHECK: %[[SLICE:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
|
||||
%0 = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
|
||||
%1 = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
|
||||
%2= tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<?x?x?xf32>
|
||||
@ -1146,7 +1148,7 @@ func.func @resize_negative_output_dim(%arg0: tensor<1x3x1x1xi8>) {
|
||||
%scale = tosa.const_shape { values = dense<[1, 3, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
|
||||
%offset = tosa.const_shape { values = dense<[6, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
|
||||
%border = tosa.const_shape { values = dense<[-15, 0]> : tensor<2xindex> } : () -> !tosa.shape<2>
|
||||
// expected-error@+1 {{calculated output height and width must be non-negative, got height = -5, width = 0}}
|
||||
// expected-error@below {{calculated output height and width must be non-negative, got height = -5, width = 0}}
|
||||
%0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x3x1x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xi8>
|
||||
return
|
||||
}
|
||||
@ -1214,6 +1216,25 @@ func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : t
|
||||
|
||||
// -----
|
||||
|
||||
func.func @if_test_propagate_dynamic(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
|
||||
// CHECK: tosa.cond_if
|
||||
// CHECK: -> tensor<3xf32>
|
||||
%0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
|
||||
// CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
|
||||
^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
|
||||
// CHECK: tosa.yield %arg3 : tensor<3xf32>
|
||||
tosa.yield %arg3 : tensor<?xf32>
|
||||
} else {
|
||||
// CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
|
||||
^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
|
||||
// CHECK: tosa.yield %arg4 : tensor<3xf32>
|
||||
tosa.yield %arg4 : tensor<?xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @while_test
|
||||
func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
|
||||
// CHECK: tosa.add
|
||||
@ -1249,7 +1270,9 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
|
||||
tosa.yield %3 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK: tensor.cast
|
||||
// DEFAULT: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
|
||||
// DEFAULT: return %[[CAST]] : tensor<*xi32>
|
||||
// FUNCBOUND: return %{{.*}} : tensor<i32>
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
@ -1323,7 +1346,9 @@ func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
|
||||
"use"(%3) : (tensor<*xi32>) -> ()
|
||||
tosa.yield %3 : tensor<*xi32>
|
||||
}
|
||||
// CHECK: tensor.cast
|
||||
// DEFAULT: %[[CAST:.+]] = tensor.cast
|
||||
// DEFAULT: return %[[CAST]] : tensor<*xi32>
|
||||
// FUNCBOUND: return %{{.*}} : tensor<i32>
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
@ -1379,7 +1404,9 @@ func.func @while_dont_crash_nested(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
|
||||
tosa.yield %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK: tensor.cast
|
||||
// DEFAULT: %[[CAST:.+]] = tensor.cast
|
||||
// DEFAULT: return %[[CAST]] : tensor<*xi32>
|
||||
// FUNCBOUND: return %{{.*}} : tensor<i32>
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
@ -1752,3 +1779,102 @@ func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi
|
||||
%0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_simple_shape_expression
|
||||
func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
|
||||
// CHECK: %[[DIM1:.+]] = tosa.dim
|
||||
// CHECK: %[[DIM2:.+]] = tosa.dim
|
||||
// CHECK: %[[ADD_SHAPE:.+]] = tosa.add_shape
|
||||
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[ADD_SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
// CHECK: %[[DIM3:.+]] = tosa.dim %[[RESHAPE]]
|
||||
// CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[DIM3]] : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
// CHECK: return %[[TILE]] : tensor<?xi32>
|
||||
%a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
|
||||
%b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
|
||||
%c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
|
||||
%d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
%e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
|
||||
%f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
|
||||
return %f : tensor<?xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_conv2d_block_scaled_static
|
||||
func.func @test_conv2d_block_scaled_static(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
|
||||
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
// CHECK: -> tensor<1x4x4x8xf32>
|
||||
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_conv2d_block_scaled_dynamic_scales
|
||||
func.func @test_conv2d_block_scaled_dynamic_scales(%arg0: tensor<?x4x4x64xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<?x1x1x64xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
|
||||
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
// CHECK: -> tensor<?x4x4x?xf32>
|
||||
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x4x4x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<?x1x1x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_conv2d_block_scaled_dynamic_data
|
||||
func.func @test_conv2d_block_scaled_dynamic_data(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
|
||||
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
// CHECK: -> tensor<1x4x4x8xf32>
|
||||
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_conv2d_block_scaled_dynamic_unranked
|
||||
func.func @test_conv2d_block_scaled_dynamic_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
|
||||
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
|
||||
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
|
||||
// CHECK: -> tensor<?x?x?x?xf32>
|
||||
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_dwconv2d_bias_broadcast
|
||||
func.func @test_dwconv2d_bias_broadcast(%input: tensor<2x8x9x?xf32>, %weight: tensor<3x3x?x?xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
|
||||
// CHECK: -> tensor<2x6x7x?xf32>
|
||||
%0 = tosa.depthwise_conv2d %input, %weight, %bias, %input_zp, %weight_zp
|
||||
{ acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
|
||||
: (tensor<2x8x9x?xf32>, tensor<3x3x?x?xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_tconv2d_bias_broadcast
|
||||
func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: tensor<?x3x3x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
|
||||
// CHECK: -> tensor<2x8x9x?xf32>
|
||||
%0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp
|
||||
{ acc_type = f32, pad = array<i64: 0, 0, 0, 0>, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
|
||||
: (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_avg_pool2d_unranked_input
|
||||
func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi32>) {
|
||||
// CHECK: -> tensor<?x?x?x?xi32>
|
||||
%0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
|
||||
return
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user