[MLIR][NVVM] Add new narrow FP convert Ops (#184291)
This change adds the following NVVM Ops for new narrow FP conversions
introduced in PTX 9.1:
- `convert.{f32x2/bf16x2}.to.s2f6x2`
- `convert.s2f6x2.to.bf16x2`
- `convert.bf16x2.to.f8x2` (extended for `f8E4M3FN` and `f8E5M2` types)
- `convert.{f16x2/bf16x2}.to.f6x2`
- `convert.{f16x2/bf16x2}.to.f4x2`
PTX ISA Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
This commit is contained in:
parent
e326ff2a88
commit
63231ebfe7
@ -1972,6 +1972,44 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Base class for the packed conversions from a two-element floating-point
// vector to f4x2. `srcType` (e.g. "F16" or "BF16") is spliced into the op
// mnemonic, the element type of the `src` operand, and the name of the C++ op
// class whose `getIntrinsicIDAndArgs` is used for the lowering.
class NVVM_ConvertFPx2ToF4x2Op<string srcType>
    : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f4x2"> {
  let summary = "Convert a pair of " # !tolower(srcType) # " inputs to f4x2";
  let description = [{
    This Op converts each of the }]#srcType#[{ inputs in the given }]#!tolower(srcType)#[{x2
    vector `src` to the specified fp4 type.
    The result `dst` is returned as an i8 type where the converted values are
    packed such that the value converted from the first element of `src` is
    stored in the lower 4 bits of `dst` and the value converted from the second
    element of `src` is stored in the upper 4 bits of `dst`.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.
  }];
  let results = (outs I8:$dst);
  let arguments = (ins
    VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$src,
    DefaultValuedAttr<BoolAttr, "false">:$relu,
    TypeAttrOf<F4E2M1FN>:$dstTy);

  let assemblyFormat =
    "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";

  let extraClassDeclaration = [{
    static NVVM::IDArgPair
    getIntrinsicIDAndArgs(NVVM::Convert}]#srcType#[{x2ToF4x2Op &op,
      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
  }];

  string llvmBuilder = [{
    // The intrinsic call yields an i16; it is narrowed to the i8 result type.
    auto [intId, args] = NVVM::Convert}]#srcType#[{x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
    $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
  }];
}
|
||||
|
||||
// `nvvm.convert.f16x2.to.f4x2`: F16 instantiation of the FPx2 -> f4x2 class.
def NVVM_ConvertF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"F16">;
|
||||
// `nvvm.convert.bf16x2.to.f4x2`: BF16 instantiation of the FPx2 -> f4x2 class.
def NVVM_ConvertBF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"BF16">;
|
||||
|
||||
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
|
||||
let summary = "Convert a pair of float inputs to f6x2";
|
||||
let description = [{
|
||||
@ -2015,6 +2053,48 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Base class for the packed conversions from a two-element floating-point
// vector to f6x2. `srcType` (e.g. "F16" or "BF16") is spliced into the op
// mnemonic, the element type of the `src` operand, and the name of the C++ op
// class whose `getIntrinsicID` is used for the lowering.
class NVVM_ConvertFPx2ToF6x2Op<string srcType>
    : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f6x2"> {
  let summary = "Convert a pair of " # !tolower(srcType) # " inputs to f6x2";
  let description = [{
    This Op converts each of the }]#srcType#[{ inputs in the given }]#!tolower(srcType)#[{x2
    vector `src` to the specified fp6 type. The result `dst` is
    represented either as an i16 type or as a vector of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from the first element of `src` is stored in the
    lower 8 bits of `dst` with 2 MSB bits padded with zeros and the value
    converted from the second element of `src` is stored in the upper 8 bits of
    `dst` with 2 MSB bits padded with zeros.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector with 2 MSB bits padded with zeros.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.
  }];

  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$src,
    DefaultValuedAttr<BoolAttr, "false">:$relu,
    TypeAttrOf<AnyTypeOf<[F6E2M3FN, F6E3M2FN]>>:$dstTy);
  let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";
  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu);
  }];

  string llvmBuilder = [{
    // The intrinsic call yields an i16; it is used directly for an i16 result
    // and bitcast to <2 x i8> for the vector result.
    auto intId = NVVM::Convert}]#srcType#[{x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$src});
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
|
||||
|
||||
// `nvvm.convert.f16x2.to.f6x2`: F16 instantiation of the FPx2 -> f6x2 class.
def NVVM_ConvertF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"F16">;
|
||||
// `nvvm.convert.bf16x2.to.f6x2`: BF16 instantiation of the FPx2 -> f6x2 class.
def NVVM_ConvertBF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"BF16">;
|
||||
|
||||
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
|
||||
let summary = "Convert a pair of float inputs to f8x2";
|
||||
let description = [{
|
||||
@ -2109,13 +2189,12 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
|
||||
let summary = "Convert a pair of bf16 inputs to f8x2";
|
||||
let description = [{
|
||||
This Op converts the given bf16 inputs in a bf16x2 vector to the specified
|
||||
f8 type.
|
||||
The result `dst` is represented as an i16 type or as a vector
|
||||
of two i8 types.
|
||||
If `dst` is returned as an i16 type, the converted values from `a`
|
||||
are packed such that the value converted from the first element of `a`
|
||||
is stored in the upper 8 bits of `dst` and the value converted from the
|
||||
second element of `a` is stored in the lower 8 bits of `dst`.
|
||||
f8 type. The result `dst` is represented either as a packed i16 type or as
|
||||
a vector of two i8 types.
|
||||
If `dst` is returned as an i16 type, the converted values are packed such
|
||||
that the value converted from the first element of `src` is stored in the
|
||||
lower 8 bits of `dst` and the value converted from the second element of
|
||||
`src` is stored in the upper 8 bits of `dst`.
|
||||
If `dst` is returned as a vector type, each converted value is stored as an
|
||||
i8 element in the vector.
|
||||
The `rnd` and `sat` attributes specify the rounding and saturation modes
|
||||
@ -2127,20 +2206,22 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
|
||||
let hasVerifier = 1;
|
||||
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
|
||||
let arguments = (ins
|
||||
VectorOfLengthAndType<[2], [BF16]>:$a,
|
||||
VectorOfLengthAndType<[2], [BF16]>:$src,
|
||||
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
|
||||
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
|
||||
TypeAttr:$dstTy);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
|
||||
DefaultValuedAttr<BoolAttr, "false">:$relu,
|
||||
TypeAttrOf<AnyTypeOf<[F8E8M0FNU, F8E4M3FN, F8E5M2]>>:$dstTy);
|
||||
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
|
||||
NVVM::SaturationMode sat);
|
||||
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
|
||||
NVVM::FPRoundingMode rnd,
|
||||
NVVM::SaturationMode sat, bool hasRelu);
|
||||
}];
|
||||
|
||||
string llvmBuilder = [{
|
||||
auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
|
||||
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
|
||||
auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
|
||||
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$src});
|
||||
if(op.getDst().getType().isInteger(16))
|
||||
$dst = packedI16;
|
||||
else
|
||||
@ -2185,6 +2266,117 @@ def NVVM_ConvertF6x2ToF16x2Op :
|
||||
def NVVM_ConvertF4x2ToF16x2Op :
|
||||
NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
|
||||
|
||||
// `nvvm.convert.f32x2.to.s2f6x2`: converts two scalar f32 operands to a packed
// S2F6x2 result, with an optional packed scale-factor operand. The intrinsic
// and its arguments are chosen by the C++ hook `getIntrinsicIDAndArgs`.
def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
  let summary = "Convert a pair of f32 inputs to S2F6x2";
  let description = [{
    This Op converts each of the given f32 inputs to the
    S2F6x2 type. The result `dst` can be either a packed i16 type or a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from `a` is stored in the upper 8 bits of `dst`
    and the value converted from `b` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant
    of the cvt instruction.
    The optional scaling-factors for each of the inputs are provided through
    the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
    as the type of the scale-factor currently. When `scaleFactor` is omitted,
    the lowering passes the constant `0x7f7f` as the scale operand.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins F32:$a, F32:$b,
                       Optional<I16>:$scaleFactor,
                       DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat =
    "$a `,` $b (`,` $scaleFactor^)? attr-dict `:` type($dst)";
  let extraClassDeclaration = [{
    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
  }];

  string llvmBuilder = [{
    // The intrinsic call yields an i16; it is used directly for an i16 result
    // and bitcast to <2 x i8> for the vector result.
    auto [id, args] = NVVM::ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
    llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
|
||||
|
||||
// `nvvm.convert.bf16x2.to.s2f6x2`: converts a two-element bf16 vector to a
// packed S2F6x2 result, with an optional packed scale-factor operand. The
// intrinsic and its arguments are chosen by the C++ hook
// `getIntrinsicIDAndArgs`.
def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> {
  let summary = "Convert a pair of BF16 inputs to S2F6x2";
  let description = [{
    This Op converts each of the given BF16 inputs in the bf16x2 vector `src`
    to the S2F6x2 type. The result `dst` can be either a packed i16 type or a
    vector of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from the first element of `src` is stored in the
    lower 8 bits of `dst` and the value converted from the second element of
    `src` is stored in the upper 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant
    of the cvt instruction.
    The optional scaling-factors for each of the inputs are provided through
    the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
    as the type of the scale-factor currently. When `scaleFactor` is omitted,
    the lowering passes the constant `0x7f7f` as the scale operand.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    VectorOfLengthAndType<[2], [BF16]>:$src,
    Optional<I16>:$scaleFactor,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat =
    "$src (`,` $scaleFactor^)? attr-dict `:` type($src) `->` type($dst)";
  let extraClassDeclaration = [{
    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
  }];

  string llvmBuilder = [{
    // The intrinsic call yields an i16; it is used directly for an i16 result
    // and bitcast to <2 x i8> for the vector result.
    auto [id, args] = NVVM::ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
    llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
    if(op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];
}
|
||||
|
||||
// `nvvm.convert.s2f6x2.to.bf16x2`: converts a packed S2F6x2 input (carried as
// a vector of two i8 elements) to a two-element bf16 vector, with an optional
// packed scale-factor operand.
def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_SingleResultIntrinsicOp<"convert.s2f6x2.to.bf16x2", [], "$dst"> {
  let summary = "Convert s2f6x2 to bf16x2";
  let description = [{
    This Op converts a pair of s2f6x2 inputs to bf16x2 type. The input `src`
    is represented as a vector of two i8 elements and the result `dst`
    is represented as a vector of two bf16 elements.

    The `relu` attribute, when set, lowers to the '.relu' variant
    of the cvt instruction.

    The `sat` attribute, when set to the SATFINITE saturation mode, lowers to
    the '.satfinite' variant of the cvt instruction.

    The optional scaling-factors for each of the inputs are provided through
    the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
    as the type of the scale-factor currently. When `scaleFactor` is omitted,
    the lowering passes the constant `0x7f7f` as the scale operand.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];

  let results = (outs VectorOfLengthAndType<[2], [BF16]>:$dst);
  let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src,
                       Optional<I16>:$scaleFactor,
                       DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
                       DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat =
    "$src (`,` $scaleFactor^)? attr-dict `:` type($src) `->` type($dst)";
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NVVM Stochastic Rounding Conversion Ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -414,18 +414,45 @@ LogicalResult ConvertF16x2ToF8x2Op::verify() {
|
||||
|
||||
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
|
||||
using RndMode = NVVM::FPRoundingMode;
|
||||
using SatMode = NVVM::SaturationMode;
|
||||
|
||||
if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
|
||||
return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
|
||||
<< " type is supported for conversions from "
|
||||
"bf16x2 to f8x2.";
|
||||
bool isRoundingModeRN = getRnd() == RndMode::RN;
|
||||
bool isRoundingModeRZ = getRnd() == RndMode::RZ;
|
||||
bool isRoundingModeRP = getRnd() == RndMode::RP;
|
||||
bool isSatFinite = getSat() == SatMode::SATFINITE;
|
||||
bool hasRelu = getRelu();
|
||||
|
||||
auto rnd = getRnd();
|
||||
if (rnd != RndMode::RZ && rnd != RndMode::RP)
|
||||
return emitOpError("Only RZ and RP rounding modes are supported for "
|
||||
"conversions from bf16x2 to f8x2.");
|
||||
mlir::MLIRContext *ctx = getContext();
|
||||
|
||||
return success();
|
||||
return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
|
||||
.Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
|
||||
[&](mlir::Type) -> LogicalResult {
|
||||
if (!isRoundingModeRN)
|
||||
return emitOpError("Only RN rounding mode is supported for "
|
||||
"conversions from bf16x2 to ")
|
||||
<< mlir::Float8E4M3FNType::get(ctx) << " and "
|
||||
<< mlir::Float8E5M2Type::get(ctx) << " types";
|
||||
if (!isSatFinite)
|
||||
return emitOpError("Only SATFINITE saturation mode is supported "
|
||||
"for conversions from bf16x2 to ")
|
||||
<< mlir::Float8E4M3FNType::get(ctx) << " and "
|
||||
<< mlir::Float8E5M2Type::get(ctx) << " types";
|
||||
return success();
|
||||
})
|
||||
.Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
|
||||
if (!(isRoundingModeRZ || isRoundingModeRP))
|
||||
return emitOpError("Only RZ and RP rounding modes are supported for "
|
||||
"conversions from bf16x2 to ")
|
||||
<< mlir::Float8E8M0FNUType::get(ctx) << " type";
|
||||
if (hasRelu)
|
||||
return emitOpError("relu not supported for conversions to ")
|
||||
<< mlir::Float8E8M0FNUType::get(ctx) << " type";
|
||||
return success();
|
||||
})
|
||||
.Default([&](mlir::Type) -> LogicalResult {
|
||||
llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult ConvertF32x2ToF4x2Op::verify() {
|
||||
@ -4232,6 +4259,80 @@ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
|
||||
});
|
||||
}
|
||||
|
||||
/// Picks the LLVM intrinsic and its operands for lowering
/// `nvvm.convert.f16x2.to.f4x2`.
///
/// \param op       The op being translated; supplies `dstTy`, `relu` and `src`.
/// \param mt       Used to look up the LLVM value already emitted for `src`.
/// \param builder  Not used by this lowering; part of the declared signature.
/// \returns The intrinsic ID (the `.relu` form when `relu` is set) paired with
///          a single argument: the translated `src` vector.
NVVM::IDArgPair
ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
                                            LLVM::ModuleTranslation &mt,
                                            llvm::IRBuilderBase &builder) {
  // f4E2M1FN is the only destination handled here. Fail loudly at the
  // selection point, like the sibling intrinsic selectors do, instead of
  // handing `not_intrinsic` to the caller.
  if (!llvm::isa<mlir::Float4E2M1FNType>(op.getDstTy()))
    llvm_unreachable("Invalid conversion in ConvertF16x2ToF4x2Op");

  llvm::Intrinsic::ID intId =
      op.getRelu() ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
                   : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getSrc()));

  return {intId, std::move(args)};
}
|
||||
|
||||
/// Picks the LLVM intrinsic and its operands for lowering
/// `nvvm.convert.bf16x2.to.f4x2`.
///
/// \param op       The op being translated; supplies `dstTy`, `relu` and `src`.
/// \param mt       Used to look up the LLVM value already emitted for `src`.
/// \param builder  Not used by this lowering; part of the declared signature.
/// \returns The intrinsic ID (the `.relu` form when `relu` is set) paired with
///          a single argument: the translated `src` vector.
NVVM::IDArgPair
ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
                                             LLVM::ModuleTranslation &mt,
                                             llvm::IRBuilderBase &builder) {
  // f4E2M1FN is the only destination handled here. Fail loudly at the
  // selection point, like the sibling intrinsic selectors do, instead of
  // handing `not_intrinsic` to the caller.
  if (!llvm::isa<mlir::Float4E2M1FNType>(op.getDstTy()))
    llvm_unreachable("Invalid conversion in ConvertBF16x2ToF4x2Op");

  llvm::Intrinsic::ID intId =
      op.getRelu() ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
                   : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;

  llvm::SmallVector<llvm::Value *> args;
  args.push_back(mt.lookupValue(op.getSrc()));

  return {intId, std::move(args)};
}
|
||||
|
||||
/// Maps the destination fp6 type and the `relu` flag to the f16x2 -> f6x2
/// conversion intrinsic.
///
/// \param dstTy    Destination element type; f6E2M3FN and f6E3M2FN are handled.
/// \param hasRelu  Selects the `.relu` flavour of the intrinsic when true.
/// \returns The matching intrinsic ID. Any other `dstTy` is unreachable.
llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
                                                         bool hasRelu) {
  if (llvm::isa<mlir::Float6E2M3FNType>(dstTy)) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite;
    return llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
  }

  if (llvm::isa<mlir::Float6E3M2FNType>(dstTy)) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite;
    return llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
  }

  llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
}
|
||||
|
||||
/// Maps the destination fp6 type and the `relu` flag to the bf16x2 -> f6x2
/// conversion intrinsic.
///
/// \param dstTy    Destination element type; f6E2M3FN and f6E3M2FN are handled.
/// \param hasRelu  Selects the `.relu` flavour of the intrinsic when true.
/// \returns The matching intrinsic ID. Any other `dstTy` is unreachable.
llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
                                                          bool hasRelu) {
  if (llvm::isa<mlir::Float6E2M3FNType>(dstTy)) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite;
    return llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
  }

  if (llvm::isa<mlir::Float6E3M2FNType>(dstTy)) {
    if (hasRelu)
      return llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite;
    return llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
  }

  llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
}
|
||||
|
||||
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
|
||||
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
|
||||
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
|
||||
@ -4287,22 +4388,39 @@ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
|
||||
});
|
||||
}
|
||||
|
||||
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
|
||||
has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
|
||||
: llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
|
||||
|
||||
llvm::Intrinsic::ID
|
||||
ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
|
||||
NVVM::SaturationMode sat) {
|
||||
ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
|
||||
NVVM::FPRoundingMode rnd,
|
||||
NVVM::SaturationMode sat, bool hasRelu) {
|
||||
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
|
||||
switch (rnd) {
|
||||
case NVVM::FPRoundingMode::RZ:
|
||||
return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
|
||||
case NVVM::FPRoundingMode::RP:
|
||||
return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
|
||||
default:
|
||||
llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
|
||||
}
|
||||
|
||||
static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
|
||||
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
|
||||
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
|
||||
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
|
||||
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
|
||||
};
|
||||
|
||||
return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
|
||||
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
|
||||
return hasRelu
|
||||
? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
|
||||
: llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
|
||||
})
|
||||
.Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
|
||||
return hasRelu
|
||||
? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
|
||||
: llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
|
||||
})
|
||||
.Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
|
||||
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
|
||||
unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
|
||||
return ue8m0x2IDs[index];
|
||||
})
|
||||
.Default([](mlir::Type) {
|
||||
llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
|
||||
return llvm::Intrinsic::not_intrinsic;
|
||||
});
|
||||
}
|
||||
|
||||
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
|
||||
@ -4397,6 +4515,74 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
|
||||
return {intId, {extendedI16}};
|
||||
}
|
||||
|
||||
/// Picks the LLVM intrinsic and its operands for lowering
/// `nvvm.convert.f32x2.to.s2f6x2`.
///
/// The `relu` attribute selects the `.relu` flavour of the intrinsic. The
/// arguments are, in order: the translated `a`, the translated `b`, and the
/// scale operand. When the op carries no `scaleFactor`, the i16 constant
/// 0x7f7f is passed instead.
/// NOTE(review): 0x7f7f presumably packs two ue8m0 encodings of a unit scale;
/// confirm against the PTX ISA.
NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto convOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);

  llvm::Intrinsic::ID intrinsic =
      llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
  if (convOp.getRelu())
    intrinsic =
        llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0;

  llvm::SmallVector<llvm::Value *> operands;
  operands.push_back(mt.lookupValue(convOp.getA()));
  operands.push_back(mt.lookupValue(convOp.getB()));
  if (convOp.getScaleFactor())
    operands.push_back(mt.lookupValue(convOp.getScaleFactor()));
  else
    operands.push_back(builder.getInt16(0x7f7f));

  return {intrinsic, std::move(operands)};
}
|
||||
|
||||
/// Picks the LLVM intrinsic and its operands for lowering
/// `nvvm.convert.bf16x2.to.s2f6x2`.
///
/// The `relu` attribute selects the `.relu` flavour of the intrinsic. The
/// arguments are, in order: the translated `src` vector and the scale operand.
/// When the op carries no `scaleFactor`, the i16 constant 0x7f7f is passed
/// instead.
/// NOTE(review): 0x7f7f presumably packs two ue8m0 encodings of a unit scale;
/// confirm against the PTX ISA.
NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto convOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);

  llvm::Intrinsic::ID intrinsic =
      llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
  if (convOp.getRelu())
    intrinsic =
        llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0;

  llvm::SmallVector<llvm::Value *> operands;
  operands.push_back(mt.lookupValue(convOp.getSrc()));
  if (convOp.getScaleFactor())
    operands.push_back(mt.lookupValue(convOp.getScaleFactor()));
  else
    operands.push_back(builder.getInt16(0x7f7f));

  return {intrinsic, std::move(operands)};
}
|
||||
|
||||
/// Picks the LLVM intrinsic and its operands for lowering
/// `nvvm.convert.s2f6x2.to.bf16x2`.
///
/// The intrinsic flavour depends on two flags: whether `sat` is SATFINITE and
/// whether `relu` is set. The arguments are, in order: `src` bitcast to i16,
/// and the scale operand. When the op carries no `scaleFactor`, the i16
/// constant 0x7f7f is passed instead.
/// NOTE(review): 0x7f7f presumably packs two ue8m0 encodings of a unit scale;
/// confirm against the PTX ISA.
NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto convOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
  const bool relu = convOp.getRelu();
  const bool satFinite = convOp.getSat() == NVVM::SaturationMode::SATFINITE;

  llvm::Intrinsic::ID intrinsic;
  if (satFinite)
    intrinsic =
        relu ? llvm::Intrinsic::
                   nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0
             : llvm::Intrinsic::
                   nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0;
  else
    intrinsic =
        relu ? llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0
             : llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0;

  llvm::SmallVector<llvm::Value *> operands;

  // The source is reinterpreted as a single i16 before being passed on.
  llvm::Type *i16Ty = llvm::Type::getInt16Ty(builder.getContext());
  operands.push_back(
      builder.CreateBitCast(mt.lookupValue(convOp.getSrc()), i16Ty));

  if (convOp.getScaleFactor())
    operands.push_back(mt.lookupValue(convOp.getScaleFactor()));
  else
    operands.push_back(builder.getInt16(0x7f7f));

  return {intrinsic, std::move(operands)};
}
|
||||
|
||||
llvm::Intrinsic::ID
|
||||
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
|
||||
LLVM::ModuleTranslation &mt,
|
||||
|
||||
@ -64,7 +64,7 @@ llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) {
|
||||
// -----
|
||||
|
||||
llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) {
|
||||
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
|
||||
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}}
|
||||
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
@ -11,6 +11,34 @@ llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @convert_f16x2_to_f4x2
|
||||
llvm.func @convert_f16x2_to_f4x2(%srcA : vector<2xf16>) {
|
||||
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.satfinite(<2 x half> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
|
||||
%res1 = nvvm.convert.f16x2.to.f4x2 %srcA : vector<2xf16> -> i8 (f4E2M1FN)
|
||||
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.relu.satfinite(<2 x half> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
|
||||
%res2 = nvvm.convert.f16x2.to.f4x2 %srcA {relu = true} : vector<2xf16> -> i8 (f4E2M1FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @convert_bf16x2_to_f4x2
|
||||
llvm.func @convert_bf16x2_to_f4x2(%srcA : vector<2xbf16>) {
|
||||
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.satfinite(<2 x bfloat> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
|
||||
%res1 = nvvm.convert.bf16x2.to.f4x2 %srcA : vector<2xbf16> -> i8 (f4E2M1FN)
|
||||
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
|
||||
%res2 = nvvm.convert.bf16x2.to.f4x2 %srcA {relu = true} : vector<2xbf16> -> i8 (f4E2M1FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @convert_f4x2_to_f16x2
|
||||
llvm.func @convert_f4x2_to_f16x2(%src : i8) {
|
||||
// CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
|
||||
|
||||
17
mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
Normal file
17
mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
Normal file
@ -0,0 +1,17 @@
|
||||
// Each `// -----` chunk below carries its own expected error, so the chunks are
// translated independently.
// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics -split-input-file %s
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @convert_f16x2_to_f4x2_invalid_type(%src : vector<2xf16>) {
|
||||
// expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f4E2M1FN type}}
|
||||
%res = nvvm.convert.f16x2.to.f4x2 %src : vector<2xf16> -> i8 (f8E4M3FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @convert_bf16x2_to_f4x2_invalid_type(%src : vector<2xbf16>) {
|
||||
// expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f4E2M1FN type}}
|
||||
%res = nvvm.convert.bf16x2.to.f4x2 %src : vector<2xbf16> -> i8 (f8E4M3FN)
|
||||
llvm.return
|
||||
}
|
||||
@ -1,11 +1,20 @@
|
||||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
|
||||
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
|
||||
// CHECK-LABEL: @convert_f32x2_to_fp6x2_e2m3
|
||||
llvm.func @convert_f32x2_to_fp6x2_e2m3(%srcA : f32, %srcB : f32) {
|
||||
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
|
||||
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
|
||||
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
|
||||
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E2M3FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @convert_f32x2_to_fp6x2_e3m2
|
||||
llvm.func @convert_f32x2_to_fp6x2_e3m2(%srcA : f32, %srcB : f32) {
|
||||
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
|
||||
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
|
||||
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
|
||||
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
|
||||
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E3M2FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
@ -22,6 +31,68 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @convert_f16x2_to_fp6x2_e2m3
|
||||
llvm.func @convert_f16x2_to_fp6x2_e2m3(%srcA : vector<2xf16>) {
|
||||
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}})
|
||||
%res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E2M3FN)
|
||||
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.relu.satfinite(<2 x half> %{{.*}})
|
||||
%res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E2M3FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @convert_f16x2_to_fp6x2_e3m2
|
||||
llvm.func @convert_f16x2_to_fp6x2_e3m2(%srcA : vector<2xf16>) {
|
||||
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}})
|
||||
%res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E3M2FN)
|
||||
// CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.relu.satfinite(<2 x half> %{{.*}})
|
||||
%res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E3M2FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @convert_f16x2_to_fp6x2_vector
|
||||
llvm.func @convert_f16x2_to_fp6x2_vector(%srcA : vector<2xf16>) {
|
||||
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
|
||||
%res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E2M3FN)
|
||||
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
|
||||
%res2 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E3M2FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// bf16x2 -> f6x2 (f6E2M3FN) with a packed i16 result, without and with `relu`.
// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e2m3
llvm.func @convert_bf16x2_to_fp6x2_e2m3(%srcA : vector<2xbf16>) {
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
  %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E2M3FN)
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
  %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E2M3FN)
  llvm.return
}
|
||||
|
||||
// bf16x2 -> f6x2 with dstTy f6E3M2FN and an i16 result, default and relu forms.
// NOTE(review): %scale_factor is never referenced in the body.
// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e3m2
llvm.func @convert_bf16x2_to_fp6x2_e3m2(%input : vector<2xbf16>, %scale_factor : i16) {
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
  %plain = nvvm.convert.bf16x2.to.f6x2 %input : vector<2xbf16> -> i16 (f6E3M2FN)
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
  %with_relu = nvvm.convert.bf16x2.to.f6x2 %input {relu = true} : vector<2xbf16> -> i16 (f6E3M2FN)
  llvm.return
}
|
||||
|
||||
// bf16x2 -> f6x2 with a vector<2xi8> result. For each destination type the
// patterns match the i16 intrinsic call and then require the very next
// instruction to bitcast that same value to <2 x i8>. The call result is
// captured so the bitcast cannot match an unrelated i16 value.
// NOTE(review): %scale_factor is never referenced in the body; it is kept so
// the function signature stays unchanged.
// CHECK-LABEL: @convert_bf16x2_to_fp6x2_vector
llvm.func @convert_bf16x2_to_fp6x2_vector(%srcA : vector<2xbf16>, %scale_factor : i16) {
  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
  %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)
  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
  %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E3M2FN)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
|
||||
llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
|
||||
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
|
||||
|
||||
17
mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
Normal file
17
mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
Normal file
@ -0,0 +1,17 @@
|
||||
// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics -split-input-file %s
// The cases below are delimited by `// -----`; splitting the input makes each
// one get parsed and verified as its own module.
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: an f8E4M3FN destination type on the f16x2 -> f6x2 op must
// trip the dstTy attribute constraint.
llvm.func @convert_f16x2_to_f6x2_invalid_type(%input : vector<2xf16>) {
  // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
  %bad = nvvm.convert.f16x2.to.f6x2 %input : vector<2xf16> -> vector<2xi8> (f8E4M3FN)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: an f8E4M3FN destination type on the bf16x2 -> f6x2 op must
// trip the dstTy attribute constraint.
llvm.func @convert_bf16x2_to_f6x2_invalid_type(%input : vector<2xbf16>) {
  // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
  %bad = nvvm.convert.bf16x2.to.f6x2 %input : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
  llvm.return
}
|
||||
@ -90,6 +90,25 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
|
||||
// bf16x2 -> f8x2 with dstTy f8E4M3FN, rn rounding and satfinite saturation,
// returned as i16. Covers the default form and the relu form.
// CHECK-LABEL: @convert_bf16x2_to_f8x2_e4m3
llvm.func @convert_bf16x2_to_f8x2_e4m3(%input : vector<2xbf16>) {
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
  %plain = nvvm.convert.bf16x2.to.f8x2 %input {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E4M3FN)
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
  %with_relu = nvvm.convert.bf16x2.to.f8x2 %input {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E4M3FN)
  llvm.return
}
|
||||
|
||||
// bf16x2 -> f8x2 with dstTy f8E5M2, rn rounding and satfinite saturation.
// Covers the default form and the relu form.
// NOTE(review): both results are vector<2xi8>, yet only the intrinsic call is
// matched (no bitcast pattern), while the e4m3 counterpart returns i16 —
// confirm this is intended.
// CHECK-LABEL: @convert_bf16x2_to_f8x2_e5m2
llvm.func @convert_bf16x2_to_f8x2_e5m2(%input : vector<2xbf16>) {
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
  %plain = nvvm.convert.bf16x2.to.f8x2 %input {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
  %with_relu = nvvm.convert.bf16x2.to.f8x2 %input {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
  llvm.return
}
|
||||
|
||||
// CHECK-LABEL: @convert_bf16x2_to_f8x2_vector_return
|
||||
llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
|
||||
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
|
||||
@ -98,6 +117,12 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
|
||||
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
|
||||
%res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
|
||||
// CHECK: %[[res3:.*]] = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res3]] to <2 x i8>
|
||||
%res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
|
||||
// CHECK: %[[res4:.*]] = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
|
||||
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res4]] to <2 x i8>
|
||||
%res4 = nvvm.convert.bf16x2.to.f8x2 %src {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
|
||||
41
mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
Normal file
41
mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
Normal file
@ -0,0 +1,41 @@
|
||||
// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics -split-input-file %s
// The cases below are delimited by `// -----`; splitting the input makes each
// one get parsed and verified as its own module.
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: an f6E2M3FN destination type on the bf16x2 -> f8x2 op must
// trip the dstTy attribute constraint.
llvm.func @convert_bf16x2_to_f8x2_invalid_type(%input : vector<2xbf16>) {
  // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f8E8M0FNU type or f8E4M3FN type or f8E5M2 type}}
  %bad = nvvm.convert.bf16x2.to.f8x2 %input : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: rm rounding combined with an f8E4M3FN destination type must
// be rejected with the RN-only diagnostic.
llvm.func @convert_bf16x2_to_f8x2_invalid_rounding_1(%input : vector<2xbf16>) {
  // expected-error @below {{Only RN rounding mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
  %bad = nvvm.convert.bf16x2.to.f8x2 %input {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: rn rounding combined with an f8E8M0FNU destination type must
// be rejected with the RZ/RP-only diagnostic.
// NOTE(review): the other cases in this file use a `convert_` name prefix;
// this one uses `nvvm_cvt_`. Left unchanged to keep the symbol stable.
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding_2(%input : vector<2xbf16>) {
  // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}}
  %bad = nvvm.convert.bf16x2.to.f8x2 %input {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: saturation mode `none` with rn rounding and an f8E4M3FN
// destination type must be rejected with the SATFINITE-only diagnostic.
llvm.func @convert_bf16x2_to_f8x2_invalid_sat_mode(%input : vector<2xbf16>) {
  // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
  %bad = nvvm.convert.bf16x2.to.f8x2 %input {sat = #nvvm.sat_mode<none>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
  llvm.return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: relu = true with an f8E8M0FNU destination type (rp rounding)
// must be rejected with the relu-not-supported diagnostic.
llvm.func @convert_bf16x2_to_f8x2_invalid_relu(%input : vector<2xbf16>) {
  // expected-error @below {{relu not supported for conversions to 'f8E8M0FNU' type}}
  %bad = nvvm.convert.bf16x2.to.f8x2 %input {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
  llvm.return
}
|
||||
189
mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
Normal file
189
mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
Normal file
@ -0,0 +1,189 @@
|
||||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
|
||||
|
||||
// Two f32 inputs -> s2f6x2 as i16, with no scale operand supplied. The matched
// IR passes the constant i16 32639 as the last intrinsic argument, for both
// the default form and the relu form.
llvm.func @convert_f32x2_to_s2f6x2(%a : f32, %b : f32) -> i16 {
  // CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2(float %0, float %1) {
  // CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
  // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
  // CHECK-NEXT: %5 = or i16 %3, %4
  // CHECK-NEXT: ret i16 %5
  // CHECK-NEXT: }
  %plain = nvvm.convert.f32x2.to.s2f6x2 %a, %b : i16
  %with_relu = nvvm.convert.f32x2.to.s2f6x2 %a, %b {relu = true} : i16

  // OR the two conversions together so both feed the returned value.
  %merged = llvm.or %plain, %with_relu : i16
  llvm.return %merged : i16
}
|
||||
|
||||
// Two f32 inputs -> s2f6x2 as i16, with an explicit i16 scale operand that the
// matched IR forwards as the last intrinsic argument. Default and relu forms.
llvm.func @convert_f32x2_to_s2f6x2_scale(%a : f32, %b : f32, %sf : i16) -> i16 {
  // CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2_scale(float %0, float %1, i16 %2) {
  // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
  // CHECK-NEXT: %5 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
  // CHECK-NEXT: %6 = or i16 %4, %5
  // CHECK-NEXT: ret i16 %6
  // CHECK-NEXT: }
  %plain = nvvm.convert.f32x2.to.s2f6x2 %a, %b, %sf : i16
  %with_relu = nvvm.convert.f32x2.to.s2f6x2 %a, %b, %sf {relu = true} : i16

  // OR the two conversions together so both feed the returned value.
  %merged = llvm.or %plain, %with_relu : i16
  llvm.return %merged : i16
}
|
||||
|
||||
// Two f32 inputs -> s2f6x2 returned as vector<2xi8>, no scale operand. The
// matched IR calls the i16 intrinsic and bitcasts the result to <2 x i8>.
llvm.func @convert_f32x2_to_s2f6x2_vector(%a : f32, %b : f32) -> vector<2xi8> {
  // CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector(float %0, float %1) {
  // CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
  // CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8>
  // CHECK-NEXT: ret <2 x i8> %4
  // CHECK-NEXT: }
  %packed = nvvm.convert.f32x2.to.s2f6x2 %a, %b : vector<2xi8>
  llvm.return %packed : vector<2xi8>
}
|
||||
|
||||
// Two f32 inputs -> s2f6x2 returned as vector<2xi8>, with an explicit i16
// scale operand. The matched IR calls the i16 intrinsic with that scale and
// bitcasts the result to <2 x i8>.
llvm.func @convert_f32x2_to_s2f6x2_vector_scale(%a : f32, %b : f32, %sf : i16) -> vector<2xi8> {
  // CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector_scale(float %0, float %1, i16 %2) {
  // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
  // CHECK-NEXT: %5 = bitcast i16 %4 to <2 x i8>
  // CHECK-NEXT: ret <2 x i8> %5
  // CHECK-NEXT: }
  %packed = nvvm.convert.f32x2.to.s2f6x2 %a, %b, %sf : vector<2xi8>
  llvm.return %packed : vector<2xi8>
}
|
||||
|
||||
// bf16x2 -> s2f6x2 as i16, with no scale operand supplied. The matched IR
// passes the constant i16 32639 as the last intrinsic argument, for both the
// default form and the relu form.
llvm.func @convert_bf16x2_to_s2f6x2(%pair : vector<2xbf16>) -> i16 {
  // CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2(<2 x bfloat> %0) {
  // CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
  // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
  // CHECK-NEXT: %4 = or i16 %2, %3
  // CHECK-NEXT: ret i16 %4
  // CHECK-NEXT: }
  %plain = nvvm.convert.bf16x2.to.s2f6x2 %pair : vector<2xbf16> -> i16
  %with_relu = nvvm.convert.bf16x2.to.s2f6x2 %pair {relu = true} : vector<2xbf16> -> i16

  // OR the two conversions together so both feed the returned value.
  %merged = llvm.or %plain, %with_relu : i16
  llvm.return %merged : i16
}
|
||||
|
||||
// bf16x2 -> s2f6x2 as i16, with an explicit i16 scale operand that the matched
// IR forwards as the last intrinsic argument. Default and relu forms.
llvm.func @convert_bf16x2_to_s2f6x2_scale(%pair : vector<2xbf16>, %sf : i16) -> i16 {
  // CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2_scale(<2 x bfloat> %0, i16 %1) {
  // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
  // CHECK-NEXT: %4 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
  // CHECK-NEXT: %5 = or i16 %3, %4
  // CHECK-NEXT: ret i16 %5
  // CHECK-NEXT: }
  %plain = nvvm.convert.bf16x2.to.s2f6x2 %pair, %sf : vector<2xbf16> -> i16
  %with_relu = nvvm.convert.bf16x2.to.s2f6x2 %pair, %sf {relu = true} : vector<2xbf16> -> i16

  // OR the two conversions together so both feed the returned value.
  %merged = llvm.or %plain, %with_relu : i16
  llvm.return %merged : i16
}
|
||||
|
||||
// bf16x2 -> s2f6x2 returned as vector<2xi8>, no scale operand. The matched IR
// calls the i16 intrinsic and bitcasts the result to <2 x i8>.
llvm.func @convert_bf16x2_to_s2f6x2_vector(%pair : vector<2xbf16>) -> vector<2xi8> {
  // CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector(<2 x bfloat> %0) {
  // CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
  // CHECK-NEXT: %3 = bitcast i16 %2 to <2 x i8>
  // CHECK-NEXT: ret <2 x i8> %3
  // CHECK-NEXT: }
  %packed = nvvm.convert.bf16x2.to.s2f6x2 %pair : vector<2xbf16> -> vector<2xi8>
  llvm.return %packed : vector<2xi8>
}
|
||||
|
||||
// bf16x2 -> s2f6x2 returned as vector<2xi8>, with an explicit i16 scale
// operand. The matched IR calls the i16 intrinsic with that scale and bitcasts
// the result to <2 x i8>.
llvm.func @convert_bf16x2_to_s2f6x2_vector_scale(%pair : vector<2xbf16>, %sf : i16) -> vector<2xi8> {
  // CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector_scale(<2 x bfloat> %0, i16 %1) {
  // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
  // CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8>
  // CHECK-NEXT: ret <2 x i8> %4
  // CHECK-NEXT: }
  %packed = nvvm.convert.bf16x2.to.s2f6x2 %pair, %sf : vector<2xbf16> -> vector<2xi8>
  llvm.return %packed : vector<2xi8>
}
|
||||
|
||||
// Case 1: relu off, no scale operand, satfinite off.
// The <2 x i8> input is bitcast to i16 and the matched call passes the
// constant i16 32639 as the second intrinsic argument.
llvm.func @convert_s2f6x2_to_bf16x2(%packed : vector<2xi8>) -> vector<2xbf16> {
  // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2(<2 x i8> %0) {
  // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
  // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %2, i16 32639)
  // CHECK-NEXT: ret <2 x bfloat> %3
  // CHECK-NEXT: }
  %out = nvvm.convert.s2f6x2.to.bf16x2 %packed : vector<2xi8> -> vector<2xbf16>
  llvm.return %out : vector<2xbf16>
}
|
||||
|
||||
// Case 2: relu on, no scale operand, satfinite off.
// The matched call is the `.relu` intrinsic variant with constant i16 32639.
llvm.func @convert_s2f6x2_to_bf16x2_relu(%packed : vector<2xi8>) -> vector<2xbf16> {
  // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu(<2 x i8> %0) {
  // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
  // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %2, i16 32639)
  // CHECK-NEXT: ret <2 x bfloat> %3
  // CHECK-NEXT: }
  %out = nvvm.convert.s2f6x2.to.bf16x2 %packed {relu = true} : vector<2xi8> -> vector<2xbf16>
  llvm.return %out : vector<2xbf16>
}
|
||||
|
||||
// Case 3: relu off, explicit scale operand, satfinite off.
// The matched call forwards the i16 scale argument instead of a constant.
llvm.func @convert_s2f6x2_to_bf16x2_scale(%packed : vector<2xi8>, %sf : i16) -> vector<2xbf16> {
  // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale(<2 x i8> %0, i16 %1) {
  // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
  // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %3, i16 %1)
  // CHECK-NEXT: ret <2 x bfloat> %4
  // CHECK-NEXT: }
  %out = nvvm.convert.s2f6x2.to.bf16x2 %packed, %sf : vector<2xi8> -> vector<2xbf16>
  llvm.return %out : vector<2xbf16>
}
|
||||
|
||||
// Case 4: relu on, explicit scale operand, satfinite off.
// The matched call is the `.relu` intrinsic variant with the i16 scale argument.
llvm.func @convert_s2f6x2_to_bf16x2_scale_relu(%packed : vector<2xi8>, %sf : i16) -> vector<2xbf16> {
  // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu(<2 x i8> %0, i16 %1) {
  // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
  // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %3, i16 %1)
  // CHECK-NEXT: ret <2 x bfloat> %4
  // CHECK-NEXT: }
  %out = nvvm.convert.s2f6x2.to.bf16x2 %packed, %sf {relu = true} : vector<2xi8> -> vector<2xbf16>
  llvm.return %out : vector<2xbf16>
}
|
||||
|
||||
// Case 5: relu off, no scale operand, satfinite on.
// The matched call is the `.satfinite` intrinsic variant with constant i16 32639.
llvm.func @convert_s2f6x2_to_bf16x2_satfinite(%packed : vector<2xi8>) -> vector<2xbf16> {
  // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_satfinite(<2 x i8> %0) {
  // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
  // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %2, i16 32639)
  // CHECK-NEXT: ret <2 x bfloat> %3
  // CHECK-NEXT: }
  %out = nvvm.convert.s2f6x2.to.bf16x2 %packed {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
  llvm.return %out : vector<2xbf16>
}
|
||||
|
||||
// Case 6: relu on, no scale operand, satfinite on.
// The matched call is the `.relu.satfinite` intrinsic variant with constant
// i16 32639.
llvm.func @convert_s2f6x2_to_bf16x2_relu_satfinite(%packed : vector<2xi8>) -> vector<2xbf16> {
  // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu_satfinite(<2 x i8> %0) {
  // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
  // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %2, i16 32639)
  // CHECK-NEXT: ret <2 x bfloat> %3
  // CHECK-NEXT: }
  %out = nvvm.convert.s2f6x2.to.bf16x2 %packed {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
  llvm.return %out : vector<2xbf16>
}
|
||||
|
||||
// Case 7: relu off, explicit scale operand, satfinite on.
// The matched call is the `.satfinite` intrinsic variant with the i16 scale
// argument.
llvm.func @convert_s2f6x2_to_bf16x2_scale_satfinite(%packed : vector<2xi8>, %sf : i16) -> vector<2xbf16> {
  // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_satfinite(<2 x i8> %0, i16 %1) {
  // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
  // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %3, i16 %1)
  // CHECK-NEXT: ret <2 x bfloat> %4
  // CHECK-NEXT: }
  %out = nvvm.convert.s2f6x2.to.bf16x2 %packed, %sf {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
  llvm.return %out : vector<2xbf16>
}
|
||||
|
||||
// Case 8: relu on, explicit scale operand, satfinite on.
// The matched call is the `.relu.satfinite` intrinsic variant with the i16
// scale argument.
llvm.func @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(%packed : vector<2xi8>, %sf : i16) -> vector<2xbf16> {
  // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(<2 x i8> %0, i16 %1) {
  // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
  // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %3, i16 %1)
  // CHECK-NEXT: ret <2 x bfloat> %4
  // CHECK-NEXT: }
  %out = nvvm.convert.s2f6x2.to.bf16x2 %packed, %sf {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
  llvm.return %out : vector<2xbf16>
}
|
||||
@ -209,22 +209,6 @@ llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: an f8E4M3FN destination type on bf16x2 -> f8x2 with rz rounding.
// NOTE(review): the enclosing hunk header (-209,22 +209,6) suggests this case
// is among the lines removed by the change — confirm before relying on it.
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
|
||||
// expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}}
|
||||
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E4M3FN)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative case: rn rounding with an f8E8M0FNU destination type on bf16x2 -> f8x2.
// NOTE(review): the enclosing hunk header (-209,22 +209,6) suggests this case
// is among the lines removed by the change — confirm before relying on it.
llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
|
||||
// expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
|
||||
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16 (f8E8M0FNU)
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
|
||||
// expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}}
|
||||
%res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user