I am not sure about the meaning of Type in the name (was it meant to be interpreted as Kind?), and given the importance and meaning of Type in the context of MLIR, it's probably better to rename it. Given the comment in the source code, the suggestion in the GitHub issue, and the final discussions in the review, this patch renames OperandType to UnresolvedOperand. Fixes https://github.com/llvm/llvm-project/issues/54446 Differential Revision: https://reviews.llvm.org/D122142
2011 lines
72 KiB
C++
2011 lines
72 KiB
C++
//===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <utility>
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/CommonFolders.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
|
|
#include "llvm/ADT/APSInt.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::arith;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds an integer attribute of `res`'s type holding `lhs + rhs`.
///
/// Both `lhs` and `rhs` must be IntegerAttrs (they are cast unconditionally)
/// and are expected to share the same bit width. The sum is computed on the
/// underlying APInt values so that it wraps at the attribute's bit width
/// instead of being routed through `int64_t`, which cannot represent constants
/// wider than 64 bits and has undefined behavior on signed overflow.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  const APInt lhsValue = lhs.cast<IntegerAttr>().getValue();
  const APInt rhsValue = rhs.cast<IntegerAttr>().getValue();
  return builder.getIntegerAttr(res.getType(), lhsValue + rhsValue);
}
|
|
|
|
/// Builds an integer attribute of `res`'s type holding `lhs - rhs`.
///
/// Both `lhs` and `rhs` must be IntegerAttrs (they are cast unconditionally)
/// and are expected to share the same bit width. The difference is computed on
/// the underlying APInt values so that it wraps at the attribute's bit width
/// instead of being routed through `int64_t`, which cannot represent constants
/// wider than 64 bits and has undefined behavior on signed overflow.
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  const APInt lhsValue = lhs.cast<IntegerAttr>().getValue();
  const APInt rhsValue = rhs.cast<IntegerAttr>().getValue();
  return builder.getIntegerAttr(res.getType(), lhsValue - rhsValue);
}
|
|
|
|
/// Invert an integer comparison predicate.
|
|
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
|
|
switch (pred) {
|
|
case arith::CmpIPredicate::eq:
|
|
return arith::CmpIPredicate::ne;
|
|
case arith::CmpIPredicate::ne:
|
|
return arith::CmpIPredicate::eq;
|
|
case arith::CmpIPredicate::slt:
|
|
return arith::CmpIPredicate::sge;
|
|
case arith::CmpIPredicate::sle:
|
|
return arith::CmpIPredicate::sgt;
|
|
case arith::CmpIPredicate::sgt:
|
|
return arith::CmpIPredicate::sle;
|
|
case arith::CmpIPredicate::sge:
|
|
return arith::CmpIPredicate::slt;
|
|
case arith::CmpIPredicate::ult:
|
|
return arith::CmpIPredicate::uge;
|
|
case arith::CmpIPredicate::ule:
|
|
return arith::CmpIPredicate::ugt;
|
|
case arith::CmpIPredicate::ugt:
|
|
return arith::CmpIPredicate::ule;
|
|
case arith::CmpIPredicate::uge:
|
|
return arith::CmpIPredicate::ult;
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Attribute-level overload: unwraps the predicate stored in `pred`, inverts
/// it, and wraps the result in a new attribute created in the same context.
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
  arith::CmpIPredicate inverted = invertPredicate(pred.getValue());
  return arith::CmpIPredicateAttr::get(pred.getContext(), inverted);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd canonicalization patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
#include "ArithmeticCanonicalization.inc"
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Suggests an SSA name for the result of this constant.
///
/// - Non-integer attributes are named "cst".
/// - Constants whose type is a 1-bit integer are named "true" / "false".
/// - Other integer attributes are named "c<value>", followed by "_<type>" when
///   the result type is an IntegerType.
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto type = getType();
  auto intCst = getValue().dyn_cast<IntegerAttr>();
  if (!intCst) {
    setNameFn(getResult(), "cst");
    return;
  }

  auto intType = type.dyn_cast<IntegerType>();

  // Sugar i1 constants with 'true' and 'false'.
  if (intType && intType.getWidth() == 1)
    return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));

  // Otherwise, build a complex name with the value and type. The APInt is
  // streamed directly (it prints as a signed decimal) rather than through
  // `getInt()`, which cannot represent constants wider than 64 bits.
  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << 'c' << intCst.getValue();
  if (intType)
    specialName << '_' << type;
  setNameFn(getResult(), specialName.str());
}
|
|
|
|
/// TODO: disallow arith.constant to return anything other than signless integer
|
|
/// or float like.
|
|
LogicalResult arith::ConstantOp::verify() {
|
|
auto type = getType();
|
|
// The value's type must match the return type.
|
|
if (getValue().getType() != type) {
|
|
return emitOpError() << "value type " << getValue().getType()
|
|
<< " must match return type: " << type;
|
|
}
|
|
// Integer values must be signless.
|
|
if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
|
|
return emitOpError("integer return type must be signless");
|
|
// Any float or elements attribute are acceptable.
|
|
if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
|
|
return emitOpError(
|
|
"value must be an integer, float, or elements attribute");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Returns true if a constant op can be built from `value` with result type
/// `type`: the types must match, an integer type must be signless, and the
/// attribute must be an integer, float, or elements attribute.
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
  // Attribute type and requested type must be identical.
  const bool typesMatch = value.getType() == type;
  if (!typesMatch)
    return false;

  // Reject signed / unsigned integer types.
  if (auto intType = type.dyn_cast<IntegerType>())
    if (!intType.isSignless())
      return false;

  // Finally, restrict the attribute kind.
  return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
}
|
|
|
|
/// A constant folds to the attribute it holds; `operands` is not consulted.
OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
  auto constantValue = getValue();
  return constantValue;
}
|
|
|
|
/// Builds an integer constant of the given bit `width` holding `value`. The
/// integer type is obtained from `builder.getIntegerType(width)`.
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, unsigned width) {
  auto intType = builder.getIntegerType(width);
  auto valueAttr = builder.getIntegerAttr(intType, value);
  arith::ConstantOp::build(builder, result, intType, valueAttr);
}
|
|
|
|
/// Builds an integer constant of the given `type` holding `value`. `type` must
/// be a signless integer type (checked by assertion).
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, Type type) {
  assert(type.isSignlessInteger() &&
         "ConstantIntOp can only have signless integer type values");
  auto valueAttr = builder.getIntegerAttr(type, value);
  arith::ConstantOp::build(builder, result, type, valueAttr);
}
|
|
|
|
/// Returns true when `op` is a non-null arith::ConstantOp whose result type is
/// a signless integer.
bool arith::ConstantIntOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  if (!constOp)
    return false;
  return constOp.getType().isSignlessInteger();
}
|
|
|
|
/// Builds a floating-point constant of the given float `type` holding `value`.
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                                   const APFloat &value, FloatType type) {
  auto valueAttr = builder.getFloatAttr(type, value);
  arith::ConstantOp::build(builder, result, type, valueAttr);
}
|
|
|
|
/// Returns true when `op` is a non-null arith::ConstantOp whose result type is
/// a FloatType.
bool arith::ConstantFloatOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  if (!constOp)
    return false;
  return constOp.getType().isa<FloatType>();
}
|
|
|
|
/// Builds a constant of index type holding `value`.
void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
                                   int64_t value) {
  auto indexType = builder.getIndexType();
  auto valueAttr = builder.getIndexAttr(value);
  arith::ConstantOp::build(builder, result, indexType, valueAttr);
}
|
|
|
|
/// Returns true when `op` is a non-null arith::ConstantOp whose result type is
/// the index type.
bool arith::ConstantIndexOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  if (!constOp)
    return false;
  return constOp.getType().isIndex();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AddIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.addi. Tried in order:
///   addi(x, 0)          -> x
///   addi(subi(a, b), b) -> a
///   addi(b, subi(a, b)) -> a
/// and otherwise constant-folds the operand attributes.
OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  // Adding zero is the identity.
  if (matchPattern(rhs, m_Zero()))
    return lhs;

  // (a - b) + b cancels to a.
  if (auto lhsSub = lhs.getDefiningOp<SubIOp>()) {
    if (rhs == lhsSub.getRhs())
      return lhsSub.getLhs();
  }

  // b + (a - b) cancels to a.
  if (auto rhsSub = rhs.getDefiningOp<SubIOp>()) {
    if (lhs == rhsSub.getRhs())
      return rhsSub.getLhs();
  }

  auto addValues = [](APInt a, const APInt &b) {
    a += b;
    return a;
  };
  return constFoldBinaryOp<IntegerAttr>(operands, addValues);
}
|
|
|
|
/// Registers the addi canonicalization patterns (AddIAddConstant,
/// AddISubConstantRHS, AddISubConstantLHS) with `patterns`. The pattern classes
/// are defined outside this function, presumably in the included
/// ArithmeticCanonicalization.inc.
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AddIAddConstant>(context)
      .add<AddISubConstantRHS>(context)
      .add<AddISubConstantLHS>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.subi. Tried in order:
