The fillOp's value needs to be cast

During elementwise fusion the fillOp's value was referenced directly
without a cast, which could produce mismatched dtypes.

Reviewed By: mravishankar, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D137447
Prashant Kumar 2022-11-04 17:09:22 +00:00
parent c9eeaedccd
commit 04b449e147
5 changed files with 82 additions and 43 deletions
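For illustration, here is a rough sketch (hand-written, not part of the patch) of the kind of IR the fold is now expected to produce when the fill value's type (f32) differs from the generic's element type (f16); SSA names, maps, and exact op placement are illustrative:

  #map0 = affine_map<(d0) -> (d0)>
  func.func @sketch(%arg0: tensor<?xf16>, %init: tensor<?xf16>) -> tensor<?xf16> {
    %cst = arith.constant 7.000000e+00 : f32
    // The fill value is now cast to the result element type before it replaces
    // the payload block argument.
    %cast = arith.truncf %cst : f32 to f16
    %res = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
        ins(%arg0 : tensor<?xf16>) outs(%init : tensor<?xf16>) {
    ^bb0(%in: f16, %out: f16):
      %sum = arith.addf %in, %cast : f16
      linalg.yield %sum : f16
    } -> tensor<?xf16>
    return %res : tensor<?xf16>
  }

Previously the raw f32 fill value replaced the f16 block argument as-is, leaving ops such as arith.addf with mismatched operand types.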


@@ -92,6 +92,12 @@ SmallVector<Value>
getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                ArrayRef<OpFoldResult> valueOrAttrVec);
/// Converts a scalar value `operand` to type `toType`. If the value doesn't
/// convert, a warning will be issued and the operand is returned as is (which
/// will presumably yield a verification issue downstream).
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
                           Type toType, bool isUnsignedCast);
/// Helper struct to build simple arithmetic quantities with minimal type
/// inference support.
struct ArithBuilder {


@@ -80,6 +80,50 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
}
Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
                                 Type toType, bool isUnsignedCast) {
  if (operand.getType() == toType)
    return operand;
  if (auto toIntType = toType.dyn_cast<IntegerType>()) {
    // If operand is floating point, cast directly to the int type.
    if (operand.getType().isa<FloatType>()) {
      if (isUnsignedCast)
        return b.create<arith::FPToUIOp>(loc, toType, operand);
      return b.create<arith::FPToSIOp>(loc, toType, operand);
    }
    // Cast index operands directly to the int type.
    if (operand.getType().isIndex())
      return b.create<arith::IndexCastOp>(loc, toType, operand);
    if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
      // Either extend or truncate.
      if (toIntType.getWidth() > fromIntType.getWidth()) {
        if (isUnsignedCast)
          return b.create<arith::ExtUIOp>(loc, toType, operand);
        return b.create<arith::ExtSIOp>(loc, toType, operand);
      }
      if (toIntType.getWidth() < fromIntType.getWidth())
        return b.create<arith::TruncIOp>(loc, toType, operand);
    }
  } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
    // If operand is integer, cast directly to the float type.
    // Note that it is unclear how to cast from BF16<->FP16.
    if (operand.getType().isa<IntegerType>()) {
      if (isUnsignedCast)
        return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
      return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
    }
    if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
      if (toFloatType.getWidth() > fromFloatType.getWidth())
        return b.create<arith::ExtFOp>(loc, toFloatType, operand);
      if (toFloatType.getWidth() < fromFloatType.getWidth())
        return b.create<arith::TruncFOp>(loc, toFloatType, operand);
    }
  }
  emitWarning(loc) << "could not cast operand of type " << operand.getType()
                   << " to " << toType;
  return operand;
}
SmallVector<Value>
mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
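As a quick reference for the new convertScalarToDtype helper above, a hedged sketch (types and names chosen for illustration, not taken from the patch) of the arith casts it materializes for representative conversions:

  func.func @cast_examples(%f: f32, %idx: index, %i8: i8, %i64: i64, %half: f16) {
    %a = arith.fptosi %f : f32 to i32          // float -> int, signed
    %b = arith.fptoui %f : f32 to i32          // float -> int, isUnsignedCast = true
    %c = arith.index_cast %idx : index to i64  // index -> int
    %d = arith.extsi %i8 : i8 to i32           // narrower -> wider int, sign-extend
    %e = arith.trunci %i64 : i64 to i32        // wider -> narrower int
    %g = arith.sitofp %d : i32 to f32          // int -> float, signed
    %h = arith.extf %half : f16 to f32         // narrower -> wider float
    %k = arith.truncf %f : f32 to f16          // wider -> narrower float
    return
  }

If none of these cases applies, the helper emits a warning and returns the operand unchanged, as documented in the header.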


@@ -423,48 +423,7 @@ private:
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();
    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa<FloatType>()) {
        if (isUnsignedCast)
          return builder.create<arith::FPToUIOp>(loc, toType, operand);
        return builder.create<arith::FPToSIOp>(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create<arith::IndexCastOp>(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create<arith::ExtUIOp>(loc, toType, operand);
          return builder.create<arith::ExtSIOp>(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create<arith::TruncIOp>(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa<IntegerType>()) {
        if (isUnsignedCast)
          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
      }
    }
    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }
  bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }


@@ -1744,8 +1744,14 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
      if (!fillOp)
        continue;
      fillFound = true;
      Value fillVal = fillOp.value();
      auto resultType =
          fillOp.result().getType().cast<RankedTensorType>().getElementType();
      Value convertedVal =
          convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType,
                               /*isUnsignedCast =*/false);
      payload.getArgument(opOperand->getOperandNumber())
          .replaceAllUsesWith(fillOp.value());
          .replaceAllUsesWith(convertedVal);
    }
    return success(fillFound);
  }


@@ -1017,6 +1017,30 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
// -----
// CHECK-LABEL: func @fold_fill_generic_different_dtype
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
#map0 = affine_map<(d0) -> (d0)>
func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 7.0 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
  %1 = tensor.empty(%0) : tensor<?xf16>
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
  %3 = tensor.empty(%0) : tensor<?xf16>
  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
    %5 = arith.addf %arg1, %arg2 : f16
    linalg.yield %5 : f16
  } -> tensor<?xf16>
  return %4 : tensor<?xf16>
}
// -----
// CHECK-LABEL: func @fold_fill_generic_mixedaccess
// CHECK-NOT: linalg.fill
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic