- introduces a simplify pass, which finds all-gather + all-slice patterns and replaces them with the equivalent `reduce-scatter`
- promotes the test pass `test-shard-optimizations` to a proper pass and adds folding of allgather+allslice into `reduce_scatter`
- sanitizes the `shard.reduce_scatter` op
- adds a new `mpi.reduce_scatter_block` op
- lowers `shard.reduce_scatter` to MPI
- lowers `mpi.reduce_scatter_block` to LLVM
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
114 lines
4.0 KiB
C++
114 lines
4.0 KiB
C++
//===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MPI/IR/MPI.h"
|
|
#include "mlir/Dialect/MPI/IR/Utils.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::mpi;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifiers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Verifier for ReduceScatterBlockOp: the op is rejected unless its `sendbuf`
// and `recvbuf` operands have identical element types.
LogicalResult mlir::mpi::ReduceScatterBlockOp::verify() {
  auto sendElemTy = getSendbuf().getType().getElementType();
  auto recvElemTy = getRecvbuf().getType().getElementType();

  // Matching element types: nothing else is checked here.
  if (sendElemTy == recvElemTy)
    return success();

  return emitOpError("sendbuf and recvbuf must have the same element type");
}
|
|
|
|
namespace {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Canonicalization patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// If input memref has dynamic shape and is a cast and if the cast's input has
|
|
// static shape, fold the cast's static input into the given operation.
|
|
template <typename OpT>
|
|
struct FoldCast final : public mlir::OpRewritePattern<OpT> {
|
|
using mlir::OpRewritePattern<OpT>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpT op,
|
|
mlir::PatternRewriter &b) const override {
|
|
auto mRef = op.getRef();
|
|
if (mRef.getType().hasStaticShape()) {
|
|
return mlir::failure();
|
|
}
|
|
auto defOp = mRef.getDefiningOp();
|
|
if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
|
|
return mlir::failure();
|
|
}
|
|
auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
|
|
if (!src.getType().hasStaticShape()) {
|
|
return mlir::failure();
|
|
}
|
|
op.getRefMutable().assign(src);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
// Canonicalization for CommRankOp. Whether and how the op is rewritten is
// decided entirely by the `FoldToDLTIConst` helper, which receives the key
// "MPI:comm_world_rank" (presumably it folds the rank to a constant when a
// matching DLTI entry is available; see Utils.h to confirm).
struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
  using Base = mlir::OpRewritePattern<mlir::mpi::CommRankOp>;
  using Base::Base;

  LogicalResult matchAndRewrite(mlir::mpi::CommRankOp rankOp,
                                mlir::PatternRewriter &rewriter) const override {
    static constexpr const char *kWorldRankKey = "MPI:comm_world_rank";
    return FoldToDLTIConst(rankOp, kWorldRankKey, rewriter);
  }
};
|
|
|
|
// Canonicalization for CommSizeOp. The rewrite decision is delegated to
// `FoldToDLTIConst` with the key "MPI:comm_world_size" (presumably it folds the
// size to a constant taken from a DLTI entry; see Utils.h to confirm).
struct FoldSize final : public mlir::OpRewritePattern<mlir::mpi::CommSizeOp> {
  using Base = mlir::OpRewritePattern<mlir::mpi::CommSizeOp>;
  using Base::Base;

  LogicalResult matchAndRewrite(mlir::mpi::CommSizeOp sizeOp,
                                mlir::PatternRewriter &rewriter) const override {
    const char *const dltiKey = "MPI:comm_world_size";
    return FoldToDLTIConst(sizeOp, dltiKey, rewriter);
  }
};
|
|
} // namespace
|
|
|
|
void mlir::mpi::SendOp::getCanonicalizationPatterns(
|
|
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
|
|
results.add<FoldCast<mlir::mpi::SendOp>>(context);
|
|
}
|
|
|
|
void mlir::mpi::RecvOp::getCanonicalizationPatterns(
|
|
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
|
|
results.add<FoldCast<mlir::mpi::RecvOp>>(context);
|
|
}
|
|
|
|
void mlir::mpi::ISendOp::getCanonicalizationPatterns(
|
|
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
|
|
results.add<FoldCast<mlir::mpi::ISendOp>>(context);
|
|
}
|
|
|
|
void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
|
|
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
|
|
results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
|
|
}
|
|
|
|
void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
|
|
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
|
|
results.add<FoldRank>(context);
|
|
}
|
|
|
|
void mlir::mpi::CommSizeOp::getCanonicalizationPatterns(
|
|
mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
|
|
results.add<FoldSize>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
|