diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index d9de79e41529..53372f56f2bf 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -137,32 +137,6 @@ public: OpComplianceInfo findMatchedEntry(Operation *op, SmallVector> compInfo); - SmallVector getCooperativeProfiles(Extension ext) { - switch (ext) { - case Extension::int16: - case Extension::int4: - case Extension::doubleround: - case Extension::inexactround: - return {Profile::pro_int}; - case Extension::bf16: - case Extension::fp8e4m3: - case Extension::fp8e5m2: - case Extension::fft: - case Extension::mxfp: - case Extension::mxfp_conv: - return {Profile::pro_fp}; - case Extension::variable: - case Extension::controlflow: - case Extension::dynamic: - case Extension::int64: - case Extension::shape: - return {Profile::pro_fp, Profile::pro_int}; - case Extension::none: - return {}; - }; - llvm_unreachable("bad Extension type"); - } - // Debug utilites. template SmallVector stringifyProfile(ArrayRef profiles); diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 9f616b223bf9..118d8c6443be 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -52,6 +52,32 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) { llvm_unreachable("Unknown TOSA extension"); } +SmallVector getCooperativeProfiles(Extension ext) { + switch (ext) { + case Extension::int16: + case Extension::int4: + case Extension::doubleround: + case Extension::inexactround: + return {Profile::pro_int}; + case Extension::bf16: + case Extension::fp8e4m3: + case Extension::fp8e5m2: + case Extension::fft: + case Extension::mxfp: + case Extension::mxfp_conv: + return {Profile::pro_fp}; + case Extension::variable: + case Extension::controlflow: + case Extension::dynamic: + case Extension::int64: + case Extension::shape: + return {Profile::pro_fp, Profile::pro_int}; + case Extension::none: + return {}; + }; + llvm_unreachable("bad Extension type"); +} + TosaSpecificationVersion getMinVersion(const Level &level) { switch (level) { case Level::eightK: @@ -90,14 +116,34 @@ LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr, return success(); }; + const auto isExtensionCooperativeWithProfile = + [&](Extension ext) -> LogicalResult { + const auto cooperativeProfiles = getCooperativeProfiles(ext); + + const ArrayRef targetProfiles = targetAttr.getProfiles(); + if (!llvm::any_of(cooperativeProfiles, + [&targetProfiles](const auto &profile) { + return llvm::is_contained(targetProfiles, profile); + })) + return emitError(targetAttrLoc) + << "use of extension '" << stringifyEnum(ext) + << "' requires any of profiles: [" << cooperativeProfiles + << "] to be enabled in the target"; + + return success(); + }; + for (const auto &profile : targetAttr.getProfiles()) if (failed( isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile"))) return failure(); - for (const auto &extension : targetAttr.getExtensions()) + for (const auto &extension : targetAttr.getExtensions()) { if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc, "extension"))) return failure(); + if (failed(isExtensionCooperativeWithProfile(extension))) + return failure(); + } if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc, "level"))) return failure(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp index a0661e4ee0bd..410d55d63e5f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaAttachTarget.cpp @@ -63,6 +63,10 @@ public: MLIRContext *ctx = &getContext(); const auto targetEnvAttr = TargetEnvAttr::get( ctx, specificationVersion, level, selectedProfiles, selectedExtensions); + + if (failed(TargetEnv::verifyTargetInformation(targetEnvAttr, mod.getLoc()))) + return signalPassFailure(); + mod->setAttr(TargetEnvAttr::name, targetEnvAttr); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 5037d0176d68..1b824a4d3258 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -435,21 +435,6 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension( return failure(); } - // Each extension can contain a list of profiles that it works with, usually - // have the same data type. - if constexpr (std::is_same_v) { - for (const auto &mode : opRequiredMode) { - SmallVector coProfs = getCooperativeProfiles(mode); - if (!targetEnv.allowsAnyOf(coProfs)) { - op->emitOpError() << "illegal: requires [" - << llvm::join(stringifyProfile(coProfs), - ", ") - << "] to work with but not enabled in target\n"; - return failure(); - } - } - } - // Ensure the profile inference match the profile knowledge of the // specification. for (const auto &cands : specRequiredModeSet) { diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 16190bb69411..f0f001ec8511 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -186,6 +186,13 @@ func.func @test_concat(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x21x3xbf16>) return %0 : tensor<26x21x3xbf16> } +// ----- +func.func @test_concat(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x21x3xi16>) -> tensor<26x21x3xi16> { + // expected-error@+1 {{'tosa.concat' op illegal: requires [int16] but not enabled in target}} + %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<26x21x3xi16> + return %0 : tensor<26x21x3xi16> +} + // ----- func.func @test_pad(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> { %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6> diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index d9de2e0d37c2..d061da14bb10 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=dynamic" -tosa-validate +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=dynamic" -tosa-validate func.func @test_argmax_rank_invalid(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index d1e61345ff31..8e56c9c54446 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -1,8 +1,8 @@ -//-------------------------------------------------------------------------------------------------- -// Enable all supported extensions to focus the verification of expected profile requirement errors. -//-------------------------------------------------------------------------------------------------- +//----------------------------------------------------------------------------- +// Check validation of operations when no profiles are specified in the target. +//----------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_add_i32(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index d7330cb763fc..fb0ce19dfc5b 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -1,8 +1,8 @@ -//-------------------------------------------------------------------------------------------------- -// Enable all supported extensions to focus the verification of expected profile requirement errors. -//-------------------------------------------------------------------------------------------------- +//----------------------------------------------------------------------------- +// Check operations fail to validate when pro_fp is not provided in the target. +//----------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround,mxfp,mxfp_conv" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_f16() -> tensor<3x11x11x3xf16> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index cb760956fbd6..99b602e48feb 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -1,8 +1,8 @@ -//-------------------------------------------------------------------------------------------------- -// Enable all supported extensions to focus the verification of expected profile requirement errors. -//-------------------------------------------------------------------------------------------------- +//-------------------------------------------------------------------------------- +// Check operations fail to validation when pro_int is not provided in the target. +//-------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_fp" -tosa-validate="strict-op-spec-alignment" // ----- func.func @test_const_i1() -> tensor<3x11x11x3xi1> { @@ -179,13 +179,6 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xi32>) -> tensor<1x21x3xi32> { return %0 : tensor<1x21x3xi32> } -// ----- -func.func @test_concat(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x21x3xi16>) -> tensor<26x21x3xi16> { - // expected-error@+1 {{'tosa.concat' op illegal: requires [pro_int] to work with but not enabled in target}} - %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<26x21x3xi16> - return %0 : tensor<26x21x3xi16> -} - // ----- func.func @test_pad(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { %padding = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6> diff --git a/mlir/test/Dialect/Tosa/tosa-attach-target-non-cooperative-profile.mlir b/mlir/test/Dialect/Tosa/tosa-attach-target-non-cooperative-profile.mlir new file mode 100644 index 000000000000..a72228896cd0 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-attach-target-non-cooperative-profile.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s -verify-diagnostics -tosa-attach-target="profiles=pro_fp extensions=int16" + +// expected-error@below {{use of extension 'int16' requires any of profiles: [pro_int] to be enabled in the target}} +module { + func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> { + %1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + return %1 : tensor<1x1x1x1xf32> + } +}