[MLIR][NVVM] Add new narrow FP convert Ops (#184291)

This change adds the following NVVM Ops for new narrow FP conversions
introduced in PTX 9.1:
- `convert.{f32x2/bf16x2}.to.s2f6x2`
- `convert.s2f6x2.to.bf16x2`
- `convert.bf16x2.to.f8x2` (extended for `f8E4M3FN` and `f8E5M2` types)
- `convert.{f16x2/bf16x2}.to.f6x2`
- `convert.{f16x2/bf16x2}.to.f4x2`

PTX ISA Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
This commit is contained in:
Srinivasa Ravi 2026-04-06 12:06:25 +05:30 committed by GitHub
parent e326ff2a88
commit 63231ebfe7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 807 additions and 57 deletions

View File

@ -1972,6 +1972,44 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
}];
}
// Common definition for Ops that convert a two-element `srcType` vector
// (e.g. "F16" or "BF16") to a pair of fp4 values packed into an i8 result.
class NVVM_ConvertFPx2ToF4x2Op<string srcType>
    : NVVM_Op<"convert." # !tolower(srcType) # "x2.to.f4x2"> {
  let summary = "Convert a pair of " # !tolower(srcType) # " inputs to f4x2";
  let description = [{
    This Op converts each of the given }] # srcType # [{ inputs in the }] #
    !tolower(srcType) # [{x2 vector `src` to the specified fp4 type.
    The result `dst` is returned as an i8 type where the converted values are
    packed such that the value converted from the first element of `src` is
    stored in the lower 4 bits of `dst` and the value converted from the second
    element of `src` is stored in the upper 4 bits of `dst`.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.
  }];

  let results = (outs I8:$dst);
  let arguments = (ins
    VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$src,
    DefaultValuedAttr<BoolAttr, "false">:$relu,
    TypeAttrOf<F4E2M1FN>:$dstTy);
  let assemblyFormat =
    "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";

  let extraClassDeclaration = [{
    static NVVM::IDArgPair
    getIntrinsicIDAndArgs(NVVM::Convert}] # srcType # [{x2ToF4x2Op &op,
      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
  }];

  string llvmBuilder = [{
    auto [intId, args] = NVVM::Convert}] # srcType # [{x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
    // The intrinsic call result is narrowed to the i8 type of `dst`.
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
    $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
  }];
}
// Concrete fp4 conversion Ops: one instantiation of the shared class per
// supported source element type (f16 and bf16).
def NVVM_ConvertF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"F16">;
def NVVM_ConvertBF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"BF16">;
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
@ -2015,6 +2053,48 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
}];
}
// Common definition for Ops that convert a two-element `srcType` vector
// (e.g. "F16" or "BF16") to a pair of fp6 values, returned either as a packed
// i16 or as a vector of two i8 elements.
class NVVM_ConvertFPx2ToF6x2Op<string srcType>
    : NVVM_Op<"convert." # !tolower(srcType) # "x2.to.f6x2"> {
  let summary = "Convert a pair of " # !tolower(srcType) # " inputs to f6x2";
  let description = [{
    This Op converts each of the given }] # srcType # [{ inputs in the }] #
    !tolower(srcType) # [{x2 vector `src` to the specified fp6 type. The result
    `dst` is represented either as an i16 type or as a vector of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from the first element of `src` is stored in the
    lower 8 bits of `dst` with 2 MSB bits padded with zeros and the value
    converted from the second element of `src` is stored in the upper 8 bits of
    `dst` with 2 MSB bits padded with zeros.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector with 2 MSB bits padded with zeros.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.
  }];

  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$src,
    DefaultValuedAttr<BoolAttr, "false">:$relu,
    TypeAttrOf<AnyTypeOf<[F6E2M3FN, F6E3M2FN]>>:$dstTy);
  let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu);
  }];

  string llvmBuilder = [{
    auto intId = NVVM::Convert}] # srcType # [{x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$src});
    // An i16 result uses the intrinsic value directly; the vector form
    // reinterprets the same 16 bits as <2 x i8>.
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
// Concrete fp6 conversion Ops: one instantiation of the shared class per
// supported source element type (f16 and bf16).
def NVVM_ConvertF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"F16">;
def NVVM_ConvertBF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"BF16">;
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
@ -2109,13 +2189,12 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
let summary = "Convert a pair of bf16 inputs to f8x2";
let description = [{
This Op converts the given bf16 inputs in a bf16x2 vector to the specified
f8 type.
The result `dst` is represented as an i16 type or as a vector
of two i8 types.
If `dst` is returned as an i16 type, the converted values from `a`
are packed such that the value converted from the first element of `a`
is stored in the upper 8 bits of `dst` and the value converted from the
second element of `a` is stored in the lower 8 bits of `dst`.
f8 type. The result `dst` is represented either as a packed i16 type or as
a vector of two i8 types.
If `dst` is returned as an i16 type, the converted values are packed such
that the value converted from the first element of `a` is stored in the
lower 8 bits of `dst` and the value converted from the second element of
`a` is stored in the upper 8 bits of `dst`.
If `dst` is returned as a vector type, each converted value is stored as an
i8 element in the vector.
The `rnd` and `sat` attributes specify the rounding and saturation modes
@ -2127,20 +2206,22 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
VectorOfLengthAndType<[2], [BF16]>:$a,
VectorOfLengthAndType<[2], [BF16]>:$src,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
TypeAttr:$dstTy);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttrOf<AnyTypeOf<[F8E8M0FNU, F8E4M3FN, F8E5M2]>>:$dstTy);
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat);
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu);
}];
string llvmBuilder = [{
auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$src});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
else
@ -2185,6 +2266,117 @@ def NVVM_ConvertF6x2ToF16x2Op :
// f4x2 -> f16x2 conversion Op, instantiated from NVVM_ConvertToFP16x2Op_Base
// with the arguments ("F4", I8, "F16").
// NOTE(review): the base class is defined outside this view; the meaning of
// each template argument should be confirmed there.
def NVVM_ConvertF4x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
// Conversion of two scalar f32 operands to a packed s2f6x2 result, with an
// optional packed scale-factor operand.
def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
  let summary = "Convert a pair of f32 inputs to S2F6x2";
  let description = [{
    This Op converts each of the given f32 inputs to the
    S2F6x2 type. The result `dst` can be either a packed i16 type or a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from `a` is stored in the upper 8 bits of `dst`
    and the value converted from `b` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant
    of the cvt instruction.
    The optional scaling-factors for each of the inputs are provided through
    the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
    as the type of the scale-factor currently.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let arguments = (ins
    F32:$a,
    F32:$b,
    Optional<I16>:$scaleFactor,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let assemblyFormat =
    "$a `,` $b (`,` $scaleFactor^)? attr-dict `:` type($dst)";

  let extraClassDeclaration = [{
    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
  }];

  string llvmBuilder = [{
    auto [id, args] = NVVM::ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
    llvm::Value *converted = createIntrinsicCall(builder, id, args);
    // Anything other than an i16 result is the <2 x i8> form: reinterpret the
    // packed value as a two-element byte vector.
    if (!op.getDst().getType().isInteger(16))
      converted = builder.CreateBitCast(converted,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
    $dst = converted;
  }];
}
// Conversion of a bf16x2 vector to a packed s2f6x2 result, with an optional
// packed scale-factor operand.
def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> {
  let summary = "Convert a pair of BF16 inputs to S2F6x2";
  let description = [{
    This Op converts each of the given BF16 inputs in the bf16x2 vector `src`
    to the S2F6x2 type. The result `dst` can be either a packed i16 type or a
    vector of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from the first element of `src` is stored in the
    lower 8 bits of `dst` and the value converted from the second element of
    `src` is stored in the upper 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant
    of the cvt instruction.
    The optional scaling-factors for each of the inputs are provided through
    the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
    as the type of the scale-factor currently.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    VectorOfLengthAndType<[2], [BF16]>:$src,
    Optional<I16>:$scaleFactor,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat =
    "$src (`,` $scaleFactor^)? attr-dict `:` type($src) `->` type($dst)";

  let extraClassDeclaration = [{
    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
  }];

  string llvmBuilder = [{
    auto [id, args] = NVVM::ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
    llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
    // An i16 result uses the intrinsic value directly; the vector form
    // reinterprets the same 16 bits as <2 x i8>.
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
// Conversion of a packed s2f6x2 input (a vector of two i8 elements) to a
// bf16x2 vector, with an optional packed scale-factor operand.
def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_SingleResultIntrinsicOp<"convert.s2f6x2.to.bf16x2", [], "$dst"> {
  let summary = "Convert s2f6x2 to bf16x2";
  let description = [{
    This Op converts a pair of s2f6x2 inputs, provided through `src` as a
    vector of two i8 elements, to bf16x2 type. The result `dst` is represented
    as a vector of two bf16 elements.
    The `sat` attribute, when set to `satfinite`, lowers to the '.satfinite'
    variant of the cvt instruction.
    The `relu` attribute, when set, lowers to the '.relu' variant
    of the cvt instruction.
    The optional scaling-factors for each of the inputs are provided through
    the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
    as the type of the scale-factor currently.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let results = (outs VectorOfLengthAndType<[2], [BF16]>:$dst);
  let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src,
    Optional<I16>:$scaleFactor,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat =
    "$src (`,` $scaleFactor^)? attr-dict `:` type($src) `->` type($dst)";
}
//===----------------------------------------------------------------------===//
// NVVM Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//

View File

@ -414,18 +414,45 @@ LogicalResult ConvertF16x2ToF8x2Op::verify() {
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
<< " type is supported for conversions from "
"bf16x2 to f8x2.";
bool isRoundingModeRN = getRnd() == RndMode::RN;
bool isRoundingModeRZ = getRnd() == RndMode::RZ;
bool isRoundingModeRP = getRnd() == RndMode::RP;
bool isSatFinite = getSat() == SatMode::SATFINITE;
bool hasRelu = getRelu();
auto rnd = getRnd();
if (rnd != RndMode::RZ && rnd != RndMode::RP)
return emitOpError("Only RZ and RP rounding modes are supported for "
"conversions from bf16x2 to f8x2.");
mlir::MLIRContext *ctx = getContext();
return success();
return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
.Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
[&](mlir::Type) -> LogicalResult {
if (!isRoundingModeRN)
return emitOpError("Only RN rounding mode is supported for "
"conversions from bf16x2 to ")
<< mlir::Float8E4M3FNType::get(ctx) << " and "
<< mlir::Float8E5M2Type::get(ctx) << " types";
if (!isSatFinite)
return emitOpError("Only SATFINITE saturation mode is supported "
"for conversions from bf16x2 to ")
<< mlir::Float8E4M3FNType::get(ctx) << " and "
<< mlir::Float8E5M2Type::get(ctx) << " types";
return success();
})
.Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
if (!(isRoundingModeRZ || isRoundingModeRP))
return emitOpError("Only RZ and RP rounding modes are supported for "
"conversions from bf16x2 to ")
<< mlir::Float8E8M0FNUType::get(ctx) << " type";
if (hasRelu)
return emitOpError("relu not supported for conversions to ")
<< mlir::Float8E8M0FNUType::get(ctx) << " type";
return success();
})
.Default([&](mlir::Type) -> LogicalResult {
llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
return failure();
});
}
LogicalResult ConvertF32x2ToF4x2Op::verify() {
@ -4232,6 +4259,80 @@ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
});
}
/// Selects the NVVM intrinsic and its operand list for lowering an
/// f16x2 -> f4x2 conversion.
///
/// \param op      The conversion Op being translated.
/// \param mt      Used to look up the LLVM value already emitted for `src`.
/// \param builder Unused here; part of the common getIntrinsicIDAndArgs
///                signature.
/// \returns The intrinsic ID (relu or non-relu variant) paired with the
///          single source operand.
NVVM::IDArgPair
ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  // Only f4E2M1FN has intrinsics selected here. Fail loudly on any other
  // destination type instead of handing `not_intrinsic` to the caller, which
  // matches the other conversion selectors in this file.
  if (!llvm::isa<mlir::Float4E2M1FNType>(op.getDstTy()))
    llvm_unreachable("Invalid conversion in ConvertF16x2ToF4x2Op");

  llvm::Intrinsic::ID intId =
      op.getRelu() ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
                   : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getSrc()));
  return {intId, std::move(args)};
}
/// Selects the NVVM intrinsic and its operand list for lowering a
/// bf16x2 -> f4x2 conversion.
///
/// \param op      The conversion Op being translated.
/// \param mt      Used to look up the LLVM value already emitted for `src`.
/// \param builder Unused here; part of the common getIntrinsicIDAndArgs
///                signature.
/// \returns The intrinsic ID (relu or non-relu variant) paired with the
///          single source operand.
NVVM::IDArgPair
ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
                                             LLVM::ModuleTranslation &mt,
                                             llvm::IRBuilderBase &builder) {
  // Only f4E2M1FN has intrinsics selected here. Fail loudly on any other
  // destination type instead of handing `not_intrinsic` to the caller, which
  // matches the other conversion selectors in this file.
  if (!llvm::isa<mlir::Float4E2M1FNType>(op.getDstTy()))
    llvm_unreachable("Invalid conversion in ConvertBF16x2ToF4x2Op");

  llvm::Intrinsic::ID intId =
      op.getRelu() ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
                   : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getSrc()));
  return {intId, std::move(args)};
}
/// Maps the destination fp6 type and the relu flag to the matching
/// f16x2 -> f6x2 NVVM intrinsic. Any other destination type is unreachable.
llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  if (llvm::isa<mlir::Float6E2M3FNType>(dstTy)) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite;
    return llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
  }

  if (llvm::isa<mlir::Float6E3M2FNType>(dstTy)) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite;
    return llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
  }

  llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
  return llvm::Intrinsic::not_intrinsic;
}
/// Maps the destination fp6 type and the relu flag to the matching
/// bf16x2 -> f6x2 NVVM intrinsic. Any other destination type is unreachable.
llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
                                                          bool hasRelu) {
  if (llvm::isa<mlir::Float6E2M3FNType>(dstTy)) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite;
    return llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
  }

  if (llvm::isa<mlir::Float6E3M2FNType>(dstTy)) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite;
    return llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
  }

  llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
  return llvm::Intrinsic::not_intrinsic;
}
// Expands to the nvvm_ff_to_ue8m0x2_<rnd> intrinsic ID, choosing the
// `_satfinite` variant when `has_satf` is true. `rnd` is token-pasted into the
// intrinsic name, so it must be a bare rounding-mode suffix.
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
@ -4287,22 +4388,39 @@ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
});
}
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
: llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
llvm::Intrinsic::ID
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat) {
ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
switch (rnd) {
case NVVM::FPRoundingMode::RZ:
return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
case NVVM::FPRoundingMode::RP:
return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
default:
llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
}
static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
};
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
return hasRelu
? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
: llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
})
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
return hasRelu
? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
: llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
})
.Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
return ue8m0x2IDs[index];
})
.Default([](mlir::Type) {
llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
return llvm::Intrinsic::not_intrinsic;
});
}
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
@ -4397,6 +4515,74 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
return {intId, {extendedI16}};
}
/// Picks the f32 pair -> s2f6x2 intrinsic (relu or non-relu) and assembles its
/// operands: both f32 inputs followed by the packed scale factor.
NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);

  llvm::Intrinsic::ID intrinsicId =
      llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
  if (cvtOp.getRelu())
    intrinsicId =
        llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0;

  llvm::SmallVector<llvm::Value *> operands;
  operands.push_back(mt.lookupValue(cvtOp.getA()));
  operands.push_back(mt.lookupValue(cvtOp.getB()));

  // The intrinsic always takes a scale operand; when the Op carries none, a
  // constant is supplied instead.
  // NOTE(review): 0x7f7f presumably packs two ue8m0 bytes of 0x7f, i.e. a
  // neutral scale for each input. Confirm against the PTX ISA.
  if (auto scale = cvtOp.getScaleFactor())
    operands.push_back(mt.lookupValue(scale));
  else
    operands.push_back(builder.getInt16(0x7f7f));

  return {intrinsicId, std::move(operands)};
}
/// Picks the bf16x2 -> s2f6x2 intrinsic (relu or non-relu) and assembles its
/// operands: the source vector followed by the packed scale factor.
NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);

  llvm::Intrinsic::ID intrinsicId =
      llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
  if (cvtOp.getRelu())
    intrinsicId =
        llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0;

  llvm::SmallVector<llvm::Value *> operands;
  operands.push_back(mt.lookupValue(cvtOp.getSrc()));

  // The intrinsic always takes a scale operand; when the Op carries none, a
  // constant is supplied instead.
  // NOTE(review): 0x7f7f presumably packs two ue8m0 bytes of 0x7f, i.e. a
  // neutral scale for each input. Confirm against the PTX ISA.
  if (auto scale = cvtOp.getScaleFactor())
    operands.push_back(mt.lookupValue(scale));
  else
    operands.push_back(builder.getInt16(0x7f7f));

  return {intrinsicId, std::move(operands)};
}
/// Picks the s2f6x2 -> bf16x2 intrinsic from the (satfinite, relu) combination
/// and assembles its operands: the source reinterpreted as an i16, followed by
/// the packed scale factor.
NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto cvtOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
  const bool useRelu = cvtOp.getRelu();

  llvm::Intrinsic::ID intrinsicId;
  if (cvtOp.getSat() == NVVM::SaturationMode::SATFINITE)
    intrinsicId =
        useRelu
            ? llvm::Intrinsic::
                  nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0
            : llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0;
  else
    intrinsicId =
        useRelu ? llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0
                : llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0;

  llvm::SmallVector<llvm::Value *> operands;
  // The source arrives as a two-element i8 vector; the intrinsic is passed the
  // same bits as a single i16.
  operands.push_back(
      builder.CreateBitCast(mt.lookupValue(cvtOp.getSrc()),
                            llvm::Type::getInt16Ty(builder.getContext())));

  // The intrinsic always takes a scale operand; when the Op carries none, a
  // constant is supplied instead.
  // NOTE(review): 0x7f7f presumably packs two ue8m0 bytes of 0x7f, i.e. a
  // neutral scale for each input. Confirm against the PTX ISA.
  if (auto scale = cvtOp.getScaleFactor())
    operands.push_back(mt.lookupValue(scale));
  else
    operands.push_back(builder.getInt16(0x7f7f));

  return {intrinsicId, std::move(operands)};
}
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,

View File

@ -64,7 +64,7 @@ llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) {
// -----
llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) {
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU)
llvm.return
}

View File

@ -11,6 +11,34 @@ llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
llvm.return
}
// -----
// Lowering of f16x2 -> f4x2 without and with relu: the i16 intrinsic result is
// truncated to the i8 Op result.
// CHECK-LABEL: @convert_f16x2_to_f4x2
llvm.func @convert_f16x2_to_f4x2(%srcA : vector<2xf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.satfinite(<2 x half> %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
%res1 = nvvm.convert.f16x2.to.f4x2 %srcA : vector<2xf16> -> i8 (f4E2M1FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.relu.satfinite(<2 x half> %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
%res2 = nvvm.convert.f16x2.to.f4x2 %srcA {relu = true} : vector<2xf16> -> i8 (f4E2M1FN)
llvm.return
}
// -----
// Lowering of bf16x2 -> f4x2 without and with relu: the i16 intrinsic result
// is truncated to the i8 Op result.
// CHECK-LABEL: @convert_bf16x2_to_f4x2
llvm.func @convert_bf16x2_to_f4x2(%srcA : vector<2xbf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
%res1 = nvvm.convert.bf16x2.to.f4x2 %srcA : vector<2xbf16> -> i8 (f4E2M1FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
%res2 = nvvm.convert.bf16x2.to.f4x2 %srcA {relu = true} : vector<2xbf16> -> i8 (f4E2M1FN)
llvm.return
}
// -----
// CHECK-LABEL: @convert_f4x2_to_f16x2
llvm.func @convert_f4x2_to_f16x2(%src : i8) {
// CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16

View File

@ -0,0 +1,17 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s
// -----
// An f8 destination type must be rejected by the `dstTy` attribute constraint
// of the f16x2 -> f4x2 Op.
llvm.func @convert_f16x2_to_f4x2_invalid_type(%src : vector<2xf16>) {
// expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f4E2M1FN type}}
%res = nvvm.convert.f16x2.to.f4x2 %src : vector<2xf16> -> i8 (f8E4M3FN)
llvm.return
}
// -----
// An f8 destination type must be rejected by the `dstTy` attribute constraint
// of the bf16x2 -> f4x2 Op.
llvm.func @convert_bf16x2_to_f4x2_invalid_type(%src : vector<2xbf16>) {
// expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f4E2M1FN type}}
%res = nvvm.convert.bf16x2.to.f4x2 %src : vector<2xbf16> -> i8 (f8E4M3FN)
llvm.return
}

View File

@ -1,11 +1,20 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
// CHECK-LABEL: @convert_f32x2_to_fp6x2_e2m3
llvm.func @convert_f32x2_to_fp6x2_e2m3(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E2M3FN)
llvm.return
}
// CHECK-LABEL: @convert_f32x2_to_fp6x2_e3m2
llvm.func @convert_f32x2_to_fp6x2_e3m2(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E3M2FN)
llvm.return
}
@ -22,6 +31,68 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
// -----
// Lowering of f16x2 -> f6E2M3FN with an i16 result, without and with relu.
// CHECK-LABEL: @convert_f16x2_to_fp6x2_e2m3
llvm.func @convert_f16x2_to_fp6x2_e2m3(%srcA : vector<2xf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}})
%res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E2M3FN)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.relu.satfinite(<2 x half> %{{.*}})
%res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E2M3FN)
llvm.return
}
// Lowering of f16x2 -> f6E3M2FN with an i16 result, without and with relu.
// CHECK-LABEL: @convert_f16x2_to_fp6x2_e3m2
llvm.func @convert_f16x2_to_fp6x2_e3m2(%srcA : vector<2xf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}})
%res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E3M2FN)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.relu.satfinite(<2 x half> %{{.*}})
%res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E3M2FN)
llvm.return
}
// Lowering of f16x2 -> f6x2 with a vector<2xi8> result: the i16 intrinsic
// result is bitcast to <2 x i8> for both fp6 destination types.
// CHECK-LABEL: @convert_f16x2_to_fp6x2_vector
llvm.func @convert_f16x2_to_fp6x2_vector(%srcA : vector<2xf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
%res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E2M3FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
%res2 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E3M2FN)
llvm.return
}
// -----
// Lowering of bf16x2 -> f6E2M3FN with an i16 result, without and with relu.
// The Op takes no scale operand, so the function only needs the source vector.
// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e2m3
llvm.func @convert_bf16x2_to_fp6x2_e2m3(%srcA : vector<2xbf16>) {
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
  %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E2M3FN)
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
  %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E2M3FN)
  llvm.return
}
// Lowering of bf16x2 -> f6E3M2FN with an i16 result, without and with relu.
// The Op takes no scale operand, so the function only needs the source vector.
// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e3m2
llvm.func @convert_bf16x2_to_fp6x2_e3m2(%srcA : vector<2xbf16>) {
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
  %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E3M2FN)
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
  %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E3M2FN)
  llvm.return
}
// Lowering of bf16x2 -> f6x2 with a vector<2xi8> result: the i16 intrinsic
// result is bitcast to <2 x i8> for both fp6 destination types. Each bitcast
// is tied to the call it consumes so an unrelated bitcast cannot match.
// CHECK-LABEL: @convert_bf16x2_to_fp6x2_vector
llvm.func @convert_bf16x2_to_fp6x2_vector(%srcA : vector<2xbf16>) {
  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
  %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)
  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
  %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E3M2FN)
  llvm.return
}
// -----
// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16

View File

@ -0,0 +1,17 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s
// -----
// An f8 destination type must be rejected by the `dstTy` attribute constraint
// of the f16x2 -> f6x2 Op.
llvm.func @convert_f16x2_to_f6x2_invalid_type(%src : vector<2xf16>) {
// expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
%res = nvvm.convert.f16x2.to.f6x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN)
llvm.return
}
// -----
// An f8 destination type must be rejected by the `dstTy` attribute constraint
// of the bf16x2 -> f6x2 Op.
llvm.func @convert_bf16x2_to_f6x2_invalid_type(%src : vector<2xbf16>) {
// expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
%res = nvvm.convert.bf16x2.to.f6x2 %src : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
llvm.return
}

View File

@ -90,6 +90,25 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
llvm.return
}
// Lowering of bf16x2 -> f8E4M3FN (rn, satfinite) with an i16 result, without
// and with relu.
// CHECK-LABEL: @convert_bf16x2_to_f8x2_e4m3
llvm.func @convert_bf16x2_to_f8x2_e4m3(%srcA : vector<2xbf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
%res1 = nvvm.convert.bf16x2.to.f8x2 %srcA {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E4M3FN)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
%res2 = nvvm.convert.bf16x2.to.f8x2 %srcA {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E4M3FN)
llvm.return
}
// Lowering of bf16x2 -> f8E5M2 (rn, satfinite) with a vector<2xi8> result,
// without and with relu. Only the intrinsic calls are matched here.
// CHECK-LABEL: @convert_bf16x2_to_f8x2_e5m2
llvm.func @convert_bf16x2_to_f8x2_e5m2(%srcA : vector<2xbf16>) {
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
%res1 = nvvm.convert.bf16x2.to.f8x2 %srcA {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
// CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
%res2 = nvvm.convert.bf16x2.to.f8x2 %srcA {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
llvm.return
}
// CHECK-LABEL: @convert_bf16x2_to_f8x2_vector_return
llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
@ -98,6 +117,12 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
%res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
// CHECK: %[[res3:.*]] = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res3]] to <2 x i8>
%res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
// CHECK: %[[res4:.*]] = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res4]] to <2 x i8>
%res4 = nvvm.convert.bf16x2.to.f8x2 %src {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
llvm.return
}

View File

@ -0,0 +1,41 @@
// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics -split-input-file %s
// Each dash-line-separated case below must be processed as its own module:
// without -split-input-file the whole file is verified as one module and
// verification may stop at the first failing op, so the diagnostics awaited
// by the later cases would never be produced.
// -----
// Negative case: f6E2M3FN is used as dstTy, which is not one of the types
// listed in the awaited diagnostic (f8E8M0FNU, f8E4M3FN, f8E5M2).
llvm.func @convert_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
// expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f8E8M0FNU type or f8E4M3FN type or f8E5M2 type}}
%res = nvvm.convert.bf16x2.to.f8x2 %src : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)
llvm.return
}
// -----
// Negative case: rm rounding combined with an f8E4M3FN destination; the
// awaited diagnostic states that only RN is accepted for f8E4M3FN/f8E5M2.
llvm.func @convert_bf16x2_to_f8x2_invalid_rounding_1(%src : vector<2xbf16>) {
// expected-error @below {{Only RN rounding mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
llvm.return
}
// -----
// Negative case: rn rounding combined with an f8E8M0FNU destination; the
// awaited diagnostic states that only RZ and RP are accepted for that type.
// NOTE(review): this function keeps an `nvvm_cvt_` prefix while the other
// cases in this file are named `convert_...`; consider renaming it to
// `convert_bf16x2_to_f8x2_invalid_rounding_2` for consistency.
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding_2(%src : vector<2xbf16>) {
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
llvm.return
}
// -----
// Negative case: saturation mode `none` with an f8E4M3FN destination (the
// rounding mode is rn, so only the saturation mode is at fault); the awaited
// diagnostic states that only SATFINITE is accepted for f8E4M3FN/f8E5M2.
llvm.func @convert_bf16x2_to_f8x2_invalid_sat_mode(%src : vector<2xbf16>) {
// expected-error @below {{Only SATFINITE saturation mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {sat = #nvvm.sat_mode<none>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
llvm.return
}
// -----
// Negative case: relu requested together with an f8E8M0FNU destination (the
// rounding mode rp is one the f8E8M0FNU rounding diagnostic lists as
// accepted); the awaited diagnostic rejects relu for that type.
llvm.func @convert_bf16x2_to_f8x2_invalid_relu(%src : vector<2xbf16>) {
// expected-error @below {{relu not supported for conversions to 'f8E8M0FNU' type}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
llvm.return
}

