8 Commits

Author SHA1 Message Date
Boian Petkantchin
ff2720d190
[mlir][mesh] Deduplicate iterator type and partial type information (#81920)
The two types mostly duplicated the same values.
Here they are decomposed to carry orthogonal and complementary
information.

Use `utils::IteratorType` instead of `mesh::IteratorType`. It now has
only parallel and reduction values.

Rename `Partial` to `ReductionKind`.

Add `getReductionLoopIteratorKinds` method to `ShardingInterface`.
2024-02-16 07:10:46 -08:00
Boian Petkantchin
dc3258c617
[mlir][mesh] Add all-slice operation (#81218)
This op is the inverse of all-gather. It is useful to have an explicit
concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the
in-group process index.

Make resharding generate an all-slice instead of inserting the slicing
logic directly.
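
A minimal sketch of the inverse relationship described above, written in plain Python rather than MLIR. The helper names `all_gather` and `all_slice`, the 1-D list "tensors", and the assumption that the tensor length divides evenly by the group size are all illustrative and are not taken from the dialect implementation.

```python
def all_gather(shards):
    """Simulate all-gather: every process receives the concatenation of all shards."""
    full = [element for shard in shards for element in shard]
    return [list(full) for _ in shards]


def all_slice(replicas):
    """Simulate all-slice: each process keeps the chunk selected by its in-group index.

    Assumes every replica has the same length and that the length is
    divisible by the group size (hypothetical simplification).
    """
    group_size = len(replicas)
    result = []
    for in_group_index, tensor in enumerate(replicas):
        chunk = len(tensor) // group_size
        start = in_group_index * chunk
        result.append(tensor[start:start + chunk])
    return result


# One shard per process in a group of three.
shards = [[0, 1], [2, 3], [4, 5]]
gathered = all_gather(shards)
resliced = all_slice(gathered)
```

Under these assumptions, slicing the gathered tensors by in-group index recovers the original shards, which is the sense in which the two operations are inverses.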
2024-02-15 13:03:58 -08:00
Boian Petkantchin
9a8437f504
[mlir][mesh] Rename cluster to mesh (#79484)
Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.

The name `mesh` is more specific to what it really represents. It is a
mesh of devices.
The name `cluster` implies a broader possibility of device
configurations. When the word `mesh` is used on its own, context
usually makes clear whether it refers to the mesh dialect or a
device mesh. The full name can be used when needed.
2024-01-26 07:03:29 -08:00
Boian Petkantchin
5df2c00af3
[mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (#77838)
Remove the somewhat redundant rank attribute.
Before this change
```
mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)
```
After
```
mesh.cluster @mesh(shape = 2x3x?)
```

The rank is instead determined by the provided shape. With this change
`getDimSizes()` can no longer be wrongly assumed to have size equal to
the cluster rank.
Now `getShape().size()` will always equal `getRank()`.
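
The invariant can be illustrated with a small Python model. This is only a sketch: the `Mesh` class, its fields, and the use of `None` for a dynamic (`?`) dimension are hypothetical and do not mirror the C++ op definition.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Mesh:
    """Hypothetical model of a mesh symbol; None marks a dynamic (`?`) dimension."""
    name: str
    shape: Tuple[Optional[int], ...]

    @property
    def rank(self) -> int:
        # The rank is derived from the shape, so the two cannot disagree.
        return len(self.shape)


# Models `shape = 2x3x?` from the example above.
mesh = Mesh(name="mesh", shape=(2, 3, None))
```

Because `rank` is computed rather than stored, there is no separate attribute that could fall out of sync with the shape.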
2024-01-15 07:39:09 -08:00
Matthias Springer
0cb024b357
[mlir][Mesh] Fix invalid IR in rewrite pattern (#78094)
This commit fixes `test/Dialect/Mesh/folding.mlir` when running with
`MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`.

```
/usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Mesh/folding.mlir:19:10: error: Unexpected number of results 0. Expected 2.
  %0:2 = mesh.cluster_shape @mesh1 : index, index
         ^
/usr/local/google/home/springerm/mlir_public/llvm-project/mlir/test/Dialect/Mesh/folding.mlir:19:10: note: see current operation: "mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> ()
mlir-asm-printer: Verifying operation: builtin.module
Unexpected number of results 0. Expected 2.
mlir-asm-printer: 'builtin.module' failed to verify and will be printed in generic form
"builtin.module"() ({
  "mesh.cluster"() <{dim_sizes = array<i64: 2, 3>, rank = 2 : i64, sym_name = "mesh1"}> : () -> ()
  "func.func"() <{function_type = () -> (index, index), sym_name = "cluster_shape_op_folding_all_axes_static_mesh"}> ({
    %0 = "arith.constant"() <{value = 2 : index}> : () -> index
    %1 = "arith.constant"() <{value = 3 : index}> : () -> index
    "mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> ()
    %2:2 = "mesh.cluster_shape"() <{axes = array<i16>, mesh = @mesh1}> : () -> (index, index)
    "func.return"(%0, %1) : (index, index) -> ()
  }) : () -> ()
}) : () -> ()
LLVM ERROR: IR failed to verify after pattern application
```

If `axes` is empty, the op verifier assumes that all dimensions are
queried, so 2 results are expected here.
2024-01-15 09:00:43 +01:00
Boian Petkantchin
79aa776267
[mlir][mesh] Add lowering of process multi-index op (#77490)
* Rename mesh.process_index -> mesh.process_multi_index.
* Add mesh.process_linear_index op.
* Add lowering of mesh.process_multi_index into an expression using
mesh.process_linear_index, mesh.cluster_shape and
affine.delinearize_index.

This is useful for lowering mesh ops and preparing them for further
lowering where the runtime may have only the linear index of a
device/process. For example, in MPI we have a rank (linear index) in a
communicator.
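
The arithmetic behind this lowering can be sketched in Python. The function name `delinearize_index` is borrowed from the affine op for readability, and the row-major convention (last mesh axis varies fastest) is an assumption of this sketch, not something stated in the commit.

```python
def delinearize_index(linear_index, shape):
    """Convert a linear process index into a multi-index over `shape`.

    Assumes row-major order: the last axis varies fastest.
    """
    multi_index = []
    for size in reversed(shape):
        multi_index.append(linear_index % size)
        linear_index //= size
    return tuple(reversed(multi_index))


# A 2x3 mesh has six processes with linear indices 0..5.
mesh_shape = (2, 3)
multi_indices = [delinearize_index(i, mesh_shape) for i in range(6)]
```

Under that convention, a runtime that only exposes a rank, as MPI does, can recover the per-axis position of each process from the mesh shape alone.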
2024-01-10 07:01:16 -08:00
Boian Petkantchin
ab590377a3
[mlir][mesh] Add folding of ClusterShapeOp (#77033)
If the mesh has static size on some of the requested axes, the result is
substituted with a constant.
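
A rough Python model of the folding decision follows. The function `fold_mesh_shape` and the use of `None` for a dynamic dimension are hypothetical; treating an empty `axes` list as "all axes" follows the verifier behaviour described in #78094.

```python
def fold_mesh_shape(mesh_shape, axes):
    """Return, per requested axis, a constant size or None if it must stay dynamic.

    `mesh_shape` uses None for dynamic dimensions. An empty `axes`
    list is treated as a query of every axis.
    """
    if not axes:
        axes = range(len(mesh_shape))
    return [mesh_shape[axis] for axis in axes]


# Axis 1 is dynamic, so only axes 0 and 2 can be replaced by constants.
folded_all = fold_mesh_shape((2, None, 4), [])
folded_some = fold_mesh_shape((2, None, 4), [2, 0])
```

Entries that come back as integers correspond to results that can be replaced by constants; `None` entries still need the runtime query.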
2024-01-09 13:42:56 -08:00
Boian Petkantchin
4b3446771f
[mlir][mesh] Add endomorphism simplification for all-reduce (#73150)
Performs transformations such as
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)

max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce element-wise op is max.

Added general rewrite patterns HomomorphismSimplification and
EndomorphismSimplification that encapsulate the general algorithm.
Added specializations for all-reduce with respect to
addf, addi, minsi, maxsi, minimumf and maximumf
in the Arithmetic dialect.
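
The identity behind the rewrite can be checked with a small Python simulation. This is a sketch only: `all_reduce` and `elementwise` are illustrative helpers that model one scalar per process, and integers are used so that reassociating the additions cannot change the result.

```python
from functools import reduce
import operator


def all_reduce(per_process_values, op):
    """Simulate all-reduce: every process ends up with the reduction of all inputs."""
    total = reduce(op, per_process_values)
    return [total] * len(per_process_values)


def elementwise(op, xs, ys):
    """Apply a binary op process by process."""
    return [op(x, y) for x, y in zip(xs, ys)]


# One value per process in a group of four.
x = [1, 5, 2, 7]
y = [4, 0, 9, 3]

# all_reduce(x) + all_reduce(y) versus all_reduce(x + y), with a sum reduction.
sum_before = elementwise(operator.add, all_reduce(x, operator.add), all_reduce(y, operator.add))
sum_after = all_reduce(elementwise(operator.add, x, y), operator.add)

# max(all_reduce(x), all_reduce(y)) versus all_reduce(max(x, y)), with a max reduction.
max_before = elementwise(max, all_reduce(x, max), all_reduce(y, max))
max_after = all_reduce(elementwise(max, x, y), max)
```

Both forms give the same values on every process, while the rewritten form needs one collective instead of two.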
2023-12-12 10:21:52 -08:00