[mlir][tosa] Enhance TosaInferShapes pass for simple shape inference (#178418)

This commit enhances the TosaInferShapes pass with two new options:
- fold-shape-expressions
- convert-function-boundaries

The "fold-shape-expressions" option enables greedily folding the newly
added TOSA shape operations when possible. Folding these operations
directly within TosaInferShapes is useful since it allows shapes of
later operations to be inferred in a single pass.

The "convert-function-boundaries" option updates the return types of a
function to the newly inferred output shapes. This avoids the need for
additional tensor.cast operations at function boundaries. This option is
particularly useful when resolving a function with dynamic shapes to a
fully static one.

When both of these options are used in conjunction with the
"tosa-input-shapes" pass option, it's possible to resolve a dynamic
function to static in a single pass.

Note: This PR is split into two commits.
[68888db](68888dbb1e)
is a simple refactor and consists of no logic changes.
[db9a6a3](db9a6a323a)
includes the changes for shape inference.
This commit is contained in:
Luke Hutton 2026-02-25 17:38:37 +00:00 committed by GitHub
parent 0ac8e41ee1
commit 6874e945fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 485 additions and 210 deletions

View File

@ -42,6 +42,16 @@ def TosaInferShapesPass : Pass<"tosa-infer-shapes", "func::FuncOp"> {
"tensor::TensorDialect",
"tosa::TosaDialect",
];
let options = [
  // Opt-in folding of TOSA shape operations during inference.
  Option<"foldShapeExpressions", "fold-shape-expressions", "bool",
         /*default=*/"false",
         "Fold TOSA shape operations when they have known input values">,
  // Opt-in rewriting of the function signature. Adjacent string literals are
  // concatenated verbatim, so the first fragment must end with a space to
  // keep "will" and "be" apart in the generated help text.
  Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool",
         /*default=*/"false",
         "If enabled, the pass will convert function I/O types as well. Otherwise casts will "
         "be inserted at the I/O boundaries.">
];
}
def TosaMakeBroadcastablePass

View File

@ -18,8 +18,10 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
namespace mlir {
namespace tosa {
@ -128,179 +130,6 @@ private:
llvm::SmallVector<std::pair<Value, Type>> oldTypes;
};
void propagateShapesInRegion(Region &region, TypeModificationState &state);
/// Pushes the operand types of a tosa.cond_if into the entry block arguments
/// of each of its regions, then continues shape propagation inside those
/// regions. Does nothing when `op` is not an IfOp.
///
/// \param op    Operation to inspect.
/// \param state Records every type change so it can later be committed or
///              rolled back.
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
  auto condIf = dyn_cast<IfOp>(op);
  if (!condIf)
    return;