View File

@ -0,0 +1,189 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// f32x2 -> s2f6x2 with an i16 result and no scale operand: plain and relu
// variants. With no scale operand the matched calls receive the constant
// i16 32639 (0x7F7F) as their last argument.
llvm.func @convert_f32x2_to_s2f6x2(%srcA : f32, %srcB : f32) -> i16 {
// CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2(float %0, float %1) {
// CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
// CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
// CHECK-NEXT: %5 = or i16 %3, %4
// CHECK-NEXT: ret i16 %5
// CHECK-NEXT: }
%res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB : i16
%res2 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB {relu = true} : i16
// Combine results to avoid dead code elimination
%final_result = llvm.or %res1, %res2 : i16
llvm.return %final_result : i16
}
// f32x2 -> s2f6x2 with an i16 result and an explicit %scale operand: plain
// and relu variants. The scale value is forwarded unchanged as the last
// argument of the matched calls.
llvm.func @convert_f32x2_to_s2f6x2_scale(%srcA : f32, %srcB : f32, %scale : i16) -> i16 {
// CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2_scale(float %0, float %1, i16 %2) {
// CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
// CHECK-NEXT: %5 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
// CHECK-NEXT: %6 = or i16 %4, %5
// CHECK-NEXT: ret i16 %6
// CHECK-NEXT: }
%res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale : i16
%res2 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale {relu = true} : i16
// Combine results to avoid dead code elimination
%final_result = llvm.or %res1, %res2 : i16
llvm.return %final_result : i16
}
// f32x2 -> s2f6x2 with a vector<2xi8> result and no scale operand: the i16
// returned by the call is bitcast to <2 x i8> before being returned.
llvm.func @convert_f32x2_to_s2f6x2_vector(%srcA : f32, %srcB : f32) -> vector<2xi8> {
// CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector(float %0, float %1) {
// CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
// CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8>
// CHECK-NEXT: ret <2 x i8> %4
// CHECK-NEXT: }
%res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB : vector<2xi8>
llvm.return %res1 : vector<2xi8>
}
// f32x2 -> s2f6x2 with a vector<2xi8> result and an explicit %scale operand:
// the scale is forwarded to the call and the i16 result is bitcast to
// <2 x i8>.
llvm.func @convert_f32x2_to_s2f6x2_vector_scale(%srcA : f32, %srcB : f32, %scale : i16) -> vector<2xi8> {
// CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector_scale(float %0, float %1, i16 %2) {
// CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
// CHECK-NEXT: %5 = bitcast i16 %4 to <2 x i8>
// CHECK-NEXT: ret <2 x i8> %5
// CHECK-NEXT: }
%res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale : vector<2xi8>
llvm.return %res1 : vector<2xi8>
}
// bf16x2 -> s2f6x2 with an i16 result and no scale operand: plain and relu
// variants. With no scale operand the matched calls receive the constant
// i16 32639 (0x7F7F) as their last argument.
llvm.func @convert_bf16x2_to_s2f6x2(%srcA : vector<2xbf16>) -> i16 {
// CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2(<2 x bfloat> %0) {
// CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
// CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
// CHECK-NEXT: %4 = or i16 %2, %3
// CHECK-NEXT: ret i16 %4
// CHECK-NEXT: }
%res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA : vector<2xbf16> -> i16
%res2 = nvvm.convert.bf16x2.to.s2f6x2 %srcA {relu = true} : vector<2xbf16> -> i16
// Combine results to avoid dead code elimination
%final_result = llvm.or %res1, %res2 : i16
llvm.return %final_result : i16
}
// bf16x2 -> s2f6x2 with an i16 result and an explicit %scale operand: plain
// and relu variants. The scale value is forwarded unchanged as the last
// argument of the matched calls.
llvm.func @convert_bf16x2_to_s2f6x2_scale(%srcA : vector<2xbf16>, %scale : i16) -> i16 {
// CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2_scale(<2 x bfloat> %0, i16 %1) {
// CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
// CHECK-NEXT: %4 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
// CHECK-NEXT: %5 = or i16 %3, %4
// CHECK-NEXT: ret i16 %5
// CHECK-NEXT: }
%res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale : vector<2xbf16> -> i16
%res2 = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale {relu = true} : vector<2xbf16> -> i16
// Combine results to avoid dead code elimination
%final_result = llvm.or %res1, %res2 : i16
llvm.return %final_result : i16
}
// bf16x2 -> s2f6x2 with a vector<2xi8> result and no scale operand: the i16
// returned by the call is bitcast to <2 x i8> before being returned.
llvm.func @convert_bf16x2_to_s2f6x2_vector(%srcA : vector<2xbf16>) -> vector<2xi8> {
// CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector(<2 x bfloat> %0) {
// CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
// CHECK-NEXT: %3 = bitcast i16 %2 to <2 x i8>
// CHECK-NEXT: ret <2 x i8> %3
// CHECK-NEXT: }
%res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA : vector<2xbf16> -> vector<2xi8>
llvm.return %res1 : vector<2xi8>
}
// bf16x2 -> s2f6x2 with a vector<2xi8> result and an explicit %scale operand:
// the scale is forwarded to the call and the i16 result is bitcast to
// <2 x i8>.
llvm.func @convert_bf16x2_to_s2f6x2_vector_scale(%srcA : vector<2xbf16>, %scale : i16) -> vector<2xi8> {
// CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector_scale(<2 x bfloat> %0, i16 %1) {
// CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
// CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8>
// CHECK-NEXT: ret <2 x i8> %4
// CHECK-NEXT: }
%res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale : vector<2xbf16> -> vector<2xi8>
llvm.return %res1 : vector<2xi8>
}
// 1. no relu, no scale, no satfinite
// s2f6x2 -> bf16x2 baseline: the <2 x i8> source is bitcast to i16 before the
// call, and with no scale operand the constant i16 32639 (0x7F7F) is passed.
llvm.func @convert_s2f6x2_to_bf16x2(%src : vector<2xi8>) -> vector<2xbf16> {
// CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2(<2 x i8> %0) {
// CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
// CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %2, i16 32639)
// CHECK-NEXT: ret <2 x bfloat> %3
// CHECK-NEXT: }
%res = nvvm.convert.s2f6x2.to.bf16x2 %src : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 2. relu, no scale, no satfinite
// Setting relu selects the `.rn.relu` intrinsic; the default scale constant
// i16 32639 is still passed because no scale operand is given.
llvm.func @convert_s2f6x2_to_bf16x2_relu(%src : vector<2xi8>) -> vector<2xbf16> {
// CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu(<2 x i8> %0) {
// CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
// CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %2, i16 32639)
// CHECK-NEXT: ret <2 x bfloat> %3
// CHECK-NEXT: }
%res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 3. no relu, with scale, no satfinite
// The explicit %scale operand replaces the default constant and is forwarded
// unchanged as the second argument of the call.
llvm.func @convert_s2f6x2_to_bf16x2_scale(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> {
// CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale(<2 x i8> %0, i16 %1) {
// CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
// CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %3, i16 %1)
// CHECK-NEXT: ret <2 x bfloat> %4
// CHECK-NEXT: }
%res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 4. relu, with scale, no satfinite
// Combines the `.rn.relu` intrinsic with an explicit %scale operand that is
// forwarded unchanged to the call.
llvm.func @convert_s2f6x2_to_bf16x2_scale_relu(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> {
// CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu(<2 x i8> %0, i16 %1) {
// CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
// CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %3, i16 %1)
// CHECK-NEXT: ret <2 x bfloat> %4
// CHECK-NEXT: }
%res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 5. no relu, no scale, satfinite
// Setting sat to satfinite selects the `.rn.satfinite` intrinsic; the default
// scale constant i16 32639 is passed because no scale operand is given.
llvm.func @convert_s2f6x2_to_bf16x2_satfinite(%src : vector<2xi8>) -> vector<2xbf16> {
// CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_satfinite(<2 x i8> %0) {
// CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
// CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %2, i16 32639)
// CHECK-NEXT: ret <2 x bfloat> %3
// CHECK-NEXT: }
%res = nvvm.convert.s2f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 6. relu, no scale, satfinite
// relu and satfinite together select the `.rn.relu.satfinite` intrinsic; the
// default scale constant i16 32639 is passed.
llvm.func @convert_s2f6x2_to_bf16x2_relu_satfinite(%src : vector<2xi8>) -> vector<2xbf16> {
// CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu_satfinite(<2 x i8> %0) {
// CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
// CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %2, i16 32639)
// CHECK-NEXT: ret <2 x bfloat> %3
// CHECK-NEXT: }
%res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 7. no relu, with scale, satfinite
// Combines the `.rn.satfinite` intrinsic with an explicit %scale operand that
// is forwarded unchanged to the call.
llvm.func @convert_s2f6x2_to_bf16x2_scale_satfinite(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> {
// CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_satfinite(<2 x i8> %0, i16 %1) {
// CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
// CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %3, i16 %1)
// CHECK-NEXT: ret <2 x bfloat> %4
// CHECK-NEXT: }
%res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 8. relu, with scale, satfinite
// All three options at once: the `.rn.relu.satfinite` intrinsic is called
// with the explicit %scale operand forwarded unchanged.
llvm.func @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> {
// CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(<2 x i8> %0, i16 %1) {
// CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
// CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %3, i16 %1)
// CHECK-NEXT: ret <2 x bfloat> %4
// CHECK-NEXT: }
%res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}

View File

@ -209,22 +209,6 @@ llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
// -----
// Negative case: an f8E4M3FN destination with rz rounding; the awaited
// diagnostic states that only f8E8M0FNU is accepted for bf16x2 -> f8x2.
// NOTE(review): the enclosing hunk header (-209,22 +209,6) suggests this case
// is being removed by the change, which adds f8E4M3FN/f8E5M2 support --
// confirm it is not kept alongside the new positive tests.
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
// expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E4M3FN)
llvm.return
}
// -----
// Negative case: rn rounding with an f8E8M0FNU destination; the awaited
// diagnostic states that only RZ and RP are accepted for bf16x2 -> f8x2.
// NOTE(review): the enclosing hunk header (-209,22 +209,6) suggests this case
// is being removed by the change; an equivalent case with a type-specific
// message exists in the dedicated bf16x2 -> f8x2 invalid-case file -- confirm.
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16 (f8E8M0FNU)
llvm.return
}
// -----
llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
// expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}}
%res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU)