21 Commits

Author SHA1 Message Date
Boian Petkantchin
abfac563f5
[mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface (#84415)
Make them more general instead of only supporting `func::FuncOp`.
2024-03-08 08:14:36 -08:00
Boian Petkantchin
4f7ab789bf
[mlir][mesh] add support in spmdization for incomplete sharding annotations (#82442)
Don't require that `mesh.shard` operations come in pairs. If there is
only a single `mesh.shard` operation we assume that the producer result
and consumer operand have the same sharding.
2024-02-22 11:06:14 -08:00
Boian Petkantchin
f78027dfec
[mlir][mesh] Better op result names (#82408)
Implement OpAsmOpInterface for most ops to increase IR readability. For
example `mesh.process_linear_index` would produce a value with name
`proc_linear_idx`.
2024-02-20 16:53:26 -08:00
Boian Petkantchin
dc3258c617
[mlir][mesh] Add all-slice operation (#81218)
This op is the inverse of all-gather. It is useful to have an explicit
concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the
in-group process index.

Make resharding generate an all-slice instead of inserting the slicing
logic directly.
2024-02-15 13:03:58 -08:00
Boian Petkantchin
adbf21f12b
[mlir][mesh] Add spmdization pass (#80518)
Add a pass that converts a function that has sharding annotations into
SPMD form.
2024-02-06 20:55:14 -08:00
Boian Petkantchin
31fc0a12e1
[mlir][mesh] Refactoring code organization, tests and docs (#79606)
* Split out `MeshDialect.h` form `MeshOps.h` that defines the dialect
class. Reduces include clutter if you care only about the dialect and
not the ops.

* Expose functions `getMesh` and `collectiveProcessGroupSize`. There
functions are useful for outside users of the dialect.

* Remove unused code.

* Remove examples and tests of mesh.shard attribute in tensor encoding.
Per the decision that Spmdization would be performed on sharding
annotations and there will be no tensors with sharding specified in the
type.
For more info see this RFC comment:
https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81
2024-01-31 07:20:14 -08:00
Boian Petkantchin
9a8437f504
[mlir][mesh] Rename cluster to mesh (#79484)
Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.

The name `mesh` is more specific to what it really represents. It is a
mesh of devices.
The name `cluster` implies a broader posibility of device
configurations. When just the word `mesh` is used the meaning can often
be inferred from the context whether it refers to the mesh dialect or a
device mesh. The full name can be used when needed.
2024-01-26 07:03:29 -08:00
Boian Petkantchin
5df2c00af3
[mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (#77838)
Remove the somewhat redundant rank attribute.
Before this change
```
mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)
```
After
```
mesh.cluster @mesh(shape = 2x3x?)
```

The rank is instead determined by the provided shape. With this change
no longer `getDimSizes()` can be wrongly assumed to have size equal to
the cluster rank.
Now `getShape().size()` will always equal `getRank()`.
2024-01-15 07:39:09 -08:00
Boian Petkantchin
79aa776267
[mlir][mesh] Add lowering of process multi-index op (#77490)
* Rename mesh.process_index -> mesh.process_multi_index.
* Add mesh.process_linear_index op.
* Add lowering of mesh.process_multi_index into an expression using
mesh.process_linear_index, mesh.cluster_shape and
affine.delinearize_index.

This is useful to lower mesh ops and prepare them for further lowering
where the runtime may have only the linear index of a device/process.
For example in MPI we have a rank (linear index) in a communicator.
2024-01-10 07:01:16 -08:00
Boian Petkantchin
ab590377a3
[mlir][mesh] Add folding of ClusterShapeOp (#77033)
If the mesh has static size on some of the requested axes, the result is
substituted with a constant.
2024-01-09 13:42:56 -08:00
Boian Petkantchin
fc18b13492
[mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr (#76886)
Analogous to func.call use FlatSymbolRefAttr to reference the
corresponding mesh.
2024-01-05 07:14:07 -08:00
Matthias Springer
bb6d5c2200
[mlir][Transforms] GreedyPatternRewriteDriver: Do not CSE constants during iterations (#75897)
The `GreedyPatternRewriteDriver` tries to iteratively fold ops and apply
rewrite patterns to ops. It has special handling for constants: they are
CSE'd and sometimes moved to parent regions to allow for additional
CSE'ing. This happens in `OperationFolder`.

To allow for efficient CSE'ing, `OperationFolder` maintains an internal
lookup data structure to find the existing constant ops with the same
value for each `IsolatedFromAbove` region:
```c++
/// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap<Region *, ConstantMap> foldScopes;
```

Rewrite patterns are allowed to modify operations. In particular, they
may move operations (including constants) from one region to another
one. Such an IR rewrite can make the above lookup data structure
inconsistent.

We encountered such a bug in a downstream project. This bug materialized
in the form of an op that uses the result of a constant op from a
different `IsolatedFromAbove` region (that is not accessible).

This commit changes the behavior of the `GreedyPatternRewriteDriver`
such that `OperationFolder` is used to CSE constants at the beginning of
each iteration (as the worklist is populated), but no longer during an
iteration. `OperationFolder` is no longer used after populating the
worklist, so we do not have to care about inconsistent state in the
`OperationFolder` due to IR rewrites. The `GreedyPatternRewriteDriver`
now performs the op folding by itself instead of calling
`OperationFolder::tryToFold`.

This change changes the order of constant ops in test cases, but not the
region in which they appear. All broken test cases were fixed by turning
`CHECK` into `CHECK-DAG`.

Alternatives considered: The state of `OperationFolder` could be
partially invalidated with every `notifyOperationModified` notification.
That is more fragile than the solution in this commit because incorrect
rewriter API usage can lead to missing notifications and hard-to-debug
`IsolatedFromAbove` violations. (It did not fix the above mention bug in
a downstream project, which could be due to incorrect rewriter API usage
or due to another conceptual problem that I missed.) Moreover, ops are
frequently getting modified during a greedy pattern rewrite, so we would
likely keep invalidating large parts of the state of `OperationFolder`
over and over.

Migration guide: Turn `CHECK` into `CHECK-DAG` in test cases. Constant
ops are no longer folded during a greedy pattern rewrite. If you rely on
folding (and rematerialization) of constant ops during a greedy pattern
rewrite, turn the folder into a pattern.
2024-01-05 09:22:18 +01:00
Boian Petkantchin
1a8fb88719
[mlir][mesh] Add resharding spmdization on a 1D device mesh (#76179)
The current implementation supports only sharding of tensor axes that
have size divisible by the mesh axis size.
2024-01-02 15:50:07 -08:00
Boian Petkantchin
5e29112719
[mlir][mesh] Add verification and canonicalization for some collectives (#74905)
Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.

The canonicalizations only remove trivial collectives with empty
mesh_axes attrubutes.
2023-12-15 06:41:10 -08:00
Boian Petkantchin
4b3446771f
[mlir][mesh] Add endomorphism simplification for all-reduce (#73150)
Does transformations like
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)

max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce element-wise op is max.

Added general rewrite pattern HomomorphismSimplification and
EndomorphismSimplification that encapsulate the general algorithm.
Made specialization for all-reduce with respect to
addf, addi, minsi, maxsi, minimumf and maximumf
in the Arithmetic dialect.
2023-12-12 10:21:52 -08:00
Boian Petkantchin
944e031e36
[mlir][mesh] Use tensor shape notation for the shape of a cluster (#73826)
Examle:

substitute
mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 4])

with
mesh.cluster @mesh0(rank = 2, dim_sizes = ?x4)

Same as tensor/memref shapes. The only difference is for 0-rank shapes.
With tensors you would have something like `tensor<f32>`. Here to avoid
matching an empty string a 0-rank shape is denoted by `[]`.
2023-12-08 11:34:44 -08:00
Boian Petkantchin
5f7c8c1068
[mlir][mesh] Add collective communication operations (#71960)
Add all-gather, all-reduce, all-to-all and reduce-scatter. These
operations have device mesh semantics.
2023-11-21 06:50:24 -08:00
Chengji Yao
b0d5b4d252
[MLIR][Mesh] Add sharding propagation pass (#71261)
Add a pass that propagates sharding information throughout the graph.
After this pass, each of the operations' operands and results is
annotated with a mesh.shard operation.

The pass is driven by a newly added ShardingInterface, and an implementation
for element-wise and matmul ops in the TOSA dialect is provided.
2023-11-03 21:07:31 -07:00
Mehdi Amini
466abaf152 Revert "[MLIR][Mesh] Add sharding propagation pass (#69665)"
This reverts commit 9d9400d7de9b928e3018af97e8b381a4a6ba5162.
This reverts commit bda763aea0b854178c01eac9f309042d9aaa823b.

The buildbot is broken and tests are failing.
2023-11-03 17:52:41 -07:00
Chengji Yao
9d9400d7de
[MLIR][Mesh] Add sharding propagation pass (#69665)
Add a pass that propagates sharding information throughout the graph.
After this pass, each of the operations' operands and results is
annotated with a `mesh.shard` operation, and the operations themselves
are added with sharding option attributes.

The pass is driven by  a newly added `ShardingInterface`, and an implementation
for element-wise and matmul ops in the TOSA dialect is provided.
2023-11-03 17:12:42 -07:00
Chengji Yao
08545e8516
[MLIR] Add a new Mesh dialect (#68007)
This is the 1st PR of [Mesh sharding
RFC](https://discourse.llvm.org/t/open-mlir-meeting-9-28-2023-rfc-sharding-framework-design-for-device-mesh/73695),
includes

Includes:

- mesh.cluster op
- mesh.shard op (the mesh.annotate op in the RFC slides, the name is
modified a bit from @stellaraccident 's advice, which I think might be a
bit more concise)
- MeshSharding attribute
2023-10-10 11:35:40 -07:00