  for (Region &branch : op.getRegions()) {
    Block &entry = branch.front();
    // Operand 0 has no matching block argument (presumably the condition), so
    // a well-formed region has exactly one argument fewer than the op has
    // operands. Anything else aborts the whole propagation for this op.
    if (entry.getNumArguments() + 1 != condIf.getNumOperands())
      return;

    // First sweep: stamp the operand's shape onto the matching block argument
    // whenever the operand is ranked, keeping the argument's other properties.
    const unsigned numOperands = op.getNumOperands();
    for (unsigned operandIdx = 1; operandIdx < numOperands; ++operandIdx) {
      auto operandTy = cast<ShapedType>(op.getOperand(operandIdx).getType());
      Value arg = entry.getArgument(operandIdx - 1);
      auto argTy = cast<ShapedType>(arg.getType());
      if (!operandTy.hasRank())
        continue;
      state.setType(arg, argTy.clone(operandTy.getShape()));
    }

    // Second sweep: join what is known from the operand with what is known
    // from the (possibly just updated) argument; skip arguments for which the
    // join yields no usable knowledge.
    const int numArgs = entry.getNumArguments();
    for (int argIdx = 0; argIdx < numArgs; ++argIdx) {
      Value arg = entry.getArgument(argIdx);
      ValueKnowledge fromOperand = ValueKnowledge::getKnowledgeFromType(
          condIf.getOperand(argIdx + 1).getType());
      ValueKnowledge fromArg =
          ValueKnowledge::getKnowledgeFromType(arg.getType());
      ValueKnowledge merged = ValueKnowledge::join(fromOperand, fromArg);
      if (merged)
        state.setType(arg, merged.getType());
    }

    propagateShapesInRegion(branch, state);
  }
}
/// Infers loop-carried types for a tosa.while and propagates shapes into its
/// regions. Does nothing when `op` is not a WhileOp.
///
/// The block-argument types are found by iterating to a fixed point: the body
/// (region 1) is speculatively propagated with the current candidate types,
/// the candidates are met with the resulting tosa.yield operand types, and the
/// speculative changes are rolled back. Once the candidates stop changing they
/// are recorded in `state` for every region of the op.
///
/// \param op    Operation to inspect.
/// \param state Receives the final (non-speculative) type changes.
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
  WhileOp whileOp = dyn_cast<WhileOp>(op);
  if (!whileOp)
    return;

  // Determine what the expected argument types are to the cond/body blocks.
  // The expected arguments should be compatible with every iteration of the
  // loop body / condition for tosa.while.
  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());

  bool hasNewTypes = true;
  while (hasNewTypes) {
    TypeModificationState localState;

    // Set types on the block args.
    Region &bodyRegion = op.getRegion(1);
    Block &block = bodyRegion.front();
    for (int i = 0, s = argTypes.size(); i < s; i++)
      localState.setType(block.getArgument(i), argTypes[i]);

    // Propagate to the end.
    propagateShapesInRegion(bodyRegion, localState);

    // Find all the tosa yield types and verify there is a single one.
    llvm::SmallVector<YieldOp> yieldOps;
    for (auto &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);
    assert(yieldOps.size() == 1 && "missing or non-unique yield op");

    // This should never happen. The operand count is validated *before* the
    // yield operands are used as indices below; otherwise a yield with too
    // many operands would index past the end of yieldTypeInfo.
    for (auto yieldOp : yieldOps) {
      if (yieldOp->getNumOperands() == argTypes.size())
        continue;
      op.emitWarning("has a tosa.yield with the incorrect number of operands");
      // Do not leak the speculative types set above.
      localState.rollBack();
      return;
    }

    // Using the new tosa.yield operand types, infer the new subtypes.
    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (auto ty : argTypes)
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));

    for (auto yieldOp : yieldOps) {
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // Determine the new block args and see if any changed.
    hasNewTypes = false;
    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    // Roll back all changes made during the speculative part of the algorithm.
    localState.rollBack();
  }

  // We now set the block arguments according to the most recent shape
  // inference results. This gives us the block arg types for the next
  // iteration.
  for (auto &region : op.getRegions()) {
    for (unsigned int i = 0, s = argTypes.size(); i < s; i++)
      state.setType(region.front().getArgument(i), argTypes[i]);

    propagateShapesInRegion(region, state);
  }
}
/// Visits every TOSA-dialect operation in `region` and refines its result
/// types from the shapes reported by InferShapedTypeOpInterface. Operations
/// from other dialects are skipped. tosa.cond_if / tosa.while are given a
/// chance to push types into their nested regions first.
///
/// \param region Region whose blocks are scanned in order.
/// \param state  Records every type change that is applied.
void propagateShapesInRegion(Region &region, TypeModificationState &state) {
  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();

  for (Block &block : region) {
    for (Operation &op : block) {
      if (op.getDialect() != tosaDialect)
        continue;

      // Both helpers are no-ops unless `op` is the matching control-flow op.
      propagateShapesToTosaIf(op, state);
      propagateShapesToTosaWhile(op, state);

      auto shapeInterface = dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapeInterface)
        continue;

      SmallVector<ShapedTypeComponents> returnedShapes;
      auto inferStatus = shapeInterface.inferReturnTypeComponents(
          op.getContext(), op.getLoc(), op.getOperands(),
          op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
          op.getRegions(), returnedShapes);
      if (!inferStatus.succeeded())
        continue;

      for (auto resultAndShape : llvm::zip(op.getResults(), returnedShapes)) {
        Value result = std::get<0>(resultAndShape);
        ShapedTypeComponents predictedShape = std::get<1>(resultAndShape);

        // Determine the knowledge based on the output type.
        // TODO: should also query WIP type probably
        Type resultTy = result.getType();
        auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);

        // Build the knowledge carried by the inferred shape; the element type
        // is taken from the existing result type.
        auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
        inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
        inferredKnowledge.hasRank = predictedShape.hasRank();
        if (predictedShape.hasRank()) {
          for (auto extent : predictedShape.getDims())
            inferredKnowledge.sizes.push_back(extent);
        }

        // Only update the result when joining existing and inferred knowledge
        // produces something usable.
        if (auto newKnowledge =
                ValueKnowledge::join(currentKnowledge, inferredKnowledge))
          state.setType(result, newKnowledge.getType());
      }
    }
  }
}
/// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
/// and all nested regions
void validateSameOperandsAndResultRankTrait(Region &region) {
@ -333,13 +162,261 @@ void validateSameOperandsAndResultRankTrait(Region &region) {
struct TosaInferShapes
: public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
public:
explicit TosaInferShapes() = default;
explicit TosaInferShapes(const TosaInferShapesPassOptions &options)
: TosaInferShapes() {
this->foldShapeExpressions = options.foldShapeExpressions;
this->convertFunctionBoundaries = options.convertFunctionBoundaries;
}
void runOnOperation() override {
func::FuncOp func = getOperation();
TypeModificationState state;
propagateShapesInRegion(func.getBody(), state);
state.commit();
if (foldShapeExpressions) {
// Folding shape expressions may leave dead tosa.const_shape operations
func.walk<WalkOrder::PostOrder, ReverseIterator>(
[](tosa::ConstShapeOp op) {
if (isOpTriviallyDead(op))
op->erase();
});
}
validateSameOperandsAndResultRankTrait(func.getBody());
if (convertFunctionBoundaries)
convertFunctionReturnTypes(func);
}
private:
/// Pushes the operand types of a tosa.cond_if into the entry block arguments
/// of each of its regions, then continues shape propagation inside those
/// regions. Does nothing when `op` is not an IfOp.
///
/// \param op    Operation to inspect.
/// \param state Records every type change so it can later be committed or
///              rolled back.
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
  auto condIf = dyn_cast<IfOp>(op);
  if (!condIf)
    return;

  for (Region &branch : op.getRegions()) {
    Block &entry = branch.front();
    // Operand 0 has no matching block argument (presumably the condition), so
    // a well-formed region has exactly one argument fewer than the op has
    // operands. Anything else aborts the whole propagation for this op.
    if (entry.getNumArguments() + 1 != condIf.getNumOperands())
      return;

    // First sweep: stamp the operand's shape onto the matching block argument
    // whenever the operand is ranked, keeping the argument's other properties.
    const unsigned numOperands = op.getNumOperands();
    for (unsigned operandIdx = 1; operandIdx < numOperands; ++operandIdx) {
      auto operandTy = cast<ShapedType>(op.getOperand(operandIdx).getType());
      Value arg = entry.getArgument(operandIdx - 1);
      auto argTy = cast<ShapedType>(arg.getType());
      if (!operandTy.hasRank())
        continue;
      state.setType(arg, argTy.clone(operandTy.getShape()));
    }

    // Second sweep: join what is known from the operand with what is known
    // from the (possibly just updated) argument; skip arguments for which the
    // join yields no usable knowledge.
    const int numArgs = entry.getNumArguments();
    for (int argIdx = 0; argIdx < numArgs; ++argIdx) {
      Value arg = entry.getArgument(argIdx);
      ValueKnowledge fromOperand = ValueKnowledge::getKnowledgeFromType(
          condIf.getOperand(argIdx + 1).getType());
      ValueKnowledge fromArg =
          ValueKnowledge::getKnowledgeFromType(arg.getType());
      ValueKnowledge merged = ValueKnowledge::join(fromOperand, fromArg);
      if (merged)
        state.setType(arg, merged.getType());
    }

    propagateShapesInRegion(branch, state);
  }
}
/// Infers loop-carried types for a tosa.while and propagates shapes into its
/// regions. Does nothing when `op` is not a WhileOp.
///
/// The block-argument types are found by iterating to a fixed point: the body
/// (region 1) is speculatively propagated with the current candidate types,
/// the candidates are met with the resulting tosa.yield operand types, and the
/// speculative changes are rolled back. Once the candidates stop changing they
/// are recorded in `state` for every region of the op.
///
/// \param op    Operation to inspect.
/// \param state Receives the final (non-speculative) type changes.
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
  WhileOp whileOp = dyn_cast<WhileOp>(op);
  if (!whileOp)
    return;

  // Determine what the expected argument types are to the cond/body blocks.
  // The expected arguments should be compatible with every iteration of the
  // loop body / condition for tosa.while.
  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());

  // The fixed-point search below runs propagateShapesInRegion on types that
  // are only candidates. TypeModificationState tracks type changes only, so
  // rollBack() cannot undo a shape-op fold performed against a candidate type
  // that later turns out to be too specific. Suppress folding while
  // speculating and restore the option on every exit path; folding still
  // happens in the final, non-speculative propagation below.
  // NOTE(review): this temporarily mutates the pass option; confirm no other
  // code observes it during this window.
  const bool savedFoldShapeExpressions = foldShapeExpressions;
  foldShapeExpressions = false;

  bool hasNewTypes = true;
  while (hasNewTypes) {
    TypeModificationState localState;

    // Set types on the block args.
    Region &bodyRegion = op.getRegion(1);
    Block &block = bodyRegion.front();
    for (int i = 0, s = argTypes.size(); i < s; i++)
      localState.setType(block.getArgument(i), argTypes[i]);

    // Propagate to the end.
    propagateShapesInRegion(bodyRegion, localState);

    // Find all the tosa yield types and verify there is a single one.
    llvm::SmallVector<YieldOp> yieldOps;
    for (auto &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);
    assert(yieldOps.size() == 1 && "missing or non-unique yield op");

    // This should never happen. The operand count is validated *before* the
    // yield operands are used as indices below; otherwise a yield with too
    // many operands would index past the end of yieldTypeInfo.
    for (auto yieldOp : yieldOps) {
      if (yieldOp->getNumOperands() == argTypes.size())
        continue;
      op.emitWarning(
          "has a tosa.yield with the incorrect number of operands");
      // Do not leak the speculative types or the suppressed option.
      localState.rollBack();
      foldShapeExpressions = savedFoldShapeExpressions;
      return;
    }

    // Using the new tosa.yield operand types, infer the new subtypes.
    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (auto ty : argTypes)
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));

    for (auto yieldOp : yieldOps) {
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // Determine the new block args and see if any changed.
    hasNewTypes = false;
    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    // Roll back all changes made during the speculative part of the
    // algorithm.
    localState.rollBack();
  }

  // Speculation is over; from here on folding is safe again.
  foldShapeExpressions = savedFoldShapeExpressions;

  // We now set the block arguments according to the most recent shape
  // inference results. This gives us the block arg types for the next
  // iteration.
  for (auto &region : op.getRegions()) {
    for (unsigned int i = 0, s = argTypes.size(); i < s; i++)
      state.setType(region.front().getArgument(i), argTypes[i]);

    propagateShapesInRegion(region, state);
  }
}
/// Visits every TOSA-dialect operation in `region` and refines its result
/// types from the shapes reported by InferShapedTypeOpInterface. Operations
/// from other dialects are skipped. tosa.cond_if / tosa.while are given a
/// chance to push types into their nested regions first. When the
/// `foldShapeExpressions` option is set, operations carrying the
/// TosaShapeOperator trait are handed to an OperationFolder instead of going
/// through type inference.
///
/// \param region Region whose blocks are scanned in order.
/// \param state  Records every type change that is applied.
void propagateShapesInRegion(Region &region, TypeModificationState &state) {
  MLIRContext *context = region.getContext();
  Dialect *tosaDialect = context->getLoadedDialect<TosaDialect>();
  OperationFolder folder(context);

  for (Block &block : region) {
    // Folding may erase the operation being visited, so the cursor is always
    // advanced past the current operation before that operation is touched.
    auto cursor = block.begin();
    while (cursor != block.end()) {
      Operation &op = *cursor++;
      if (op.getDialect() != tosaDialect)
        continue;

      // Both helpers are no-ops unless `op` is the matching control-flow op.
      propagateShapesToTosaIf(op, state);
      propagateShapesToTosaWhile(op, state);

      // Shape operators are folded (best effort, result ignored) and never
      // reach the type-inference path below.
      if (foldShapeExpressions &&
          op.hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
        (void)folder.tryToFold(&op);
        continue;
      }

      auto shapeInterface = dyn_cast<InferShapedTypeOpInterface>(op);
      if (!shapeInterface)
        continue;

      SmallVector<ShapedTypeComponents> returnedShapes;
      auto inferStatus = shapeInterface.inferReturnTypeComponents(
          op.getContext(), op.getLoc(), op.getOperands(),
          op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
          op.getRegions(), returnedShapes);
      if (!inferStatus.succeeded())
        continue;

      for (auto resultAndShape : llvm::zip(op.getResults(), returnedShapes)) {
        Value result = std::get<0>(resultAndShape);
        ShapedTypeComponents predictedShape = std::get<1>(resultAndShape);

        // Determine the knowledge based on the output type.
        // TODO: should also query WIP type probably
        Type resultTy = result.getType();
        auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);

        // Build the knowledge carried by the inferred shape; the element type
        // is taken from the existing result type.
        auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
        inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
        inferredKnowledge.hasRank = predictedShape.hasRank();
        if (predictedShape.hasRank()) {
          for (auto extent : predictedShape.getDims())
            inferredKnowledge.sizes.push_back(extent);
        }

        // Only update the result when joining existing and inferred knowledge
        // produces something usable.
        if (auto newKnowledge =
                ValueKnowledge::join(currentKnowledge, inferredKnowledge))
          state.setType(result, newKnowledge.getType());
      }
    }
  }
}
/// Rewrites the function so that it returns the refined types directly:
/// every func.return operand produced by a tensor.cast is replaced by the
/// cast's source, casts left without uses are erased, and the function type's
/// results are updated to the types of the new return operands. Input types
/// are left untouched.
///
/// The function is left unchanged when it contains no func.return (e.g. an
/// external declaration) or when its func.return ops would disagree on the
/// refined types, since a single signature cannot describe either case.
///
/// NOTE(review): any tensor.cast feeding a return is looked through, not only
/// casts to a less static type — confirm that is intended.
///
/// \param func Function whose returns and signature are rewritten in place.
void convertFunctionReturnTypes(func::FuncOp func) {
  // Collect the returns first so the IR is not mutated while it is walked.
  SmallVector<func::ReturnOp> returnOps;
  func.walk([&returnOps](func::ReturnOp ret) { returnOps.push_back(ret); });

  // Without a return there is nothing to derive result types from; rewriting
  // the signature here would silently drop the declared results.
  if (returnOps.empty())
    return;

  auto lookThroughCast = [](Value v) -> Value {
    if (auto castOp = v.getDefiningOp<tensor::CastOp>())
      return castOp.getSource();
    return v;
  };
  auto refinedTypesOf = [&lookThroughCast](func::ReturnOp ret) {
    SmallVector<Type> types;
    OperandRange operands = ret.getOperands();
    types.reserve(operands.size());
    for (Value operand : operands)
      types.push_back(lookThroughCast(operand).getType());
    return types;
  };

  // Every return must yield the same refined types; otherwise keep the
  // existing boundary (and its casts) so the function stays valid.
  const SmallVector<Type> newReturnTypes = refinedTypesOf(returnOps.front());
  for (func::ReturnOp ret : llvm::drop_begin(returnOps))
    if (refinedTypesOf(ret) != newReturnTypes)
      return;

  // Rewrite func.return ops, remembering each bypassed tensor.cast once so a
  // cast used by several operands/returns is not erased twice.
  IRRewriter rewriter(func.getContext());
  SmallVector<Operation *> maybeDeadCasts;
  for (func::ReturnOp ret : returnOps) {
    SmallVector<Value> newReturnValues;
    OperandRange returnOperands = ret.getOperands();
    newReturnValues.reserve(returnOperands.size());
    for (Value operand : returnOperands) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        if (!llvm::is_contained(maybeDeadCasts, castOp.getOperation()))
          maybeDeadCasts.push_back(castOp.getOperation());
      }
      newReturnValues.push_back(lookThroughCast(operand));
    }

    rewriter.setInsertionPoint(ret);
    rewriter.replaceOpWithNewOp<func::ReturnOp>(ret, newReturnValues);
  }

  // Erase the casts only after all returns were rewritten, and only when no
  // other user still depends on them.
  for (Operation *castOp : maybeDeadCasts)
    if (castOp->use_empty())
      rewriter.eraseOp(castOp);

  // Update function return types with newly inferred types
  const FunctionType oldType = func.getFunctionType();
  const FunctionType newType = FunctionType::get(
      func.getContext(), oldType.getInputs(), newReturnTypes);
  func.setType(newType);
}
};
} // namespace

