[mlir][tosa] Check extension cooperative profiles in target environment (#185476)
This commit moves checks on extension cooperative profiles out of profile conformance and into target environment verification. This allows the checks to be enforced when the target is created, not during profile conformance validation.
This commit is contained in:
parent
ebaf174409
commit
e266cb5a9a
@ -137,32 +137,6 @@ public:
|
||||
OpComplianceInfo<T>
|
||||
findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
|
||||
|
||||
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
|
||||
switch (ext) {
|
||||
case Extension::int16:
|
||||
case Extension::int4:
|
||||
case Extension::doubleround:
|
||||
case Extension::inexactround:
|
||||
return {Profile::pro_int};
|
||||
case Extension::bf16:
|
||||
case Extension::fp8e4m3:
|
||||
case Extension::fp8e5m2:
|
||||
case Extension::fft:
|
||||
case Extension::mxfp:
|
||||
case Extension::mxfp_conv:
|
||||
return {Profile::pro_fp};
|
||||
case Extension::variable:
|
||||
case Extension::controlflow:
|
||||
case Extension::dynamic:
|
||||
case Extension::int64:
|
||||
case Extension::shape:
|
||||
return {Profile::pro_fp, Profile::pro_int};
|
||||
case Extension::none:
|
||||
return {};
|
||||
};
|
||||
llvm_unreachable("bad Extension type");
|
||||
}
|
||||
|
||||
// Debug utilites.
|
||||
template <typename T>
|
||||
SmallVector<StringRef> stringifyProfile(ArrayRef<T> profiles);
|
||||
|
||||
@ -52,6 +52,32 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
|
||||
llvm_unreachable("Unknown TOSA extension");
|
||||
}
|
||||
|
||||
SmallVector<Profile, 2> getCooperativeProfiles(Extension ext) {
|
||||
switch (ext) {
|
||||
case Extension::int16:
|
||||
case Extension::int4:
|
||||
case Extension::doubleround:
|
||||
case Extension::inexactround:
|
||||
return {Profile::pro_int};
|
||||
case Extension::bf16:
|
||||
case Extension::fp8e4m3:
|
||||
case Extension::fp8e5m2:
|
||||
case Extension::fft:
|
||||
case Extension::mxfp:
|
||||
case Extension::mxfp_conv:
|
||||
return {Profile::pro_fp};
|
||||
case Extension::variable:
|
||||
case Extension::controlflow:
|
||||
case Extension::dynamic:
|
||||
case Extension::int64:
|
||||
case Extension::shape:
|
||||
return {Profile::pro_fp, Profile::pro_int};
|
||||
case Extension::none:
|
||||
return {};
|
||||
};
|
||||
llvm_unreachable("bad Extension type");
|
||||
}
|
||||
|
||||
TosaSpecificationVersion getMinVersion(const Level &level) {
|
||||
switch (level) {
|
||||
case Level::eightK:
|
||||
@ -90,14 +116,34 @@ LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
|
||||
return success();
|
||||
};
|
||||
|
||||
const auto isExtensionCooperativeWithProfile =
|
||||
[&](Extension ext) -> LogicalResult {
|
||||
const auto cooperativeProfiles = getCooperativeProfiles(ext);
|
||||
|
||||
const ArrayRef<Profile> targetProfiles = targetAttr.getProfiles();
|
||||
if (!llvm::any_of(cooperativeProfiles,
|
||||
[&targetProfiles](const auto &profile) {
|
||||
return llvm::is_contained(targetProfiles, profile);
|
||||
}))
|
||||
return emitError(targetAttrLoc)
|
||||
<< "use of extension '" << stringifyEnum(ext)
|
||||
<< "' requires any of profiles: [" << cooperativeProfiles
|
||||
<< "] to be enabled in the target";
|
||||
|
||||
return success();
|
||||
};
|
||||
|
||||
for (const auto &profile : targetAttr.getProfiles())
|
||||
if (failed(
|
||||
isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
|
||||
return failure();
|
||||
for (const auto &extension : targetAttr.getExtensions())
|
||||
for (const auto &extension : targetAttr.getExtensions()) {
|
||||
if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
|
||||
"extension")))
|
||||
return failure();
|
||||
if (failed(isExtensionCooperativeWithProfile(extension)))
|
||||
return failure();
|
||||
}
|
||||
if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
|
||||
"level")))
|
||||
return failure();
|
||||
|
||||
@ -63,6 +63,10 @@ public:
|
||||
MLIRContext *ctx = &getContext();
|
||||
const auto targetEnvAttr = TargetEnvAttr::get(
|
||||
ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
|
||||
|
||||
if (failed(TargetEnv::verifyTargetInformation(targetEnvAttr, mod.getLoc())))
|
||||
return signalPassFailure();
|
||||
|
||||
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
|
||||
}
|
||||
|
||||
|
||||
@ -435,21 +435,6 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Each extension can contain a list of profiles that it works with, usually
|
||||
// have the same data type.
|
||||
if constexpr (std::is_same_v<T, Extension>) {
|
||||
for (const auto &mode : opRequiredMode) {
|
||||
SmallVector<Profile> coProfs = getCooperativeProfiles(mode);
|
||||
if (!targetEnv.allowsAnyOf(coProfs)) {
|
||||
op->emitOpError() << "illegal: requires ["
|
||||
<< llvm::join(stringifyProfile<Profile>(coProfs),
|
||||
", ")
|
||||
<< "] to work with but not enabled in target\n";
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the profile inference match the profile knowledge of the
|
||||
// specification.
|
||||
for (const auto &cands : specRequiredModeSet) {
|
||||
|
||||
@ -186,6 +186,13 @@ func.func @test_concat(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x21x3xbf16>)
|
||||
return %0 : tensor<26x21x3xbf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @test_concat(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x21x3xi16>) -> tensor<26x21x3xi16> {
|
||||
// expected-error@+1 {{'tosa.concat' op illegal: requires [int16] but not enabled in target}}
|
||||
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<26x21x3xi16>
|
||||
return %0 : tensor<26x21x3xi16>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @test_pad(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> {
|
||||
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate
|
||||
|
||||
func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> {
|
||||
// expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}}
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
// Enable all supported extensions to focus the verification of expected profile requirement errors.
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
//-----------------------------------------------------------------------------
|
||||
// Check validation of operations when no profiles are specified in the target.
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="" -tosa-validate="strict-op-spec-alignment"
|
||||
|
||||
// -----
|
||||
func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
// Enable all supported extensions to focus the verification of expected profile requirement errors.
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
//-----------------------------------------------------------------------------
|
||||
// Check operations fail to validate when pro_fp is not provided in the target.
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp,mxfp_conv" -tosa-validate="strict-op-spec-alignment"
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int" -tosa-validate="strict-op-spec-alignment"
|
||||
|
||||
// -----
|
||||
func.func @test_const_f16() -> tensor<3x11x11x3xf16> {
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
// Enable all supported extensions to focus the verification of expected profile requirement errors.
|
||||
//--------------------------------------------------------------------------------------------------
|
||||
//--------------------------------------------------------------------------------
|
||||
// Check operations fail to validation when pro_int is not provided in the target.
|
||||
//--------------------------------------------------------------------------------
|
||||
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp" -tosa-validate="strict-op-spec-alignment"
|
||||
|
||||
// -----
|
||||
func.func @test_const_i1() -> tensor<3x11x11x3xi1> {
|
||||
@ -179,13 +179,6 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xi32>) -> tensor<1x21x3xi32> {
|
||||
return %0 : tensor<1x21x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @test_concat(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x21x3xi16>) -> tensor<26x21x3xi16> {
|
||||
// expected-error@+1 {{'tosa.concat' op illegal: requires [pro_int] to work with but not enabled in target}}
|
||||
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<26x21x3xi16>
|
||||
return %0 : tensor<26x21x3xi16>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @test_pad(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
|
||||
%padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
|
||||
|
||||
@ -0,0 +1,9 @@
|
||||
// RUN: mlir-opt %s -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16"
|
||||
|
||||
// expected-error@below {{use of extension 'int16' requires any of profiles: [pro_int] to be enabled in the target}}
|
||||
module {
|
||||
func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
|
||||
%1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
|
||||
return %1 : tensor<1x1x1x1xf32>
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user