[flang] Reset all extents to zero for empty hlfir.elemental loops. (#124867)

An hlfir.elemental with a shape `(0, HUGE)` still runs `HUGE`
number of iterations when expanded into a loop nest.
HLFIR transformational operations inlined as hlfir.elemental
may execute slower comparing to Fortran runtime implementation.
This patch adds an option for BufferizeHLFIR pass to reset all
upper bounds in the elemental loop nests to zero, if the result
is an empty array.

A separate patch will enable this option in the driver after I do
more performance testing. The option is off by default now.
This commit is contained in:
Slava Zakharin 2025-01-29 12:03:05 -08:00 committed by GitHub
parent b8708753c8
commit bac9575274
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 119 additions and 9 deletions

View File

@ -813,6 +813,18 @@ uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);
llvm::SmallVector<mlir::Value> deduceOptimalExtents(mlir::ValueRange extents1,
mlir::ValueRange extents2);
/// Given array extents generate code that sets them all to zeroes,
/// if the array is empty, e.g.:
/// %false = arith.constant false
/// %c0 = arith.constant 0 : index
/// %p1 = arith.cmpi eq, %e0, %c0 : index
/// %p2 = arith.ori %false, %p1 : i1
/// %p3 = arith.cmpi eq, %e1, %c0 : index
/// %p4 = arith.ori %p1, %p2 : i1
/// %result0 = arith.select %p4, %c0, %e0 : index
/// %result1 = arith.select %p4, %c0, %e1 : index
llvm::SmallVector<mlir::Value> updateRuntimeExtentsForEmptyArrays(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents);
} // namespace fir::factory
#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H

View File

