diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index c3d553b1e3b7..83d20c620b96 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -1367,6 +1367,24 @@ struct GenericPluginTy { virtual Expected isELFCompatible(uint32_t DeviceID, StringRef Image) const = 0; + /// Indicate if an image is compatible with the plugin. This is called if + /// the image is not recognized as compatible by the common layer. This gives + /// the plugin a chance to inspect the image and decide if it is compatible. + virtual Expected isImageCompatible(StringRef Image) const { + return false; + } + + /// Indicate if an image is compatible with the plugin devices. This is + /// called if the image is not recognized as compatible by the common layer. + /// This gives the plugin a chance to inspect the image and decide if it is + /// compatible. Notice that this function may be called before actually + /// initializing the devices. So we could not move this function into + /// GenericDeviceTy. + virtual Expected isImageCompatible(uint32_t DeviceID, + StringRef Image) const { + return isImageCompatible(Image); + } + virtual Error flushQueueImpl(omp_interop_val_t *Interop) { return Plugin::success(); } diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index f681213b3879..21ba9db292c4 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -1652,7 +1652,10 @@ int32_t GenericPluginTy::isPluginCompatible(StringRef Image) { return *MatchOrErr; } default: - return false; + auto MatchOrErr = isImageCompatible(Image); + if (Error Err = MatchOrErr.takeError()) + return HandleError(std::move(Err)); + return *MatchOrErr; } } @@ -1689,7 +1692,10 @@ int32_t GenericPluginTy::isDeviceCompatible(int32_t DeviceId, StringRef Image) { return *MatchOrErr; } default: - return false; + auto MatchOrErr = isImageCompatible(DeviceId, Image); + if (Error Err = MatchOrErr.takeError()) + return HandleError(std::move(Err)); + return *MatchOrErr; } } diff --git a/offload/plugins-nextgen/level_zero/include/L0Plugin.h b/offload/plugins-nextgen/level_zero/include/L0Plugin.h index cd964a0d4689..7ab3696a49af 100644 --- a/offload/plugins-nextgen/level_zero/include/L0Plugin.h +++ b/offload/plugins-nextgen/level_zero/include/L0Plugin.h @@ -107,6 +107,8 @@ public: Error flushQueueImpl(omp_interop_val_t *Interop) override; Error syncBarrierImpl(omp_interop_val_t *Interop) override; Error asyncBarrierImpl(omp_interop_val_t *Interop) override; + + Expected isImageCompatible(StringRef Image) const override; }; } // namespace llvm::omp::target::plugin diff --git a/offload/plugins-nextgen/level_zero/src/L0Plugin.cpp b/offload/plugins-nextgen/level_zero/src/L0Plugin.cpp index 285fe797b5d7..2b3e61256639 100644 --- a/offload/plugins-nextgen/level_zero/src/L0Plugin.cpp +++ b/offload/plugins-nextgen/level_zero/src/L0Plugin.cpp @@ -233,6 +233,11 @@ Error LevelZeroPluginTy::asyncBarrierImpl(omp_interop_val_t *Interop) { return Plugin::success(); } +// We only need to check for formats other than ELF here +Expected LevelZeroPluginTy::isImageCompatible(StringRef Image) const { + return identify_magic(Image) == file_magic::spirv_object; +} + } // namespace llvm::omp::target::plugin extern "C" {