Reland "[flang] Inline hlfir.dot_product. (#123143)" (#123385)

This reverts commit afc43a7b626ae07f56e6534320e0b46d26070750.
+Fixed declaration of hlfir::genExtentsVector().

Some good results for induct2, where dot_product is applied
to a vector of unknow size and a known 3-element vector:
the inlining ends up generating a 3-iteration loop, which
is then fully unrolled. With late FIR simplification
it is not happening even when the simplified intrinsics
implementation is inlined by LLVM (because the loop bounds
are not known).

This change just follows the current approach to expose
the loops for later worksharing application.
This commit is contained in:
Slava Zakharin 2025-01-17 12:09:44 -08:00 committed by GitHub
parent 580ba2eed2
commit 71ff486bee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 326 additions and 115 deletions

View File

@ -513,6 +513,12 @@ genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
Entity loadElementAt(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity, mlir::ValueRange oneBasedIndices);
/// Return a vector of extents for the given entity.
/// The function creates new operations, but tries to clean-up
/// after itself.
llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder, Entity entity);
} // namespace hlfir
#endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H

View File

@ -1421,3 +1421,15 @@ hlfir::Entity hlfir::loadElementAt(mlir::Location loc,
return loadTrivialScalar(loc, builder,
getElementAt(loc, builder, entity, oneBasedIndices));
}
llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
hlfir::genExtentsVector(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity entity) {
entity = hlfir::derefPointersAndAllocatables(loc, builder, entity);
mlir::Value shape = hlfir::genShape(loc, builder, entity);
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> extents =
hlfir::getExplicitExtentsFromShape(shape, builder);
if (shape.getUses().empty())
shape.getDefiningOp()->erase();
return extents;
}

View File

