From 3a20c005f6fce92600ca1167255e2a7ef7a0e0d0 Mon Sep 17 00:00:00 2001 From: Davide Grohmann Date: Mon, 4 Aug 2025 14:55:29 +0200 Subject: [PATCH] Reland "[mlir][spirv] Fix UpdateVCEPass to deduce the correct set of capabilities" (#151502) This reland PR #151108 The original PR made sanitizer builds to fail. The issue are now resolved in this new patch. Original commit message: > When deducing capabilities implied capabilities are not considered, which causes generation of incorrect SPIR-V modules. This commit fixes that by pulling in the capability set all the implied ones. --------- Signed-off-by: Davide Grohmann Co-authored-by: Jakub Kuderski --- .../SPIRV/Transforms/UpdateVCEPass.cpp | 9 ++++++ .../SPIRV/Transforms/vce-deduction.mlir | 32 +++++++++---------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index a53d0a7e4c44..6d3bda421f30 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -95,6 +95,13 @@ static LogicalResult checkAndUpdateCapabilityRequirements( return success(); } +static void addAllImpliedCapabilities(SetVector &caps) { + SetVector tmp; + for (spirv::Capability cap : caps) + tmp.insert_range(getRecursiveImpliedCapabilities(cap)); + caps.insert_range(std::move(tmp)); +} + void UpdateVCEPass::runOnOperation() { spirv::ModuleOp module = getOperation(); @@ -174,6 +181,8 @@ void UpdateVCEPass::runOnOperation() { if (walkResult.wasInterrupted()) return signalPassFailure(); + addAllImpliedCapabilities(deducedCapabilities); + // Update min version requirement for capabilities after deducing them. for (spirv::Capability cap : deducedCapabilities) { if (std::optional minVersion = spirv::getMinVersion(cap)) { diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir index 8d7f3da4007c..4e534a30ad51 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -7,7 +7,7 @@ // Test deducing minimal version. // spirv.IAdd is available from v1.0. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -21,7 +21,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal version. // spirv.GroupNonUniformBallot is available since v1.3. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -32,7 +32,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { @@ -48,7 +48,7 @@ spirv.module Logical GLSL450 attributes { // Test minimal capabilities. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -61,10 +61,10 @@ spirv.module Logical GLSL450 attributes { // Test Physical Storage Buffers are deduced correctly. -// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce +// CHECK: spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce spirv.module PhysicalStorageBuffer64 GLSL450 attributes { spirv.target_env = #spirv.target_env< - #spirv.vce, #spirv.resource_limits<>> + #spirv.vce, #spirv.resource_limits<>> } { spirv.func @physical_ptr(%val : !spirv.ptr { spirv.decoration = #spirv.decoration }) "None" { spirv.Return @@ -74,7 +74,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 attributes { // Test deducing implied capability. // AtomicStorage implies Shader. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -95,7 +95,7 @@ spirv.module Logical GLSL450 attributes { // * GroupNonUniformArithmetic // * GroupNonUniformBallot -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -106,7 +106,7 @@ spirv.module Logical GLSL450 attributes { } } -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -120,7 +120,7 @@ spirv.module Logical GLSL450 attributes { // Test type required capabilities // Using 8-bit integers in non-interface storage class requires Int8. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -132,7 +132,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-bit floats in non-interface storage class requires Float16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -144,7 +144,7 @@ spirv.module Logical GLSL450 attributes { } // Using 16-element vectors requires Vector16. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -162,7 +162,7 @@ spirv.module Logical GLSL450 attributes { // Test deducing minimal extensions. // spirv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, #spirv.resource_limits<>> @@ -208,7 +208,7 @@ spirv.module Logical GLSL450 attributes { // Complicated nested types // * Buffer requires ImageBuffer or SampledBuffer. // * Rg32f requires StorageImageExtendedFormats. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce, @@ -219,7 +219,7 @@ spirv.module Logical GLSL450 attributes { } // Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension. -// CHECK: requires #spirv.vce +// CHECK: requires #spirv.vce spirv.module Logical GLSL450 attributes { spirv.target_env = #spirv.target_env< #spirv.vce,