Pre-commiting this before landing the new check in https://github.com/llvm/llvm-project/pull/177892
85 lines
2.9 KiB
C++
85 lines
2.9 KiB
C++
//===- LowerNontemporal.cpp -------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Add nontemporal attributes to load and stores of variables marked as
|
|
// nontemporal.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
|
|
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
|
|
#include "flang/Optimizer/OpenMP/Passes.h"
|
|
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace flangomp {
|
|
#define GEN_PASS_DEF_LOWERNONTEMPORALPASS
|
|
#include "flang/Optimizer/OpenMP/Passes.h.inc"
|
|
} // namespace flangomp
|
|
|
|
namespace {
|
|
class LowerNontemporalPass
|
|
: public flangomp::impl::LowerNontemporalPassBase<LowerNontemporalPass> {
|
|
void addNonTemporalAttr(omp::SimdOp simdOp) {
|
|
if (simdOp.getNontemporalVars().empty())
|
|
return;
|
|
|
|
std::function<mlir::Value(mlir::Value)> getBaseOperand =
|
|
[&](mlir::Value operand) -> mlir::Value {
|
|
auto *defOp = operand.getDefiningOp();
|
|
while (defOp) {
|
|
llvm::TypeSwitch<Operation *>(defOp)
|
|
.Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp, fir::LoadOp>(
|
|
[&](auto op) {
|
|
operand = op.getMemref();
|
|
defOp = operand.getDefiningOp();
|
|
})
|
|
.Case([&](fir::BoxAddrOp op) {
|
|
operand = op.getVal();
|
|
defOp = operand.getDefiningOp();
|
|
})
|
|
.Default([&](auto op) { defOp = nullptr; });
|
|
}
|
|
return operand;
|
|
};
|
|
|
|
// walk through the operations and mark the load and store as nontemporal
|
|
simdOp->walk([&](Operation *op) {
|
|
mlir::Value operand = nullptr;
|
|
|
|
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
|
|
operand = loadOp.getMemref();
|
|
else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
|
|
operand = storeOp.getMemref();
|
|
|
|
// Skip load and store operations involving boxes (allocatable or pointer
|
|
// types).
|
|
if (operand && !(fir::isAllocatableType(operand.getType()) ||
|
|
fir::isPointerType((operand.getType())))) {
|
|
operand = getBaseOperand(operand);
|
|
|
|
// TODO : Handling of nontemporal clause inside atomic construct
|
|
if (llvm::is_contained(simdOp.getNontemporalVars(), operand)) {
|
|
if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(op))
|
|
loadOp.setNontemporal(true);
|
|
else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(op))
|
|
storeOp.setNontemporal(true);
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
op->walk([&](omp::SimdOp simdOp) { addNonTemporalAttr(simdOp); });
|
|
}
|
|
};
|
|
} // namespace
|