[mlir][mesh] Add collective communication operations (#71960)

Add all-gather, all-reduce, all-to-all and reduce-scatter. These
operations have device mesh semantics.
Boian Petkantchin 2023-11-21 06:50:24 -08:00 committed by GitHub
parent ac75171d41
commit 5f7c8c1068
8 changed files with 1179 additions and 3 deletions


@ -0,0 +1,43 @@
# 'mesh' Dialect
The `mesh` dialect contains a set of attributes, operations and interfaces that
are useful for representing sharding and communication on a device mesh
cluster.
[TOC]
## Collective Communication Operations
There are a number of operations in the Mesh dialect to facilitate
communication between devices in a mesh.
It is assumed that the user is familiar with collective operations.
[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
explanation.
The main addition is that the collectives in this dialect have mesh
semantics.
The operation attributes `mesh` and `mesh_axes` specify a list of device mesh
axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
Devices that have the same coordinates outside of the axes in `mesh_axes` are
in the same group.
For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
axes list is `[0, 1]`, then the devices are partitioned into the groups
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
Device (1, 0, 2, 4) will be in another group.
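The grouping rule can be sketched in C++ as follows. This is an illustration
of the semantics only, not part of the dialect implementation, and
`inSameGroup` is a hypothetical helper:

```c++
#include <algorithm>
#include <cstdint>
#include <vector>

// Two devices are in the same group iff their coordinates agree on every
// mesh axis that is not listed in `meshAxes`.
bool inSameGroup(const std::vector<int64_t> &a, const std::vector<int64_t> &b,
                 const std::vector<int16_t> &meshAxes) {
  for (int16_t axis = 0; axis < int16_t(a.size()); ++axis) {
    bool isPartitionAxis =
        std::find(meshAxes.begin(), meshAxes.end(), axis) != meshAxes.end();
    if (!isPartitionAxis && a[axis] != b[axis])
      return false;
  }
  return true;
}

// inSameGroup({1, 0, 2, 3}, {1, 1, 2, 3}, {0, 1}) == true
// inSameGroup({1, 0, 2, 3}, {1, 0, 2, 4}, {0, 1}) == false
```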
Some collective operations like all-to-all and all-gather care about the
order of devices.
The order of devices in a device group is induced by the order of axes in
`mesh_axes`.
The axes are ordered from outer to inner.
If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
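A minimal sketch of the induced order, again illustrative only
(`precedesInGroup` is a hypothetical helper):

```c++
#include <cstdint>
#include <vector>

// Lexicographic comparison of two device coordinates along `meshAxes`,
// from the outermost (first listed) axis to the innermost (last listed).
bool precedesInGroup(const std::vector<int64_t> &a,
                     const std::vector<int64_t> &b,
                     const std::vector<int16_t> &meshAxes) {
  for (int16_t axis : meshAxes) {
    if (a[axis] != b[axis])
      return a[axis] < b[axis];
  }
  return false;
}

// With meshAxes = {3, 1}: (i, 1, k, 0) precedes both (i, 0, k, 1)
// and (i, 2, k, 0).
```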
## Operations
[include "Dialects/MeshOps.md"]
## Attributes
[include "Dialects/MeshAttributes.md"]


@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
let cppNamespace = "::mlir::mesh";
let description = [{
The `mesh` dialect contains a set of attributes, operations, and interfaces
that are useful for representing sharding and communication on a device mesh
cluster.
See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
}];
let dependentDialects = [
@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
let assemblyFormat = "`<` $value `>`";
}
// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or


@ -10,9 +10,12 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <algorithm>
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"


@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@ -77,6 +79,18 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
attr-dict
}];
let extraClassDeclaration = [{
// The `dim_sizes` attribute may have size less than the rank of the mesh.
// Returns the shape of the mesh with the missing trailing dimensions
// explicitly set to dynamic (encoded as 0).
::mlir::SmallVector<int64_t> canonicalDimSizes();
template <typename OutIt>
void canonicalDimSizes(OutIt outIt) {
std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
}
}];
let hasVerifier = 1;
}
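A hedged illustration of the padding behavior described above, assuming, per
the dialect, that a dimension size of 0 encodes a dynamic mesh dimension
(`canonicalShape` is a hypothetical stand-in for the member function):

```c++
#include <cstdint>
#include <vector>

// Pads missing trailing dimensions with 0, which encodes "dynamic".
std::vector<int64_t> canonicalShape(std::vector<int64_t> dimSizes,
                                    int64_t rank) {
  dimSizes.resize(rank, 0);
  return dimSizes;
}

// canonicalShape({2, 3}, 4) == {2, 3, 0, 0}
```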
@ -171,4 +185,219 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
class Mesh_CollectiveCommunicationOpBase<
string mnemonic, list<Trait> traits = []> :
Mesh_Op<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
);
}
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
let summary = "All-gather over a device mesh.";
let description = [{
Gathers along the `gather_axis` tensor axis.
Example:
```mlir
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
```
Input:
```
+-------+-------+
device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
| 11 12 | 15 16 |
+-------+-------+
```
Result:
```
gather tensor
axis 1
------------>
+-------------+
| 1 2 5 6 | <- devices (0, 0) and (0, 1)
| 3 4 7 8 |
+-------------+
| 9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$gather_axis
));
let results = (outs
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}
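A minimal sketch of the static shape rule implied by this example, ignoring
dynamic dimensions (`allGatherResultShape` is a hypothetical helper, not
dialect API):

```c++
#include <cstdint>
#include <vector>

// The gather axis grows by the device-group size; other axes are unchanged.
std::vector<int64_t> allGatherResultShape(std::vector<int64_t> shape,
                                          int64_t gatherAxis,
                                          int64_t groupSize) {
  shape[gatherAxis] *= groupSize;
  return shape;
}

// allGatherResultShape({2, 2}, /*gatherAxis=*/1, /*groupSize=*/2) == {2, 4}
```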
def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
SameOperandsAndResultShape]> {
let summary = "All-reduce over a device mesh.";
let description = [{
The accumulation element type is specified by the result type and
it does not need to match the input element type.
The input elements are converted to the result element type before
performing the reduction.
Attributes:
`reduction`: Indicates the reduction method.
Example:
```mlir
%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
```
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}
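A minimal sketch of the accumulation rule, assuming f32 inputs and an f64
result (`reduceMax` is a hypothetical helper):

```c++
#include <algorithm>
#include <limits>
#include <vector>

// Elements are converted to the result element type before the reduction.
double reduceMax(const std::vector<float> &values) {
  double acc = -std::numeric_limits<double>::infinity();
  for (float v : values)
    acc = std::max(acc, static_cast<double>(v));
  return acc;
}
```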
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
let summary = "All-to-all over a device mesh.";
let description = [{
Performs an all-to-all on tensor pieces split along `split_axis`.
The resulting pieces are concatenated along `concat_axis` on each device.
Example:
```mlir
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
...
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
```
Input:
```
device device device
(0) (1) (2)
+-------+-------+-------+ | split and concat along
| 11 12 | 21 22 | 31 32 | | tensor axis 0
| 13 14 | 23 24 | 33 34 |
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
```
Result:
```
device device device
(0) (1) (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
IndexAttr:$split_axis,
IndexAttr:$concat_axis
));
let results = (outs
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
`split_axis` `=` $split_axis
`concat_axis` `=` $concat_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}
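A minimal sketch of the static shape rule, assuming the split axis is
divisible by the device-group size (`allToAllResultShape` is a hypothetical
helper):

```c++
#include <cstdint>
#include <vector>

// The split axis shrinks and the concat axis grows by the device-group size.
// When split_axis == concat_axis the shape is unchanged.
std::vector<int64_t> allToAllResultShape(std::vector<int64_t> shape,
                                         int64_t splitAxis, int64_t concatAxis,
                                         int64_t groupSize) {
  if (splitAxis == concatAxis)
    return shape;
  shape[splitAxis] /= groupSize;
  shape[concatAxis] *= groupSize;
  return shape;
}

// allToAllResultShape({3, 2}, /*splitAxis=*/0, /*concatAxis=*/1,
//                     /*groupSize=*/3) == {1, 6}
```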
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
SameOperandsAndResultRank]> {
let summary = "Reduce-scatter over a device mesh.";
let description = [{
After the reduction over each device group, the result is scattered within
the group: the tensor is split along `scatter_axis` and the pieces are
distributed across the device group.
Example:
```mlir
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
...
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
reduction = <sum> scatter_axis = 0
: tensor<2x2xf32> -> tensor<1x2xf64>
```
Input:
```
device
(0, 1)
+-------+-------+ | scatter tensor
device (0, 0) -> | 1 2 | 5 6 | | axis 0
| 3 4 | 7 8 |
+-------+-------+
device (1, 0) -> | 9 10 | 13 14 |
| 11 12 | 15 16 |
+-------+-------+
device
(1, 1)
```
Result:
```
+-------+
| 6 8 | <- device (0, 0)
+-------+
| 10 12 | <- device (0, 1)
+-------+
| 22 24 | <- device (1, 0)
+-------+
| 26 28 | <- device (1, 1)
+-------+
```
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
IndexAttr:$scatter_axis
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}
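A minimal sketch of the static shape rule (`reduceScatterResultShape` is a
hypothetical helper):

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// The scatter axis must be divisible by the device-group size and is divided
// by it; other axes are unchanged.
std::vector<int64_t> reduceScatterResultShape(std::vector<int64_t> shape,
                                              int64_t scatterAxis,
                                              int64_t groupSize) {
  assert(shape[scatterAxis] % groupSize == 0);
  shape[scatterAxis] /= groupSize;
  return shape;
}

// reduceScatterResultShape({2, 2}, /*scatterAxis=*/0, /*groupSize=*/2)
//     == {1, 2}
```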
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD


@ -8,10 +8,27 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@ -21,6 +38,60 @@ using namespace mlir::mesh;
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
template <typename It>
static It canonicalizeSetAsArray(It begin, It end) {
llvm::sort(begin, end);
return std::unique(begin, end);
}
template <typename R>
static auto canonicalizeSetAsArray(R &&range) {
return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
}
template <typename T>
static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
auto newEnd = canonicalizeSetAsArray(vec);
vec.resize(newEnd - vec.begin());
return vec;
}
template <typename DimSize>
static bool isMeshDimensionDynamic(DimSize size) {
return size <= DimSize(0);
}
using MeshAxis = int16_t;
namespace {
struct DimensionSize {
static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
DimensionSize(int64_t val) : val(val) {}
int64_t value() const { return val; }
operator int64_t() const { return val; }
bool isDynamic() const { return ShapedType::isDynamic(val); }
private:
int64_t val;
};
} // namespace
static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
if (lhs.isDynamic() || rhs.isDynamic()) {
return DimensionSize::dynamic();
}
return lhs.value() / rhs.value();
}
static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
if (lhs.isDynamic() || rhs.isDynamic()) {
return DimensionSize::dynamic();
}
return lhs.value() * rhs.value();
}
//===----------------------------------------------------------------------===//
// Mesh dialect
//===----------------------------------------------------------------------===//
@ -96,6 +167,13 @@ LogicalResult ClusterOp::verify() {
return success();
}
SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
SmallVector<int64_t> result;
result.reserve(getRank());
canonicalDimSizes(std::back_inserter(result));
return result;
}
//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
@ -129,6 +207,327 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
namespace {
// Folds a collective operation into its input when `mesh_axes` is empty,
// i.e. each device group contains a single device, and the input and result
// types match.
template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
using OpRewritePattern<Op>::OpRewritePattern;
LogicalResult matchAndRewrite(Op op,
PatternRewriter &rewriter) const override {
auto meshAxes = op.getMeshAxes();
if (!meshAxes.empty()) {
return failure();
}
if (op.getInput().getType() != op.getResult().getType()) {
return failure();
}
rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
rewriter.eraseOp(op.getOperation());
return success();
}
};
} // namespace
static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
mesh::ClusterOp mesh =
symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
if (!mesh) {
return op->emitError() << "Undefined required mesh symbol \""
<< meshSymbol.getValue() << "\".";
}
return mesh;
}
// Returns true if the elements of the range [begin, end) are pairwise
// distinct. Assumes the range is sorted.
template <typename It>
static bool isUnique(It begin, It end) {
if (begin == end) {
return true;
}
It next = std::next(begin);
if (next == end) {
return true;
}
for (; next != end; ++next, ++begin) {
if (*begin == *next) {
return false;
}
}
return true;
}
static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
ClusterOp mesh) {
SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
llvm::sort(sorted);
if (!isUnique(sorted.begin(), sorted.end())) {
return emitError(loc) << "Mesh axes contains duplicate elements.";
}
MeshAxis rank = mesh.getRank();
for (auto axis : axes) {
if (axis >= rank || axis < 0) {
return emitError(loc)
<< "0-based mesh axis index " << axis
<< " is out of bounds. The referenced mesh \"" << mesh.getSymName()
<< "\" is of rank " << rank << ".";
}
}
return success();
}
template <typename Op>
static FailureOr<ClusterOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
if (failed(mesh)) {
return failure();
}
if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
return failure();
}
return mesh;
}
template <typename It>
static auto product(It begin, It end) {
using ElementType = std::decay_t<decltype(*begin)>;
return std::accumulate(begin, end, static_cast<ElementType>(1),
std::multiplies<ElementType>());
}
template <typename R>
static auto product(R &&range) {
return product(adl_begin(range), adl_end(range));
}
// Returns the number of devices in each device group, or
// ShapedType::kDynamic if any of the listed mesh axes has a dynamic size.
static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
                                         ArrayRef<int64_t> meshShape) {
int64_t res = 1;
for (MeshAxis axis : meshAxes) {
assert(size_t(axis) < meshShape.size());
if (isMeshDimensionDynamic(meshShape[axis])) {
return ShapedType::kDynamic;
}
res *= meshShape[axis];
}
return res;
}
static LogicalResult verifyDimensionCompatibility(Location loc,
int64_t expectedDimSize,
int64_t resultDimSize,
int64_t resultAxis) {
if (!ShapedType::isDynamic(resultDimSize) &&
expectedDimSize != resultDimSize) {
return emitError(loc) << "Dimension size mismatch for result axis "
<< resultAxis << ". Expected "
<< (ShapedType::isDynamic(expectedDimSize)
? Twine("dynamic")
: Twine(expectedDimSize))
<< ", but got " << resultDimSize << ".";
}
return success();
}
static LogicalResult verifyAllGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
auto resultRank = result.getType().cast<ShapedType>().getRank();
if (gatherAxis < 0 || gatherAxis >= resultRank) {
return emitError(result.getLoc())
<< "Gather axis " << gatherAxis << " is out of bounds [0, "
<< resultRank << ").";
}
ShapedType operandType = operand.getType().cast<ShapedType>();
ShapedType resultType = result.getType().cast<ShapedType>();
auto deviceGroupSize =
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
auto expectedResultDimSize =
axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
if (failed(verifyDimensionCompatibility(
result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
return failure();
}
}
return success();
}
static LogicalResult verifyAllToAllOperandAndResultShape(
Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
ShapedType resultType = result.getType().cast<ShapedType>();
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
if (failed(verifyDimensionCompatibility(
result.getLoc(), operandType.getDimSize(axis),
resultType.getDimSize(axis), axis))) {
return failure();
}
}
}
if (splitAxis == concatAxis) {
return success();
}
auto deviceGroupSize =
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
DimensionSize expectedResultConcatDimSize =
operandConcatDimSize * deviceGroupSize;
DimensionSize expectedResultSplitDimSize =
operandSplitDimSize / deviceGroupSize;
if (!expectedResultSplitDimSize.isDynamic() &&
int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
expectedResultSplitDimSize = DimensionSize::dynamic();
}
if (failed(verifyDimensionCompatibility(
result.getLoc(), expectedResultConcatDimSize.value(),
resultType.getDimSize(concatAxis), concatAxis))) {
return failure();
}
if (failed(verifyDimensionCompatibility(
result.getLoc(), expectedResultSplitDimSize.value(),
resultType.getDimSize(splitAxis), splitAxis))) {
return failure();
}
return success();
}
static LogicalResult verifyReduceScatterOperandAndResultShape(
Value operand, Value result, int64_t scatterAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
ShapedType resultType = result.getType().cast<ShapedType>();
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
if (axis != scatterAxis) {
if (failed(verifyDimensionCompatibility(
result.getLoc(), operandType.getDimSize(axis),
resultType.getDimSize(axis), axis))) {
return failure();
}
}
}
auto deviceGroupSize =
DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
auto operandScatterDimSize =
DimensionSize(operandType.getDimSize(scatterAxis));
if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
return emitError(result.getLoc())
<< "Operand dimension size " << int64_t(operandScatterDimSize)
<< " is not divisible by collective device group size "
<< int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
<< ".";
}
DimensionSize expectedResultScatterDimSize =
operandScatterDimSize / deviceGroupSize;
if (failed(verifyDimensionCompatibility(
result.getLoc(), expectedResultScatterDimSize.value(),
resultType.getDimSize(scatterAxis), scatterAxis))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// mesh.all_gather op
//===----------------------------------------------------------------------===//
LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
gatherAxis, getMeshAxes(),
mesh.value().canonicalDimSizes());
}
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.all_reduce op
//===----------------------------------------------------------------------===//
LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return getMeshAndVerifyAxes(*this, symbolTable);
}
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
return verifyAllToAllOperandAndResultShape(
getOperand(), getResult(), getSplitAxis().getSExtValue(),
getConcatAxis().getSExtValue(), getMeshAxes(),
mesh.value().canonicalDimSizes());
}
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}
//===----------------------------------------------------------------------===//
// mesh.reduce_scatter op
//===----------------------------------------------------------------------===//
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
return verifyReduceScatterOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
mesh.value().canonicalDimSizes());
}
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//


@ -0,0 +1,101 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_reduce
%0 = mesh.all_reduce %arg0 on @mesh0
mesh_axes = []
: tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
func.func @all_reduce_empty_mesh_axes_different_return_type(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// CHECK: mesh.all_reduce
%0 = mesh.all_reduce %arg0 on @mesh0
// CHECK-NOT: mesh_axes
mesh_axes = []
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// CHECK-LABEL: func @all_reduce_default_reduction
func.func @all_reduce_default_reduction(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
%0 = mesh.all_reduce %arg0 on @mesh0
mesh_axes = [0]
// CHECK-NOT: reduction
reduction = <sum>
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// CHECK-LABEL: func @all_to_all_empty_mesh_axes
func.func @all_to_all_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
%arg0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-NOT: mesh.all_to_all
%0 = mesh.all_to_all %arg0 on @mesh0
mesh_axes = []
split_axis = 0
concat_axis = 0
: tensor<8xf32> -> tensor<8xf32>
// CHECK: return %[[ARG]]
return %0 : tensor<8xf32>
}
// CHECK-LABEL: func @all_gather_empty_mesh_axes
func.func @all_gather_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_gather
%0 = mesh.all_gather %arg0 on @mesh0
mesh_axes = []
gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
func.func @reduce_scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.reduce_scatter
%0 = mesh.reduce_scatter %arg0 on @mesh0
mesh_axes = []
scatter_axis = 0
: tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
func.func @reduce_scatter_empty_mesh_axes_different_return_type(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// CHECK: mesh.reduce_scatter
%0 = mesh.reduce_scatter %arg0 on @mesh0
// CHECK-NOT: mesh_axes
mesh_axes = []
scatter_axis = 0
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// CHECK-LABEL: func @reduce_scatter_default_reduction
func.func @reduce_scatter_default_reduction(
%arg0 : tensor<4xf32>) -> tensor<2xf64> {
%0 = mesh.reduce_scatter %arg0 on @mesh0
mesh_axes = [0]
// CHECK-NOT: reduction
reduction = <sum>
scatter_axis = 0
: tensor<4xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}


@ -67,3 +67,279 @@ func.func @mesh_axis_negtive_in_partial(
tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
}
// -----
func.func @all_reduce_invalid_mesh_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
%0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = <sum>
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
func.func @all_reduce_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = <sum>
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
func.func @all_reduce_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = <sum>
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
func.func @all_reduce_invalid_tensor_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<5xf64> {
// expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
%0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
return %0 : tensor<5xf64>
}
// -----
func.func @all_gather_invalid_mesh_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
%0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
func.func @all_gather_invalid_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
func.func @all_gather_duplicate_mesh_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
func.func @all_gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
: tensor<3x4xf32> -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])
func.func @all_gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
: tensor<3x4xf32> -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
func.func @all_gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
%0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
: tensor<?xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
func.func @all_gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
: tensor<3xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
func.func @all_gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
// expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
: tensor<3xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
func.func @all_to_all_invalid_mesh_symbol(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
%0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [1])
func.func @all_to_all_duplicate_mesh_axis(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
%0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
split_axis = 0 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])
func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
%0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
%0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
split_axis = 0 concat_axis = 1
: tensor<?x6xi8> -> tensor<3x?xi8>
return %0 : tensor<3x?xi8>
}
// -----
mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
%0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
split_axis = 0 concat_axis = 1
: tensor<3x?xi8> -> tensor<?x3xi8>
return %0 : tensor<?x3xi8>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
%0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x2xi8> -> tensor<1x7xi8>
return %0 : tensor<1x7xi8>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
%0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x2xi8> -> tensor<2x6xi8>
return %0 : tensor<2x6xi8>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
func.func @reduce_scatter_duplicate_mesh_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
%0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
func.func @reduce_scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
%0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
: tensor<?xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
func.func @reduce_scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
%0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
: tensor<3xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
// -----
mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
// expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
%0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
: tensor<4xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}


@ -12,6 +12,8 @@ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
// CHECK: mesh.cluster @mesh3
mesh.cluster @mesh3(rank = 2)
mesh.cluster @mesh4(rank = 1, dim_sizes = [3])
// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
func.func @mesh_shard_encoding_fully_replicated(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
@ -126,3 +128,124 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
%2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
// CHECK-LABEL: func @all_reduce
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
// CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = <max>
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
%0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
return %0 : tensor<3x4xf64>
}
// CHECK-LABEL: func @all_gather
func.func @all_gather(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
: tensor<3x4xf32> -> tensor<3x16xf32>
return %0 : tensor<3x16xf32>
}
// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
func.func @all_gather_dynamic_dims_in_tensor(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
%0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
: tensor<?x?xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
func.func @all_gather_dynamic_dims_in_mesh(
// CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
%arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
// CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
%0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
: tensor<5x6xf32> -> tensor<5x?xf32>
return %0 : tensor<5x?xf32>
}
// CHECK-LABEL: func @all_to_all
func.func @all_to_all(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
// CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
%0 = mesh.all_to_all %arg0 on @mesh4
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
}
// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
func.func @all_to_all_dynamic_dims_in_result(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
// CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
%0 = mesh.all_to_all %arg0 on @mesh4
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x?xi8>
return %0 : tensor<3x?xi8>
}
// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
// CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
%arg0 : tensor<3xi8>) -> tensor<3xi8> {
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
// CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
// CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
%0 = mesh.all_to_all %arg0 on @mesh4
split_axis = 0 concat_axis = 0
: tensor<3xi8> -> tensor<3xi8>
return %0 : tensor<3xi8>
}
// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
func.func @all_to_all_non_divisible_split_axis_size(
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
%arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
// CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
// CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
%0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
split_axis = 0 concat_axis = 1
: tensor<2x3xi8> -> tensor<?x12xi8>
return %0 : tensor<?x12xi8>
}
// CHECK-LABEL: func @reduce_scatter_static_dimensions
func.func @reduce_scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
// CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
// CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = <max> scatter_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
%0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
reduction = <max> scatter_axis = 1
: tensor<3x4xf32> -> tensor<3x1xf64>
return %0 : tensor<3x1xf64>
}
// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
func.func @reduce_scatter_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
// CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
// CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
%0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}