///   subi(x, x) -> 0
///   subi(x, 0) -> x
/// and otherwise constant-folds the operand attributes.
OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
  // Identical operands cancel out to the zero attribute of the result type.
  const bool sameOperands = getOperand(0) == getOperand(1);
  if (sameOperands)
    return Builder(getContext()).getZeroAttr(getType());

  // Subtracting zero is the identity.
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  auto subValues = [](APInt a, const APInt &b) {
    a -= b;
    return a;
  };
  return constFoldBinaryOp<IntegerAttr>(operands, subValues);
}
|
|
|
|
/// Registers the subi canonicalization patterns with `patterns`. The pattern
/// classes are defined outside this function, presumably in the included
/// ArithmeticCanonicalization.inc.
void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Patterns involving an addi with a constant.
  patterns.add<SubIRHSAddConstant>(context).add<SubILHSAddConstant>(context);
  // Patterns involving a nested subi with a constant.
  patterns.add<SubIRHSSubConstantRHS>(context)
      .add<SubIRHSSubConstantLHS>(context)
      .add<SubILHSSubConstantRHS>(context)
      .add<SubILHSSubConstantLHS>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.muli. Tried in order:
///   muli(x, 0) -> 0 (the zero right-hand operand is reused as the result)
///   muli(x, 1) -> x
/// and otherwise constant-folds the operand attributes.
OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
  Value rhs = getRhs();

  // Multiplying by zero yields zero.
  if (matchPattern(rhs, m_Zero()))
    return rhs;

  // Multiplying by one is the identity.
  if (matchPattern(rhs, m_One()))
    return getOperand(0);

  // TODO: Handle the overflow case.

  // Fall back to folding two constant operands.
  auto mulValues = [](const APInt &a, const APInt &b) { return a * b; };
  return constFoldBinaryOp<IntegerAttr>(operands, mulValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.divui. Division by a constant one (scalar or splat) folds to
/// the left-hand operand; two constant operands are folded with unsigned
/// division. Nothing is folded if a division by zero would be required.
OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
  // Set as soon as any element-wise division would divide by zero.
  bool sawDivByZero = false;
  auto divide = [&](APInt a, const APInt &b) {
    if (sawDivByZero || !b) {
      sawDivByZero = true;
      return a;
    }
    return a.udiv(b);
  };
  Attribute folded = constFoldBinaryOp<IntegerAttr>(operands, divide);

  // Fold out division by one. Assumes all tensors of all ones are splats.
  Attribute rhsAttr = operands[1];
  if (auto scalar = rhsAttr.dyn_cast_or_null<IntegerAttr>()) {
    if (scalar.getValue() == 1)
      return getLhs();
  } else if (auto splat = rhsAttr.dyn_cast_or_null<SplatElementsAttr>()) {
    if (splat.getSplatValue<IntegerAttr>().getValue() == 1)
      return getLhs();
  }

  if (sawDivByZero)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.divsi. Division by a constant one (scalar or splat) folds to
/// the left-hand operand; two constant operands are folded with signed
/// division. Nothing is folded if the division overflows or divides by zero.
OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
  // Set as soon as any element-wise division overflows or divides by zero.
  bool cannotFold = false;
  auto divide = [&](APInt a, const APInt &b) {
    if (cannotFold || !b) {
      cannotFold = true;
      return a;
    }
    return a.sdiv_ov(b, cannotFold);
  };
  Attribute folded = constFoldBinaryOp<IntegerAttr>(operands, divide);

  // Fold out division by one. Assumes all tensors of all ones are splats.
  Attribute rhsAttr = operands[1];
  if (auto scalar = rhsAttr.dyn_cast_or_null<IntegerAttr>()) {
    if (scalar.getValue() == 1)
      return getLhs();
  } else if (auto splat = rhsAttr.dyn_cast_or_null<SplatElementsAttr>()) {
    if (splat.getSplatValue<IntegerAttr>().getValue() == 1)
      return getLhs();
  }

  if (cannotFold)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Ceil and floor division folding helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Computes ceil(a / b) as (a - 1) / b + 1 using signed APInt arithmetic.
///
/// `overflow` is sticky: it is set to true if any of the three intermediate
/// operations overflows and is never reset to false, so an overflow already
/// recorded by the caller (or by an earlier step here) is not lost. Each APInt
/// `*_ov` call overwrites the flag it is given, which is why every step uses
/// its own local flag.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {
  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
  bool subOverflow = false;
  bool divOverflow = false;
  bool addOverflow = false;
  // Returns (a-1)/b + 1
  APInt quotient = a.ssub_ov(one, subOverflow).sdiv_ov(b, divOverflow);
  APInt result = quotient.sadd_ov(one, addOverflow);
  overflow = overflow || subOverflow || divOverflow || addOverflow;
  return result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CeilDivUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.ceildivui. Division by a constant one (scalar or splat) folds
/// to the left-hand operand; two constant operands are folded to the unsigned
/// quotient, incremented by one when the remainder is non-zero. Nothing is
/// folded if that increment overflows or a division by zero is required.
OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
  // Set as soon as any element-wise computation overflows or divides by zero.
  bool cannotFold = false;
  auto ceilDivide = [&](APInt a, const APInt &b) {
    if (cannotFold || !b) {
      cannotFold = true;
      return a;
    }
    APInt floorQuotient = a.udiv(b);
    // Exact division: the floor quotient is already the ceiling.
    if (!a.urem(b))
      return floorQuotient;
    APInt one(a.getBitWidth(), 1, true);
    return floorQuotient.uadd_ov(one, cannotFold);
  };
  Attribute folded = constFoldBinaryOp<IntegerAttr>(operands, ceilDivide);

  // Fold out ceil division by one. Assumes all tensors of all ones are
  // splats.
  Attribute rhsAttr = operands[1];
  if (auto scalar = rhsAttr.dyn_cast_or_null<IntegerAttr>()) {
    if (scalar.getValue() == 1)
      return getLhs();
  } else if (auto splat = rhsAttr.dyn_cast_or_null<SplatElementsAttr>()) {
    if (splat.getSplatValue<IntegerAttr>().getValue() == 1)
      return getLhs();
  }

  if (cannotFold)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CeilDivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.ceildivsi. Division by a constant one (scalar or splat) folds
/// to the left-hand operand; two constant operands are folded by reducing to
/// non-negative operands according to their signs.
///
/// Nothing is folded if a division by zero is required or if ANY intermediate
/// operation overflows. In particular, operands equal to the minimum signed
/// value are not folded when they have to be negated, because the negation
/// itself overflows.
OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result =
      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        if (!a)
          return a;
        // After this point we know that neither a nor b is zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);

        // Each APInt `*_ov` call overwrites the flag it is given, so every
        // step records into its own flag; the flags are combined at the end.
        // Sharing one flag would let a later, non-overflowing step hide an
        // earlier overflow and produce a wrong constant.
        bool negAOverflow = false;
        bool negBOverflow = false;
        bool opOverflow = false;
        bool negResultOverflow = false;
        APInt folded = a;
        if (aGtZero && bGtZero) {
          // Both positive, return ceil(a, b).
          folded = signedCeilNonnegInputs(a, b, opOverflow);
        } else if (!aGtZero && !bGtZero) {
          // Both negative, return ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, negAOverflow);
          APInt posB = zero.ssub_ov(b, negBOverflow);
          folded = signedCeilNonnegInputs(posA, posB, opOverflow);
        } else if (!aGtZero) {
          // A is negative, b is positive, return - ( -a / b).
          APInt posA = zero.ssub_ov(a, negAOverflow);
          APInt div = posA.sdiv_ov(b, opOverflow);
          folded = zero.ssub_ov(div, negResultOverflow);
        } else {
          // A is positive, b is negative, return - (a / -b).
          APInt posB = zero.ssub_ov(b, negBOverflow);
          APInt div = a.sdiv_ov(posB, opOverflow);
          folded = zero.ssub_ov(div, negResultOverflow);
        }
        overflowOrDiv0 =
            negAOverflow || negBOverflow || opOverflow || negResultOverflow;
        return folded;
      });

  // Fold out ceil division by one. Assumes all tensors of all ones are
  // splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return getLhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return getLhs();
  }

  return overflowOrDiv0 ? Attribute() : result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FloorDivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.floordivsi. Division by a constant one (scalar or splat) folds
/// to the left-hand operand; two constant operands are folded by reducing to
/// non-negative operands according to their signs.
///
/// Nothing is folded if a division by zero is required or if ANY intermediate
/// operation overflows. In particular, operands equal to the minimum signed
/// value are not folded when they have to be negated, because the negation
/// itself overflows.
OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result =
      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        if (!a)
          return a;
        // After this point we know that neither a nor b is zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);

        // Each APInt `*_ov` call overwrites the flag it is given, so every
        // step records into its own flag; the flags are combined at the end.
        // Sharing one flag would let a later, non-overflowing step hide an
        // earlier overflow and produce a wrong constant.
        bool negAOverflow = false;
        bool negBOverflow = false;
        bool opOverflow = false;
        bool negResultOverflow = false;
        APInt folded = a;
        if (aGtZero && bGtZero) {
          // Both positive, return a / b.
          folded = a.sdiv_ov(b, opOverflow);
        } else if (!aGtZero && !bGtZero) {
          // Both negative, return -a / -b.
          APInt posA = zero.ssub_ov(a, negAOverflow);
          APInt posB = zero.ssub_ov(b, negBOverflow);
          folded = posA.sdiv_ov(posB, opOverflow);
        } else if (!aGtZero) {
          // A is negative, b is positive, return - ceil(-a, b).
          APInt posA = zero.ssub_ov(a, negAOverflow);
          APInt ceil = signedCeilNonnegInputs(posA, b, opOverflow);
          folded = zero.ssub_ov(ceil, negResultOverflow);
        } else {
          // A is positive, b is negative, return - ceil(a, -b).
          APInt posB = zero.ssub_ov(b, negBOverflow);
          APInt ceil = signedCeilNonnegInputs(a, posB, opOverflow);
          folded = zero.ssub_ov(ceil, negResultOverflow);
        }
        overflowOrDiv0 =
            negAOverflow || negBOverflow || opOverflow || negResultOverflow;
        return folded;
      });

  // Fold out floor division by one. Assumes all tensors of all ones are
  // splats.
  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
    if (rhs.getValue() == 1)
      return getLhs();
  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
      return getLhs();
  }

  return overflowOrDiv0 ? Attribute() : result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RemUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.remui when the right-hand operand is a scalar integer
/// constant: `x % 1` folds to zero, a zero divisor is never folded, and two
/// scalar constants fold to their unsigned remainder.
OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
  auto divisorAttr = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!divisorAttr)
    return {};
  const APInt divisor = divisorAttr.getValue();

  // x % 1 = 0
  if (divisor.isOneValue()) {
    APInt zero(divisor.getBitWidth(), 0);
    return IntegerAttr::get(divisorAttr.getType(), zero);
  }

  // Don't fold if it requires division by zero.
  if (divisor.isNullValue())
    return {};

  auto dividendAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!dividendAttr)
    return {};
  APInt remainder = dividendAttr.getValue().urem(divisor);
  return IntegerAttr::get(dividendAttr.getType(), remainder);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RemSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.remsi when the right-hand operand is a scalar integer
/// constant: `x % 1` folds to zero, a zero divisor is never folded, and two
/// scalar constants fold to their signed remainder.
OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
  auto divisorAttr = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!divisorAttr)
    return {};
  const APInt divisor = divisorAttr.getValue();

  // x % 1 = 0
  if (divisor.isOneValue()) {
    APInt zero(divisor.getBitWidth(), 0);
    return IntegerAttr::get(divisorAttr.getType(), zero);
  }

  // Don't fold if it requires division by zero.
  if (divisor.isNullValue())
    return {};

  auto dividendAttr = operands.front().dyn_cast_or_null<IntegerAttr>();
  if (!dividendAttr)
    return {};
  APInt remainder = dividendAttr.getValue().srem(divisor);
  return IntegerAttr::get(dividendAttr.getType(), remainder);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AndIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.andi. Tried in order:
///   and(x, 0)       -> 0 (the zero right-hand operand is reused)
///   and(x, allOnes) -> x
/// and otherwise constant-folds the operand attributes.
OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
  Value rhs = getRhs();

  // AND with zero is zero.
  if (matchPattern(rhs, m_Zero()))
    return rhs;

  // AND with an all-ones constant is the identity.
  APInt rhsConstant;
  const bool rhsIsConstant = matchPattern(rhs, m_ConstantInt(&rhsConstant));
  if (rhsIsConstant && rhsConstant.isAllOnes())
    return getLhs();

  auto andValues = [](APInt a, const APInt &b) {
    a &= b;
    return a;
  };
  return constFoldBinaryOp<IntegerAttr>(operands, andValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OrIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.ori. Tried in order:
///   or(x, 0)          -> x
///   or(x, <all ones>) -> <all ones> (scalar integer constant only)
/// and otherwise constant-folds the operand attributes.
OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
  // OR with zero is the identity.
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  // OR with an all-ones scalar constant yields that constant.
  auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (rhsAttr && rhsAttr.getValue().isAllOnes())
    return rhsAttr;

  auto orValues = [](APInt a, const APInt &b) {
    a |= b;
    return a;
  };
  return constFoldBinaryOp<IntegerAttr>(operands, orValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// XOrIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.xori. Tried in order:
///   xor(x, 0)         -> x
///   xor(x, x)         -> 0
///   xor(xor(x, a), a) -> x
/// and otherwise constant-folds the operand attributes.
OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  // XOR with zero is the identity.
  if (matchPattern(rhs, m_Zero()))
    return lhs;

  // A value XOR-ed with itself is the zero attribute of the result type.
  if (lhs == rhs)
    return Builder(getContext()).getZeroAttr(getType());

  // XOR-ing twice with the same right-hand value cancels out.
  if (auto innerXor = lhs.getDefiningOp<arith::XOrIOp>()) {
    if (innerXor.getRhs() == rhs)
      return innerXor.getLhs();
  }

  auto xorValues = [](APInt a, const APInt &b) {
    a ^= b;
    return a;
  };
  return constFoldBinaryOp<IntegerAttr>(operands, xorValues);
}
|
|
|
|
/// Registers the single xori canonicalization pattern, XOrINotCmpI, with
/// `patterns`. The pattern class is defined outside this function, presumably
/// in the included ArithmeticCanonicalization.inc.
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AddFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.addf: `addf(x, -0) -> x`, otherwise constant-folds the operand
/// attributes with floating-point addition.
OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
  // Adding negative zero leaves the left-hand operand unchanged.
  const bool rhsIsNegZero = matchPattern(getRhs(), m_NegZeroFloat());
  if (rhsIsNegZero)
    return getLhs();

  auto addValues = [](const APFloat &x, const APFloat &y) { return x + y; };
  return constFoldBinaryOp<FloatAttr>(operands, addValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.subf: `subf(x, +0) -> x`, otherwise constant-folds the operand
/// attributes with floating-point subtraction.
OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
  // Subtracting positive zero leaves the left-hand operand unchanged.
  const bool rhsIsPosZero = matchPattern(getRhs(), m_PosZeroFloat());
  if (rhsIsPosZero)
    return getLhs();

  auto subValues = [](const APFloat &x, const APFloat &y) { return x - y; };
  return constFoldBinaryOp<FloatAttr>(operands, subValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaxFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.maxf. Tried in order:
///   maxf(x, x)    -> x
///   maxf(x, -inf) -> x
/// and otherwise constant-folds the operand attributes with `llvm::maximum`.
OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "maxf takes two operands");

  Value lhs = getLhs();
  Value rhs = getRhs();

  // Identical operands: either one is the result.
  if (lhs == rhs)
    return rhs;

  // Negative infinity never wins the maximum.
  if (matchPattern(rhs, m_NegInfFloat()))
    return lhs;

  auto maxValues = [](const APFloat &x, const APFloat &y) {
    return llvm::maximum(x, y);
  };
  return constFoldBinaryOp<FloatAttr>(operands, maxValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaxSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.maxsi. Tried in order:
///   maxsi(x, x)       -> x
///   maxsi(x, MAX_INT) -> MAX_INT
///   maxsi(x, MIN_INT) -> x
/// and otherwise constant-folds the operand attributes with signed max.
OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  Value lhs = getLhs();
  Value rhs = getRhs();

  // Identical operands: either one is the result.
  if (lhs == rhs)
    return rhs;

  // A constant right-hand side at either signed extreme decides the result.
  APInt rhsConstant;
  if (matchPattern(rhs, m_ConstantInt(&rhsConstant))) {
    if (rhsConstant.isMaxSignedValue())
      return rhs;
    if (rhsConstant.isMinSignedValue())
      return lhs;
  }

  auto maxValues = [](const APInt &x, const APInt &y) {
    return llvm::APIntOps::smax(x, y);
  };
  return constFoldBinaryOp<IntegerAttr>(operands, maxValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaxUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.maxui. Tried in order:
///   maxui(x, x)       -> x
///   maxui(x, MAX_INT) -> MAX_INT
///   maxui(x, MIN_INT) -> x
/// and otherwise constant-folds the operand attributes with unsigned max.
OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  Value lhs = getLhs();
  Value rhs = getRhs();

  // Identical operands: either one is the result.
  if (lhs == rhs)
    return rhs;

  // A constant right-hand side at either unsigned extreme decides the result.
  APInt rhsConstant;
  if (matchPattern(rhs, m_ConstantInt(&rhsConstant))) {
    if (rhsConstant.isMaxValue())
      return rhs;
    if (rhsConstant.isMinValue())
      return lhs;
  }

  auto maxValues = [](const APInt &x, const APInt &y) {
    return llvm::APIntOps::umax(x, y);
  };
  return constFoldBinaryOp<IntegerAttr>(operands, maxValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.minf. Tried in order:
///   minf(x, x)    -> x
///   minf(x, +inf) -> x
/// and otherwise constant-folds the operand attributes with `llvm::minimum`.
OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "minf takes two operands");

  Value lhs = getLhs();
  Value rhs = getRhs();

  // Identical operands: either one is the result.
  if (lhs == rhs)
    return rhs;

  // Positive infinity never wins the minimum.
  if (matchPattern(rhs, m_PosInfFloat()))
    return lhs;

  auto minValues = [](const APFloat &x, const APFloat &y) {
    return llvm::minimum(x, y);
  };
  return constFoldBinaryOp<FloatAttr>(operands, minValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.minsi. Tried in order:
///   minsi(x, x)       -> x
///   minsi(x, MIN_INT) -> MIN_INT
///   minsi(x, MAX_INT) -> x
/// and otherwise constant-folds the operand attributes with signed min.
OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  Value lhs = getLhs();
  Value rhs = getRhs();

  // Identical operands: either one is the result.
  if (lhs == rhs)
    return rhs;

  // A constant right-hand side at either signed extreme decides the result.
  APInt rhsConstant;
  if (matchPattern(rhs, m_ConstantInt(&rhsConstant))) {
    if (rhsConstant.isMinSignedValue())
      return rhs;
    if (rhsConstant.isMaxSignedValue())
      return lhs;
  }

  auto minValues = [](const APInt &x, const APInt &y) {
    return llvm::APIntOps::smin(x, y);
  };
  return constFoldBinaryOp<IntegerAttr>(operands, minValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.minui. Tried in order:
///   minui(x, x)       -> x
///   minui(x, MIN_INT) -> MIN_INT
///   minui(x, MAX_INT) -> x
/// and otherwise constant-folds the operand attributes with unsigned min.
OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "binary operation takes two operands");

  Value lhs = getLhs();
  Value rhs = getRhs();

  // Identical operands: either one is the result.
  if (lhs == rhs)
    return rhs;

  // A constant right-hand side at either unsigned extreme decides the result.
  APInt rhsConstant;
  if (matchPattern(rhs, m_ConstantInt(&rhsConstant))) {
    if (rhsConstant.isMinValue())
      return rhs;
    if (rhsConstant.isMaxValue())
      return lhs;
  }

  auto minValues = [](const APInt &x, const APInt &y) {
    return llvm::APIntOps::umin(x, y);
  };
  return constFoldBinaryOp<IntegerAttr>(operands, minValues);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.mulf. Tried in order:
///   mulf(x, 1) -> x
///   mulf(1, x) -> x
/// and otherwise constant-folds the operand attributes with floating-point
/// multiplication.
OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
  // mulf(x, 1) -> x
  if (matchPattern(getRhs(), m_OneFloat()))
    return getLhs();

  // mulf(1, x) -> x
  if (matchPattern(getLhs(), m_OneFloat()))
    return getRhs();

  return constFoldBinaryOp<FloatAttr>(
      operands, [](const APFloat &a, const APFloat &b) { return a * b; });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds arith.divf: `divf(x, 1) -> x`, otherwise constant-folds the operand
/// attributes with floating-point division.
OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
  // divf(x, 1) -> x
  if (matchPattern(getRhs(), m_OneFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      operands, [](const APFloat &a, const APFloat &b) { return a / b; });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utility functions for verifying cast ops
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Tag type carrying a pack of types: a pointer to a tuple of them, so a
/// value of it can be created as a null pointer without constructing any of
/// the listed types.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
|
|
|
|
/// Returns a non-null type only if the provided type is one of the allowed
|
|
/// types or one of the allowed shaped types of the allowed types. Returns the
|
|
/// element type if a valid shaped type is provided.
|
|
template <typename... ShapedTypes, typename... ElementTypes>
|
|
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
|
|
type_list<ElementTypes...>) {
|
|
if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
|
|
return {};
|
|
|
|
auto underlyingType = getElementTypeOrSelf(type);
|
|
if (!underlyingType.isa<ElementTypes...>())
|
|
return {};
|
|
|
|
return underlyingType;
|
|
}
|
|
|
|
/// Get allowed underlying types for vectors and tensors.
|
|
template <typename... ElementTypes>
|
|
static Type getTypeIfLike(Type type) {
|
|
return getUnderlyingType(type, type_list<VectorType, TensorType>(),
|
|
type_list<ElementTypes...>());
|
|
}
|
|
|
|
/// Get allowed underlying types for vectors, tensors, and memrefs.
|
|
template <typename... ElementTypes>
|
|
static Type getTypeIfLikeOrMemRef(Type type) {
|
|
return getUnderlyingType(type,
|
|
type_list<VectorType, TensorType, MemRefType>(),
|
|
type_list<ElementTypes...>());
|
|
}
|
|
|
|
/// Returns true if there is exactly one input type and one output type and
/// the two have compatible shapes.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  // A cast has a single operand and a single result.
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifiers for integer and floating point extension/truncation ops
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Extend ops can only extend to a wider type.
|
|
template <typename ValType, typename Op>
|
|
static LogicalResult verifyExtOp(Op op) {
|
|
Type srcType = getElementTypeOrSelf(op.getIn().getType());
|
|
Type dstType = getElementTypeOrSelf(op.getType());
|
|
|
|
if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
|
|
return op.emitError("result type ")
|
|
<< dstType << " must be wider than operand type " << srcType;
|
|
|
|
return success();
|
|
}
|
|
|
|
// Truncate ops can only truncate to a shorter type.
|
|
template <typename ValType, typename Op>
|
|
static LogicalResult verifyTruncateOp(Op op) {
|
|
Type srcType = getElementTypeOrSelf(op.getIn().getType());
|
|
Type dstType = getElementTypeOrSelf(op.getType());
|
|
|
|
if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
|
|
return op.emitError("result type ")
|
|
<< dstType << " must be shorter than operand type " << srcType;
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Validate a cast that changes the width of a type.
|
|
template <template <typename> class WidthComparator, typename... ElementTypes>
|
|
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
|
|
if (!areValidCastInputsAndOutputs(inputs, outputs))
|
|
return false;
|
|
|
|
auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
|
|
auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
|
|
if (!srcType || !dstType)
|
|
return false;
|
|
|
|
return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
|
|
srcType.getIntOrFloatBitWidth());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `extui`:
///  - an integer constant operand is zero-extended to the result bit width;
///  - extui(extui(x)) has its operand rewired to `x`, and the op's own result
///    is returned.
OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (constIn) {
    const unsigned dstWidth = getType().getIntOrFloatBitWidth();
    return IntegerAttr::get(getType(), constIn.getValue().zext(dstWidth));
  }

  auto innerExt = getIn().getDefiningOp<ExtUIOp>();
  if (!innerExt)
    return {};

  // Skip the intermediate extension and extend its input directly.
  getInMutable().assign(innerExt.getIn());
  return getResult();
}
|
|
|
|
/// An `extui` cast is compatible when it goes between integer-like types and
/// the destination bit width is greater than the source bit width.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isWidening =
      checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
  return isWidening;
}
|
|
|
|
/// Verifies that the integer result type is wider than the operand type.
LogicalResult arith::ExtUIOp::verify() {
  auto self = *this;
  return verifyExtOp<IntegerType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `extsi`:
///  - an integer constant operand is sign-extended to the result bit width;
///  - extsi(extsi(x)) has its operand rewired to `x`, and the op's own result
///    is returned.
OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (constIn) {
    const unsigned dstWidth = getType().getIntOrFloatBitWidth();
    return IntegerAttr::get(getType(), constIn.getValue().sext(dstWidth));
  }

  auto innerExt = getIn().getDefiningOp<ExtSIOp>();
  if (!innerExt)
    return {};

  // Skip the intermediate extension and extend its input directly.
  getInMutable().assign(innerExt.getIn());
  return getResult();
}
|
|
|
|
/// An `extsi` cast is compatible when it goes between integer-like types and
/// the destination bit width is greater than the source bit width.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isWidening =
      checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
  return isWidening;
}
|
|
|
|
/// Registers the canonicalization patterns for `extsi` (currently only
/// `ExtSIOfExtUI`) into `patterns`.
void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}
|
|
|
|
/// Verifies that the integer result type is wider than the operand type.
LogicalResult arith::ExtSIOp::verify() {
  auto self = *this;
  return verifyExtOp<IntegerType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// An `extf` cast is compatible when it goes between float-like types and the
/// destination bit width is greater than the source bit width.
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isWidening =
      checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
  return isWidening;
}
|
|
|
|
/// Verifies that the float result type is wider than the operand type.
LogicalResult arith::ExtFOp::verify() {
  auto self = *this;
  return verifyExtOp<FloatType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TruncIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `trunci`:
///  - trunci(extui(a)) / trunci(extsi(a)) -> a, when `a` already has the
///    result type;
///  - trunci(extui(a)) / trunci(extsi(a)) -> trunci(a), when `a` is wider than
///    the result type;
///  - trunci(trunci(a)) -> trunci(a);
///  - an integer constant operand is truncated to the result bit width.
OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "unary operation takes one operand");

  if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
      matchPattern(getOperand(), m_Op<arith::ExtSIOp>())) {
    Value src = getOperand().getDefiningOp()->getOperand(0);

    // trunci(extui(a)) -> a
    // trunci(extsi(a)) -> a
    // Only valid when `a` has exactly the result type; otherwise the fold
    // would replace this op's result with a value of a different type.
    if (src.getType() == getType())
      return src;

    // trunci(extui(a)) -> trunci(a)
    // trunci(extsi(a)) -> trunci(a)
    // When `a` is still wider than the result, truncate it directly.
    auto srcIntTy = getElementTypeOrSelf(src.getType()).dyn_cast<IntegerType>();
    auto dstIntTy = getElementTypeOrSelf(getType()).dyn_cast<IntegerType>();
    if (srcIntTy && dstIntTy && srcIntTy.getWidth() > dstIntTy.getWidth()) {
      setOperand(src);
      return getResult();
    }
    // Remaining case: `a` is narrower than the result, which would need an
    // extension rather than a truncation; leave the op untouched here.
  }

  // trunci(trunci(a)) -> trunci(a)
  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
    setOperand(getOperand().getDefiningOp()->getOperand(0));
    return getResult();
  }

  if (!operands[0])
    return {};

  // Constant-fold a scalar integer operand.
  if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
    return IntegerAttr::get(
        getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
  }

  return {};
}
|
|
|
|
/// A `trunci` cast is compatible when it goes between integer-like types and
/// the destination bit width is less than the source bit width.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isNarrowing =
      checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
  return isNarrowing;
}
|
|
|
|
/// Verifies that the integer result type is narrower than the operand type.
LogicalResult arith::TruncIOp::verify() {
  auto self = *this;
  return verifyTruncateOp<IntegerType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TruncFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Perform safe const propagation for truncf, i.e. only propagate if FP value
|
|
/// can be represented without precision loss or rounding.
|
|
OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
|
|
assert(operands.size() == 1 && "unary operation takes one operand");
|
|
|
|
auto constOperand = operands.front();
|
|
if (!constOperand || !constOperand.isa<FloatAttr>())
|
|
return {};
|
|
|
|
// Convert to target type via 'double'.
|
|
double sourceValue =
|
|
constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
|
|
auto targetAttr = FloatAttr::get(getType(), sourceValue);
|
|
|
|
// Propagate if constant's value does not change after truncation.
|
|
if (sourceValue == targetAttr.getValue().convertToDouble())
|
|
return targetAttr;
|
|
|
|
return {};
|
|
}
|
|
|
|
/// A `truncf` cast is compatible when it goes between float-like types and
/// the destination bit width is less than the source bit width.
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isNarrowing =
      checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
  return isNarrowing;
}
|
|
|
|
/// Verifies that the float result type is narrower than the operand type.
LogicalResult arith::TruncFOp::verify() {
  auto self = *this;
  return verifyTruncateOp<FloatType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AndIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Registers the canonicalization patterns for `andi` (`AndOfExtUI` and
/// `AndOfExtSI`) into `patterns`.
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI>(context);
  patterns.add<AndOfExtSI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OrIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Registers the canonicalization patterns for `ori` (`OrOfExtUI` and
/// `OrOfExtSI`) into `patterns`.
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI>(context);
  patterns.add<OrOfExtSI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifiers for casts between integers and floats.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename From, typename To>
|
|
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
|
|
if (!areValidCastInputsAndOutputs(inputs, outputs))
|
|
return false;
|
|
|
|
auto srcType = getTypeIfLike<From>(inputs.front());
|
|
auto dstType = getTypeIfLike<To>(outputs.back());
|
|
|
|
return srcType && dstType;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UIToFPOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// A `uitofp` cast is compatible when it converts an integer-like type to a
/// float-like type.
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isIntToFloat =
      checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
  return isIntToFloat;
}
|
|
|
|
/// Constant-folds `uitofp`: an integer constant operand is converted, as an
/// unsigned value and rounding to nearest-even, into a float constant of the
/// result type.
OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!constIn)
    return {};

  APInt intValue = constIn.getValue();
  FloatType resultTy = getType().cast<FloatType>();

  // Start from an all-zero bit pattern in the target semantics, then convert.
  APFloat converted(resultTy.getFloatSemantics(),
                    APInt::getZero(resultTy.getWidth()));
  converted.convertFromAPInt(intValue, /*IsSigned=*/false,
                             APFloat::rmNearestTiesToEven);
  return FloatAttr::get(resultTy, converted);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SIToFPOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// A `sitofp` cast is compatible when it converts an integer-like type to a
/// float-like type.
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isIntToFloat =
      checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
  return isIntToFloat;
}
|
|
|
|
/// Constant-folds `sitofp`: an integer constant operand is converted, as a
/// signed value and rounding to nearest-even, into a float constant of the
/// result type.
OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!constIn)
    return {};

  APInt intValue = constIn.getValue();
  FloatType resultTy = getType().cast<FloatType>();

  // Start from an all-zero bit pattern in the target semantics, then convert.
  APFloat converted(resultTy.getFloatSemantics(),
                    APInt::getZero(resultTy.getWidth()));
  converted.convertFromAPInt(intValue, /*IsSigned=*/true,
                             APFloat::rmNearestTiesToEven);
  return FloatAttr::get(resultTy, converted);
}
|
|
//===----------------------------------------------------------------------===//
|
|
// FPToUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// An `fptoui` cast is compatible when it converts a float-like type to an
/// integer-like type.
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isFloatToInt =
      checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
  return isFloatToInt;
}
|
|
|
|
/// Constant-folds `fptoui`: a float constant operand is converted (rounding
/// toward zero) into an unsigned integer of the result width. Nothing is
/// folded when the conversion reports `opInvalidOp`.
OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<FloatAttr>();
  if (!constIn)
    return {};

  APFloat floatValue = constIn.getValue();
  IntegerType resultTy = getType().cast<IntegerType>();

  bool isExact;
  APSInt converted(resultTy.getWidth(), /*isUnsigned=*/true);
  const APFloat::opStatus status =
      floatValue.convertToInteger(converted, APFloat::rmTowardZero, &isExact);

  // Undefined behavior invoked - the destination type can't represent the
  // input constant.
  if (status == APFloat::opInvalidOp)
    return {};

  return IntegerAttr::get(getType(), converted);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FPToSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// An `fptosi` cast is compatible when it converts a float-like type to an
/// integer-like type.
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isFloatToInt =
      checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
  return isFloatToInt;
}
|
|
|
|
/// Constant-folds `fptosi`: a float constant operand is converted (rounding
/// toward zero) into a signed integer of the result width. Nothing is folded
/// when the conversion reports `opInvalidOp`.
OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<FloatAttr>();
  if (!constIn)
    return {};

  APFloat floatValue = constIn.getValue();
  IntegerType resultTy = getType().cast<IntegerType>();

  bool isExact;
  APSInt converted(resultTy.getWidth(), /*isUnsigned=*/false);
  const APFloat::opStatus status =
      floatValue.convertToInteger(converted, APFloat::rmTowardZero, &isExact);

  // Undefined behavior invoked - the destination type can't represent the
  // input constant.
  if (status == APFloat::opInvalidOp)
    return {};

  return IntegerAttr::get(getType(), converted);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IndexCastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// An `index_cast` is compatible when exactly one side is `index` and the
/// other is a signless integer (looking through the shaped/memref wrappers
/// that `getTypeIfLikeOrMemRef` accepts).
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto fromTy = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
  auto toTy = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
  if (fromTy && toTy) {
    const bool indexToInt = fromTy.isIndex() && toTy.isSignlessInteger();
    const bool intToIndex = fromTy.isSignlessInteger() && toTy.isIndex();
    return indexToInt || intToIndex;
  }
  return false;
}
|
|
|
|
/// Constant-folds `index_cast`: index_cast(constant) -> constant.
OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
  auto constIn = operands[0].dyn_cast_or_null<IntegerAttr>();
  if (!constIn)
    return {};

  // A little hack because we go through int. Otherwise, the size of the
  // constant might need to change.
  return IntegerAttr::get(getType(), constIn.getInt());
}
|
|
|
|
/// Registers the canonicalization patterns for `index_cast`
/// (`IndexCastOfIndexCast` and `IndexCastOfExtSI`) into `patterns`.
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast>(context);
  patterns.add<IndexCastOfExtSI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BitcastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// A `bitcast` is compatible when both sides are integer-, index- or
