[flang] add hlfir.all intrinsic
Adds a new HLFIR operation for the ALL intrinsic according to the design set out in flang/docs/HighLevel.md Differential Revision: https://reviews.llvm.org/D151090
This commit is contained in:
parent
544a240ff7
commit
206b8538a6
@ -317,6 +317,27 @@ def hlfir_ConcatOp : hlfir_Op<"concat", []> {
|
|||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def hlfir_AllOp : hlfir_Op<"all", []> {
|
||||||
|
let summary = "ALL transformational intrinsic";
|
||||||
|
let description = [{
|
||||||
|
Takes a logical array MASK as argument, optionally along a particular dimension,
|
||||||
|
and returns true if all elements of MASK are true.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
AnyFortranLogicalArrayObject:$mask,
|
||||||
|
Optional<AnyIntegerType>:$dim
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs AnyFortranValue);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$mask (`dim` $dim^)? attr-dict `:` functional-type(operands, results)
|
||||||
|
}];
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def hlfir_AnyOp : hlfir_Op<"any", []> {
|
def hlfir_AnyOp : hlfir_Op<"any", []> {
|
||||||
let summary = "ANY transformational intrinsic";
|
let summary = "ANY transformational intrinsic";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -442,16 +442,19 @@ mlir::LogicalResult hlfir::ParentComponentOp::verify() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AnyOp
|
// LogicalReductionOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
mlir::LogicalResult hlfir::AnyOp::verify() {
|
template <typename LogicalReductionOp>
|
||||||
mlir::Operation *op = getOperation();
|
static mlir::LogicalResult
|
||||||
|
verifyLogicalReductionOp(LogicalReductionOp reductionOp) {
|
||||||
|
mlir::Operation *op = reductionOp->getOperation();
|
||||||
|
|
||||||
auto results = op->getResultTypes();
|
auto results = op->getResultTypes();
|
||||||
assert(results.size() == 1);
|
assert(results.size() == 1);
|
||||||
|
|
||||||
mlir::Value mask = getMask();
|
mlir::Value mask = reductionOp->getMask();
|
||||||
mlir::Value dim = getDim();
|
mlir::Value dim = reductionOp->getDim();
|
||||||
|
|
||||||
fir::SequenceType maskTy =
|
fir::SequenceType maskTy =
|
||||||
hlfir::getFortranElementOrSequenceType(mask.getType())
|
hlfir::getFortranElementOrSequenceType(mask.getType())
|
||||||
.cast<fir::SequenceType>();
|
.cast<fir::SequenceType>();
|
||||||
@ -462,7 +465,7 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
|
|||||||
if (mlir::isa<fir::LogicalType>(resultType)) {
|
if (mlir::isa<fir::LogicalType>(resultType)) {
|
||||||
// Result is of the same type as MASK
|
// Result is of the same type as MASK
|
||||||
if (resultType != logicalTy)
|
if (resultType != logicalTy)
|
||||||
return emitOpError(
|
return reductionOp->emitOpError(
|
||||||
"result must have the same element type as MASK argument");
|
"result must have the same element type as MASK argument");
|
||||||
|
|
||||||
} else if (auto resultExpr =
|
} else if (auto resultExpr =
|
||||||
@ -470,25 +473,42 @@ mlir::LogicalResult hlfir::AnyOp::verify() {
|
|||||||
// Result should only be in hlfir.expr form if it is an array
|
// Result should only be in hlfir.expr form if it is an array
|
||||||
if (maskShape.size() > 1 && dim != nullptr) {
|
if (maskShape.size() > 1 && dim != nullptr) {
|
||||||
if (!resultExpr.isArray())
|
if (!resultExpr.isArray())
|
||||||
return emitOpError("result must be an array");
|
return reductionOp->emitOpError("result must be an array");
|
||||||
|
|
||||||
if (resultExpr.getEleTy() != logicalTy)
|
if (resultExpr.getEleTy() != logicalTy)
|
||||||
return emitOpError(
|
return reductionOp->emitOpError(
|
||||||
"result must have the same element type as MASK argument");
|
"result must have the same element type as MASK argument");
|
||||||
|
|
||||||
llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
|
llvm::ArrayRef<int64_t> resultShape = resultExpr.getShape();
|
||||||
// Result has rank n-1
|
// Result has rank n-1
|
||||||
if (resultShape.size() != (maskShape.size() - 1))
|
if (resultShape.size() != (maskShape.size() - 1))
|
||||||
return emitOpError("result rank must be one less than MASK");
|
return reductionOp->emitOpError(
|
||||||
|
"result rank must be one less than MASK");
|
||||||
} else {
|
} else {
|
||||||
return emitOpError("result must be of logical type");
|
return reductionOp->emitOpError("result must be of logical type");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return emitOpError("result must be of logical type");
|
return reductionOp->emitOpError("result must be of logical type");
|
||||||
}
|
}
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AllOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
mlir::LogicalResult hlfir::AllOp::verify() {
|
||||||
|
return verifyLogicalReductionOp<hlfir::AllOp *>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AnyOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
mlir::LogicalResult hlfir::AnyOp::verify() {
|
||||||
|
return verifyLogicalReductionOp<hlfir::AnyOp *>(this);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConcatOp
|
// ConcatOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -537,11 +557,12 @@ void hlfir::ConcatOp::build(mlir::OpBuilder &builder,
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ReductionOp
|
// NumericalReductionOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
template <typename ReductionOp>
|
template <typename NumericalReductionOp>
|
||||||
static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
|
static mlir::LogicalResult
|
||||||
|
verifyNumericalReductionOp(NumericalReductionOp reductionOp) {
|
||||||
mlir::Operation *op = reductionOp->getOperation();
|
mlir::Operation *op = reductionOp->getOperation();
|
||||||
|
|
||||||
auto results = op->getResultTypes();
|
auto results = op->getResultTypes();
|
||||||
@ -619,7 +640,7 @@ static mlir::LogicalResult verifyReductionOp(ReductionOp reductionOp) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
mlir::LogicalResult hlfir::ProductOp::verify() {
|
mlir::LogicalResult hlfir::ProductOp::verify() {
|
||||||
return verifyReductionOp<hlfir::ProductOp *>(this);
|
return verifyNumericalReductionOp<hlfir::ProductOp *>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -645,7 +666,7 @@ void hlfir::SetLengthOp::build(mlir::OpBuilder &builder,
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
mlir::LogicalResult hlfir::SumOp::verify() {
|
mlir::LogicalResult hlfir::SumOp::verify() {
|
||||||
return verifyReductionOp<hlfir::SumOp *>(this);
|
return verifyNumericalReductionOp<hlfir::SumOp *>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
113
flang/test/HLFIR/all.fir
Normal file
113
flang/test/HLFIR/all.fir
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
// Test hlfir.all operation parse, verify (no errors), and unparse
|
||||||
|
|
||||||
|
// RUN: fir-opt %s | fir-opt | FileCheck %s
|
||||||
|
|
||||||
|
// mask is an expression of known shape
|
||||||
|
func.func @all0(%arg0: !hlfir.expr<2x!fir.logical<4>>) {
|
||||||
|
%all = hlfir.all %arg0 : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all0(%[[ARRAY:.*]]: !hlfir.expr<2x!fir.logical<4>>) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<2x!fir.logical<4>>) -> !fir.logical<4>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// mask is an expression of assumed shape
|
||||||
|
func.func @all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||||
|
%all = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all1(%[[ARRAY:.*]]: !hlfir.expr<?x!fir.logical<4>>) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<4>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// mask is a boxed array
|
||||||
|
func.func @all2(%arg0: !fir.box<!fir.array<2x!fir.logical<4>>>) {
|
||||||
|
%all = hlfir.all %arg0 : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all2(%[[ARRAY:.*]]: !fir.box<!fir.array<2x!fir.logical<4>>>) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<2x!fir.logical<4>>>) -> !fir.logical<4>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// mask is an assumed shape boxed array
|
||||||
|
func.func @all3(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>){
|
||||||
|
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all3(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// mask is a 2-dimensional array
|
||||||
|
func.func @all4(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>){
|
||||||
|
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all4(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// mask and dim argument
|
||||||
|
func.func @all5(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: i32) {
|
||||||
|
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all5(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, i32) -> !fir.logical<4>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// hlfir.all with dim argument with an unusual type
|
||||||
|
func.func @all6(%arg0: !fir.box<!fir.array<?x!fir.logical<4>>>, %arg1: index) {
|
||||||
|
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) ->!fir.logical<4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all6(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<4>>>, %[[DIM:.*]]: index) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x!fir.logical<4>>>, index) -> !fir.logical<4>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// mask is a 2 dimensional array with dim
|
||||||
|
func.func @all7(%arg0: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %arg1: i32) {
|
||||||
|
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all7(%[[ARRAY:.*]]: !fir.box<!fir.array<?x?x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<?x?x!fir.logical<4>>>, i32) -> !hlfir.expr<?x!fir.logical<4>>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// known shape expr return
|
||||||
|
func.func @all8(%arg0: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %arg1: i32) {
|
||||||
|
%all = hlfir.all %arg0 dim %arg1 : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all8(%[[ARRAY:.*]]: !fir.box<!fir.array<2x2x!fir.logical<4>>>, %[[DIM:.*]]: i32) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] dim %[[DIM]] : (!fir.box<!fir.array<2x2x!fir.logical<4>>>, i32) -> !hlfir.expr<2x!fir.logical<4>>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// hlfir.all with mask argument of ref<array<>> type
|
||||||
|
func.func @all9(%arg0: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
|
||||||
|
%all = hlfir.all %arg0 : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all9(%[[ARRAY:.*]]: !fir.ref<!fir.array<?x!fir.logical<4>>>) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.ref<!fir.array<?x!fir.logical<4>>>) -> !fir.logical<4>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
|
||||||
|
// hlfir.all with fir.logical<8> type
|
||||||
|
func.func @all10(%arg0: !fir.box<!fir.array<?x!fir.logical<8>>>) {
|
||||||
|
%all = hlfir.all %arg0 : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func.func @all10(%[[ARRAY:.*]]: !fir.box<!fir.array<?x!fir.logical<8>>>) {
|
||||||
|
// CHECK-NEXT: %[[ALL:.*]] = hlfir.all %[[ARRAY]] : (!fir.box<!fir.array<?x!fir.logical<8>>>) -> !fir.logical<8>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT: }
|
@ -332,6 +332,42 @@ func.func @bad_any6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
|||||||
%0 = hlfir.any %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
|
%0 = hlfir.any %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func.func @bad_all1(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||||
|
// expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
|
||||||
|
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !fir.logical<8>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func.func @bad_all2(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
|
||||||
|
// expected-error@+1 {{'hlfir.all' op result must have the same element type as MASK argument}}
|
||||||
|
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x!fir.logical<8>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func.func @bad_all3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32){
|
||||||
|
// expected-error@+1 {{'hlfir.all' op result rank must be one less than MASK}}
|
||||||
|
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<?x?x!fir.logical<4>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func.func @bad_all4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: i32) {
|
||||||
|
// expected-error@+1 {{'hlfir.all' op result must be an array}}
|
||||||
|
%0 = hlfir.all %arg0 dim %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, i32) -> !hlfir.expr<!fir.logical<4>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func.func @bad_all5(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||||
|
// expected-error@+1 {{'hlfir.all' op result must be of logical type}}
|
||||||
|
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> i32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
func.func @bad_all6(%arg0: !hlfir.expr<?x!fir.logical<4>>) {
|
||||||
|
// expected-error@+1 {{'hlfir.all' op result must be of logical type}}
|
||||||
|
%0 = hlfir.all %arg0 : (!hlfir.expr<?x!fir.logical<4>>) -> !hlfir.expr<!fir.logical<4>>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
|
func.func @bad_product1(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.logical<4>>) {
|
||||||
// expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}
|
// expected-error@+1 {{'hlfir.product' op result must have the same element type as ARRAY argument}}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user