
This is a follow-up to 5df62bdc9be9c258c5ac45c8093b71e23777fa0e. That commit should not have needed to make the vector.insert and vector.extract conversions to SPIR-V handle the static poison index case directly. There is already a fold from those ops to ub.poison, and a conversion pattern from ub.poison to spirv.Undef. However, that path did not work because:

- The ub.poison fold result could not be materialized by the vector dialect (fixed as of d13940ee263ff50b7a71e21424913cc0266bf9d4).
- The conversion pattern wasn't being populated in VectorToSPIRVPass, which is used by the tests. This commit changes this.
- The ub.poison to spirv.Undef pattern rejected non-scalar types, which prevented its use for vector results. It is unclear why this restriction existed. A remark in D156163 said it was to avoid converting "user types", but it is not obvious why these shouldn't be permitted: the SPIR-V specification allows OpUndef for all types except OpTypeVoid. This commit removes this restriction.

With these fixed, this commit removes the redundant static poison index handling and updates the tests.
80 lines
2.6 KiB
C++
80 lines
2.6 KiB
C++
//===- UBToSPIRV.cpp - UB to SPIR-V dialect conversion --------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
|
|
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
|
#include "mlir/Dialect/UB/IR/UBOps.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_UBTOSPIRVCONVERSIONPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// Lowers `ub.poison` to `spirv.Undef`.
///
/// The op's result type is passed through the pattern's type converter. Any
/// result type the converter accepts is lowered; no restriction is placed on
/// the kind of type (scalar, vector, ...). If the converter rejects the type,
/// the pattern reports a match failure and leaves the op untouched.
struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ub::PoisonOp poisonOp, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    Type sourceType = poisonOp.getType();
    Type spirvType = getTypeConverter()->convertType(sourceType);

    // Happy path: the type is representable in SPIR-V, so emit an undef value
    // of the converted type in place of the poison op.
    if (spirvType) {
      rewriter.replaceOpWithNewOp<spirv::UndefOp>(poisonOp, spirvType);
      return success();
    }

    // The diagnostic is built lazily, only when match-failure notifications
    // are actually consumed.
    return rewriter.notifyMatchFailure(poisonOp, [&](Diagnostic &diag) {
      diag << "failed to convert result type " << sourceType;
    });
  }
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pass Definition
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
struct UBToSPIRVConversionPass final
|
|
: impl::UBToSPIRVConversionPassBase<UBToSPIRVConversionPass> {
|
|
using Base::Base;
|
|
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
|
|
std::unique_ptr<SPIRVConversionTarget> target =
|
|
SPIRVConversionTarget::get(targetAttr);
|
|
|
|
SPIRVConversionOptions options;
|
|
SPIRVTypeConverter typeConverter(targetAttr, options);
|
|
|
|
RewritePatternSet patterns(&getContext());
|
|
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
|
|
|
|
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern Population
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void mlir::ub::populateUBToSPIRVConversionPatterns(
|
|
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
|
|
patterns.add<PoisonOpLowering>(converter, patterns.getContext());
|
|
}
|