/// float-like (as accepted by `getTypeIfLikeOrMemRef`) and report the same
/// bit width.
bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  auto fromTy =
      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
  auto toTy =
      getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
  if (fromTy && toTy)
    return fromTy.getIntOrFloatBitWidth() == toTy.getIntOrFloatBitWidth();

  return false;
}
|
|
|
|
/// Constant-folds `bitcast`:
///  - dense elements constants are bitcast element-wise to the result
///    element type;
///  - any other constant with a shaped result type is left alone;
///  - scalar integer/float constants have their bits reinterpreted as the
///    result integer/float type.
OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "bitcast op expects 1 operand");

  auto constIn = operands[0];
  if (!constIn)
    return {};

  auto resultTy = getType();

  /// Bitcast dense elements.
  if (auto dense = constIn.dyn_cast<DenseElementsAttr>()) {
    Type resultElemTy = resultTy.cast<ShapedType>().getElementType();
    return dense.bitcast(resultElemTy);
  }

  /// Other shaped types unhandled.
  if (resultTy.isa<ShapedType>())
    return {};

  /// Bitcast integer or float to integer or float.
  APInt rawBits;
  if (auto floatIn = constIn.dyn_cast<FloatAttr>())
    rawBits = floatIn.getValue().bitcastToAPInt();
  else
    rawBits = constIn.cast<IntegerAttr>().getValue();

  auto resultFloatTy = resultTy.dyn_cast<FloatType>();
  if (!resultFloatTy)
    return IntegerAttr::get(resultTy, rawBits);

  return FloatAttr::get(resultTy,
                        APFloat(resultFloatTy.getFloatSemantics(), rawBits));
}
|
|
|
|
/// Registers the canonicalization patterns for `bitcast` (currently only
/// `BitcastOfBitcast`) into `patterns`.
void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Helpers for compare ops
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
|
static Type getI1SameShape(Type type) {
|
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
|
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
|
if (type.isa<UnrankedTensorType>())
|
|
return UnrankedTensorType::get(i1Type);
|
|
if (auto vectorType = type.dyn_cast<VectorType>())
|
|
return VectorType::get(vectorType.getShape(), i1Type,
|
|
vectorType.getNumScalableDims());
|
|
return i1Type;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
|
/// comparison predicates.
|
|
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
|
|
const APInt &lhs, const APInt &rhs) {
|
|
switch (predicate) {
|
|
case arith::CmpIPredicate::eq:
|
|
return lhs.eq(rhs);
|
|
case arith::CmpIPredicate::ne:
|
|
return lhs.ne(rhs);
|
|
case arith::CmpIPredicate::slt:
|
|
return lhs.slt(rhs);
|
|
case arith::CmpIPredicate::sle:
|
|
return lhs.sle(rhs);
|
|
case arith::CmpIPredicate::sgt:
|
|
return lhs.sgt(rhs);
|
|
case arith::CmpIPredicate::sge:
|
|
return lhs.sge(rhs);
|
|
case arith::CmpIPredicate::ult:
|
|
return lhs.ult(rhs);
|
|
case arith::CmpIPredicate::ule:
|
|
return lhs.ule(rhs);
|
|
case arith::CmpIPredicate::ugt:
|
|
return lhs.ugt(rhs);
|
|
case arith::CmpIPredicate::uge:
|
|
return lhs.uge(rhs);
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Returns true if the predicate is true for two equal operands.
|
|
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
|
|
switch (predicate) {
|
|
case arith::CmpIPredicate::eq:
|
|
case arith::CmpIPredicate::sle:
|
|
case arith::CmpIPredicate::sge:
|
|
case arith::CmpIPredicate::ule:
|
|
case arith::CmpIPredicate::uge:
|
|
return true;
|
|
case arith::CmpIPredicate::ne:
|
|
case arith::CmpIPredicate::slt:
|
|
case arith::CmpIPredicate::sgt:
|
|
case arith::CmpIPredicate::ult:
|
|
case arith::CmpIPredicate::ugt:
|
|
return false;
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Builds a boolean constant attribute holding `value` for `type`: a
/// `DenseElementsAttr` filled with `value` when `type` is a shaped type, a
/// plain `BoolAttr` otherwise (including when `type` is null).
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
  auto scalarAttr = BoolAttr::get(ctx, value);
  if (auto shaped = type.dyn_cast_or_null<ShapedType>())
    return DenseElementsAttr::get(shaped, scalarAttr);
  return scalarAttr;
}
|
|
|
|
/// Folds `cmpi`:
///  - cmpi(pred, x, x) -> constant determined by the predicate alone;
///  - extsi/extui(%x : i1 -> iN) != 0 -> %x;
///  - two integer constant operands -> constant result of the comparison.
OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpi takes two operands");

  // cmpi(pred, x, x)
  if (getLhs() == getRhs()) {
    auto val = applyCmpPredicateToEqualOperands(getPredicate());
    return getBoolAttribute(getType(), getContext(), val);
  }

  // extsi(%x : i1 -> iN) != 0 -> %x
  // extui(%x : i1 -> iN) != 0 -> %x
  if (getPredicate() == arith::CmpIPredicate::ne &&
      matchPattern(getRhs(), m_Zero())) {
    Value extInput;
    if (auto extOp = getLhs().getDefiningOp<ExtSIOp>())
      extInput = extOp.getOperand();
    else if (auto extOp = getLhs().getDefiningOp<ExtUIOp>())
      extInput = extOp.getOperand();

    // Inspect the element type rather than casting the operand's full type to
    // IntegerType: the operand may be a vector/tensor of i1, for which an
    // unchecked cast is invalid. Require the replacement to have exactly the
    // result type so the fold never changes the type of the value.
    if (extInput && extInput.getType() == getType() &&
        getElementTypeOrSelf(extInput.getType()).isInteger(1))
      return extInput;
  }

  // Constant-fold a comparison of two scalar integer constants.
  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
  if (!lhs || !rhs)
    return {};

  auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
  return BoolAttr::get(getContext(), val);
}
|
|
|
|
/// Registers the canonicalization patterns for `cmpi` (`CmpIExtSI` and
/// `CmpIExtUI`) into `patterns`.
void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpIExtSI>(context);
  patterns.insert<CmpIExtUI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
|
|
/// comparison predicates.
|
|
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
|
|
const APFloat &lhs, const APFloat &rhs) {
|
|
auto cmpResult = lhs.compare(rhs);
|
|
switch (predicate) {
|
|
case arith::CmpFPredicate::AlwaysFalse:
|
|
return false;
|
|
case arith::CmpFPredicate::OEQ:
|
|
return cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::OGT:
|
|
return cmpResult == APFloat::cmpGreaterThan;
|
|
case arith::CmpFPredicate::OGE:
|
|
return cmpResult == APFloat::cmpGreaterThan ||
|
|
cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::OLT:
|
|
return cmpResult == APFloat::cmpLessThan;
|
|
case arith::CmpFPredicate::OLE:
|
|
return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ONE:
|
|
return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ORD:
|
|
return cmpResult != APFloat::cmpUnordered;
|
|
case arith::CmpFPredicate::UEQ:
|
|
return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UGT:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpGreaterThan;
|
|
case arith::CmpFPredicate::UGE:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpGreaterThan ||
|
|
cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ULT:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpLessThan;
|
|
case arith::CmpFPredicate::ULE:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UNE:
|
|
return cmpResult != APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UNO:
|
|
return cmpResult == APFloat::cmpUnordered;
|
|
case arith::CmpFPredicate::AlwaysTrue:
|
|
return true;
|
|
}
|
|
llvm_unreachable("unknown cmpf predicate kind");
|
|
}
|
|
|
|
/// Constant-folds `cmpf` when both operands are float constants, or when at
/// least one operand is a NaN constant.
OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "cmpf takes two operands");

  auto lhsConst = operands.front().dyn_cast_or_null<FloatAttr>();
  auto rhsConst = operands.back().dyn_cast_or_null<FloatAttr>();

  // If one operand is NaN, making them both NaN does not change the result.
  // This lets the fold proceed even when the other operand is not constant.
  if (lhsConst && lhsConst.getValue().isNaN())
    rhsConst = lhsConst;
  if (rhsConst && rhsConst.getValue().isNaN())
    lhsConst = rhsConst;

  if (lhsConst && rhsConst) {
    const bool result = applyCmpPredicate(getPredicate(), lhsConst.getValue(),
                                          rhsConst.getValue());
    return BoolAttr::get(getContext(), result);
  }
  return {};
}
|
|
|
|
/// Canonicalization pattern for a `cmpf` whose left-hand side is produced by
/// `sitofp`/`uitofp` and whose right-hand side is a non-NaN float constant.
/// The comparison is rewritten either into a constant i1 or into an
/// equivalent `cmpi` on the original integer value.
class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
public:
  using OpRewritePattern<CmpFOp>::OpRewritePattern;

  /// Maps a float comparison predicate onto the integer predicate with the
  /// same ordering meaning, picking the signed or unsigned flavor according
  /// to `isUnsigned`. Only the (ordered/unordered) eq, ne, gt, ge, lt and le
  /// predicates are supported; callers must handle every other predicate
  /// themselves.
  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                                 bool isUnsigned) {
    using namespace arith;
    switch (pred) {
    case CmpFPredicate::UEQ:
    case CmpFPredicate::OEQ:
      return CmpIPredicate::eq;
    case CmpFPredicate::UGT:
    case CmpFPredicate::OGT:
      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
    case CmpFPredicate::UGE:
    case CmpFPredicate::OGE:
      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
    case CmpFPredicate::ULT:
    case CmpFPredicate::OLT:
      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
    case CmpFPredicate::ULE:
    case CmpFPredicate::OLE:
      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
    case CmpFPredicate::UNE:
    case CmpFPredicate::ONE:
      return CmpIPredicate::ne;
    default:
      llvm_unreachable("Unexpected predicate!");
    }
  }

  LogicalResult matchAndRewrite(CmpFOp op,
                                PatternRewriter &rewriter) const override {
    FloatAttr flt;
    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
      return failure();

    const APFloat &rhs = flt.getValue();

    // Don't attempt to fold a nan.
    if (rhs.isNaN())
      return failure();

    // Get the width of the mantissa.  We don't want to hack on conversions that
    // might lose information from the integer, e.g. "i64 -> float"
    FloatType floatTy = op.getRhs().getType().cast<FloatType>();
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)
      return failure();

    bool isUnsigned;
    Value intVal;

    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      isUnsigned = false;
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      isUnsigned = true;
      intVal = ui.getIn();
    } else {
      return failure();
    }

    // Check to see that the input is converted from an integer type that is
    // small enough that preserves all bits.
    auto intTy = intVal.getType().cast<IntegerType>();
    auto intWidth = intTy.getWidth();

    // Number of bits representing values, as opposed to the sign
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    // Following test does NOT adjust intWidth downwards for signed inputs,
    // because the most negative value still requires all the mantissa bits
    // to distinguish it from one less than that value.
    if ((int)intWidth > mantissaWidth) {
      // Conversion would lose accuracy. Check if loss can impact comparison.
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
        if (maxExponent < (int)valueBits) {
          // Conversion could create infinity.
          return failure();
        }
      } else {
        // Note that if rhs is zero or NaN, then Exp is negative
        // and first condition is trivially false.
        if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
          // Conversion could affect comparison.
          return failure();
        }
      }
    }

    // Convert to equivalent cmpi predicate
    CmpIPredicate pred;
    switch (op.getPredicate()) {
    case CmpFPredicate::AlwaysTrue:
      // The comparison does not depend on its operands. Handled here because
      // convertToIntegerPredicate has no integer equivalent for it.
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                 /*width=*/1);
      return success();
    case CmpFPredicate::AlwaysFalse:
      // The comparison does not depend on its operands. Handled here because
      // convertToIntegerPredicate has no integer equivalent for it.
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                 /*width=*/1);
      return success();
    case CmpFPredicate::ORD:
      // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                 /*width=*/1);
      return success();
    case CmpFPredicate::UNO:
      // Int to fp conversion doesn't create a nan (uno checks either is a nan)
      rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                 /*width=*/1);
      return success();
    default:
      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
      break;
    }

    if (!isUnsigned) {
      // If the rhs value is > SignedMax, fold the comparison.  This handles
      // +INF and large values.
      APFloat signedMax(rhs.getSemantics());
      signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMax < rhs) { // smax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
            pred == CmpIPredicate::sle)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // If the rhs value is > UnsignedMax, fold the comparison. This handles
      // +INF and large values.
      APFloat unsignedMax(rhs.getSemantics());
      unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMax < rhs) { // umax < 13123.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
            pred == CmpIPredicate::ule)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    }

    if (!isUnsigned) {
      // See if the rhs value is < SignedMin.
      APFloat signedMin(rhs.getSemantics());
      signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
                                 APFloat::rmNearestTiesToEven);
      if (signedMin > rhs) { // smin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
            pred == CmpIPredicate::sge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    } else {
      // See if the rhs value is < UnsignedMin.
      APFloat unsignedMin(rhs.getSemantics());
      unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
                                   APFloat::rmNearestTiesToEven);
      if (unsignedMin > rhs) { // umin > 12312.0
        if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
            pred == CmpIPredicate::uge)
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
        else
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
        return success();
      }
    }

    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
    // [0, UMAX], but it may still be fractional.  See if it is fractional by
    // casting the FP value to the integer value and back, checking for
    // equality. Don't do this for zero, because -0.0 is not fractional.
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
      // Undefined behavior invoked - the destination type can't represent
      // the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        // If we had a comparison against a fractional value, we have to adjust
        // the compare predicate and sometimes the value.  rhsInt is rounded
        // towards zero at this point.
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4   --> true
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::eq: // (float)int == 4.4   --> false
          rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                     /*width=*/1);
          return success();
        case CmpIPredicate::ule:
          // (float)int <= 4.4   --> int <= 4
          // (float)int <= -4.4  --> false
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sle:
          // (float)int <= 4.4   --> int <= 4
          // (float)int <= -4.4  --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4   --> false
          // (float)int < 4.4    --> int <= 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < -4.4   --> int < -4
          // (float)int < 4.4    --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > 4.4    --> int > 4
          // (float)int > -4.4   --> true
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          break;
        case CmpIPredicate::sgt:
          // (float)int > 4.4    --> int > 4
          // (float)int > -4.4   --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4   --> true
          // (float)int >= 4.4    --> int > 4
          if (rhs.isNegative()) {
            rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
                                                       /*width=*/1);
            return success();
          }
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= -4.4   --> int >= -4
          // (float)int >= 4.4    --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        }
      }
    }

    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        rewriter.create<ConstantOp>(
            op.getLoc(), intVal.getType(),
            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }
};
|
|
|
|
/// Registers the canonicalization patterns for `cmpf` (currently only
/// `CmpFIntToFPConst`) into `patterns`.
void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.insert<CmpFIntToFPConst>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SelectOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Transforms a select of a boolean to arithmetic operations
|
|
//
|
|
// arith.select %arg, %x, %y : i1
|
|
//
|
|
// becomes
|
|
//
|
|
// and(%arg, %x) or and(!%arg, %y)
|
|
struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
|
|
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(arith::SelectOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!op.getType().isInteger(1))
|
|
return failure();
|
|
|
|
Value falseConstant =
|
|
rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
|
|
Value notCondition = rewriter.create<arith::XOrIOp>(
|
|
op.getLoc(), op.getCondition(), falseConstant);
|
|
|
|
Value trueVal = rewriter.create<arith::AndIOp>(
|
|
op.getLoc(), op.getCondition(), op.getTrueValue());
|
|
Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
|
|
op.getFalseValue());
|
|
rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// select %arg, %c1, %c0 => extui %arg
|
|
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
|
|
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(arith::SelectOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
// Cannot extui i1 to i1, or i1 to f32
|
|
if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
|
|
return failure();
|
|
|
|
// select %x, c1, %c0 => extui %arg
|
|
if (matchPattern(op.getTrueValue(), m_One()))
|
|
if (matchPattern(op.getFalseValue(), m_Zero())) {
|
|
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
|
|
op.getCondition());
|
|
return success();
|
|
}
|
|
|
|
// select %x, c0, %c1 => extui (xor %arg, true)
|
|
if (matchPattern(op.getTrueValue(), m_Zero()))
|
|
if (matchPattern(op.getFalseValue(), m_One())) {
|
|
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
|
|
op, op.getType(),
|
|
rewriter.create<arith::XOrIOp>(
|
|
op.getLoc(), op.getCondition(),
|
|
rewriter.create<arith::ConstantIntOp>(
|
|
op.getLoc(), 1, op.getCondition().getType())));
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Registers the `arith.select` canonicalization patterns (SelectI1Simplify
/// and SelectToExtUI) into `results`.
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<SelectI1Simplify>(context);
  results.add<SelectToExtUI>(context);
}
|
|
|
|
/// Folds `arith.select` to one of its existing operands when possible:
///   - both arms are the same value;
///   - the condition is a constant one / zero;
///   - an i1 select of `true, false` is just the condition;
///   - the condition is an eq/ne `arith.cmpi` of the two arms.
/// Returns a null result when none of these apply.
OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
  Value whenTrue = getTrueValue();
  Value whenFalse = getFalseValue();

  // Identical arms make the condition irrelevant.
  if (whenTrue == whenFalse)
    return whenTrue;

  Value cond = getCondition();

  // select true, %0, %1 => %0
  if (matchPattern(cond, m_One()))
    return whenTrue;

  // select false, %0, %1 => %1
  if (matchPattern(cond, m_Zero()))
    return whenFalse;

  // select %x, true, false => %x
  if (getType().isInteger(1) && matchPattern(whenTrue, m_One()) &&
      matchPattern(whenFalse, m_Zero()))
    return cond;

  // The remaining folds need the condition to come from an eq/ne comparison.
  auto cmp = dyn_cast_or_null<arith::CmpIOp>(cond.getDefiningOp());
  if (!cmp)
    return nullptr;

  arith::CmpIPredicate pred = cmp.getPredicate();
  bool isEq = pred == arith::CmpIPredicate::eq;
  bool isNe = pred == arith::CmpIPredicate::ne;
  if (!isEq && !isNe)
    return nullptr;

  // %0 = arith.cmpi eq, %arg0, %arg1
  // %1 = arith.select %0, %arg0, %arg1 => %arg1
  //
  // %0 = arith.cmpi ne, %arg0, %arg1
  // %1 = arith.select %0, %arg0, %arg1 => %arg0
  Value cmpLhs = cmp.getLhs();
  Value cmpRhs = cmp.getRhs();
  bool sameOrder = cmpLhs == whenTrue && cmpRhs == whenFalse;
  bool swappedOrder = cmpRhs == whenTrue && cmpLhs == whenFalse;
  if (!sameOrder && !swappedOrder)
    return nullptr;

  if (isNe)
    return whenTrue;
  return whenFalse;
}
|
|
|
|
/// Parses the custom assembly form of `arith.select`:
///   three operands, an optional attribute dictionary, then
///   `: result-type` or `: condition-type, result-type`.
/// When only one type is given, the condition type defaults to i1.
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  Type conditionType, resultType;

  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  if (failed(parser.parseOptionalComma())) {
    // No second type: the condition is a plain i1.
    conditionType = parser.getBuilder().getI1Type();
  } else {
    // The type parsed so far was the condition type; the result type follows.
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}
|
|
|
|
/// Prints the custom assembly form: the operands, the optional attribute
/// dictionary, then ` : ` and the result type. A shaped condition type is
/// printed before the result type, separated by a comma.
void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " ";
  p << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";

  // Only a shaped condition type is spelled out explicitly.
  auto shapedCondType = getCondition().getType().dyn_cast<ShapedType>();
  if (shapedCondType)
    p << shapedCondType << ", ";

  p << getType();
}
|
|
|
|
/// Verifies the condition type: it must be a signless i1, or, when the result
/// is a tensor or vector, the type returned by getI1SameShape for the result
/// type.
LogicalResult arith::SelectOp::verify() {
  Type condType = getCondition().getType();

  // A scalar signless i1 condition is always accepted.
  if (condType.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the type can be a mask with the
  // same elements.
  Type resType = getType();
  if (!resType.isa<TensorType, VectorType>())
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << condType;

  Type expectedCondType = getI1SameShape(resType);
  if (condType == expectedCondType)
    return success();

  return emitOpError() << "expected condition type to have the same shape "
                          "as the result type, expected "
                       << expectedCondType << ", but got " << condType;
}
|
|
//===----------------------------------------------------------------------===//
|
|
// ShLIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.shli` when both operands are constants and no shift
/// amount exceeds the bit width. Returns a null attribute otherwise.
OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if shifting more than the bit width.
  //
  // The flag is accumulated rather than overwritten: if constFoldBinaryOp
  // invokes the callback more than once (presumably once per element for
  // non-scalar constants -- see CommonFolders.h), every shift amount must be
  // in range, not just the last one visited.
  bool bounded = true;
  auto result = constFoldBinaryOp<IntegerAttr>(
      operands, [&](const APInt &a, const APInt &b) {
        bounded = bounded && b.ule(b.getBitWidth());
        return a.shl(b);
      });
  // When nothing was folded `result` is already null, so the initial `true`
  // cannot produce a spurious fold.
  return bounded ? result : Attribute();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ShRUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.shrui` (logical shift right) when both operands are
/// constants and no shift amount exceeds the bit width. Returns a null
/// attribute otherwise.
OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if shifting more than the bit width.
  //
  // The flag is accumulated rather than overwritten: if constFoldBinaryOp
  // invokes the callback more than once (presumably once per element for
  // non-scalar constants -- see CommonFolders.h), every shift amount must be
  // in range, not just the last one visited.
  bool bounded = true;
  auto result = constFoldBinaryOp<IntegerAttr>(
      operands, [&](const APInt &a, const APInt &b) {
        bounded = bounded && b.ule(b.getBitWidth());
        return a.lshr(b);
      });
  // When nothing was folded `result` is already null, so the initial `true`
  // cannot produce a spurious fold.
  return bounded ? result : Attribute();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ShRSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.shrsi` (arithmetic shift right) when both operands
/// are constants and no shift amount exceeds the bit width. Returns a null
/// attribute otherwise.
OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
  // Don't fold if shifting more than the bit width.
  //
  // The flag is accumulated rather than overwritten: if constFoldBinaryOp
  // invokes the callback more than once (presumably once per element for
  // non-scalar constants -- see CommonFolders.h), every shift amount must be
  // in range, not just the last one visited.
  bool bounded = true;
  auto result = constFoldBinaryOp<IntegerAttr>(
      operands, [&](const APInt &a, const APInt &b) {
        bounded = bounded && b.ule(b.getBitWidth());
        return a.ashr(b);
      });
  // When nothing was folded `result` is already null, so the initial `true`
  // cannot produce a spurious fold.
  return bounded ? result : Attribute();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Atomic Enum
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the identity value attribute associated with an AtomicRMWKind op.
|
|
Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
|
OpBuilder &builder, Location loc) {
|
|
switch (kind) {
|
|
case AtomicRMWKind::maxf:
|
|
return builder.getFloatAttr(
|
|
resultType,
|
|
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
|
/*Negative=*/true));
|
|
case AtomicRMWKind::addf:
|
|
case AtomicRMWKind::addi:
|
|
case AtomicRMWKind::maxu:
|
|
case AtomicRMWKind::ori:
|
|
return builder.getZeroAttr(resultType);
|
|
case AtomicRMWKind::andi:
|
|
return builder.getIntegerAttr(
|
|
resultType,
|
|
APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
|
|
case AtomicRMWKind::maxs:
|
|
return builder.getIntegerAttr(
|
|
resultType,
|
|
APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
|
|
case AtomicRMWKind::minf:
|
|
return builder.getFloatAttr(
|
|
resultType,
|
|
APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
|
|
/*Negative=*/false));
|
|
case AtomicRMWKind::mins:
|
|
return builder.getIntegerAttr(
|
|
resultType,
|
|
APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
|
|
case AtomicRMWKind::minu:
|
|
return builder.getIntegerAttr(
|
|
resultType,
|
|
APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
|
|
case AtomicRMWKind::muli:
|
|
return builder.getIntegerAttr(resultType, 1);
|
|
case AtomicRMWKind::mulf:
|
|
return builder.getFloatAttr(resultType, 1);
|
|
// TODO: Add remaining reduction operations.
|
|
default:
|
|
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
|
break;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// Returns the identity value associated with an AtomicRMWKind op.
|
|
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
|
|
OpBuilder &builder, Location loc) {
|
|
Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
|
|
return builder.create<arith::ConstantOp>(loc, attr);
|
|
}
|
|
|
|
/// Return the value obtained by applying the reduction operation kind
|
|
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
|
|
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
|
|
Location loc, Value lhs, Value rhs) {
|
|
switch (op) {
|
|
case AtomicRMWKind::addf:
|
|
return builder.create<arith::AddFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::addi:
|
|
return builder.create<arith::AddIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::mulf:
|
|
return builder.create<arith::MulFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::muli:
|
|
return builder.create<arith::MulIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::maxf:
|
|
return builder.create<arith::MaxFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::minf:
|
|
return builder.create<arith::MinFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::maxs:
|
|
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::mins:
|
|
return builder.create<arith::MinSIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::maxu:
|
|
return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::minu:
|
|
return builder.create<arith::MinUIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::ori:
|
|
return builder.create<arith::OrIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::andi:
|
|
return builder.create<arith::AndIOp>(loc, lhs, rhs);
|
|
// TODO: Add remaining reduction operations.
|
|
default:
|
|
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
|
break;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd enum attribute definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
|