[mlir][tosa] Use math.ctlz intrinsic for tosa.clz
Previously, the tosa.clz lowering counted leading zeros one bit at a time using a custom loop. The Math dialect now provides an intrinsic (math.ctlz) that does this in a single operation. Migrated the lowering to this intrinsic and fixed a minor bug in the math-to-llvm conversion for the intrinsic. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D125592
This commit is contained in:
parent
52c615553c
commit
cb4a5eae1e
@ -74,8 +74,8 @@ struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
|
||||
[&](Type llvm1DVectorTy, ValueRange operands) {
|
||||
LLVM::ConstantOp zero =
|
||||
rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
|
||||
return rewriter.replaceOpWithNewOp<LLVMOp>(op, llvm1DVectorTy,
|
||||
operands[0], zero);
|
||||
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
|
||||
zero);
|
||||
},
|
||||
rewriter);
|
||||
}
|
||||
|
@ -259,54 +259,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
|
||||
|
||||
// tosa::ClzOp
|
||||
if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
|
||||
int bitWidth = elementTy.getIntOrFloatBitWidth();
|
||||
auto zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
|
||||
auto leadingZeros = rewriter.create<arith::ConstantOp>(
|
||||
loc, IntegerAttr::get(elementTy, bitWidth));
|
||||
|
||||
SmallVector<Value> operands = {args[0], leadingZeros, zero};
|
||||
SmallVector<Type> types = {elementTy, elementTy, elementTy};
|
||||
SmallVector<Location> locations = {loc, loc, loc};
|
||||
|
||||
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
|
||||
Block *before =
|
||||
rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
|
||||
Block *after =
|
||||
rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
|
||||
|
||||
// The conditional block of the while loop.
|
||||
{
|
||||
rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
|
||||
Value input = before->getArgument(0);
|
||||
Value zero = before->getArgument(2);
|
||||
|
||||
Value inputLargerThanZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ne, input, zero);
|
||||
rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
|
||||
before->getArguments());
|
||||
}
|
||||
|
||||
// The body of the while loop: shift right until reaching a value of 0.
|
||||
{
|
||||
rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
|
||||
Value input = after->getArgument(0);
|
||||
Value leadingZeros = after->getArgument(1);
|
||||
|
||||
auto one = rewriter.create<arith::ConstantOp>(
|
||||
loc, IntegerAttr::get(elementTy, 1));
|
||||
auto shifted =
|
||||
rewriter.create<arith::ShRUIOp>(loc, resultTypes, input, one);
|
||||
auto leadingZerosMinusOne =
|
||||
rewriter.create<arith::SubIOp>(loc, resultTypes, leadingZeros, one);
|
||||
|
||||
rewriter.create<scf::YieldOp>(
|
||||
loc,
|
||||
ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointAfter(whileOp);
|
||||
return whileOp->getResult(1);
|
||||
return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
|
||||
}
|
||||
|
||||
// tosa::LogicalAnd
|
||||
|
@ -366,12 +366,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
|
||||
// CHECK: arith.addi
|
||||
%12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
|
||||
// CHECK: scf.while
|
||||
// CHECK: arith.cmpi ne
|
||||
// CHECK: scf.condition
|
||||
// CHECK: arith.shrui
|
||||
// CHECK: arith.subi
|
||||
// CHECK: scf.yield
|
||||
// CHECK: math.ctlz
|
||||
%13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
|
||||
// CHECK: linalg.generic
|
||||
|
Loading…
x
Reference in New Issue
Block a user