@ -19,6 +19,11 @@ def ConvertHLFIRtoFIR : Pass<"convert-hlfir-to-fir", "::mlir::ModuleOp"> {
def BufferizeHLFIR : Pass<"bufferize-hlfir", "::mlir::ModuleOp"> {
let summary = "Convert HLFIR operations operating on hlfir.expr into operations on memory";
let options = [Option<"optimizeEmptyElementals", "opt-empty-elementals",
"bool", /*default=*/"false",
"When converting hlfir.elemental into a loop nest, "
"check if the resulting expression is an empty array, "
"and make sure none of the loops is executed.">];
}
def OptimizedBufferization : Pass<"opt-bufferization"> {

View File

@ -1759,3 +1759,29 @@ fir::factory::deduceOptimalExtents(mlir::ValueRange extents1,
}
return extents;
}
llvm::SmallVector<mlir::Value> fir::factory::updateRuntimeExtentsForEmptyArrays(
fir::FirOpBuilder &builder, mlir::Location loc, mlir::ValueRange extents) {
if (extents.size() <= 1)
return extents;
mlir::Type i1Type = builder.getI1Type();
mlir::Value isEmpty = createZeroValue(builder, loc, i1Type);
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> zeroes;
for (mlir::Value extent : extents) {
mlir::Type type = extent.getType();
mlir::Value zero = createZeroValue(builder, loc, type);
zeroes.push_back(zero);
mlir::Value isZero = builder.create<mlir::arith::CmpIOp>(
loc, mlir::arith::CmpIPredicate::eq, extent, zero);
isEmpty = builder.create<mlir::arith::OrIOp>(loc, isEmpty, isZero);
}
llvm::SmallVector<mlir::Value, Fortran::common::maxRank> newExtents;
for (auto [zero, extent] : llvm::zip_equal(zeroes, extents)) {
newExtents.push_back(
builder.create<mlir::arith::SelectOp>(loc, isEmpty, zero, extent));
}
return newExtents;
}

View File

@ -761,8 +761,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
struct ElementalOpConversion
: public mlir::OpConversionPattern<hlfir::ElementalOp> {
using mlir::OpConversionPattern<hlfir::ElementalOp>::OpConversionPattern;
explicit ElementalOpConversion(mlir::MLIRContext *ctx)
: mlir::OpConversionPattern<hlfir::ElementalOp>{ctx} {
explicit ElementalOpConversion(mlir::MLIRContext *ctx,
bool optimizeEmptyElementals = false)
: mlir::OpConversionPattern<hlfir::ElementalOp>{ctx},
optimizeEmptyElementals(optimizeEmptyElementals) {
// This pattern recursively converts nested ElementalOp's
// by cloning and then converting them, so we have to allow
// for recursive pattern application. The recursion is bounded
@ -791,6 +793,10 @@ struct ElementalOpConversion
// of the loop nest.
temp = derefPointersAndAllocatables(loc, builder, temp);
if (optimizeEmptyElementals)
extents = fir::factory::updateRuntimeExtentsForEmptyArrays(builder, loc,
extents);
// Generate a loop nest looping around the fir.elemental shape and clone
// fir.elemental region inside the inner loop.
hlfir::LoopNest loopNest =
@ -861,6 +867,9 @@ struct ElementalOpConversion
rewriter.replaceOp(elemental, bufferizedExpr);
return mlir::success();
}
private:
bool optimizeEmptyElementals = false;
};
struct CharExtremumOpConversion
: public mlir::OpConversionPattern<hlfir::CharExtremumOp> {
@ -932,6 +941,8 @@ struct EvaluateInMemoryOpConversion
class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
public:
using BufferizeHLFIRBase<BufferizeHLFIR>::BufferizeHLFIRBase;
void runOnOperation() override {
// TODO: make this a pass operating on FuncOp. The issue is that
// FirOpBuilder helpers may generate new FuncOp because of runtime/llvm
@ -943,13 +954,13 @@ public:
auto module = this->getOperation();
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns
.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
AssociateOpConversion, CharExtremumOpConversion,
ConcatOpConversion, DestroyOpConversion, ElementalOpConversion,
EndAssociateOpConversion, EvaluateInMemoryOpConversion,
NoReassocOpConversion, SetLengthOpConversion,
ShapeOfOpConversion, GetLengthOpConversion>(context);
patterns.insert<ApplyOpConversion, AsExprOpConversion, AssignOpConversion,
AssociateOpConversion, CharExtremumOpConversion,
ConcatOpConversion, DestroyOpConversion,
EndAssociateOpConversion, EvaluateInMemoryOpConversion,
NoReassocOpConversion, SetLengthOpConversion,
ShapeOfOpConversion, GetLengthOpConversion>(context);
patterns.insert<ElementalOpConversion>(context, optimizeEmptyElementals);
mlir::ConversionTarget target(*context);
// Note that YieldElementOp is not marked as an illegal operation.
// It must be erased by its parent converter and there is no explicit

View File

@ -0,0 +1,56 @@
// Test hlfir.elemental code generation with a dynamic check
// for empty result array
// RUN: fir-opt %s --bufferize-hlfir=opt-empty-elementals=true | FileCheck %s
func.func @test(%v: i32, %e0: i32, %e1: i32, %e2: i64, %e3: i64) {
%shape = fir.shape %e0, %e1, %e2, %e3 : (i32, i32, i64, i64) -> !fir.shape<4>
%result = hlfir.elemental %shape : (!fir.shape<4>) -> !hlfir.expr<?x?x?x?xi32> {
^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index):
hlfir.yield_element %v : i32
}
return
}
// CHECK-LABEL: func.func @test(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32,
// CHECK-SAME: %[[VAL_3:.*]]: i64, %[[VAL_4:.*]]: i64) {
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] : (i32, i32, i64, i64) -> !fir.shape<4>
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_1]] : (i32) -> index
// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_2]] : (i32) -> index
// CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
// CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_4]] : (i64) -> index
// CHECK: %[[VAL_10:.*]] = fir.allocmem !fir.array<?x?x?x?xi32>, %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] {bindc_name = ".tmp.array", uniq_name = ""}
// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<?x?x?x?xi32>>, !fir.shape<4>) -> (!fir.box<!fir.array<?x?x?x?xi32>>, !fir.heap<!fir.array<?x?x?x?xi32>>)
// CHECK: %[[VAL_12:.*]] = arith.constant true
// CHECK: %[[VAL_13:.*]] = arith.constant false
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_15:.*]] = arith.cmpi eq, %[[VAL_6]], %[[C0_1]] : index
// CHECK: %[[VAL_16:.*]] = arith.ori %[[VAL_13]], %[[VAL_15]] : i1
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_17:.*]] = arith.cmpi eq, %[[VAL_7]], %[[C0_2]] : index
// CHECK: %[[VAL_18:.*]] = arith.ori %[[VAL_16]], %[[VAL_17]] : i1
// CHECK: %[[C0_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_19:.*]] = arith.cmpi eq, %[[VAL_8]], %[[C0_3]] : index
// CHECK: %[[VAL_20:.*]] = arith.ori %[[VAL_18]], %[[VAL_19]] : i1
// CHECK: %[[C0_4:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_9]], %[[C0_4]] : index
// CHECK: %[[VAL_22:.*]] = arith.ori %[[VAL_20]], %[[VAL_21]] : i1
// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_22]], %[[C0_1]], %[[VAL_6]] : index
// CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_22]], %[[C0_2]], %[[VAL_7]] : index
// CHECK: %[[VAL_25:.*]] = arith.select %[[VAL_22]], %[[C0_3]], %[[VAL_8]] : index
// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_22]], %[[C0_4]], %[[VAL_9]] : index
// CHECK: %[[VAL_27:.*]] = arith.constant 1 : index
// CHECK: fir.do_loop %[[VAL_28:.*]] = %[[VAL_27]] to %[[VAL_26]] step %[[VAL_27]] {
// CHECK: fir.do_loop %[[VAL_29:.*]] = %[[VAL_27]] to %[[VAL_25]] step %[[VAL_27]] {
// CHECK: fir.do_loop %[[VAL_30:.*]] = %[[VAL_27]] to %[[VAL_24]] step %[[VAL_27]] {
// CHECK: fir.do_loop %[[VAL_31:.*]] = %[[VAL_27]] to %[[VAL_23]] step %[[VAL_27]] {
// CHECK: %[[VAL_32:.*]] = hlfir.designate %[[VAL_11]]#0 (%[[VAL_31]], %[[VAL_30]], %[[VAL_29]], %[[VAL_28]]) : (!fir.box<!fir.array<?x?x?x?xi32>>, index, index, index, index) -> !fir.ref<i32>
// CHECK: hlfir.assign %[[VAL_0]] to %[[VAL_32]] temporary_lhs : i32, !fir.ref<i32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: %[[VAL_33:.*]] = fir.undefined tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>
// CHECK: %[[VAL_34:.*]] = fir.insert_value %[[VAL_33]], %[[VAL_12]], [1 : index] : (tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>, i1) -> tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>
// CHECK: %[[VAL_35:.*]] = fir.insert_value %[[VAL_34]], %[[VAL_11]]#0, [0 : index] : (tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>, !fir.box<!fir.array<?x?x?x?xi32>>) -> tuple<!fir.box<!fir.array<?x?x?x?xi32>>, i1>
// CHECK: return
// CHECK: }