The fillOp's value needs to casted
During elementwise fusion the fillOp's value was directly referred without casting which can create mismatching dtypes. Reviewed By: mravishankar, ThomasRaoux Differential Revision: https://reviews.llvm.org/D137447
This commit is contained in:
parent
c9eeaedccd
commit
04b449e147
@ -92,6 +92,12 @@ SmallVector<Value>
|
||||
getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
|
||||
ArrayRef<OpFoldResult> valueOrAttrVec);
|
||||
|
||||
/// Converts a scalar value `operand` to type `toType`. If the value doesn't
|
||||
/// convert, a warning will be issued and the operand is returned as is (which
|
||||
/// will presumably yield a verification issue downstream).
|
||||
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
|
||||
Type toType, bool isUnsignedCast);
|
||||
|
||||
/// Helper struct to build simple arithmetic quantities with minimal type
|
||||
/// inference support.
|
||||
struct ArithBuilder {
|
||||
|
@ -80,6 +80,50 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
|
||||
return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
|
||||
}
|
||||
|
||||
Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
|
||||
Type toType, bool isUnsignedCast) {
|
||||
if (operand.getType() == toType)
|
||||
return operand;
|
||||
if (auto toIntType = toType.dyn_cast<IntegerType>()) {
|
||||
// If operand is floating point, cast directly to the int type.
|
||||
if (operand.getType().isa<FloatType>()) {
|
||||
if (isUnsignedCast)
|
||||
return b.create<arith::FPToUIOp>(loc, toType, operand);
|
||||
return b.create<arith::FPToSIOp>(loc, toType, operand);
|
||||
}
|
||||
// Cast index operands directly to the int type.
|
||||
if (operand.getType().isIndex())
|
||||
return b.create<arith::IndexCastOp>(loc, toType, operand);
|
||||
if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
|
||||
// Either extend or truncate.
|
||||
if (toIntType.getWidth() > fromIntType.getWidth()) {
|
||||
if (isUnsignedCast)
|
||||
return b.create<arith::ExtUIOp>(loc, toType, operand);
|
||||
return b.create<arith::ExtSIOp>(loc, toType, operand);
|
||||
}
|
||||
if (toIntType.getWidth() < fromIntType.getWidth())
|
||||
return b.create<arith::TruncIOp>(loc, toType, operand);
|
||||
}
|
||||
} else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
|
||||
// If operand is integer, cast directly to the float type.
|
||||
// Note that it is unclear how to cast from BF16<->FP16.
|
||||
if (operand.getType().isa<IntegerType>()) {
|
||||
if (isUnsignedCast)
|
||||
return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
|
||||
return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
|
||||
}
|
||||
if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
|
||||
if (toFloatType.getWidth() > fromFloatType.getWidth())
|
||||
return b.create<arith::ExtFOp>(loc, toFloatType, operand);
|
||||
if (toFloatType.getWidth() < fromFloatType.getWidth())
|
||||
return b.create<arith::TruncFOp>(loc, toFloatType, operand);
|
||||
}
|
||||
}
|
||||
emitWarning(loc) << "could not cast operand of type " << operand.getType()
|
||||
<< " to " << toType;
|
||||
return operand;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
|
||||
ArrayRef<OpFoldResult> valueOrAttrVec) {
|
||||
|
@ -423,48 +423,7 @@ private:
|
||||
Value cast(Type toType, Value operand, bool isUnsignedCast) {
|
||||
OpBuilder builder = getBuilder();
|
||||
auto loc = operand.getLoc();
|
||||
|
||||
if (operand.getType() == toType)
|
||||
return operand;
|
||||
if (auto toIntType = toType.dyn_cast<IntegerType>()) {
|
||||
// If operand is floating point, cast directly to the int type.
|
||||
if (operand.getType().isa<FloatType>()) {
|
||||
if (isUnsignedCast)
|
||||
return builder.create<arith::FPToUIOp>(loc, toType, operand);
|
||||
return builder.create<arith::FPToSIOp>(loc, toType, operand);
|
||||
}
|
||||
// Cast index operands directly to the int type.
|
||||
if (operand.getType().isIndex())
|
||||
return builder.create<arith::IndexCastOp>(loc, toType, operand);
|
||||
if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
|
||||
// Either extend or truncate.
|
||||
if (toIntType.getWidth() > fromIntType.getWidth()) {
|
||||
if (isUnsignedCast)
|
||||
return builder.create<arith::ExtUIOp>(loc, toType, operand);
|
||||
return builder.create<arith::ExtSIOp>(loc, toType, operand);
|
||||
}
|
||||
if (toIntType.getWidth() < fromIntType.getWidth())
|
||||
return builder.create<arith::TruncIOp>(loc, toType, operand);
|
||||
}
|
||||
} else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
|
||||
// If operand is integer, cast directly to the float type.
|
||||
// Note that it is unclear how to cast from BF16<->FP16.
|
||||
if (operand.getType().isa<IntegerType>()) {
|
||||
if (isUnsignedCast)
|
||||
return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
|
||||
return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
|
||||
}
|
||||
if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
|
||||
if (toFloatType.getWidth() > fromFloatType.getWidth())
|
||||
return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
|
||||
if (toFloatType.getWidth() < fromFloatType.getWidth())
|
||||
return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
|
||||
}
|
||||
}
|
||||
|
||||
emitWarning(operand.getLoc()) << "could not cast operand of type "
|
||||
<< operand.getType() << " to " << toType;
|
||||
return operand;
|
||||
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
|
||||
}
|
||||
|
||||
bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
|
||||
|
@ -1744,8 +1744,14 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
|
||||
if (!fillOp)
|
||||
continue;
|
||||
fillFound = true;
|
||||
Value fillVal = fillOp.value();
|
||||
auto resultType =
|
||||
fillOp.result().getType().cast<RankedTensorType>().getElementType();
|
||||
Value convertedVal =
|
||||
convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
|
||||
/*isUnsignedCast =*/false);
|
||||
payload.getArgument(opOperand->getOperandNumber())
|
||||
.replaceAllUsesWith(fillOp.value());
|
||||
.replaceAllUsesWith(convertedVal);
|
||||
}
|
||||
return success(fillFound);
|
||||
}
|
||||
|
@ -1017,6 +1017,30 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_fill_generic_different_dtype
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
|
||||
// CHECK-NOT: linalg.fill
|
||||
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
|
||||
// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
|
||||
#map0 = affine_map<(d0) -> (d0)>
|
||||
func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst = arith.constant 7.0 : f32
|
||||
%0 = tensor.dim %arg0, %c0 : tensor<?xf16>
|
||||
%1 = tensor.empty(%0) : tensor<?xf16>
|
||||
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
|
||||
%3 = tensor.empty(%0) : tensor<?xf16>
|
||||
%4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
|
||||
^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
|
||||
%5 = arith.addf %arg1, %arg2 : f16
|
||||
linalg.yield %5 : f16
|
||||
} -> tensor<?xf16>
|
||||
return %4 : tensor<?xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fold_fill_generic_mixedaccess
|
||||
// CHECK-NOT: linalg.fill
|
||||
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
|
||||
|
Loading…
x
Reference in New Issue
Block a user