llvm-project/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
2026-01-29 10:47:26 -08:00

559 lines
23 KiB
C++

//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir
using namespace mlir;
#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
/// Casts the vector value `v` to the vector type `expectedTy`.
///
/// - If `v` already has type `expectedTy`, `v` itself is returned.
/// - If the element types match and both types hold the same number of
///   elements, a `vector.shape_cast` is emitted. The element-type check is
///   required because `vector.shape_cast` only changes the shape.
/// - Otherwise, a `builtin.unrealized_conversion_cast` is emitted.
static Value castValueTo(ConversionPatternRewriter &rewriter,
                         TypedValue<VectorType> v, VectorType expectedTy) {
  VectorType srcTy = v.getType();
  // If the type matches, simply return the value itself.
  if (srcTy == expectedTy)
    return v;
  // If only the shape differs, use a shape cast.
  if (srcTy.getElementType() == expectedTy.getElementType() &&
      srcTy.getNumElements() == expectedTy.getNumElements())
    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
  // Else create an unrealized cast.
  auto castOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
                                                   expectedTy, ValueRange{v});
  return castOp.getResult(0);
}
/// Walks `root` and every operation nested in it, checking that layouts are
/// present where the distribution patterns expect them:
///  * ops implementing `xegpu::AnchorLayoutInterface` must return a non-null
///    anchor layout (their results are not checked further);
///  * for every other op, each vector-typed result must have a distribute
///    layout attribute.
/// An error is emitted on the first offending op and the walk stops.
static LogicalResult verifyLayouts(Operation *root) {
  WalkResult status = root->walk([&](Operation *op) -> WalkResult {
    if (auto anchor = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
      if (anchor.getAnchorLayout())
        return WalkResult::advance();
      op->emitError("expected anchor layout attribute on operation");
      return WalkResult::interrupt();
    }
    for (OpResult res : op->getResults()) {
      // Non-vector results, and vector results that carry a layout, are fine.
      if (!isa<VectorType>(res.getType()) ||
          xegpu::getDistributeLayoutAttr(res))
        continue;
      op->emitError("expected result layout attribute on vector result");
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (status.wasInterrupted())
    return failure();
  return success();
}
/// Distributes a subgroup-level CreateNdDesc op to a workitem-level
/// CreateNdDesc op. This drops the layout attribute from the tensor descriptor
/// type; operands and attributes are otherwise carried over unchanged.
struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::TensorDescType resultType = op.getType();
    // If no layout, nothing to do.
    if (!resultType.getLayout())
      return failure();
    // Build the new op from the adaptor's (remapped) operands rather than the
    // original op's operands, as conversion patterns are expected to do.
    auto newOp = xegpu::CreateNdDescOp::create(
        rewriter, op.getLoc(), resultType.dropLayouts(), adaptor.getOperands(),
        op->getAttrs());
    rewriter.replaceOp(op, newOp.getResult());
    return success();
  }
};
/// Distributes a subgroup-level LoadNd op to a workitem-level LoadNd op.
/// The workitem-level op produces the type returned by
/// `xegpu::getDistributedVectorType` for the tensor descriptor. Its result is
/// then cast (via castValueTo) to the type derived from the lane layout, which
/// restores the original rank.
struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
    // If no layout, nothing to do.
    if (!layout)
      return failure();
    // Check if the layout attached to the tensor descriptor is same as the
    // anchor layout. Otherwise, this is a conflict.
    if (op.getTensorDescType().getLayout() != layout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    // requirePacked/requireTranspose below take a plain xegpu::LayoutAttr.
    // Use dyn_cast so another DistributeLayoutAttr kind is reported as a match
    // failure instead of tripping the assertion inside cast<>.
    auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(layout);
    if (!layoutAttr)
      return rewriter.notifyMatchFailure(
          op, "expected a xegpu::LayoutAttr anchor layout for LoadNdOp");
    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
    if (!uArch)
      return rewriter.notifyMatchFailure(
          op, "xegpu::LoadNdOp require target attribute attached to "
              "determine transpose "
              "requirement");
    auto supportedWiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getTensorDescType());
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
    if (failed(supportedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute the workitem vector type for LoadNdOp");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute expected workitem vector type from lane layout");
    auto newOp = xegpu::LoadNdOp::create(
        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
        op.getL3HintAttr(), /*layout=*/nullptr);
    // Set the packed attribute if the layout requires it.
    newOp.setPacked(xegpu::requirePacked(layoutAttr));
    // Set the transpose attribute if the layout requires it.
    if (xegpu::requireTranspose(layoutAttr, uArch))
      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
    // Cast the workitem-level result to the type expected by its users.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};
/// Rewrites a subgroup-level StoreNd op into a workitem-level StoreNd op.
/// The stored value of the workitem-level op has the type returned by
/// `xegpu::getDistributedVectorType` for the tensor descriptor, so the already
/// converted value is cast (via castValueTo) to that type first.
struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr anchorLayout = op.getAnchorLayout();
    // Ops without an anchor layout are not handled here.
    if (!anchorLayout)
      return failure();
    // Both the tensor descriptor layout and the stored value's layout must
    // agree with the anchor layout.
    if (op.getTensorDescType().getLayout() != anchorLayout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on tensor descriptor and anchor");
    if (xegpu::getDistributeLayoutAttr(op->getOpOperand(0)) != anchorLayout)
      return rewriter.notifyMatchFailure(
          op, "conflicting layout attributes on value and anchor");
    auto wiValueTy = xegpu::getDistributedVectorType(op.getTensorDescType());
    if (failed(wiValueTy))
      return rewriter.notifyMatchFailure(
          op,
          "unable to compute wi vector type for StoreNdOp value from tensor "
          "descriptor");
    auto storedValue = cast<TypedValue<VectorType>>(adaptor.getValue());
    Value wiValue = castValueTo(rewriter, storedValue, *wiValueTy);
    xegpu::StoreNdOp::create(rewriter, op.getLoc(), wiValue,
                             adaptor.getTensorDesc(), op.getMixedOffsets(),
                             op.getL1HintAttr(), op.getL2HintAttr(),
                             op.getL3HintAttr(), /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Distributes a subgroup-level Dpas op to a workitem-level Dpas op. All
/// inputs and the output of the workitem-level Dpas op are 1D. Necessary casts
/// are added to convert the inputs and output to/from 1D.
struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The A, B and CD layouts must all be present and be plain LayoutAttrs.
    // dyn_cast_if_present is required: cast<> asserts on a null (or
    // differently-kinded) attribute, which would make the check below
    // unreachable.
    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
    auto layoutCd =
        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
    if (!layoutA || !layoutB || !layoutCd)
      return failure();
    auto wiResultTyOrFailure =
        xegpu::getDistributedVectorType(op.getType(), layoutCd);
    auto wiATypeOrFailure =
        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
    auto wiBTypeOrFailure =
        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
    auto expectedWiResultTyOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
        failed(wiBTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "failed to calculate supported workitem vector types for DpasOp "
              "from layouts");
    if (failed(expectedWiResultTyOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute expected workitem vector type for DpasOp from "
              "lane layout");
    // The accumulator operand is optional; cast it only when it is present.
    // A null Value leaves the new op without an accumulator as well.
    Value wiAcc;
    if (Value acc = adaptor.getAcc())
      wiAcc = castValueTo(rewriter, cast<TypedValue<VectorType>>(acc),
                          wiResultTyOrFailure.value());
    auto newOp = xegpu::DpasOp::create(
        rewriter, op->getLoc(), wiResultTyOrFailure.value(),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
                    wiATypeOrFailure.value()),
        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
                    wiBTypeOrFailure.value()),
        wiAcc, /*layoutA=*/nullptr, /*layoutB=*/nullptr,
        /*layoutCd=*/nullptr);
    // Cast the workitem-level result to the type derived from the lane layout
    // so that it matches what the users of the original op expect.
    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
                                       expectedWiResultTyOrFailure.value()));
    return success();
  }
};
/// Distributes elementwise ops to workitem-level elementwise ops. This
/// currently handles elementwise ops with a single vector result only. The new
/// op is created generically from the original op's name, the (converted)
/// operands, the workitem-level result type and the original attributes minus
/// any DistributeLayoutAttr.
struct SgToWiElementWise : public ConversionPattern {
  // The type converter is forwarded to the base pattern, as it is for every
  // other pattern in this file, so that the conversion driver passes operands
  // converted to their workitem-level types (materializing casts if needed).
  SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                          ctx) {}
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match ops with elementwise trait and single result.
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(
          op, "operation result is not a vector type");
    xegpu::DistributeLayoutAttr layout =
        xegpu::getTemporaryLayout(op->getResult(0));
    if (!layout || !layout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");
    auto wiShapeOrFailure =
        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
    if (failed(wiShapeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    VectorType newResultType = wiShapeOrFailure.value();
    OperationState state(op->getLoc(), op->getName());
    state.addOperands(operands);
    state.addTypes(newResultType);
    // Copy all attributes except for DistributeLayoutAttr.
    for (auto attr : op->getAttrs()) {
      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
        state.addAttribute(attr.getName(), attr.getValue());
    }
    Operation *newOp = rewriter.create(state);
    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }
};
/// Rewrites a subgroup-level splat vector `arith.constant` into a
/// workitem-level splat constant whose type is derived from the lane layout
/// of the constant's result.
struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sgVecTy = dyn_cast<VectorType>(op.getType());
    if (!sgVecTy)
      return failure();
    // Only splat constants can simply be re-materialized at the new shape.
    auto splat = dyn_cast<SplatElementsAttr>(op.getValue());
    if (!splat)
      return rewriter.notifyMatchFailure(
          op, "only dense splat vector constants are supported");
    xegpu::DistributeLayoutAttr resultLayout =
        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
    if (!resultLayout || !resultLayout.isForSubgroup())
      return rewriter.notifyMatchFailure(
          op, "operation result does not have subgroup distribute layout");
    auto wiVecTy =
        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, sgVecTy);
    if (failed(wiVecTy))
      return rewriter.notifyMatchFailure(
          op, "unable to compute workitem vector type from the layout");
    // Build a splat of the same scalar with the workitem-level type.
    auto splatElement = splat.getSplatValue<Attribute>();
    auto wiValue = DenseElementsAttr::get(*wiVecTy, splatElement);
    auto wiConst =
        arith::ConstantOp::create(rewriter, op.getLoc(), *wiVecTy, wiValue);
    rewriter.replaceOp(op, wiConst.getResult());
    return success();
  }
};
/// Rewrites a subgroup-level PrefetchNd op into a workitem-level PrefetchNd
/// op: the prefetch is re-created on the converted tensor descriptor with the
/// same offsets and cache hints, but without a layout.
struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only ops that still carry an anchor layout are distributed.
    if (!op.getAnchorLayout())
      return failure();
    xegpu::PrefetchNdOp::create(
        rewriter, op.getLoc(), adaptor.getTensorDesc(), op.getMixedOffsets(),
        op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
        /*layout=*/nullptr);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Pass that distributes subgroup-level XeGPU/vector IR to workitem level.
/// The work is done in runOnOperation(), defined out of line below.
class XeGPUSgToWiDistributeExperimentalPass
    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
          XeGPUSgToWiDistributeExperimentalPass> {
public:
  void runOnOperation() override;
};
} // namespace
void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
// Verify if all XeGPU anchor ops and vector ops have result layouts.
// TODO: This can be removed once the full layout refactoring is done.
Operation *root = getOperation();
if (failed(verifyLayouts(root))) {
LLVM_DEBUG(DBGS() << "XeGPUSgToWiDistributeExperimentalPass: layout "
"verification failed\n");
signalPassFailure();
return;
}
// Collect existing UnrealizedConversionCastOps. These must be preserved.
llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
root->walk(
[&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
// Perform a structural type conversion to convert structural ops to have WI
// types. This will insert UnrealizedConversionCastOps to make the IR
// valid.
auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
mlir::ValueRange inputs,
mlir::Location loc) -> mlir::Value {
UnrealizedConversionCastOp castOp =
UnrealizedConversionCastOp::create(builder, loc, type, inputs);
return castOp.getResult(0);
};
{
ConversionTarget target(getContext());
TypeConverter typeConverter;
RewritePatternSet patterns(&getContext());
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
patterns, target);
xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
typeConverter, patterns, target);
target.addLegalOp<UnrealizedConversionCastOp>();
(void)applyPartialConversion(root, target, std::move(patterns));
}
// Structural type conversion can generate some redundant
// UnrealizedConversionCastOps to materialize the SG type from type converted
// WI type. These are redundant at this point and can be eliminated by
// inserting shape casts instead.
// Example:
// %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
// %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
// This can be replaced with:
// %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
OpBuilder builder(root);
root->walk([&](UnrealizedConversionCastOp op) {
// If this op existed before, nothing to do.
if (existingCasts.contains(op))
return;
// number of inputs and outputs must be 1.
if (op.getNumOperands() != 1 || op.getNumResults() != 1)
return;
// Both input and output types must be vector types.
auto singleInput = op.getInputs()[0];
auto inputTy = dyn_cast<VectorType>(singleInput.getType());
auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
if (!inputTy || !outputTy)
return;
// Check if the defining op of the input is also an
// UnrealizedConversionCastOp and it has a single user (which is this
// op).
auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
if (!definingOp || !definingOp->hasOneUse())
return;
auto inputOfDefiningOp = definingOp.getInputs()[0];
// If the input of the defining op and output type are both vector types
// have same number of elements, insert a shape cast.
auto inputOfDefiningOpTy =
dyn_cast<VectorType>(inputOfDefiningOp.getType());
if (inputOfDefiningOpTy &&
inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
builder.setInsertionPoint(op);
auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
outputTy, inputOfDefiningOp);
op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
return;
}
});
// At this point, we will have some dead UnrealizedConversionCastOps. Just
// erase them.
bool changed = true;
while (changed) {
changed = false;
root->walk([&](UnrealizedConversionCastOp op) {
// Skip existing casts.
if (existingCasts.contains(op))
return;
if (op.use_empty()) {
op.erase();
changed = true;
}
});
}
}
/// Registers the SG-to-WI type conversions on `typeConverter`.
/// The registration order below is significant: TypeConverter consults the
/// most recently added conversion first, so the value-based vector conversion
/// is tried before the two type-based ones.
void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
    TypeConverter &typeConverter) {
  // Fallback: every type other than TensorDescType and VectorType is kept as
  // is. Returning std::nullopt leaves those two to the other conversions.
  typeConverter.addConversion([](Type ty) -> std::optional<Type> {
    if (isa<TensorDescType, VectorType>(ty))
      return std::nullopt;
    return ty;
  });
  // Tensor descriptors lose their layout attribute, if they have one.
  typeConverter.addConversion([](TensorDescType tdescTy) -> Type {
    return tdescTy.getLayoutAttr() ? tdescTy.dropLayouts() : tdescTy;
  });
  // Vector values with a subgroup distribute layout get the workitem type
  // derived from the lane layout. Other vectors keep their type, as do
  // vectors whose distributed type cannot be computed.
  typeConverter.addConversion([](Value v) -> std::optional<Type> {
    auto vecTy = dyn_cast<VectorType>(v.getType());
    if (!vecTy)
      return std::nullopt;
    auto layout = xegpu::getDistributeLayoutAttr(v);
    if (!layout || !layout.isForSubgroup())
      return vecTy;
    auto wiTy = getDistVecTypeBasedOnLaneLayout(layout, vecTy);
    if (failed(wiTy))
      return vecTy;
    return *wiTy;
  });
}
/// Registers the SG-to-WI type conversions on `typeConverter`, the legality
/// rules on `target`, and the SG-to-WI conversion patterns on `patterns`.
void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);

  // A CreateNdDescOp is legal once its result type has no layout attribute.
  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
      [](xegpu::CreateNdDescOp createOp) {
        return !createOp.getType().getLayoutAttr();
      });

  // XeGPU ops implementing AnchorLayoutInterface are legal only without an
  // anchor layout; all other XeGPU ops are legal.
  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
    if (auto anchorOp = dyn_cast<AnchorLayoutInterface>(op))
      return !anchorOp.getAnchorLayout();
    return true;
  });

  // Scalar arith constants are legal; vector ones are legal only without a
  // temporary layout attribute on the result.
  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [](arith::ConstantOp constOp) -> bool {
        if (isa<VectorType>(constOp.getResult().getType()))
          return !xegpu::getTemporaryLayout(
              dyn_cast<OpResult>(constOp.getResult()));
        return true;
      });

  // Math/arith ops are illegal only when they are elementwise, have exactly
  // one vector result, every operand is a vector of the result's shape, and
  // the result carries a temporary layout.
  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [](Operation *op) -> std::optional<bool> {
        if (!OpTrait::hasElementwiseMappableTraits(op) ||
            op->getNumResults() != 1)
          return true;
        auto resultTy = dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultTy)
          return true;
        for (Type operandTy : op->getOperandTypes()) {
          auto operandVecTy = dyn_cast<VectorType>(operandTy);
          if (!operandVecTy || operandVecTy.getShape() != resultTy.getShape())
            return true;
        }
        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
      });

  // Everything else is left alone.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
      typeConverter, patterns.getContext());
}