View File

@ -0,0 +1,62 @@
// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,DEFAULT
// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,FUNCBOUND
// -----
// CHECK-LABEL: test_simple_shape_expression
// Exercises folding of a chain of shape ops: the two tosa.dim results (80 and
// 4) and their tosa.add_shape sum are expected to collapse into a single
// tosa.const_shape of 84, which lets the reshape and tile results become
// static (84 and 84 * 84 = 7056). With convert-function-boundaries the
// trailing tensor.cast is expected to disappear and the static type is
// returned directly.
func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
// CHECK-NOT: tosa.dim
// CHECK-NOT: tosa.add_shape
// CHECK: %[[SHAPE:.+]] = tosa.const_shape {values = dense<84> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-NOT: tosa.const_shape {values = dense<4> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-NOT: tosa.const_shape {values = dense<80> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<84xi32>
// CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[SHAPE]] : (tensor<84xi32>, !tosa.shape<1>) -> tensor<7056xi32>
// DEFAULT: %[[CAST:.+]] = tensor.cast %[[TILE]] : tensor<7056xi32> to tensor<?xi32>
// DEFAULT: return %[[CAST]] : tensor<?xi32>
// FUNCBOUND-NOT: tensor.cast
// FUNCBOUND: return %[[TILE]] : tensor<7056xi32>
%a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
%b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
%c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
%d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
%e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
%f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
return %f : tensor<?xi32>
}
// -----
// CHECK-LABEL: test_cond_if_with_shape_expressions
// Exercises shape inference and folding inside tosa.cond_if regions: the
// block arguments of both branches are expected to be refined to
// tensor<3xf32>, and the tosa.dim in the first branch is expected to be
// replaced by a tosa.const_shape that appears ahead of the cond_if.
func.func @test_cond_if_with_shape_expressions(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: %[[CONST_SHAPE:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> {
%0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
// CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
// CHECK-NOT: tosa.dim
%0 = tosa.dim %arg3 {axis = 0 : i32} : (tensor<?xf32>) -> !tosa.shape<1>
// CHECK: %[[RESHAPE:.*]] = tosa.reshape %arg3, %[[CONST_SHAPE]] : (tensor<3xf32>, !tosa.shape<1>) -> tensor<3xf32>
%1 = tosa.reshape %arg3, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
// CHECK: tosa.yield %[[RESHAPE]] : tensor<3xf32>
tosa.yield %1 : tensor<?xf32>
} else {
// CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
// CHECK: tosa.yield %arg4 : tensor<3xf32>
tosa.yield %arg4 : tensor<?xf32>
}
return
}
// -----
// CHECK-LABEL: test_no_fold_shape_expression
// Negative case: the dimension queried from %arg0 (axis 1) is dynamic, so the
// tosa.dim is expected to stay in place and the tile result and return type
// are expected to remain tensor<?xf32>.
func.func @test_no_fold_shape_expression(%arg0: tensor<1x?x3xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: tosa.dim
%0 = tosa.dim %arg0 {axis = 1: i32} : (tensor<1x?x3xf32>) -> !tosa.shape<1>
// CHECK: tosa.tile
%1 = tosa.tile %arg1, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
// CHECK: return %{{.*}} : tensor<?xf32>
return %1 : tensor<?xf32>
}

View File

@ -1,9 +1,12 @@
// RUN: mlir-opt --split-input-file --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s
// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes --allow-unregistered-dialect %s | FileCheck %s --allow-unused-prefixes --check-prefixes=CHECK,DEFAULT
// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries" --allow-unregistered-dialect %s | FileCheck %s --allow-unused-prefixes --check-prefixes=CHECK,FUNCBOUND
// CHECK-LABEL: @test_return
func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
// CHECK: [[LOG:%.+]] = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: tensor.cast [[LOG]] : tensor<4xf32> to tensor<*xf32>
// CHECK: %[[LOG:.+]] = tosa.log %arg0 : (tensor<4xf32>) -> tensor<4xf32>
// DEFAULT: %[[CAST:.+]] = tensor.cast %[[LOG]] : tensor<4xf32> to tensor<*xf32>
// DEFAULT: return %[[CAST]] : tensor<*xf32>
// FUNCBOUND: return %[[LOG]] : tensor<4xf32>
%0 = tosa.log %arg0 : (tensor<4xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -12,13 +15,13 @@ func.func @test_return(%arg0 : tensor<4xf32>) -> tensor<*xf32> {
// CHECK-LABEL: @test_multiple
func.func @test_multiple(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>, %arg2 : tensor<1xf32>) -> tensor<*xf32> {
// CHECK: [[ADD:%.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
// CHECK: %[[ADD:.+]] = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%0 = tosa.add %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32>
// CHECK: [[LOG:%.+]] = tosa.log %0 : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[LOG:.+]] = tosa.log %[[ADD]] : (tensor<4xf32>) -> tensor<4xf32>
%1 = tosa.log %0 : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: [[SUB:%.+]] = tosa.sub %0, %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
// CHECK: %[[SUB:.+]] = tosa.sub %[[ADD]], %arg2 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32>
%2 = tosa.sub %0, %arg2 : (tensor<*xf32>, tensor<1xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -364,12 +367,15 @@ func.func @test_unranked_zero_points_matmul(%arg0: tensor<1x2x3xf32>, %arg1: ten
// CHECK-LABEL: @test_accepts_unranked_scalar_tensor
func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: tensor<1xf32>) -> tensor<*xf32> {
// CHECK: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
// CHECK-DAG: %[[ZP:.*]] = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<1xf32>
%0 = tosa.cast %arg1 : (tensor<1xf32>) -> tensor<*xf32>
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
%1 = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 1]> : tensor<6xindex>} : () -> !tosa.shape<6>
// CHECK: tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
// CHECK: %[[PAD:.*]] = tosa.pad %arg0, %[[SHAPE]], %[[ZP]] : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<1x3x3xf32>
%2 = tosa.pad %arg0, %1, %0 : (tensor<1x2x2xf32>, !tosa.shape<6>, tensor<*xf32>) -> tensor<*xf32>
// DEFAULT: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<1x3x3xf32> to tensor<*xf32>
// DEFAULT: return %[[CAST]] : tensor<*xf32>
// FUNCBOUND: return %[[PAD]] : tensor<1x3x3xf32>
return %2 : tensor<*xf32>
}
@ -406,18 +412,16 @@ func.func @test_table_dynamic(%arg0 : tensor<4x?xi16>, %arg1 : tensor<513xi16>)
// CHECK-LABEL: @test_static_reshape
func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
// CHECK: %[[CONST3:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-DAG: %[[CONSTSHAPE3:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE1]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE2]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE3]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
%3 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK: tosa.reshape %arg0, %[[CONST3]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
%0 = tosa.reshape %arg0, %3 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
// CHECK: %[[CONST4:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK: tosa.reshape %arg0, %[[CONST4]] : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
%4 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
%1 = tosa.reshape %arg0, %4 : (tensor<4x4xi32>, !tosa.shape<1>) -> tensor<16xi32>
// CHECK: %[[CONST5:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: tosa.reshape %arg0, %[[CONST5]] : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
%5 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%2 = tosa.reshape %arg0, %5 : (tensor<4x4xi32>, !tosa.shape<2>) -> tensor<2x8xi32>
@ -428,19 +432,17 @@ func.func @test_static_reshape(%arg0 : tensor<4x4xi32>) -> () {
// CHECK-LABEL: @test_dynamic_reshape
func.func @test_dynamic_reshape(%arg0 : tensor<4x?xi32>) -> () {
// CHECK: %0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-DAG: %[[CONSTSHAPE1:.+]] = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-DAG: %[[CONSTSHAPE2:.+]] = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-DAG: %[[CONSTSHAPE3:.+]] = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE1]] : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE2]] : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
// CHECK-DAG: tosa.reshape %arg0, %[[CONSTSHAPE3]] : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
%0 = tosa.const_shape {values = dense<16> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK: %1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<16xi32>
%1 = tosa.reshape %arg0, %0 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
// CHECK: %2 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
%2 = tosa.const_shape {values = dense<-1> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK: %3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
%3 = tosa.reshape %arg0, %2 : (tensor<4x?xi32>, !tosa.shape<1>) -> tensor<?xi32>
// CHECK: %4 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%4 = tosa.const_shape {values = dense<[2, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<2x?xi32>
%5 = tosa.reshape %arg0, %4 : (tensor<4x?xi32>, !tosa.shape<2>) -> tensor<?x?xi32>
return
@ -563,9 +565,9 @@ func.func @test_padding_simple(%arg0 : tensor<1x2xf32>) -> () {
// CHECK-LABEL: @test_slice
func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
// CHECK: %0 = tosa.const_shape {values = dense<1> : tensor<1xindex>}
// CHECK: %1 = tosa.const_shape {values = dense<2> : tensor<1xindex>}
// CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
// CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK: %[[SLICE:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<2xi32>
%0 = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
%1 = tosa.const_shape {values = dense<2> : tensor<1xindex>} : () -> !tosa.shape<1>
%2= tosa.slice %arg0, %0, %1 : (tensor<?xi32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<?xi32>
@ -576,8 +578,8 @@ func.func @test_slice(%arg0 : tensor<?xi32>) -> () {
// CHECK-LABEL: @test_slice_size_minus_one
func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
// CHECK: %[[START:.+]] = tosa.const_shape
// CHECK: %[[SIZE:.+]] = tosa.const_shape
// CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<-1> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[0, 1, -1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<?x8x8x8xi32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x7x?x?xi32>
// this checks following
// dim 0: size=-1, input dim=? => inferred output dim is ?
@ -594,9 +596,9 @@ func.func @test_slice_size_minus_one(%arg0 : tensor<?x8x8x8xi32>) -> () {
// CHECK-LABEL: @test_slice_dynamic
func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () {
// CHECK: %0 = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>}
// CHECK: %1 = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>}
// CHECK: %2 = tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
// CHECK-DAG: %[[SIZE:.+]] = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK-DAG: %[[START:.+]] = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[SLICE:.+]] = tosa.slice %arg0, %[[START]], %[[SIZE]] : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32>
%0 = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
%1 = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%2= tosa.slice %arg0, %0, %1 : (tensor<10x?x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<?x?x?xf32>
@ -1146,7 +1148,7 @@ func.func @resize_negative_output_dim(%arg0: tensor<1x3x1x1xi8>) {
%scale = tosa.const_shape { values = dense<[1, 3, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
%offset = tosa.const_shape { values = dense<[6, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
%border = tosa.const_shape { values = dense<[-15, 0]> : tensor<2xindex> } : () -> !tosa.shape<2>
// expected-error@+1 {{calculated output height and width must be non-negative, got height = -5, width = 0}}
// expected-error@below {{calculated output height and width must be non-negative, got height = -5, width = 0}}
%0 = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x3x1x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xi8>
return
}
@ -1214,6 +1216,25 @@ func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : t
// -----
// Verifies that static operand shapes of a tosa.cond_if are propagated onto
// dynamically shaped block arguments, yields and the op result. The label
// below anchors the following patterns to this function's output, matching
// the other tests in this file.
// CHECK-LABEL: @if_test_propagate_dynamic
func.func @if_test_propagate_dynamic(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
// CHECK: tosa.cond_if
// CHECK: -> tensor<3xf32>
%0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
// CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
// CHECK: tosa.yield %arg3 : tensor<3xf32>
tosa.yield %arg3 : tensor<?xf32>
} else {
// CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
// CHECK: tosa.yield %arg4 : tensor<3xf32>
tosa.yield %arg4 : tensor<?xf32>
}
return
}
// -----
// CHECK-LABEL: @while_test
func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
// CHECK: tosa.add
@ -1249,7 +1270,9 @@ func.func @while_test(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
tosa.yield %3 : tensor<*xi32>
}
// CHECK: tensor.cast
// DEFAULT: %[[CAST:.+]] = tensor.cast %{{.*}} : tensor<i32> to tensor<*xi32>
// DEFAULT: return %[[CAST]] : tensor<*xi32>
// FUNCBOUND: return %{{.*}} : tensor<i32>
return %1 : tensor<*xi32>
}
@ -1323,7 +1346,9 @@ func.func @while_dont_crash(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
"use"(%3) : (tensor<*xi32>) -> ()
tosa.yield %3 : tensor<*xi32>
}
// CHECK: tensor.cast
// DEFAULT: %[[CAST:.+]] = tensor.cast
// DEFAULT: return %[[CAST]] : tensor<*xi32>
// FUNCBOUND: return %{{.*}} : tensor<i32>
return %1 : tensor<*xi32>
}
@ -1379,7 +1404,9 @@ func.func @while_dont_crash_nested(%arg0 : tensor<i32>) -> (tensor<*xi32>) {
tosa.yield %1 : tensor<*xi32>
}
// CHECK: tensor.cast
// DEFAULT: %[[CAST:.+]] = tensor.cast
// DEFAULT: return %[[CAST]] : tensor<*xi32>
// FUNCBOUND: return %{{.*}} : tensor<i32>
return %1 : tensor<*xi32>
}
@ -1752,3 +1779,102 @@ func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi
%0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
return
}
// -----
// Shape-expression chain: tosa.dim reads axis 0 of %arg1 (extent 80) and of
// %arg2 (extent 4), tosa.add_shape sums the two, and the sum is used as the
// target shape when reshaping the 7x12 input (84 elements). The reshaped
// tensor then feeds a second tosa.dim whose result is the multiples operand
// of tosa.tile.
// The expectations below keep every shape op in the output and leave the
// reshape result, the tile result and the returned value as tensor<?xi32>.
// NOTE(review): presumably these expectations describe a run where shape
// expressions are not folded; the RUN lines are outside this view - confirm.
// CHECK-LABEL: test_simple_shape_expression
func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
// CHECK: %[[DIM1:.+]] = tosa.dim
// CHECK: %[[DIM2:.+]] = tosa.dim
// CHECK: %[[ADD_SHAPE:.+]] = tosa.add_shape
// CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[ADD_SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
// CHECK: %[[DIM3:.+]] = tosa.dim %[[RESHAPE]]
// CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[DIM3]] : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
// CHECK: return %[[TILE]] : tensor<?xi32>
// %a and %b are rank-1 shapes taken from statically shaped operands.
%a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
%b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
%c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
// The reshape target is the computed shape %c rather than a constant shape.
%d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
// %e depends on the type of %d, which is declared dynamic here.
%e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
%f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
return %f : tensor<?xi32>
}
// -----
// tosa.conv2d_block_scaled with every data and scale operand statically
// shaped and an unranked declared result. Padding is all zeros and stride and
// dilation are both [1, 1]. The expectation is a fully static
// tensor<1x4x4x8xf32> result type.
// %arg5, %arg6 and %arg7 are declared but never referenced in the body.
// CHECK-LABEL: test_conv2d_block_scaled_static
func.func @test_conv2d_block_scaled_static(%arg0: tensor<1x4x4x64xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<8x1x1x64xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
// pad / stride / dilation are supplied as constant !tosa.shape values.
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: -> tensor<1x4x4x8xf32>
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<1x4x4x64xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<8x1x1x64xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// tosa.conv2d_block_scaled where both scale operands (%arg1, %arg3) are
// unranked and the data operands are ranked with a dynamic leading dimension
// (?x4x4x64 input, ?x1x1x64 weight). Padding is all zeros and stride and
// dilation are both [1, 1]. The expectation is tensor<?x4x4x?xf32>: the
// spatial dimensions are static while the first and last stay dynamic.
// %arg5, %arg6 and %arg7 are declared but never referenced in the body.
// CHECK-LABEL: test_conv2d_block_scaled_dynamic_scales
func.func @test_conv2d_block_scaled_dynamic_scales(%arg0: tensor<?x4x4x64xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<?x1x1x64xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
// pad / stride / dilation are supplied as constant !tosa.shape values.
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: -> tensor<?x4x4x?xf32>
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<?x4x4x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<?x1x1x64xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// tosa.conv2d_block_scaled where both data operands (%arg0, %arg2) are
// unranked and both scale operands are fully static (1x4x4x2 and 8x1x1x2).
// Padding is all zeros and stride and dilation are both [1, 1]. The
// expectation is a fully static tensor<1x4x4x8xf32> result type.
// NOTE(review): presumably the static result is derived from the scale
// operand shapes; the inference code is outside this view - confirm.
// %arg5, %arg6 and %arg7 are declared but never referenced in the body.
// CHECK-LABEL: test_conv2d_block_scaled_dynamic_data
func.func @test_conv2d_block_scaled_dynamic_data(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<1x4x4x2xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<8x1x1x2xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
// pad / stride / dilation are supplied as constant !tosa.shape values.
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: -> tensor<1x4x4x8xf32>
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<1x4x4x2xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<8x1x1x2xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// tosa.conv2d_block_scaled where all four data and scale operands
// (%arg0 through %arg3) are unranked. Padding is all zeros and stride and
// dilation are both [1, 1]. The expectation is a ranked result with every
// dimension dynamic, tensor<?x?x?x?xf32>.
// %arg5, %arg6 and %arg7 are declared but never referenced in the body.
// CHECK-LABEL: test_conv2d_block_scaled_dynamic_unranked
func.func @test_conv2d_block_scaled_dynamic_unranked(%arg0: tensor<*xf4E2M1FN>, %arg1: tensor<*xf8E8M0FNU>, %arg2: tensor<*xf4E2M1FN>, %arg3: tensor<*xf8E8M0FNU>, %arg4: tensor<1xf32>, %arg5: tensor<4xi32>, %arg6: tensor<2xi32>, %arg7: tensor<2xi32>) -> tensor<*xf32> {
// pad / stride / dilation are supplied as constant !tosa.shape values.
%pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
%dilation = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: -> tensor<?x?x?x?xf32>
%0 = tosa.conv2d_block_scaled %arg0, %arg1, %arg2, %arg3, %arg4, %pad, %stride, %dilation {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>, tensor<1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// tosa.depthwise_conv2d with a single-element bias (tensor<1xf32>), a
// dynamic channel dimension on the input (2x8x9x?) and dynamic trailing
// dimensions on the 3x3 weight. Padding is zero, stride and dilation are 1.
// The expectation is tensor<2x6x7x?xf32>: batch and spatial dimensions are
// static and the last dimension stays dynamic.
// CHECK-LABEL: test_dwconv2d_bias_broadcast
func.func @test_dwconv2d_bias_broadcast(%input: tensor<2x8x9x?xf32>, %weight: tensor<3x3x?x?xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
// CHECK: -> tensor<2x6x7x?xf32>
// The declared result type is rank 4 with every dimension dynamic.
%0 = tosa.depthwise_conv2d %input, %weight, %bias, %input_zp, %weight_zp
{ acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
: (tensor<2x8x9x?xf32>, tensor<3x3x?x?xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
return
}
// -----
// tosa.transpose_conv2d with a single-element bias (tensor<1xf32>), a static
// 2x6x7x3 input and a 3x3 weight whose leading dimension is dynamic
// (?x3x3x3). pad and out_pad are zero, stride is 1, and out_shape is all -1.
// The expectation is tensor<2x8x9x?xf32>: batch and spatial dimensions are
// static and the last dimension stays dynamic.
// CHECK-LABEL: test_tconv2d_bias_broadcast
func.func @test_tconv2d_bias_broadcast(%input: tensor<2x6x7x3xf32>, %weight: tensor<?x3x3x3xf32>, %bias: tensor<1xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) {
// CHECK: -> tensor<2x8x9x?xf32>
// The declared result type is rank 4 with every dimension dynamic.
%0 = tosa.transpose_conv2d %input, %weight, %bias, %input_zp, %weight_zp
{ acc_type = f32, pad = array<i64: 0, 0, 0, 0>, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> }
: (tensor<2x6x7x3xf32>, tensor<?x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?x?x?xf32>
return
}
// -----
// tosa.avg_pool2d applied to an unranked input with a 1x1 kernel, zero
// padding and unit stride; %zp is passed for both zero-point operands. The
// expectation is that the unranked result becomes a ranked type with every
// dimension dynamic, tensor<?x?x?x?xi32>.
// CHECK-LABEL: test_avg_pool2d_unranked_input
func.func @test_avg_pool2d_unranked_input(%input: tensor<*xi32>, %zp: tensor<1xi32>) {
// CHECK: -> tensor<?x?x?x?xi32>
%0 = tosa.avg_pool2d %input, %zp, %zp { acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1> } : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
return
}