[mlir][mesh] Add collective communication operations (#71960)
Add all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics.
parent ac75171d41
commit 5f7c8c1068

mlir/docs/Dialects/Mesh.md  (new file, +43)
@@ -0,0 +1,43 @@
# 'mesh' Dialect

The `mesh` dialect contains a set of attributes, operations and interfaces that
are useful for representing sharding and communication on a device mesh
cluster.

[TOC]

## Collective Communication Operations

There are a number of operations in the Mesh dialect to facilitate
communication between devices in a mesh.
It is assumed that the user is familiar with collective operations.
[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
explanation.
The main addition is that the collectives in this dialect have mesh semantics.

The operation attributes `mesh` and `mesh_axes` specify a list of device mesh
axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
Devices that have the same coordinates outside of the axes in `mesh_axes` are
in the same group.
For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
axes list is `[0, 1]`, then the devices are partitioned into the groups
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
Device (1, 0, 2, 4) will be in another group.
Some collective operations, like all-to-all and all-gather, care about the
order of devices.
The order of devices in a device group is induced by the order of axes in
`mesh_axes`.
The axes are ordered from outer to inner.
If we have an axis list `[3, 1]`, then device `(i, 1, k, 0)` will precede
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.

## Operations

[include "Dialects/MeshOps.md"]

## Attributes

[include "Dialects/MeshAttributes.md"]
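To make the grouping and ordering rules above concrete, here is a small self-contained C++ sketch (illustrative only, not part of this commit; all names are invented for the sketch). It enumerates the device groups that a `mesh_axes` list induces on a mesh shape, and orders each group with the listed axes running from outer to inner:

```cpp
// Illustrative sketch, not part of this commit: enumerate the device groups
// induced by a list of mesh axes, ordering each group outer-to-inner over
// those axes (the order the collectives use).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

using Coord = std::vector<int64_t>;

int main() {
  const std::vector<int64_t> meshShape = {2, 3, 4, 5}; // a 2x3x4x5 mesh
  const std::vector<int> meshAxes = {0, 1};            // partitioning axes

  // Enumerate every device coordinate in the mesh.
  std::vector<Coord> devices{Coord(meshShape.size(), 0)};
  for (size_t a = 0; a < meshShape.size(); ++a) {
    std::vector<Coord> next;
    for (const Coord &c : devices)
      for (int64_t i = 0; i < meshShape[a]; ++i) {
        Coord d = c;
        d[a] = i;
        next.push_back(d);
      }
    devices = std::move(next);
  }

  // Devices that agree on every coordinate *outside* meshAxes are in the
  // same group; the key below masks out the in-group coordinates.
  std::map<Coord, std::vector<Coord>> groups;
  for (const Coord &d : devices) {
    Coord key = d;
    for (int a : meshAxes)
      key[a] = -1;
    groups[key].push_back(d);
  }

  // Order each group lexicographically over meshAxes, outer to inner.
  for (auto &[key, group] : groups)
    std::sort(group.begin(), group.end(),
              [&](const Coord &x, const Coord &y) {
                for (int a : meshAxes)
                  if (x[a] != y[a])
                    return x[a] < y[a];
                return false;
              });

  // For the 2x3x4x5 mesh with axes [0, 1] this prints 20 groups of 6
  // devices; (1, 0, 2, 3) and (1, 1, 2, 3) land in the same group.
  std::printf("%zu groups of %zu devices\n", groups.size(),
              groups.begin()->second.size());
  return 0;
}
```

With the axis list `[3, 1]` instead, the comparator above sorts first on coordinate 3 and then on coordinate 1, which is exactly why `(i, 1, k, 0)` precedes both `(i, 0, k, 1)` and `(i, 2, k, 0)`.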
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
  let cppNamespace = "::mlir::mesh";

  let description = [{
    The `mesh` dialect contains a set of attributes, operations and interfaces
    that are useful for representing sharding and communication on a device
    mesh cluster.
    See the [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
  }];

  let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
  let cppNamespace = "::mlir::mesh";
}

def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
  let assemblyFormat = "`<` $value `>`";
}

// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,12 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <algorithm>

#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
@@ -77,6 +79,18 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
      attr-dict
  }];
  let extraClassDeclaration = [{
    // The `dim_sizes` attribute may have size less than the rank of the mesh.
    // Returns the shape of the mesh with missing trailing dimensions
    // explicitly set as dynamic.
    ::mlir::SmallVector<int64_t> canonicalDimSizes();

    template <typename OutIt>
    void canonicalDimSizes(OutIt outIt) {
      std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
    }
  }];
  let hasVerifier = 1;
}
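A usage note on the declaration above (the sketch below is a standalone rewrite for illustration, not the commit's code): for a partially specified cluster such as `mesh.cluster @m(rank = 4, dim_sizes = [2, 3])`, `canonicalDimSizes()` pads the trailing, unspecified dimensions with `0`, which this dialect uses as the marker for a dynamic mesh dimension (see `isMeshDimensionDynamic` later in the patch).

```cpp
// Sketch of the padding that ClusterOp::canonicalDimSizes performs
// (standalone rewrite for illustration, not the commit's code).
#include <cstdint>
#include <vector>

std::vector<int64_t> canonicalDimSizes(const std::vector<int64_t> &dimSizes,
                                       int64_t rank) {
  std::vector<int64_t> result(dimSizes);
  result.resize(rank, 0); // missing trailing dims become 0 == dynamic
  return result;
}

// canonicalDimSizes({2, 3}, 4) == {2, 3, 0, 0}
```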
@@ -171,4 +185,219 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
  }];
}

//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//

class Mesh_CollectiveCommunicationOpBase<
    string mnemonic, list<Trait> traits = []> :
    Mesh_Op<mnemonic,
      !listconcat(traits,
        [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
  dag commonArgs = (ins
    FlatSymbolRefAttr:$mesh,
    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
  );
}

def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank
  ]> {
  let summary = "All-gather over a device mesh.";
  let description = [{
    Gathers along the `gather_axis` tensor axis.

    Example:
    ```mlir
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
    ...
    %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
      : tensor<2x2xi8> -> tensor<2x4xi8>
    ```
    Input:
    ```
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                     |  3  4 |  7  8 |
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                     | 11 12 | 15 16 |
                     +-------+-------+
    ```
    Result:
    ```
    gather tensor
    axis 1
    ------------>
    +-------------+
    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
    |  3  4  7  8 |
    +-------------+
    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
    | 11 12 15 16 |
    +-------------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$gather_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
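The shape contract here is simple to state: only the gathered axis changes, scaling by the device-group size. A hedged C++ sketch of that rule (illustrative, not the commit's code; `kDynamic` is a stand-in for MLIR's dynamic-size sentinel):

```cpp
// Illustrative sketch, not part of this commit.
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

// Expected all_gather result shape: the gather axis grows by the
// device-group size; every other axis is unchanged.
std::vector<int64_t> allGatherResultShape(std::vector<int64_t> operandShape,
                                          int64_t gatherAxis,
                                          int64_t deviceGroupSize) {
  int64_t &d = operandShape[gatherAxis];
  d = (d == kDynamic || deviceGroupSize == kDynamic) ? kDynamic
                                                     : d * deviceGroupSize;
  return operandShape;
}

// allGatherResultShape({2, 2}, /*gatherAxis=*/1, /*deviceGroupSize=*/2)
// == {2, 4}, matching the tensor<2x2xi8> -> tensor<2x4xi8> example above.
```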

def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
    SameOperandsAndResultShape]> {
  let summary = "All-reduce over a device mesh.";
  let description = [{
    The accumulation element type is specified by the result type and
    it does not need to match the input element type.
    The input element is converted to the result element type before
    performing the reduction.

    Attributes:
    `reduction`: Indicates the reduction method.

    Example:
    ```
    %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
      : tensor<3x4xf32> -> tensor<3x4xf64>
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyRankedTensor:$input,
    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
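Why the result element type is allowed to differ: reducing many low-precision values in a wider type avoids overflow and rounding drift. An illustrative C++ sketch of the convert-then-reduce semantics described above (not the commit's code; names invented for the sketch):

```cpp
// Illustrative sketch, not part of this commit.
#include <cstdio>
#include <vector>

// all_reduce semantics for one element position: every device's f32
// contribution is widened to the result element type (here double)
// *before* the reduction is applied.
double allReduceSum(const std::vector<float> &perDeviceValues) {
  double acc = 0.0;
  for (float v : perDeviceValues)
    acc += static_cast<double>(v); // convert, then accumulate
  return acc;
}

int main() {
  // Summing ~10^6 values: a double accumulator keeps fractional digits
  // that a float accumulator would start dropping once the running sum
  // grows large relative to float's 24-bit mantissa.
  std::vector<float> vals(1 << 20, 1.0f + 1e-7f);
  std::printf("%.9f\n", allReduceSum(vals) / vals.size());
  return 0;
}
```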

def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank]> {
  let summary = "All-to-all over a device mesh.";
  let description = [{
    Performs an all-to-all on tensor pieces split along `split_axis`.
    The resulting pieces are concatenated along `concat_axis` on each device.

    Example:
    ```
    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
    ...
    %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
      split_axis = 0 concat_axis = 0
      : tensor<3x2xi8> -> tensor<3x2xi8>
    ```
    Input:
    ```
     device  device  device
      (0)     (1)     (2)
    +-------+-------+-------+  | split and concat along
    | 11 12 | 21 22 | 31 32 |  | tensor axis 0
    | 13 14 | 23 24 | 33 34 |  ↓
    | 15 16 | 25 26 | 35 36 |
    +-------+-------+-------+
    ```
    Result:
    ```
     device  device  device
      (0)     (1)     (2)
    +-------+-------+-------+
    | 11 12 | 13 14 | 15 16 |
    | 21 22 | 23 24 | 25 26 |
    | 31 32 | 33 34 | 35 36 |
    +-------+-------+-------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$split_axis,
    IndexAttr:$concat_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `split_axis` `=` $split_axis
    `concat_axis` `=` $concat_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
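The example above can be replayed in a few lines: with N devices in the group, device p splits its tensor into N pieces along `split_axis` and receives piece p from every device, concatenating in device order along `concat_axis`. An illustrative C++ sketch for the 1-D mesh of 3 devices with both axes 0 (not the commit's code):

```cpp
// Illustrative sketch, not part of this commit.
#include <cstdio>
#include <vector>

int main() {
  // Each device's 3x2 tensor, row-major, as in the example above.
  std::vector<std::vector<int>> dev = {
      {11, 12, 13, 14, 15, 16}, // device (0)
      {21, 22, 23, 24, 25, 26}, // device (1)
      {31, 32, 33, 34, 35, 36}, // device (2)
  };
  const int n = 3, rowsPerPiece = 1, cols = 2;

  // all_to_all with split_axis = 0, concat_axis = 0: device p's result is
  // piece p of every device's tensor, concatenated in device order.
  std::vector<std::vector<int>> result(n);
  for (int p = 0; p < n; ++p)
    for (int src = 0; src < n; ++src)
      for (int r = 0; r < rowsPerPiece; ++r)
        for (int c = 0; c < cols; ++c)
          result[p].push_back(dev[src][(p * rowsPerPiece + r) * cols + c]);

  // Device (0) ends up with rows 11 12 / 21 22 / 31 32, matching the
  // Result diagram above.
  for (int p = 0; p < n; ++p) {
    for (int v : result[p])
      std::printf("%d ", v);
    std::printf("\n");
  }
  return 0;
}
```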

def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
    SameOperandsAndResultRank]> {
  let summary = "Reduce-scatter over a device mesh.";
  let description = [{
    After the reduction, the result is scattered within each device group.
    The tensor is split along `scatter_axis` and the pieces distributed
    across the device group.
    Example:
    ```
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
    ...
    %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
      reduction = <sum> scatter_axis = 0
      : tensor<2x2xf32> -> tensor<1x2xf64>
    ```
    Input:
    ```
                            device
                            (0, 1)
                               ↓
                     +-------+-------+  | scatter tensor
    device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                     |  3  4 |  7  8 |  ↓
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 |
                     | 11 12 | 15 16 |
                     +-------+-------+
                               ↑
                            device
                            (1, 1)
    ```
    Result:
    ```
    +-------+
    |  6  8 | <- device (0, 0)
    +-------+
    | 10 12 | <- device (0, 1)
    +-------+
    | 22 24 | <- device (1, 0)
    +-------+
    | 26 28 | <- device (1, 1)
    +-------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
    IndexAttr:$scatter_axis
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    (`reduction` `=` $reduction^)?
    `scatter_axis` `=` $scatter_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}

#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
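Shape-wise, reduce_scatter is the mirror of all_gather: the scatter axis shrinks by the device-group size and must divide evenly (otherwise the verifier added below rejects the op). An illustrative C++ sketch of that rule (not the commit's code; dynamic sizes omitted for brevity):

```cpp
// Illustrative sketch, not part of this commit.
#include <cstdint>
#include <optional>
#include <vector>

// Expected reduce_scatter result shape: the scatter axis shrinks by the
// device-group size; nullopt signals a non-divisible (invalid) size.
std::optional<std::vector<int64_t>>
reduceScatterResultShape(std::vector<int64_t> operandShape,
                         int64_t scatterAxis, int64_t deviceGroupSize) {
  int64_t &d = operandShape[scatterAxis];
  if (d % deviceGroupSize != 0)
    return std::nullopt; // e.g. a size-4 axis over 3 devices is rejected
  d /= deviceGroupSize;
  return operandShape;
}

// reduceScatterResultShape({2, 2}, /*scatterAxis=*/0, /*deviceGroupSize=*/2)
// yields {1, 2}, matching the tensor<2x2xf32> -> tensor<1x2xf64> example.
```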

mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,27 @@

#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <optional>
#include <string>
#include <utility>

#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +38,60 @@ using namespace mlir::mesh;

#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"

template <typename It>
static It canonicalizeSetAsArray(It begin, It end) {
  llvm::sort(begin, end);
  return std::unique(begin, end);
}

template <typename R>
static auto canonicalizeSetAsArray(R &&range) {
  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
}

template <typename T>
static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
  auto newEnd = canonicalizeSetAsArray(vec);
  vec.resize(newEnd - vec.begin());
  return vec;
}

template <typename DimSize>
static bool isMeshDimensionDynamic(DimSize size) {
  return size <= DimSize(0);
}

using MeshAxis = int16_t;

namespace {

struct DimensionSize {
  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
  DimensionSize(int64_t val) : val(val) {}
  int64_t value() const { return val; }
  operator int64_t() const { return val; }
  bool isDynamic() const { return ShapedType::isDynamic(val); }

private:
  int64_t val;
};

} // namespace

static DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() / rhs.value();
}

static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
  if (lhs.isDynamic() || rhs.isDynamic()) {
    return DimensionSize::dynamic();
  }
  return lhs.value() * rhs.value();
}
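A quick usage note on `DimensionSize`: its arithmetic saturates to dynamic, so a single unknown operand makes the derived dimension unknown. The sketch below assumes the struct and operators above are in scope; it is not part of the commit.

```cpp
// Usage sketch for DimensionSize, assuming the declarations above.
#include <cassert>

void dimensionSizeExamples() {
  DimensionSize six(6), two(2);
  assert(int64_t(six / two) == 3);                      // static / static
  assert((six * DimensionSize::dynamic()).isDynamic()); // dynamic poisons *
  assert((DimensionSize::dynamic() / two).isDynamic()); // ...and /
}
```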

//===----------------------------------------------------------------------===//
// Mesh dialect
//===----------------------------------------------------------------------===//
@@ -96,6 +167,13 @@ LogicalResult ClusterOp::verify() {
  return success();
}

SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
  SmallVector<int64_t> result;
  result.reserve(getRank());
  canonicalDimSizes(std::back_inserter(result));
  return result;
}

//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
@@ -129,6 +207,327 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
  return success();
}

//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//

namespace {

// Folds away a collective whose `mesh_axes` is empty: each device group
// then contains only the device itself, so the op is an identity whenever
// the operand and result types match.
template <typename Op>
struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
  using OpRewritePattern<Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(Op op,
                                PatternRewriter &rewriter) const override {
    auto meshAxes = op.getMeshAxes();
    if (!meshAxes.empty()) {
      return failure();
    }
    if (op.getInput().getType() != op.getResult().getType()) {
      return failure();
    }

    rewriter.replaceAllUsesWith(op.getResult(), op.getInput());
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};

} // namespace

static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                                    SymbolTableCollection &symbolTable) {
  mesh::ClusterOp mesh =
      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
  if (!mesh) {
    return op->emitError() << "Undefined required mesh symbol \""
                           << meshSymbol.getValue() << "\".";
  }

  return mesh;
}

template <typename It>
bool isUnique(It begin, It end) {
  if (begin == end) {
    return true;
  }
  It next = std::next(begin);
  if (next == end) {
    return true;
  }
  for (; next != end; ++next, ++begin) {
    if (*begin == *next) {
      return false;
    }
  }
  return true;
}

static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
                                    ClusterOp mesh) {
  SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
  llvm::sort(sorted);
  if (!isUnique(sorted.begin(), sorted.end())) {
    return emitError(loc) << "Mesh axes contains duplicate elements.";
  }

  MeshAxis rank = mesh.getRank();
  for (auto axis : axes) {
    if (axis >= rank || axis < 0) {
      return emitError(loc)
             << "0-based mesh axis index " << axis
             << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
             << "\" is of rank " << rank << ".";
    }
  }

  return success();
}

template <typename Op>
static FailureOr<ClusterOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
  auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
    return failure();
  }
  return mesh;
}

template <typename It>
static auto product(It begin, It end) {
  using ElementType = std::decay_t<decltype(*begin)>;
  return std::accumulate(begin, end, static_cast<ElementType>(1),
                         std::multiplies<ElementType>());
}

template <typename R>
static auto product(R &&range) {
  return product(adl_begin(range), adl_end(range));
}

static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
                                         ArrayRef<int64_t> meshShape) {
  int64_t res = 1;

  for (MeshAxis axis : meshAxes) {
    assert(size_t(axis) < meshShape.size());
    if (isMeshDimensionDynamic(meshShape[axis])) {
      return ShapedType::kDynamic;
    }
    res *= meshShape[axis];
  }

  return res;
}

static LogicalResult verifyDimensionCompatibility(Location loc,
                                                  int64_t expectedDimSize,
                                                  int64_t resultDimSize,
                                                  int64_t resultAxis) {
  if (!ShapedType::isDynamic(resultDimSize) &&
      expectedDimSize != resultDimSize) {
    return emitError(loc) << "Dimension size mismatch for result axis "
                          << resultAxis << ". Expected "
                          << (ShapedType::isDynamic(expectedDimSize)
                                  ? Twine("dynamic")
                                  : Twine(expectedDimSize))
                          << ", but got " << resultDimSize << ".";
  }

  return success();
}

static LogicalResult verifyAllGatherOperandAndResultShape(
    Value operand, Value result, int64_t gatherAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  auto resultRank = result.getType().template cast<ShapedType>().getRank();
  if (gatherAxis < 0 || gatherAxis >= resultRank) {
    return emitError(result.getLoc())
           << "Gather axis " << gatherAxis << " is out of bounds [0, "
           << resultRank << ").";
  }

  ShapedType operandType = operand.getType().cast<ShapedType>();
  ShapedType resultType = result.getType().cast<ShapedType>();
  auto deviceGroupSize =
      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
    auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
    auto expectedResultDimSize =
        axis == gatherAxis ? deviceGroupSize * operandDimSize : operandDimSize;
    if (failed(verifyDimensionCompatibility(
            result.getLoc(), expectedResultDimSize, resultDimSize, axis))) {
      return failure();
    }
  }
  return success();
}

static LogicalResult verifyAllToAllOperandAndResultShape(
    Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = operand.getType().cast<ShapedType>();
  ShapedType resultType = result.getType().cast<ShapedType>();
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  if (splitAxis == concatAxis) {
    return success();
  }

  auto deviceGroupSize =
      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
  auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
  auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
  DimensionSize expectedResultConcatDimSize =
      operandConcatDimSize * deviceGroupSize;
  DimensionSize expectedResultSplitDimSize =
      operandSplitDimSize / deviceGroupSize;
  if (!expectedResultSplitDimSize.isDynamic() &&
      int64_t(operandSplitDimSize) % int64_t(deviceGroupSize) != 0) {
    expectedResultSplitDimSize = DimensionSize::dynamic();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultConcatDimSize.value(),
          resultType.getDimSize(concatAxis), concatAxis))) {
    return failure();
  }
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultSplitDimSize.value(),
          resultType.getDimSize(splitAxis), splitAxis))) {
    return failure();
  }

  return success();
}

static LogicalResult verifyReduceScatterOperandAndResultShape(
    Value operand, Value result, int64_t scatterAxis,
    ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
  ShapedType operandType = operand.getType().cast<ShapedType>();
  ShapedType resultType = result.getType().cast<ShapedType>();
  for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
    if (axis != scatterAxis) {
      if (failed(verifyDimensionCompatibility(
              result.getLoc(), operandType.getDimSize(axis),
              resultType.getDimSize(axis), axis))) {
        return failure();
      }
    }
  }

  auto deviceGroupSize =
      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
  auto operandScatterDimSize =
      DimensionSize(operandType.getDimSize(scatterAxis));
  if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
      int64_t(operandScatterDimSize) % int64_t(deviceGroupSize) != 0) {
    return emitError(result.getLoc())
           << "Operand dimension size " << int64_t(operandScatterDimSize)
           << " is not divisible by collective device group size "
           << int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
           << ".";
  }
  DimensionSize expectedResultScatterDimSize =
      operandScatterDimSize / deviceGroupSize;
  if (failed(verifyDimensionCompatibility(
          result.getLoc(), expectedResultScatterDimSize.value(),
          resultType.getDimSize(scatterAxis), scatterAxis))) {
    return failure();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// mesh.all_gather op
//===----------------------------------------------------------------------===//

LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }
  auto gatherAxis = getGatherAxis().getSExtValue();
  return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
                                              gatherAxis, getMeshAxes(),
                                              mesh.value().canonicalDimSizes());
}

void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
}

//===----------------------------------------------------------------------===//
// mesh.all_reduce op
//===----------------------------------------------------------------------===//

LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  return getMeshAndVerifyAxes(*this, symbolTable);
}

void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}

//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//

LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  return verifyAllToAllOperandAndResultShape(
      getOperand(), getResult(), getSplitAxis().getSExtValue(),
      getConcatAxis().getSExtValue(), getMeshAxes(),
      mesh.value().canonicalDimSizes());
}

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
}

//===----------------------------------------------------------------------===//
// mesh.reduce_scatter op
//===----------------------------------------------------------------------===//

LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
  if (failed(mesh)) {
    return failure();
  }

  return verifyReduceScatterOperandAndResultShape(
      getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
      mesh.value().canonicalDimSizes());
}

void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Mesh/canonicalization.mlir  (new file, +101)
@@ -0,0 +1,101 @@
// RUN: mlir-opt --canonicalize %s | FileCheck %s

mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

// CHECK-LABEL: func @all_reduce_empty_mesh_axes
func.func @all_reduce_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_reduce
  %0 = mesh.all_reduce %arg0 on @mesh0
    mesh_axes = []
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
func.func @all_reduce_empty_mesh_axes_different_return_type(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
// CHECK: mesh.all_reduce
  %0 = mesh.all_reduce %arg0 on @mesh0
// CHECK-NOT: mesh_axes
    mesh_axes = []
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// CHECK-LABEL: func @all_reduce_default_reduction
func.func @all_reduce_default_reduction(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  %0 = mesh.all_reduce %arg0 on @mesh0
    mesh_axes = [0]
// CHECK-NOT: reduction
    reduction = <sum>
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// CHECK-LABEL: func @all_to_all_empty_mesh_axes
func.func @all_to_all_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
    %arg0 : tensor<8xf32>) -> tensor<8xf32> {
// CHECK-NOT: mesh.all_to_all
  %0 = mesh.all_to_all %arg0 on @mesh0
    mesh_axes = []
    split_axis = 0
    concat_axis = 0
    : tensor<8xf32> -> tensor<8xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<8xf32>
}

// CHECK-LABEL: func @all_gather_empty_mesh_axes
func.func @all_gather_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_gather
  %0 = mesh.all_gather %arg0 on @mesh0
    mesh_axes = []
    gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
func.func @reduce_scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.reduce_scatter
  %0 = mesh.reduce_scatter %arg0 on @mesh0
    mesh_axes = []
    scatter_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
func.func @reduce_scatter_empty_mesh_axes_different_return_type(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
// CHECK: mesh.reduce_scatter
  %0 = mesh.reduce_scatter %arg0 on @mesh0
// CHECK-NOT: mesh_axes
    mesh_axes = []
    scatter_axis = 0
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// CHECK-LABEL: func @reduce_scatter_default_reduction
func.func @reduce_scatter_default_reduction(
    %arg0 : tensor<4xf32>) -> tensor<2xf64> {
  %0 = mesh.reduce_scatter %arg0 on @mesh0
    mesh_axes = [0]
// CHECK-NOT: reduction
    reduction = <sum>
    scatter_axis = 0
    : tensor<4xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

mlir/test/Dialect/Mesh/invalid.mlir
@@ -67,3 +67,279 @@ func.func @mesh_axis_negtive_in_partial(
    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
}

// -----

func.func @all_reduce_invalid_mesh_symbol(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = <sum>
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

func.func @all_reduce_invalid_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = <sum>
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

func.func @all_reduce_duplicate_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf64> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = <sum>
    : tensor<4xf32> -> tensor<4xf64>
  return %0 : tensor<4xf64>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

func.func @all_reduce_invalid_tensor_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<5xf64> {
  // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
  %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
  return %0 : tensor<5xf64>
}

// -----

func.func @all_gather_invalid_mesh_symbol(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

func.func @all_gather_invalid_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

func.func @all_gather_duplicate_mesh_axis(
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [1])

func.func @all_gather_invalid_non_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
    : tensor<3x4xf32> -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 2])

func.func @all_gather_invalid_gather_axis_dimension_size(
    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
    : tensor<3x4xf32> -> tensor<3x5xf32>
  return %0 : tensor<3x5xf32>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [1])

func.func @all_gather_invalid_gather_axis_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
  %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
    : tensor<?xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [1])

func.func @all_gather_invalid_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
    : tensor<3xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [1])

func.func @all_gather_invalid_negative_gather_axis(
    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
    : tensor<3xf32> -> tensor<3xf32>
  return %0 : tensor<3xf32>
}

// -----

func.func @all_to_all_invalid_mesh_symbol(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
  %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
    split_axis = 1 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [1])

func.func @all_to_all_duplicate_mesh_axis(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
    split_axis = 0 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 1])

func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])

func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
    %arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
    split_axis = 0 concat_axis = 1
    : tensor<?x6xi8> -> tensor<3x?xi8>
  return %0 : tensor<3x?xi8>
}

// -----

mesh.cluster @mesh0(rank = 2, dim_sizes = [1, 1])

func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
    %arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
    split_axis = 0 concat_axis = 1
    : tensor<3x?xi8> -> tensor<?x3xi8>
  return %0 : tensor<?x3xi8>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [3])

func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
    %arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x2xi8> -> tensor<1x7xi8>
  return %0 : tensor<1x7xi8>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [3])

func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
    %arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
    split_axis = 0 concat_axis = 1
    : tensor<3x2xi8> -> tensor<2x6xi8>
  return %0 : tensor<2x6xi8>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [3])

func.func @reduce_scatter_duplicate_mesh_axis(
    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
    : tensor<?xf32> -> tensor<?xf64>
  return %0 : tensor<?xf64>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [3])

func.func @reduce_scatter_invalid_dynamic_dimension(
    %arg0 : tensor<?xf32>) -> tensor<2xf64> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
    : tensor<?xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [3])

func.func @reduce_scatter_invalid_static_dimension_size(
    %arg0 : tensor<3xf32>) -> tensor<2xf64> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
    : tensor<3xf32> -> tensor<2xf64>
  return %0 : tensor<2xf64>
}

// -----

mesh.cluster @mesh0(rank = 1, dim_sizes = [3])

func.func @reduce_scatter_invalid_operand_static_dimension_size(
    %arg0 : tensor<4xf32>) -> tensor<?xf64> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
    : tensor<4xf32> -> tensor<?xf64>
  return %0 : tensor<?xf64>
}

mlir/test/Dialect/Mesh/ops.mlir
@@ -12,6 +12,8 @@ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
// CHECK: mesh.cluster @mesh3
mesh.cluster @mesh3(rank = 2)

mesh.cluster @mesh4(rank = 1, dim_sizes = [3])

// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
func.func @mesh_shard_encoding_fully_replicated(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
@@ -126,3 +128,124 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
  %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
  return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}

// CHECK-LABEL: func @all_reduce
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
    %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
// CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = <max>
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
    : tensor<3x4xf32> -> tensor<3x4xf64>
  return %0 : tensor<3x4xf64>
}

// CHECK-LABEL: func @all_gather
func.func @all_gather(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
    %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
    : tensor<3x4xf32> -> tensor<3x16xf32>
  return %0 : tensor<3x16xf32>
}

// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
func.func @all_gather_dynamic_dims_in_tensor(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
    %arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
  %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
    : tensor<?x?xf32> -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
func.func @all_gather_dynamic_dims_in_mesh(
// CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
    %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
// CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
// CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
  %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
    : tensor<5x6xf32> -> tensor<5x?xf32>
  return %0 : tensor<5x?xf32>
}

// CHECK-LABEL: func @all_to_all
func.func @all_to_all(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
    %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
// CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
  %0 = mesh.all_to_all %arg0 on @mesh4
    split_axis = 1 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x6xi8>
  return %0 : tensor<3x6xi8>
}

// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
func.func @all_to_all_dynamic_dims_in_result(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
    %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
// CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
  %0 = mesh.all_to_all %arg0 on @mesh4
    split_axis = 1 concat_axis = 0
    : tensor<3x6xi8> -> tensor<3x?xi8>
  return %0 : tensor<3x?xi8>
}

// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
// CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
    %arg0 : tensor<3xi8>) -> tensor<3xi8> {
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
// CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
// CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
  %0 = mesh.all_to_all %arg0 on @mesh4
    split_axis = 0 concat_axis = 0
    : tensor<3xi8> -> tensor<3xi8>
  return %0 : tensor<3xi8>
}

// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
func.func @all_to_all_non_divisible_split_axis_size(
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
    %arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
// CHECK-NEXT: mesh.all_to_all %[[ARG]]
// CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
// CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
  %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
    split_axis = 0 concat_axis = 1
    : tensor<2x3xi8> -> tensor<?x12xi8>
  return %0 : tensor<?x12xi8>
}

// CHECK-LABEL: func @reduce_scatter_static_dimensions
func.func @reduce_scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
    %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
// CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
// CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = <max> scatter_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
  %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
    reduction = <max> scatter_axis = 1
    : tensor<3x4xf32> -> tensor<3x1xf64>
  return %0 : tensor<3x1xf64>
}

// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
func.func @reduce_scatter_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
    %arg0 : tensor<?xf32>) -> tensor<?xf64> {
// CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
// CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
  %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
    : tensor<?xf32> -> tensor<?xf64>
  return %0 : tensor<?xf64>
}