@ -37,6 +37,79 @@ static llvm::cl::opt<bool> forceMatmulAsElemental(
namespace {
// Helper class to generate operations related to computing
// product of values.
class ProductFactory {
public:
ProductFactory(mlir::Location loc, fir::FirOpBuilder &builder)
: loc(loc), builder(builder) {}
// Generate an update of the inner product value:
// acc += v1 * v2, OR
// acc += CONJ(v1) * v2, OR
// acc ||= v1 && v2
//
// CONJ parameter specifies whether the first complex product argument
// needs to be conjugated.
template <bool CONJ = false>
mlir::Value genAccumulateProduct(mlir::Value acc, mlir::Value v1,
mlir::Value v2) {
mlir::Type resultType = acc.getType();
acc = castToProductType(acc, resultType);
v1 = castToProductType(v1, resultType);
v2 = castToProductType(v2, resultType);
mlir::Value result;
if (mlir::isa<mlir::FloatType>(resultType)) {
result = builder.create<mlir::arith::AddFOp>(
loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
} else if (mlir::isa<mlir::ComplexType>(resultType)) {
if constexpr (CONJ)
result = fir::IntrinsicLibrary{builder, loc}.genConjg(resultType, v1);
else
result = v1;
result = builder.create<fir::AddcOp>(
loc, acc, builder.create<fir::MulcOp>(loc, result, v2));
} else if (mlir::isa<mlir::IntegerType>(resultType)) {
result = builder.create<mlir::arith::AddIOp>(
loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
} else if (mlir::isa<fir::LogicalType>(resultType)) {
result = builder.create<mlir::arith::OrIOp>(
loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
} else {
llvm_unreachable("unsupported type");
}
return builder.createConvert(loc, resultType, result);
}
private:
mlir::Location loc;
fir::FirOpBuilder &builder;
mlir::Value castToProductType(mlir::Value value, mlir::Type type) {
if (mlir::isa<fir::LogicalType>(type))
return builder.createConvert(loc, builder.getIntegerType(1), value);
// TODO: the multiplications/additions by/of zero resulting from
// complex * real are optimized by LLVM under -fno-signed-zeros
// -fno-honor-nans.
// We can make them disappear by default if we:
// * either expand the complex multiplication into real
// operations, OR
// * set nnan nsz fast-math flags to the complex operations.
if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
fir::factory::Complex helper(builder, loc);
mlir::Type partType = helper.getComplexPartType(type);
return helper.insertComplexPart(zeroCmplx,
castToProductType(value, partType),
/*isImagPart=*/false);
}
return builder.createConvert(loc, type, value);
}
};
class TransposeAsElementalConversion
: public mlir::OpRewritePattern<hlfir::TransposeOp> {
public:
@ -90,11 +163,8 @@ private:
static mlir::Value genResultShape(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::Entity array) {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();
llvm::SmallVector<mlir::Value, 2> inExtents =
hlfir::genExtentsVector(loc, builder, array);
// transpose indices
assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
@ -137,7 +207,7 @@ public:
mlir::Value resultShape, dimExtent;
llvm::SmallVector<mlir::Value> arrayExtents;
if (isTotalReduction)
arrayExtents = genArrayExtents(loc, builder, array);
arrayExtents = hlfir::genExtentsVector(loc, builder, array);
else
std::tie(resultShape, dimExtent) =
genResultShapeForPartialReduction(loc, builder, array, dimVal);
@ -163,7 +233,8 @@ public:
// If DIM is not present, do total reduction.
// Initial value for the reduction.
mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);
mlir::Value reductionInitValue =
fir::factory::createZeroValue(builder, loc, elementType);
// The reduction loop may be unordered if FastMathFlags::reassoc
// transformations are allowed. The integer reduction is always
@ -264,17 +335,6 @@ public:
}
private:
static llvm::SmallVector<mlir::Value>
genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity array) {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();
return inExtents;
}
// Return fir.shape specifying the shape of the result
// of a SUM reduction with DIM=dimVal. The second return value
// is the extent of the DIM dimension.
@ -283,7 +343,7 @@ private:
fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dimVal) {
llvm::SmallVector<mlir::Value> inExtents =
genArrayExtents(loc, builder, array);
hlfir::genExtentsVector(loc, builder, array);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
@ -293,26 +353,6 @@ private:
return {builder.create<fir::ShapeOp>(loc, inExtents), dimExtent};
}
// Generate the initial value for a SUM reduction with the given
// data type.
static mlir::Value genInitValue(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Type elementType) {
if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(loc, elementType,
llvm::APFloat::getZero(sem));
} else if (auto ty = mlir::dyn_cast<mlir::ComplexType>(elementType)) {
mlir::Value initValue = genInitValue(loc, builder, ty.getElementType());
return fir::factory::Complex{builder, loc}.createComplex(ty, initValue,
initValue);
} else if (mlir::isa<mlir::IntegerType>(elementType)) {
return builder.createIntegerConstant(loc, elementType, 0);
}
llvm_unreachable("unsupported SUM reduction type");
}
// Generate scalar addition of the two values (of the same data type).
static mlir::Value genScalarAdd(mlir::Location loc,
fir::FirOpBuilder &builder,
@ -570,16 +610,10 @@ private:
static std::tuple<mlir::Value, mlir::Value>
genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity input1, hlfir::Entity input2) {
mlir::Value input1Shape = hlfir::genShape(loc, builder, input1);
llvm::SmallVector<mlir::Value> input1Extents =
hlfir::getExplicitExtentsFromShape(input1Shape, builder);
if (input1Shape.getUses().empty())
input1Shape.getDefiningOp()->erase();
mlir::Value input2Shape = hlfir::genShape(loc, builder, input2);
llvm::SmallVector<mlir::Value> input2Extents =
hlfir::getExplicitExtentsFromShape(input2Shape, builder);
if (input2Shape.getUses().empty())
input2Shape.getDefiningOp()->erase();
llvm::SmallVector<mlir::Value, 2> input1Extents =
hlfir::genExtentsVector(loc, builder, input1);
llvm::SmallVector<mlir::Value, 2> input2Extents =
hlfir::genExtentsVector(loc, builder, input2);
llvm::SmallVector<mlir::Value, 2> newExtents;
mlir::Value innerProduct1Extent, innerProduct2Extent;
@ -627,60 +661,6 @@ private:
innerProductExtent[0]};
}
static mlir::Value castToProductType(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Value value, mlir::Type type) {
if (mlir::isa<fir::LogicalType>(type))
return builder.createConvert(loc, builder.getIntegerType(1), value);
// TODO: the multiplications/additions by/of zero resulting from
// complex * real are optimized by LLVM under -fno-signed-zeros
// -fno-honor-nans.
// We can make them disappear by default if we:
// * either expand the complex multiplication into real
// operations, OR
// * set nnan nsz fast-math flags to the complex operations.
if (fir::isa_complex(type) && !fir::isa_complex(value.getType())) {
mlir::Value zeroCmplx = fir::factory::createZeroValue(builder, loc, type);
fir::factory::Complex helper(builder, loc);
mlir::Type partType = helper.getComplexPartType(type);
return helper.insertComplexPart(
zeroCmplx, castToProductType(loc, builder, value, partType),
/*isImagPart=*/false);
}
return builder.createConvert(loc, type, value);
}
// Generate an update of the inner product value:
// acc += v1 * v2, OR
// acc ||= v1 && v2
static mlir::Value genAccumulateProduct(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::Type resultType,
mlir::Value acc, mlir::Value v1,
mlir::Value v2) {
acc = castToProductType(loc, builder, acc, resultType);
v1 = castToProductType(loc, builder, v1, resultType);
v2 = castToProductType(loc, builder, v2, resultType);
mlir::Value result;
if (mlir::isa<mlir::FloatType>(resultType))
result = builder.create<mlir::arith::AddFOp>(
loc, acc, builder.create<mlir::arith::MulFOp>(loc, v1, v2));
else if (mlir::isa<mlir::ComplexType>(resultType))
result = builder.create<fir::AddcOp>(
loc, acc, builder.create<fir::MulcOp>(loc, v1, v2));
else if (mlir::isa<mlir::IntegerType>(resultType))
result = builder.create<mlir::arith::AddIOp>(
loc, acc, builder.create<mlir::arith::MulIOp>(loc, v1, v2));
else if (mlir::isa<fir::LogicalType>(resultType))
result = builder.create<mlir::arith::OrIOp>(
loc, acc, builder.create<mlir::arith::AndIOp>(loc, v1, v2));
else
llvm_unreachable("unsupported type");
return builder.createConvert(loc, resultType, result);
}
static mlir::LogicalResult
genContiguousMatmul(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity result, mlir::Value resultShape,
@ -748,9 +728,9 @@ private:
hlfir::loadElementAt(loc, builder, lhs, {I, K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K, J});
mlir::Value productValue = genAccumulateProduct(
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct(
resultElementValue, lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
@ -785,9 +765,9 @@ private:
hlfir::loadElementAt(loc, builder, lhs, {J, K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K});
mlir::Value productValue = genAccumulateProduct(
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct(
resultElementValue, lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
@ -817,9 +797,9 @@ private:
hlfir::loadElementAt(loc, builder, lhs, {K});
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, {K, J});
mlir::Value productValue = genAccumulateProduct(
loc, builder, resultElementType, resultElementValue,
lhsElementValue, rhsElementValue);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct(
resultElementValue, lhsElementValue, rhsElementValue);
builder.create<hlfir::AssignOp>(loc, productValue, resultElement);
return {};
};
@ -885,9 +865,9 @@ private:
hlfir::loadElementAt(loc, builder, lhs, lhsIndices);
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, rhsIndices);
mlir::Value productValue = genAccumulateProduct(
loc, builder, resultElementType, reductionArgs[0], lhsElementValue,
rhsElementValue);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct(
reductionArgs[0], lhsElementValue, rhsElementValue);
return {productValue};
};
llvm::SmallVector<mlir::Value, 1> innerProductValue =
@ -904,6 +884,73 @@ private:
}
};
class DotProductConversion
: public mlir::OpRewritePattern<hlfir::DotProductOp> {
public:
using mlir::OpRewritePattern<hlfir::DotProductOp>::OpRewritePattern;
llvm::LogicalResult
matchAndRewrite(hlfir::DotProductOp product,
mlir::PatternRewriter &rewriter) const override {
hlfir::Entity op = hlfir::Entity{product};
if (!op.isScalar())
return rewriter.notifyMatchFailure(product, "produces non-scalar result");
mlir::Location loc = product.getLoc();
fir::FirOpBuilder builder{rewriter, product.getOperation()};
hlfir::Entity lhs = hlfir::Entity{product.getLhs()};
hlfir::Entity rhs = hlfir::Entity{product.getRhs()};
mlir::Type resultElementType = product.getType();
bool isUnordered = mlir::isa<mlir::IntegerType>(resultElementType) ||
mlir::isa<fir::LogicalType>(resultElementType) ||
static_cast<bool>(builder.getFastMathFlags() &
mlir::arith::FastMathFlags::reassoc);
mlir::Value extent = genProductExtent(loc, builder, lhs, rhs);
auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange oneBasedIndices,
mlir::ValueRange reductionArgs)
-> llvm::SmallVector<mlir::Value, 1> {
hlfir::Entity lhsElementValue =
hlfir::loadElementAt(loc, builder, lhs, oneBasedIndices);
hlfir::Entity rhsElementValue =
hlfir::loadElementAt(loc, builder, rhs, oneBasedIndices);
mlir::Value productValue =
ProductFactory{loc, builder}.genAccumulateProduct</*CONJ=*/true>(
reductionArgs[0], lhsElementValue, rhsElementValue);
return {productValue};
};
mlir::Value initValue =
fir::factory::createZeroValue(builder, loc, resultElementType);
llvm::SmallVector<mlir::Value, 1> result = hlfir::genLoopNestWithReductions(
loc, builder, {extent},
/*reductionInits=*/{initValue}, genBody, isUnordered);
rewriter.replaceOp(product, result[0]);
return mlir::success();
}
private:
static mlir::Value genProductExtent(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::Entity input1,
hlfir::Entity input2) {
llvm::SmallVector<mlir::Value, 1> input1Extents =
hlfir::genExtentsVector(loc, builder, input1);
llvm::SmallVector<mlir::Value, 1> input2Extents =
hlfir::genExtentsVector(loc, builder, input2);
assert(input1Extents.size() == 1 && input2Extents.size() == 1 &&
"hlfir.dot_product arguments must be vectors");
llvm::SmallVector<mlir::Value, 1> extent =
fir::factory::deduceOptimalExtents(input1Extents, input2Extents);
return extent[0];
}
};
class SimplifyHLFIRIntrinsics
: public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
public:
@ -939,6 +986,8 @@ public:
if (forceMatmulAsElemental || this->allowNewSideEffects)
patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context);
patterns.insert<DotProductConversion>(context);
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),

