[mlir][tosa] Use math.ctlz intrinsic for tosa.clz

We were custom counting per bit for the clz instruction. Math dialect
now has an intrinsic to do this in one instruction. Migrated to this
instruction and fixed a minor bug math-to-llvm for the intrinsic.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D125592
This commit is contained in:
Robert Suderman 2022-05-16 11:08:49 -07:00 committed by Rob Suderman
parent 52c615553c
commit cb4a5eae1e
3 changed files with 4 additions and 56 deletions

View File

@ -74,8 +74,8 @@ struct CountOpLowering : public ConvertOpToLLVMPattern<MathOp> {
[&](Type llvm1DVectorTy, ValueRange operands) {
LLVM::ConstantOp zero =
rewriter.create<LLVM::ConstantOp>(loc, boolType, boolZero);
return rewriter.replaceOpWithNewOp<LLVMOp>(op, llvm1DVectorTy,
operands[0], zero);
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
zero);
},
rewriter);
}

View File

@ -259,54 +259,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
// tosa::ClzOp
if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
int bitWidth = elementTy.getIntOrFloatBitWidth();
auto zero =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
auto leadingZeros = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, bitWidth));
SmallVector<Value> operands = {args[0], leadingZeros, zero};
SmallVector<Type> types = {elementTy, elementTy, elementTy};
SmallVector<Location> locations = {loc, loc, loc};
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
Block *before =
rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
Block *after =
rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
// The conditional block of the while loop.
{
rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
Value input = before->getArgument(0);
Value zero = before->getArgument(2);
Value inputLargerThanZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, input, zero);
rewriter.create<scf::ConditionOp>(loc, inputLargerThanZero,
before->getArguments());
}
// The body of the while loop: shift right until reaching a value of 0.
{
rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
Value input = after->getArgument(0);
Value leadingZeros = after->getArgument(1);
auto one = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, 1));
auto shifted =
rewriter.create<arith::ShRUIOp>(loc, resultTypes, input, one);
auto leadingZerosMinusOne =
rewriter.create<arith::SubIOp>(loc, resultTypes, leadingZeros, one);
rewriter.create<scf::YieldOp>(
loc,
ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
}
rewriter.setInsertionPointAfter(whileOp);
return whileOp->getResult(1);
return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
}
// tosa::LogicalAnd

View File

@ -366,12 +366,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
// CHECK: arith.addi
%12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: scf.while
// CHECK: arith.cmpi ne
// CHECK: scf.condition
// CHECK: arith.shrui
// CHECK: arith.subi
// CHECK: scf.yield
// CHECK: math.ctlz
%13 = "tosa.clz"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic