[mlir][tosa] Check extension cooperative profiles in target environment (#185476)

This commit moves checks on extension cooperative profiles out of
profile conformance and into target environment verification. This
allows the checks to be enforced when the target is created, not during
profile conformance validation.
This commit is contained in:
Luke Hutton 2026-03-10 16:37:02 +00:00 committed by GitHub
parent ebaf174409
commit e266cb5a9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 80 additions and 62 deletions

View File

@ -137,32 +137,6 @@ public:
OpComplianceInfo<T>
findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
switch (ext) {
case Extension::int16:
case Extension::int4:
case Extension::doubleround:
case Extension::inexactround:
return {Profile::pro_int};
case Extension::bf16:
case Extension::fp8e4m3:
case Extension::fp8e5m2:
case Extension::fft:
case Extension::mxfp:
case Extension::mxfp_conv:
return {Profile::pro_fp};
case Extension::variable:
case Extension::controlflow:
case Extension::dynamic:
case Extension::int64:
case Extension::shape:
return {Profile::pro_fp, Profile::pro_int};
case Extension::none:
return {};
};
llvm_unreachable("bad Extension type");
}
// Debug utilites.
template <typename T>
SmallVector<StringRef> stringifyProfile(ArrayRef<T> profiles);

View File

@ -52,6 +52,32 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
llvm_unreachable("Unknown TOSA extension");
}
SmallVector<Profile, 2> getCooperativeProfiles(Extension ext) {
switch (ext) {
case Extension::int16:
case Extension::int4:
case Extension::doubleround:
case Extension::inexactround:
return {Profile::pro_int};
case Extension::bf16:
case Extension::fp8e4m3:
case Extension::fp8e5m2:
case Extension::fft:
case Extension::mxfp:
case Extension::mxfp_conv:
return {Profile::pro_fp};
case Extension::variable:
case Extension::controlflow:
case Extension::dynamic:
case Extension::int64:
case Extension::shape:
return {Profile::pro_fp, Profile::pro_int};
case Extension::none:
return {};
};
llvm_unreachable("bad Extension type");
}
TosaSpecificationVersion getMinVersion(const Level &level) {
switch (level) {
case Level::eightK:
@ -90,14 +116,34 @@ LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
return success();
};
const auto isExtensionCooperativeWithProfile =
[&](Extension ext) -> LogicalResult {
const auto cooperativeProfiles = getCooperativeProfiles(ext);
const ArrayRef<Profile> targetProfiles = targetAttr.getProfiles();
if (!llvm::any_of(cooperativeProfiles,
[&targetProfiles](const auto &profile) {
return llvm::is_contained(targetProfiles, profile);
}))
return emitError(targetAttrLoc)
<< "use of extension '" << stringifyEnum(ext)
<< "' requires any of profiles: [" << cooperativeProfiles
<< "] to be enabled in the target";
return success();
};
for (const auto &profile : targetAttr.getProfiles())
if (failed(
isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
return failure();
for (const auto &extension : targetAttr.getExtensions())
for (const auto &extension : targetAttr.getExtensions()) {
if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
"extension")))
return failure();
if (failed(isExtensionCooperativeWithProfile(extension)))
return failure();
}
if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
"level")))
return failure();

View File

@ -63,6 +63,10 @@ public:
MLIRContext *ctx = &getContext();
const auto targetEnvAttr = TargetEnvAttr::get(
ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
if (failed(TargetEnv::verifyTargetInformation(targetEnvAttr, mod.getLoc())))
return signalPassFailure();
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}

View File

@ -435,21 +435,6 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
return failure();
}
// Each extension can contain a list of profiles that it works with, usually
// have the same data type.
if constexpr (std::is_same_v<T, Extension>) {
for (const auto &mode : opRequiredMode) {
SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
if (!targetEnv.allowsAnyOf(coProfs)) {
op->emitOpError() << "illegal: requires ["
<< llvm::join(stringifyProfile<Profile>(coProfs),
", ")
<< "] to work with but not enabled in target\n";
return failure();
}
}
}
// Ensure the profile inference match the profile knowledge of the
// specification.
for (const auto &cands : specRequiredModeSet) {

View File

@ -186,6 +186,13 @@ func.func @test_concat(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x21x3xbf16>)
return %0 : tensor<26x21x3xbf16>
}
// -----
func.func @test_concat(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x21x3xi16>) -> tensor<26x21x3xi16> {
// expected-error@+1 {{'tosa.concat' op illegal: requires [int16] but not enabled in target}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<26x21x3xi16>
return %0 : tensor<26x21x3xi16>
}
// -----
func.func @test_pad(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate
func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}

View File

@ -1,8 +1,8 @@
//--------------------------------------------------------------------------------------------------
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------
// Check validation of operations when no profiles are specified in the target.
//-----------------------------------------------------------------------------
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {

View File

@ -1,8 +1,8 @@
//--------------------------------------------------------------------------------------------------
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------
// Check operations fail to validate when pro_fp is not provided in the target.
//-----------------------------------------------------------------------------
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp,mxfp_conv" -tosa-validate="strict-op-spec-alignment"
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {

View File

@ -1,8 +1,8 @@
//--------------------------------------------------------------------------------------------------
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
//--------------------------------------------------------------------------------
// Check operations fail to validation when pro_int is not provided in the target.
//--------------------------------------------------------------------------------
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_const_i1() -> tensor<3x11x11x3xi1> {
@ -179,13 +179,6 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xi32>) -> tensor<1x21x3xi32> {
return %0 : tensor<1x21x3xi32>
}
// -----
func.func @test_concat(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x21x3xi16>) -> tensor<26x21x3xi16> {
// expected-error@+1 {{'tosa.concat' op illegal: requires [pro_int] to work with but not enabled in target}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<26x21x3xi16>
return %0 : tensor<26x21x3xi16>
}
// -----
func.func @test_pad(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>

View File

@ -0,0 +1,9 @@
// RUN: mlir-opt %s -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16"
// expected-error@below {{use of extension 'int16' requires any of profiles: [pro_int] to be enabled in the target}}
module {
func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
%1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
return %1 : tensor<1x1x1x1xf32>
}
}