View File

@ -0,0 +1,144 @@
// Test hlfir.dot_product simplification to a reduction loop:
// RUN: fir-opt --simplify-hlfir-intrinsics %s | FileCheck %s
func.func @dot_product_integer(%arg0: !hlfir.expr<?xi16>, %arg1: !hlfir.expr<?xi32>) -> i32 {
%res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xi16>, !hlfir.expr<?xi32>) -> i32
return %res : i32
}
// CHECK-LABEL: func.func @dot_product_integer(
// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xi16>,
// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xi32>) -> i32 {
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xi16>) -> !fir.shape<1>
// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
// CHECK: %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (i32) {
// CHECK: %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?xi16>, index) -> i16
// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr<?xi32>, index) -> i32
// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_9]] : (i16) -> i32
// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_11]], %[[VAL_10]] : i32
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_8]], %[[VAL_12]] : i32
// CHECK: fir.result %[[VAL_13]] : i32
// CHECK: }
// CHECK: return %[[VAL_6]] : i32
// CHECK: }
func.func @dot_product_real(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?xf16>) -> f32 {
%res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?xf16>) -> f32
return %res : f32
}
// CHECK-LABEL: func.func @dot_product_real(
// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xf16>) -> f32 {
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
// CHECK: %[[VAL_6:.*]] = fir.do_loop %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_8:.*]] = %[[VAL_3]]) -> (f32) {
// CHECK: %[[VAL_9:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_7]] : (!hlfir.expr<?xf32>, index) -> f32
// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_7]] : (!hlfir.expr<?xf16>, index) -> f16
// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (f16) -> f32
// CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : f32
// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_12]] : f32
// CHECK: fir.result %[[VAL_13]] : f32
// CHECK: }
// CHECK: return %[[VAL_6]] : f32
// CHECK: }
func.func @dot_product_complex(%arg0: !hlfir.expr<?xcomplex<f32>>, %arg1: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
%res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xcomplex<f32>>, !hlfir.expr<?xcomplex<f16>>) -> complex<f32>
return %res : complex<f32>
}
// CHECK-LABEL: func.func @dot_product_complex(
// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xcomplex<f32>>,
// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xcomplex<f32>>) -> !fir.shape<1>
// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
// CHECK: %[[VAL_6:.*]] = fir.undefined complex<f32>
// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex<f32>) {
// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f32>>, index) -> complex<f32>
// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f16>>, index) -> complex<f16>
// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (complex<f16>) -> complex<f32>
// CHECK: %[[VAL_15:.*]] = fir.extract_value %[[VAL_12]], [1 : index] : (complex<f32>) -> f32
// CHECK: %[[VAL_16:.*]] = arith.negf %[[VAL_15]] : f32
// CHECK: %[[VAL_17:.*]] = fir.insert_value %[[VAL_12]], %[[VAL_16]], [1 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_18:.*]] = fir.mulc %[[VAL_17]], %[[VAL_14]] : complex<f32>
// CHECK: %[[VAL_19:.*]] = fir.addc %[[VAL_11]], %[[VAL_18]] : complex<f32>
// CHECK: fir.result %[[VAL_19]] : complex<f32>
// CHECK: }
// CHECK: return %[[VAL_9]] : complex<f32>
// CHECK: }
func.func @dot_product_real_complex(%arg0: !hlfir.expr<?xf32>, %arg1: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
%res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?xf32>, !hlfir.expr<?xcomplex<f16>>) -> complex<f32>
return %res : complex<f32>
}
// CHECK-LABEL: func.func @dot_product_real_complex(
// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?xcomplex<f16>>) -> complex<f32> {
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?xf32>) -> !fir.shape<1>
// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
// CHECK: %[[VAL_6:.*]] = fir.undefined complex<f32>
// CHECK: %[[VAL_7:.*]] = fir.insert_value %[[VAL_6]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_8:.*]] = fir.insert_value %[[VAL_7]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (complex<f32>) {
// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]] : (!hlfir.expr<?xf32>, index) -> f32
// CHECK: %[[VAL_13:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_10]] : (!hlfir.expr<?xcomplex<f16>>, index) -> complex<f16>
// CHECK: %[[VAL_14:.*]] = fir.undefined complex<f32>
// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_3]], [0 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_3]], [1 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_17:.*]] = fir.insert_value %[[VAL_16]], %[[VAL_12]], [0 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_13]] : (complex<f16>) -> complex<f32>
// CHECK: %[[VAL_19:.*]] = fir.extract_value %[[VAL_17]], [1 : index] : (complex<f32>) -> f32
// CHECK: %[[VAL_20:.*]] = arith.negf %[[VAL_19]] : f32
// CHECK: %[[VAL_21:.*]] = fir.insert_value %[[VAL_17]], %[[VAL_20]], [1 : index] : (complex<f32>, f32) -> complex<f32>
// CHECK: %[[VAL_22:.*]] = fir.mulc %[[VAL_21]], %[[VAL_18]] : complex<f32>
// CHECK: %[[VAL_23:.*]] = fir.addc %[[VAL_11]], %[[VAL_22]] : complex<f32>
// CHECK: fir.result %[[VAL_23]] : complex<f32>
// CHECK: }
// CHECK: return %[[VAL_9]] : complex<f32>
// CHECK: }
func.func @dot_product_logical(%arg0: !hlfir.expr<?x!fir.logical<1>>, %arg1: !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4> {
%res = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<?x!fir.logical<1>>, !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
return %res : !fir.logical<4>
}
// CHECK-LABEL: func.func @dot_product_logical(
// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr<?x!fir.logical<1>>,
// CHECK-SAME: %[[VAL_1:.*]]: !hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4> {
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant false
// CHECK: %[[VAL_4:.*]] = hlfir.shape_of %[[VAL_0]] : (!hlfir.expr<?x!fir.logical<1>>) -> !fir.shape<1>
// CHECK: %[[VAL_5:.*]] = hlfir.get_extent %[[VAL_4]] {dim = 0 : index} : (!fir.shape<1>) -> index
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_3]] : (i1) -> !fir.logical<4>
// CHECK: %[[VAL_7:.*]] = fir.do_loop %[[VAL_8:.*]] = %[[VAL_2]] to %[[VAL_5]] step %[[VAL_2]] unordered iter_args(%[[VAL_9:.*]] = %[[VAL_6]]) -> (!fir.logical<4>) {
// CHECK: %[[VAL_10:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_8]] : (!hlfir.expr<?x!fir.logical<1>>, index) -> !fir.logical<1>
// CHECK: %[[VAL_11:.*]] = hlfir.apply %[[VAL_1]], %[[VAL_8]] : (!hlfir.expr<?x!fir.logical<4>>, index) -> !fir.logical<4>
// CHECK: %[[VAL_12:.*]] = fir.convert %[[VAL_9]] : (!fir.logical<4>) -> i1
// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_10]] : (!fir.logical<1>) -> i1
// CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_11]] : (!fir.logical<4>) -> i1
// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
// CHECK: %[[VAL_16:.*]] = arith.ori %[[VAL_12]], %[[VAL_15]] : i1
// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (i1) -> !fir.logical<4>
// CHECK: fir.result %[[VAL_17]] : !fir.logical<4>
// CHECK: }
// CHECK: return %[[VAL_7]] : !fir.logical<4>
// CHECK: }
func.func @dot_product_known_dim(%arg0: !hlfir.expr<10xf32>, %arg1: !hlfir.expr<?xi16>) -> f32 {
%res1 = hlfir.dot_product %arg0 %arg1 : (!hlfir.expr<10xf32>, !hlfir.expr<?xi16>) -> f32
%res2 = hlfir.dot_product %arg1 %arg0 : (!hlfir.expr<?xi16>, !hlfir.expr<10xf32>) -> f32
%res = arith.addf %res1, %res2 : f32
return %res : f32
}
// CHECK-LABEL: func.func @dot_product_known_dim(
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
// CHECK: fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]
// CHECK: fir.do_loop %{{.*}} = %[[VAL_2]] to %[[VAL_4]] step %[[VAL_2]]