24 Commits

Author SHA1 Message Date
Boian Petkantchin
fb582b6ace
[mlir] Implement Mesh's ShardingInterface for Linalg ops (#82284)
Allows linalg structured operations to be handled during spmdization and
sharding propagation.

There is only support for projected permutation indexing maps.
2024-03-07 17:05:44 -08:00
Boian Petkantchin
f78027dfec
[mlir][mesh] Better op result names (#82408)
Implement OpAsmOpInterface for most ops to increase IR readability. For
example `mesh.process_linear_index` would produce a value with name
`proc_linear_idx`.
2024-02-20 16:53:26 -08:00
Boian Petkantchin
ff2720d190
[mlir][mesh] Deduplicate iterator type and partial type information (#81920)
The two types mostly duplicated the same values.
Here they are decomposed to carry orthogonal and complementary
information.

Use `utils::IteratorType` instead of `mesh::IteratorType`. It now has
only parallel and reduction values.

Rename `Partial` to `ReductionKind`.

Add `getReductionLoopIteratorKinds` method to `ShardingInterface`.
2024-02-16 07:10:46 -08:00
Boian Petkantchin
dc3258c617
[mlir][mesh] Add all-slice operation (#81218)
This op is the inverse of all-gather. It is useful to have an explicit
concise representation instead of having a blob of slicing logic.

Add lowering for the op that slices from the tensor based on the
in-group process index.

Make resharding generate an all-slice instead of inserting the slicing
logic directly.
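
The inverse relationship between all-gather and all-slice can be sketched in plain Python. This is only an illustration on 1D lists within a single process group; the function names, the list-based "tensors" and the divisibility check are assumptions for the example, not the dialect's lowering.

```python
def all_gather(shards):
    """Concatenate the shards of all processes in a group.

    Returns one copy of the full tensor per process.
    """
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def all_slice(tensor, group_size, in_group_index):
    """Keep only the slice owned by the process at `in_group_index`.

    Assumes (for this sketch) that the tensor length is divisible
    by the group size.
    """
    if len(tensor) % group_size != 0:
        raise ValueError("tensor length must be divisible by group size")
    chunk = len(tensor) // group_size
    start = in_group_index * chunk
    return tensor[start:start + chunk]


# Three processes, each holding a shard of two elements.
shards = [[0, 1], [2, 3], [4, 5]]
gathered = all_gather(shards)
# Slicing each process's gathered copy by its in-group index
# recovers the original shards.
resliced = [all_slice(gathered[i], len(shards), i) for i in range(len(shards))]
```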
2024-02-15 13:03:58 -08:00
Boian Petkantchin
adbf21f12b
[mlir][mesh] Add spmdization pass (#80518)
Add a pass that converts a function that has sharding annotations into
SPMD form.
2024-02-06 20:55:14 -08:00
Boian Petkantchin
31fc0a12e1
[mlir][mesh] Refactoring code organization, tests and docs (#79606)
* Split `MeshDialect.h`, which defines the dialect class, out from
`MeshOps.h`. This reduces include clutter if you care only about the
dialect and not the ops.

* Expose functions `getMesh` and `collectiveProcessGroupSize`. These
functions are useful for outside users of the dialect.

* Remove unused code.

* Remove examples and tests of the mesh.shard attribute in tensor encoding,
per the decision that spmdization is performed on sharding annotations
and that no tensors will have sharding specified in the type.
For more info see this RFC comment:
https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81
2024-01-31 07:20:14 -08:00
Boian Petkantchin
9a8437f504
[mlir][mesh] Rename cluster to mesh (#79484)
Rename
* Op mesh.cluster -> mesh.mesh
* Op mesh.cluster_shape -> mesh.mesh_shape
* variables and attributes.

The name `mesh` is more specific to what it really represents. It is a
mesh of devices.
The name `cluster` implies a broader possibility of device
configurations. When just the word `mesh` is used, whether it refers to
the mesh dialect or a device mesh can often be inferred from the
context. The full name can be used when needed.
2024-01-26 07:03:29 -08:00
Jie Fu
071207ea41
[mlir] Fix -Wsign-compare in MeshOps.cpp (NFC)
llvm-project/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp:204:25:
 error: comparison of integers of different signs: 'size_t' (aka 'unsigned long') and 'int64_t' (aka 'long') [-Werror,-Wsign-compare]
  if (getShape().size() > rank)
      ~~~~~~~~~~~~~~~~~ ^ ~~~~
1 error generated.
2024-01-16 07:08:24 +08:00
Boian Petkantchin
5df2c00af3
[mlir][mesh] Remove rank attribute and rename dim_sizes to shape in ClusterOp (#77838)
Remove the somewhat redundant rank attribute.
Before this change
```
mesh.cluster @mesh(rank = 3, dim_sizes = 2x3)
```
After
```
mesh.cluster @mesh(shape = 2x3x?)
```

The rank is instead determined by the provided shape. With this change,
`getDimSizes()` can no longer be wrongly assumed to have a size equal to
the cluster rank.
Now `getShape().size()` will always equal `getRank()`.
2024-01-15 07:39:09 -08:00
Boian Petkantchin
31fd6d116d
[mlir][mesh] fix ProcessMultiIndexOp building (#77676)
Insert default empty mesh axes array instead of null attribute without MLIR context, since the attribute is default-valued not just optional.
2024-01-10 17:28:17 -08:00
Boian Petkantchin
79aa776267
[mlir][mesh] Add lowering of process multi-index op (#77490)
* Rename mesh.process_index -> mesh.process_multi_index.
* Add mesh.process_linear_index op.
* Add lowering of mesh.process_multi_index into an expression using
mesh.process_linear_index, mesh.cluster_shape and
affine.delinearize_index.

This is useful to lower mesh ops and prepare them for further lowering
where the runtime may have only the linear index of a device/process.
For example in MPI we have a rank (linear index) in a communicator.
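
The delinearization step can be sketched in Python as follows. The row-major convention (last mesh axis varies fastest) and the helper names are assumptions made for this illustration; this is not the code the lowering emits.

```python
def delinearize_index(linear_index, shape):
    """Convert a linear process index into a multi-index over `shape`.

    Assumes row-major order: the last mesh axis varies fastest.
    """
    multi_index = []
    for size in reversed(shape):
        multi_index.append(linear_index % size)
        linear_index //= size
    return list(reversed(multi_index))


def linearize_index(multi_index, shape):
    """Inverse of `delinearize_index` under the same row-major assumption."""
    linear_index = 0
    for index, size in zip(multi_index, shape):
        linear_index = linear_index * size + index
    return linear_index


# A process with linear index (e.g. MPI rank) 5 on a 2x3 mesh.
multi = delinearize_index(5, [2, 3])
```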
2024-01-10 07:01:16 -08:00
Boian Petkantchin
fc18b13492
[mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr (#76886)
Analogous to func.call use FlatSymbolRefAttr to reference the
corresponding mesh.
2024-01-05 07:14:07 -08:00
Boian Petkantchin
7a4c49756d
[mlir][mesh] Use one type for mesh axis (#76830)
Make all ops and attributes use the types MeshAxis and MeshAxesAttr
instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.
2024-01-03 15:47:11 -08:00
Boian Petkantchin
1a8fb88719
[mlir][mesh] Add resharding spmdization on a 1D device mesh (#76179)
The current implementation supports only sharding of tensor axes that
have size divisible by the mesh axis size.
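
As a rough illustration of the divisibility restriction, the per-process size of a sharded tensor axis could be computed like this. The function name and the error handling are hypothetical and not taken from the implementation.

```python
def shard_dim_size(tensor_dim_size, mesh_axis_size):
    """Return the per-process size of a tensor axis sharded along a mesh axis.

    Mirrors the stated restriction: the tensor axis size must be
    divisible by the mesh axis size.
    """
    if tensor_dim_size % mesh_axis_size != 0:
        raise ValueError(
            f"tensor axis of size {tensor_dim_size} is not divisible "
            f"by mesh axis of size {mesh_axis_size}"
        )
    return tensor_dim_size // mesh_axis_size


# An axis of size 8 sharded over a mesh axis of size 4
# gives each process 2 elements along that axis.
per_process = shard_dim_size(8, 4)
```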
2024-01-02 15:50:07 -08:00
Boian Petkantchin
5e29112719
[mlir][mesh] Add verification and canonicalization for some collectives (#74905)
Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.

The canonicalizations only remove trivial collectives with empty
mesh_axes attributes.
2023-12-15 06:41:10 -08:00
Boian Petkantchin
944e031e36
[mlir][mesh] Use tensor shape notation for the shape of a cluster (#73826)
Example:

substitute
mesh.cluster @mesh0(rank = 2, dim_sizes = [0, 4])

with
mesh.cluster @mesh0(rank = 2, dim_sizes = ?x4)

Same as tensor/memref shapes. The only difference is for 0-rank shapes.
With tensors you would have something like `tensor<f32>`. Here to avoid
matching an empty string a 0-rank shape is denoted by `[]`.
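
A small Python sketch of how such a shape string could be interpreted is shown below. The parser, its name, and the use of `None` for a dynamic (`?`) dimension are assumptions for illustration only; the actual parsing is done by the op's assembly format in MLIR.

```python
def parse_mesh_shape(text):
    """Parse a shape like "?x4" into a list of dimension sizes.

    "?" denotes a dynamic dimension (represented here as None) and
    "[]" denotes a 0-rank shape.
    """
    text = text.strip()
    if text == "[]":
        return []
    dims = []
    for token in text.split("x"):
        if token == "?":
            dims.append(None)
        else:
            dims.append(int(token))
    return dims


shape = parse_mesh_shape("?x4")
rank = len(shape)
```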
2023-12-08 11:34:44 -08:00
Boian Petkantchin
dff2f59be3
[mlir][mesh] Add TableGen definitions of more collective ops (#73842)
Add definitions for
broadcast, gather, receive, reduce, scatter, send and shift.
2023-12-04 09:11:47 -08:00
Boian Petkantchin
5f7c8c1068
[mlir][mesh] Add collective communication operations (#71960)
Add all-gather, all-reduce, all-to-all and reduce-scatter. These
operations have device mesh semantics.
2023-11-21 06:50:24 -08:00
Chengji Yao
b0d5b4d252
[MLIR][Mesh] Add sharding propagation pass (#71261)
Add a pass that propagates sharding information throughout the graph.
After this pass, each of the operations' operands and results is
annotated with a mesh.shard operation.

The pass is driven by a newly added ShardingInterface, and an implementation
for element-wise and matmul ops in the TOSA dialect is provided.
2023-11-03 21:07:31 -07:00
Mehdi Amini
466abaf152
Revert "[MLIR][Mesh] Add sharding propagation pass (#69665)"
This reverts commit 9d9400d7de9b928e3018af97e8b381a4a6ba5162.
This reverts commit bda763aea0b854178c01eac9f309042d9aaa823b.

The buildbot is broken and tests are failing.
2023-11-03 17:52:41 -07:00
Jie Fu
bda763aea0
[mlir] Fix -Wreturn-type in MeshOps.cpp (NFC)
/llvm-project/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp:73:1: error: non-void function does not return a value in all control paths [-Werror,-Wreturn-type]
}
^
1 error generated.
2023-11-04 08:47:30 +08:00
Chengji Yao
9d9400d7de
[MLIR][Mesh] Add sharding propagation pass (#69665)
Add a pass that propagates sharding information throughout the graph.
After this pass, each of the operations' operands and results is
annotated with a `mesh.shard` operation, and sharding option attributes
are added to the operations themselves.

The pass is driven by a newly added `ShardingInterface`, and an implementation
for element-wise and matmul ops in the TOSA dialect is provided.
2023-11-03 17:12:42 -07:00
Chengji Yao
1a21196b9a
[MLIR] reverse int8 type's printing logic (#69361)
Specializing for 8-bit integers to ensure values are printed as integers

Fixes #69310
2023-10-18 10:30:13 -07:00
Chengji Yao
08545e8516
[MLIR] Add a new Mesh dialect (#68007)
This is the first PR of the [Mesh sharding
RFC](https://discourse.llvm.org/t/open-mlir-meeting-9-28-2023-rfc-sharding-framework-design-for-device-mesh/73695).

Includes:

- mesh.cluster op
- mesh.shard op (the mesh.annotate op in the RFC slides; the name is
modified a bit following @stellaraccident's advice, which I think might
be a bit more concise)
- MeshSharding attribute
2023-10-10 11:35:40 -07:00