diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 82fbdf8e2996..8153111f3cf5 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1972,6 +1972,44 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> { }]; } +class NVVM_ConvertFPx2ToF4x2Op + : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f4x2"> { + let summary = "Convert an " # !tolower(srcType) # "x2 input to f4x2"; + let description = [{ + This Op converts each of the given }]#srcType#[{ inputs in an }]#!tolower + (srcType)#[{x2 vector to the specified fp4 type. + The result `dst` is returned as an i8 type where the converted values are + packed such that the value converted from the first element of `src` is + stored in the lower 4 bits of `dst` and the value converted from the second + element of `src` is stored in the upper 4 bits of `dst`. + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. 
+ }]; + let results = (outs I8:$dst); + let arguments = (ins + VectorOfLengthAndType<[2], [!cast(srcType)]>:$src, + DefaultValuedAttr:$relu, + TypeAttrOf:$dstTy); + + let assemblyFormat = + "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`"; + + let extraClassDeclaration = [{ + static NVVM::IDArgPair + getIntrinsicIDAndArgs(NVVM::Convert}]#srcType#[{x2ToF4x2Op &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [intId, args] = NVVM::Convert}]#srcType#[{x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args); + $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext())); + }]; +} + +def NVVM_ConvertF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"F16">; +def NVVM_ConvertBF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"BF16">; + def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { let summary = "Convert a pair of float inputs to f6x2"; let description = [{ @@ -2015,6 +2053,48 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> { }]; } +class NVVM_ConvertFPx2ToF6x2Op + : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f6x2"> { + let summary = "Convert an " # !tolower(srcType) # "x2 input to f6x2"; + let description = [{ + This Op converts each of the given }]#srcType#[{ inputs in an }]#!tolower + (srcType)#[{x2 vector to the specified fp6 type. The result `dst` is + represented either as an i16 type or as a vector of two i8 types. + If `dst` is returned as an i16 type, the converted values are packed such + that the value converted from the first element of `src` is stored in the + lower 8 bits of `dst` with 2 MSB bits padded with zeros and the value + converted from the second element of `src` is stored in the upper 8 bits of + `dst` with 2 MSB bits padded with zeros. 
+ If `dst` is returned as a vector type, each converted value is stored as an + i8 element in the vector with 2 MSB bits padded with zeros. + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + }]; + + let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); + let arguments = (ins + VectorOfLengthAndType<[2], [!cast(srcType)]>:$src, + DefaultValuedAttr:$relu, + TypeAttrOf>:$dstTy); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`"; + let extraClassDeclaration = [{ + static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu); + }]; + + string llvmBuilder = [{ + auto intId = NVVM::Convert}]#srcType#[{x2ToF6x2Op::getIntrinsicID($dstTy, $relu); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$src}); + if(op.getDst().getType().isInteger(16)) + $dst = packedI16; + else + $dst = builder.CreateBitCast(packedI16, + llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); + }]; +} + +def NVVM_ConvertF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"F16">; +def NVVM_ConvertBF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"BF16">; + def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> { let summary = "Convert a pair of float inputs to f8x2"; let description = [{ @@ -2109,13 +2189,12 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> { let summary = "Convert a pair of bf16 inputs to f8x2"; let description = [{ This Op converts the given bf16 inputs in a bf16x2 vector to the specified - f8 type. - The result `dst` is represented as an i16 type or as a vector - of two i8 types. - If `dst` is returned as an i16 type, the converted values from `a` - are packed such that the value converted from the first element of `a` - is stored in the upper 8 bits of `dst` and the value converted from the - second element of `a` is stored in the lower 8 bits of `dst`. + f8 type. 
The result `dst` is represented either as a packed i16 type or as + a vector of two i8 types. + If `dst` is returned as an i16 type, the converted values are packed such + that the value converted from the first element of `src` is stored in the + lower 8 bits of `dst` and the value converted from the second element of + `src` is stored in the upper 8 bits of `dst`. If `dst` is returned as a vector type, each converted value is stored as an i8 element in the vector. The `rnd` and `sat` attributes specify the rounding and saturation modes @@ -2127,20 +2206,22 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> { let hasVerifier = 1; let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); let arguments = (ins - VectorOfLengthAndType<[2], [BF16]>:$a, + VectorOfLengthAndType<[2], [BF16]>:$src, DefaultValuedAttr:$rnd, DefaultValuedAttr:$sat, - TypeAttr:$dstTy); - let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`"; + DefaultValuedAttr:$relu, + TypeAttrOf>:$dstTy); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`"; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd, - NVVM::SaturationMode sat); + static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, + NVVM::FPRoundingMode rnd, + NVVM::SaturationMode sat, bool hasRelu); }]; string llvmBuilder = [{ - auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat); - llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a}); + auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$src}); if(op.getDst().getType().isInteger(16)) $dst = packedI16; else @@ -2185,6 +2266,117 @@ def NVVM_ConvertF6x2ToF16x2Op : def NVVM_ConvertF4x2ToF16x2Op : NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">; +def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> { + let summary 
= "Convert a pair of f32 inputs to S2F6x2"; + let description = [{ + This Op converts each of the given f32 inputs to the + S2F6x2 type. The result `dst` can be either a packed i16 type or a vector + of two i8 types. + If `dst` is returned as an i16 type, the converted values are packed such + that the value converted from `a` is stored in the upper 8 bits of `dst` + and the value converted from `b` is stored in the lower 8 bits of `dst`. + If `dst` is returned as a vector type, each converted value is stored as an + i8 element in the vector. + The `relu` attribute, when set, lowers to the '.relu' variant + of the cvt instruction. + The optional scaling-factors for each of the inputs are provided through + the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported + as the type of the scale-factor currently. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); + let arguments = (ins F32:$a, F32:$b, + Optional:$scaleFactor, + DefaultValuedAttr:$relu); + let assemblyFormat = + "$a `,` $b (`,` $scaleFactor^)? 
attr-dict `:` type($dst)"; + let extraClassDeclaration = [{ + static IDArgPair getIntrinsicIDAndArgs(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); + llvm::Value *packedI16 = createIntrinsicCall(builder, id, args); + if(op.getDst().getType().isInteger(16)) + $dst = packedI16; + else + $dst = builder.CreateBitCast(packedI16, + llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); + }]; +} + +def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> { + let summary = "Convert a pair of BF16 inputs to S2F6x2"; + let description = [{ + This Op converts each of the given BF16 inputs in a bf16x2 vector to the + S2F6x2 type. The result `dst` can be either a packed i16 type or a vector + of two i8 types. + If `dst` is returned as an i16 type, the converted values are packed such + that the value converted from the first element of `src` is stored in the + lower 8 bits of `dst` and the value converted from the second element of + `src` is stored in the upper 8 bits of `dst`. + If `dst` is returned as a vector type, each converted value is stored as an + i8 element in the vector. + The `relu` attribute, when set, lowers to the '.relu' variant + of the cvt instruction. + The optional scaling-factors for each of the inputs are provided through + the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported + as the type of the scale-factor currently. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); + let arguments = (ins + VectorOfLengthAndType<[2], [BF16]>:$src, + Optional:$scaleFactor, + DefaultValuedAttr:$relu); + let assemblyFormat = + "$src (`,` $scaleFactor^)? 
attr-dict `:` type($src) `->` type($dst)"; + let extraClassDeclaration = [{ + static IDArgPair getIntrinsicIDAndArgs(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder); + llvm::Value *packedI16 = createIntrinsicCall(builder, id, args); + if(op.getDst().getType().isInteger(16)) + $dst = packedI16; + else + $dst = builder.CreateBitCast(packedI16, + llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); + }]; +} + +def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_SingleResultIntrinsicOp<"convert.s2f6x2.to.bf16x2", [], "$dst"> { + let summary = "Convert s2f6x2 to bf16x2"; + let description = [{ + This Op converts a pair of s2f6x2 inputs to bf16x2 type. The result `dst` + is represented as a vector of two bf16 elements. + + The `relu` attribute, when set, lowers to the '.relu' variant + of the cvt instruction. + + The optional scaling-factors for each of the inputs are provided through + the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported + as the type of the scale-factor currently. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let results = (outs VectorOfLengthAndType<[2], [BF16]>:$dst); + let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src, + Optional:$scaleFactor, + DefaultValuedAttr:$sat, + DefaultValuedAttr:$relu); + let assemblyFormat = + "$src (`,` $scaleFactor^)? 
attr-dict `:` type($src) `->` type($dst)"; +} + //===----------------------------------------------------------------------===// // NVVM Stochastic Rounding Conversion Ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 528e709629eb..24cf83ab3cb3 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -414,18 +414,45 @@ LogicalResult ConvertF16x2ToF8x2Op::verify() { LogicalResult ConvertBF16x2ToF8x2Op::verify() { using RndMode = NVVM::FPRoundingMode; + using SatMode = NVVM::SaturationMode; - if (!llvm::isa(getDstTy())) - return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext()) - << " type is supported for conversions from " - "bf16x2 to f8x2."; + bool isRoundingModeRN = getRnd() == RndMode::RN; + bool isRoundingModeRZ = getRnd() == RndMode::RZ; + bool isRoundingModeRP = getRnd() == RndMode::RP; + bool isSatFinite = getSat() == SatMode::SATFINITE; + bool hasRelu = getRelu(); - auto rnd = getRnd(); - if (rnd != RndMode::RZ && rnd != RndMode::RP) - return emitOpError("Only RZ and RP rounding modes are supported for " - "conversions from bf16x2 to f8x2."); + mlir::MLIRContext *ctx = getContext(); - return success(); + return llvm::TypeSwitch(getDstTy()) + .Case( + [&](mlir::Type) -> LogicalResult { + if (!isRoundingModeRN) + return emitOpError("Only RN rounding mode is supported for " + "conversions from bf16x2 to ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; + if (!isSatFinite) + return emitOpError("Only SATFINITE saturation mode is supported " + "for conversions from bf16x2 to ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) << " types"; + return success(); + }) + .Case([&](mlir::Type) -> LogicalResult { + if (!(isRoundingModeRZ || isRoundingModeRP)) + return 
emitOpError("Only RZ and RP rounding modes are supported for " + "conversions from bf16x2 to ") + << mlir::Float8E8M0FNUType::get(ctx) << " type"; + if (hasRelu) + return emitOpError("relu not supported for conversions to ") + << mlir::Float8E8M0FNUType::get(ctx) << " type"; + return success(); + }) + .Default([&](mlir::Type) -> LogicalResult { + llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op"); + return failure(); + }); } LogicalResult ConvertF32x2ToF4x2Op::verify() { @@ -4232,6 +4259,80 @@ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy, }); } +NVVM::IDArgPair +ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + mlir::Type dstTy = op.getDstTy(); + bool hasRelu = op.getRelu(); + + llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic; + + if (llvm::isa(dstTy)) + intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite; + + llvm::SmallVector args; + args.push_back(mt.lookupValue(op.getSrc())); + + return {intId, std::move(args)}; +} + +NVVM::IDArgPair +ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + mlir::Type dstTy = op.getDstTy(); + bool hasRelu = op.getRelu(); + + llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic; + + if (llvm::isa(dstTy)) + intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite; + + llvm::SmallVector args; + args.push_back(mt.lookupValue(op.getSrc())); + + return {intId, std::move(args)}; +} + +llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float6E2M3FNType) { + return hasRelu ? 
llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite; + }) + .Case([&](mlir::Float6E3M2FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + +llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy, + bool hasRelu) { + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float6E2M3FNType) { + return hasRelu + ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite; + }) + .Case([&](mlir::Float6E3M2FNType) { + return hasRelu + ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \ has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \ : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd @@ -4287,22 +4388,39 @@ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, }); } -#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \ - has_satf ? 
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \ - : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd - llvm::Intrinsic::ID -ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd, - NVVM::SaturationMode sat) { +ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy, + NVVM::FPRoundingMode rnd, + NVVM::SaturationMode sat, bool hasRelu) { bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE); - switch (rnd) { - case NVVM::FPRoundingMode::RZ: - return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite); - case NVVM::FPRoundingMode::RP: - return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite); - default: - llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op"); - } + + static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = { + llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz, + llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp, + llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite, + llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite, + }; + + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float8E4M3FNType) { + return hasRelu + ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite; + }) + .Case([&](mlir::Float8E5M2Type) { + return hasRelu + ? 
llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite + : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite; + }) + .Case([&](mlir::Float8E8M0FNUType) { + bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP); + unsigned index = (hasSatFinite << 1) | hasRoundingModeRP; + return ue8m0x2IDs[index]; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op"); + return llvm::Intrinsic::not_intrinsic; + }); } NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs( @@ -4397,6 +4515,74 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs( return {intId, {extendedI16}}; } +NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + bool hasRelu = thisOp.getRelu(); + bool hasScale = static_cast(thisOp.getScaleFactor()); + + llvm::Intrinsic::ID id = + hasRelu + ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0 + : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0; + + // Fill the Intrinsic Args + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getA())); + args.push_back(mt.lookupValue(thisOp.getB())); + args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor()) + : builder.getInt16(0x7f7f)); + return {id, std::move(args)}; +} + +NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + bool hasRelu = thisOp.getRelu(); + bool hasScale = static_cast(thisOp.getScaleFactor()); + + llvm::Intrinsic::ID id = + hasRelu + ? llvm::Intrinsic:: + nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0 + : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0; + + // Fill the Intrinsic Args + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getSrc())); + args.push_back(hasScale ? 
mt.lookupValue(thisOp.getScaleFactor()) + : builder.getInt16(0x7f7f)); + return {id, std::move(args)}; +} + +NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + bool hasRelu = thisOp.getRelu(); + bool hasScale = static_cast(thisOp.getScaleFactor()); + bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE; + + static constexpr llvm::Intrinsic::ID ids[] = { + llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0, + llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0, + llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0, + llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0, + }; + + unsigned idx = (hasSat << 1) | hasRelu; + + // Fill the Intrinsic Args + llvm::SmallVector args; + llvm::Value *packedI16 = + builder.CreateBitCast(mt.lookupValue(thisOp.getSrc()), + llvm::Type::getInt16Ty(builder.getContext())); + args.push_back(packedI16); + args.push_back(hasScale ? 
mt.lookupValue(thisOp.getScaleFactor()) + : builder.getInt16(0x7f7f)); + + return {ids[idx], std::move(args)}; +} + llvm::Intrinsic::ID Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir index 506b81e1e704..3e763a036f2d 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir @@ -64,7 +64,7 @@ llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) { // ----- llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) { - // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}} + // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}} %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir index 451475ca7602..3d3bd714fa8f 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir @@ -11,6 +11,34 @@ llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) { llvm.return } +// ----- + +// CHECK-LABEL: @convert_f16x2_to_f4x2 +llvm.func @convert_f16x2_to_f4x2(%srcA : vector<2xf16>) { + // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.satfinite(<2 x half> %{{.*}}) + // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8 + %res1 = nvvm.convert.f16x2.to.f4x2 %srcA : vector<2xf16> -> i8 (f4E2M1FN) + // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.relu.satfinite(<2 x half> %{{.*}}) + // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8 + %res2 = nvvm.convert.f16x2.to.f4x2 %srcA {relu = true} 
: vector<2xf16> -> i8 (f4E2M1FN) + llvm.return +} + +// ----- + +// CHECK-LABEL: @convert_bf16x2_to_f4x2 +llvm.func @convert_bf16x2_to_f4x2(%srcA : vector<2xbf16>) { + // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.satfinite(<2 x bfloat> %{{.*}}) + // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8 + %res1 = nvvm.convert.bf16x2.to.f4x2 %srcA : vector<2xbf16> -> i8 (f4E2M1FN) + // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.relu.satfinite(<2 x bfloat> %{{.*}}) + // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8 + %res2 = nvvm.convert.bf16x2.to.f4x2 %srcA {relu = true} : vector<2xbf16> -> i8 (f4E2M1FN) + llvm.return +} + +// ----- + // CHECK-LABEL: @convert_f4x2_to_f16x2 llvm.func @convert_f4x2_to_f16x2(%src : i8) { // CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16 diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir new file mode 100644 index 000000000000..d179431fe936 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics %s + +// ----- + +llvm.func @convert_f16x2_to_f4x2_invalid_type(%src : vector<2xf16>) { + // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f4E2M1FN type}} + %res = nvvm.convert.f16x2.to.f4x2 %src : vector<2xf16> -> i8 (f8E4M3FN) + llvm.return +} + +// ----- + +llvm.func @convert_bf16x2_to_f4x2_invalid_type(%src : vector<2xbf16>) { + // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f4E2M1FN type}} + %res = nvvm.convert.bf16x2.to.f4x2 %src : vector<2xbf16> -> i8 (f8E4M3FN) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir index 61a7a48f40d5..8d9e5ff2a6a8 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir @@ -1,11 +1,20 
@@ // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s -// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed -llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) { +// CHECK-LABEL: @convert_f32x2_to_fp6x2_e2m3 +llvm.func @convert_f32x2_to_fp6x2_e2m3(%srcA : f32, %srcB : f32) { //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN) + //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E2M3FN) + llvm.return +} + +// CHECK-LABEL: @convert_f32x2_to_fp6x2_e3m2 +llvm.func @convert_f32x2_to_fp6x2_e3m2(%srcA : f32, %srcB : f32) { //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) - %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN) + %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN) + //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E3M2FN) llvm.return } @@ -22,6 +31,68 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) { // ----- +// CHECK-LABEL: @convert_f16x2_to_fp6x2_e2m3 +llvm.func @convert_f16x2_to_fp6x2_e2m3(%srcA : vector<2xf16>) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}}) + %res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E2M3FN) + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.relu.satfinite(<2 x half> %{{.*}}) + %res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E2M3FN) + llvm.return +} + +// CHECK-LABEL: @convert_f16x2_to_fp6x2_e3m2 +llvm.func @convert_f16x2_to_fp6x2_e3m2(%srcA : vector<2xf16>) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}}) + %res1 = 
nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E3M2FN) + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.relu.satfinite(<2 x half> %{{.*}}) + %res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E3M2FN) + llvm.return +} + +// CHECK-LABEL: @convert_f16x2_to_fp6x2_vector +llvm.func @convert_f16x2_to_fp6x2_vector(%srcA : vector<2xf16>) { + // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> + %res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E2M3FN) + // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> + %res2 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E3M2FN) + llvm.return +} + +// ----- + +// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e2m3 +llvm.func @convert_bf16x2_to_fp6x2_e2m3(%srcA : vector<2xbf16>, %scale_factor : i16) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}}) + %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E2M3FN) + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}}) + %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E2M3FN) + llvm.return +} + +// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e3m2 +llvm.func @convert_bf16x2_to_fp6x2_e3m2(%srcA : vector<2xbf16>, %scale_factor : i16) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}}) + %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E3M2FN) + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}}) + %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E3M2FN) + llvm.return +} + +// CHECK-LABEL: @convert_bf16x2_to_fp6x2_vector 
+llvm.func @convert_bf16x2_to_fp6x2_vector(%srcA : vector<2xbf16>, %scale_factor : i16) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %{{.*}} to <2 x i8> + %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E2M3FN) + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %{{.*}} to <2 x i8> + %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E3M2FN) + llvm.return +} + +// ----- + // CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3 llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) { // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16 diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir new file mode 100644 index 000000000000..e993868cf1c9 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics %s + +// ----- + +llvm.func @convert_f16x2_to_f6x2_invalid_type(%src : vector<2xf16>) { + // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}} + %res = nvvm.convert.f16x2.to.f6x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN) + llvm.return +} + +// ----- + +llvm.func @convert_bf16x2_to_f6x2_invalid_type(%src : vector<2xbf16>) { + // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}} + %res = nvvm.convert.bf16x2.to.f6x2 %src : vector<2xbf16> -> vector<2xi8> (f8E4M3FN) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir index 4afe901bc08e..d8002d790b6a 100644 --- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir @@ 
-90,6 +90,25 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) { llvm.return } + +// CHECK-LABEL: @convert_bf16x2_to_f8x2_e4m3 +llvm.func @convert_bf16x2_to_f8x2_e4m3(%srcA : vector<2xbf16>) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}}) + %res1 = nvvm.convert.bf16x2.to.f8x2 %srcA {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 (f8E4M3FN) + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}}) + %res2 = nvvm.convert.bf16x2.to.f8x2 %srcA {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> i16 (f8E4M3FN) + llvm.return +} + +// CHECK-LABEL: @convert_bf16x2_to_f8x2_e5m2 +llvm.func @convert_bf16x2_to_f8x2_e5m2(%srcA : vector<2xbf16>) { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.satfinite(<2 x bfloat> %{{.*}}) + %res1 = nvvm.convert.bf16x2.to.f8x2 %srcA {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> (f8E5M2) + // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}}) + %res2 = nvvm.convert.bf16x2.to.f8x2 %srcA {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> (f8E5M2) + llvm.return +} + // CHECK-LABEL: @convert_bf16x2_to_f8x2_vector_return llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) { // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}}) @@ -98,6 +117,12 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) { // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}}) // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8> %res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU) + // CHECK: %[[res3:.*]] = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}}) + // 
CHECK-NEXT: %{{.*}} = bitcast i16 %[[res3]] to <2 x i8> + %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN) + // CHECK: %[[res4:.*]] = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}}) + // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res4]] to <2 x i8> + %res4 = nvvm.convert.bf16x2.to.f8x2 %src {relu = true, rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xbf16> -> vector<2xi8> (f8E5M2) llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir new file mode 100644 index 000000000000..747706dfc341 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics %s + +// ----- + +llvm.func @convert_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) { + // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f8E8M0FNU type or f8E4M3FN type or f8E5M2 type}} + %res = nvvm.convert.bf16x2.to.f8x2 %src : vector<2xbf16> -> vector<2xi8> (f6E2M3FN) + llvm.return +} + +// ----- + +llvm.func @convert_bf16x2_to_f8x2_invalid_rounding_1(%src : vector<2xbf16>) { + // expected-error @below {{Only RN rounding mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}} + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN) + llvm.return +} + +// ----- + +llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding_2(%src : vector<2xbf16>) { + // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}} + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU) + llvm.return +} + +// ----- + +llvm.func @convert_bf16x2_to_f8x2_invalid_sat_mode(%src : vector<2xbf16>) { + // expected-error 
@below {{Only SATFINITE saturation mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}} + %res = nvvm.convert.bf16x2.to.f8x2 %src {sat = #nvvm.sat_mode, rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN) + llvm.return +} + +// ----- + +llvm.func @convert_bf16x2_to_f8x2_invalid_relu(%src : vector<2xbf16>) { + // expected-error @below {{relu not supported for conversions to 'f8E8M0FNU' type}} + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode, relu = true} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir new file mode 100644 index 000000000000..7c1aa406a47a --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir @@ -0,0 +1,189 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @convert_f32x2_to_s2f6x2(%srcA : f32, %srcB : f32) -> i16 { + // CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2(float %0, float %1) { + // CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639) + // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639) + // CHECK-NEXT: %5 = or i16 %3, %4 + // CHECK-NEXT: ret i16 %5 + // CHECK-NEXT: } + %res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB : i16 + %res2 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB {relu = true} : i16 + + // Combine results to avoid dead code elimination + %final_result = llvm.or %res1, %res2 : i16 + llvm.return %final_result : i16 +} + +llvm.func @convert_f32x2_to_s2f6x2_scale(%srcA : f32, %srcB : f32, %scale : i16) -> i16 { + // CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2_scale(float %0, float %1, i16 %2) { + // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2) + // CHECK-NEXT: %5 = call i16 
@llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2) + // CHECK-NEXT: %6 = or i16 %4, %5 + // CHECK-NEXT: ret i16 %6 + // CHECK-NEXT: } + %res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale : i16 + %res2 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale {relu = true} : i16 + + // Combine results to avoid dead code elimination + %final_result = llvm.or %res1, %res2 : i16 + llvm.return %final_result : i16 +} + +llvm.func @convert_f32x2_to_s2f6x2_vector(%srcA : f32, %srcB : f32) -> vector<2xi8> { + // CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector(float %0, float %1) { + // CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639) + // CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8> + // CHECK-NEXT: ret <2 x i8> %4 + // CHECK-NEXT: } + %res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB : vector<2xi8> + llvm.return %res1 : vector<2xi8> +} + +llvm.func @convert_f32x2_to_s2f6x2_vector_scale(%srcA : f32, %srcB : f32, %scale : i16) -> vector<2xi8> { + // CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector_scale(float %0, float %1, i16 %2) { + // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2) + // CHECK-NEXT: %5 = bitcast i16 %4 to <2 x i8> + // CHECK-NEXT: ret <2 x i8> %5 + // CHECK-NEXT: } + %res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale : vector<2xi8> + llvm.return %res1 : vector<2xi8> +} + +llvm.func @convert_bf16x2_to_s2f6x2(%srcA : vector<2xbf16>) -> i16 { + // CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2(<2 x bfloat> %0) { + // CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639) + // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639) + // CHECK-NEXT: %4 = or i16 %2, %3 + // CHECK-NEXT: ret i16 %4 + // CHECK-NEXT: } + %res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA : 
vector<2xbf16> -> i16 + %res2 = nvvm.convert.bf16x2.to.s2f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 + + // Combine results to avoid dead code elimination + %final_result = llvm.or %res1, %res2 : i16 + llvm.return %final_result : i16 +} + +llvm.func @convert_bf16x2_to_s2f6x2_scale(%srcA : vector<2xbf16>, %scale : i16) -> i16 { + // CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2_scale(<2 x bfloat> %0, i16 %1) { + // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1) + // CHECK-NEXT: %4 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1) + // CHECK-NEXT: %5 = or i16 %3, %4 + // CHECK-NEXT: ret i16 %5 + // CHECK-NEXT: } + %res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale : vector<2xbf16> -> i16 + %res2 = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale {relu = true} : vector<2xbf16> -> i16 + + // Combine results to avoid dead code elimination + %final_result = llvm.or %res1, %res2 : i16 + llvm.return %final_result : i16 +} + +llvm.func @convert_bf16x2_to_s2f6x2_vector(%srcA : vector<2xbf16>) -> vector<2xi8> { + // CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector(<2 x bfloat> %0) { + // CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639) + // CHECK-NEXT: %3 = bitcast i16 %2 to <2 x i8> + // CHECK-NEXT: ret <2 x i8> %3 + // CHECK-NEXT: } + %res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA : vector<2xbf16> -> vector<2xi8> + llvm.return %res1 : vector<2xi8> +} + +llvm.func @convert_bf16x2_to_s2f6x2_vector_scale(%srcA : vector<2xbf16>, %scale : i16) -> vector<2xi8> { + // CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector_scale(<2 x bfloat> %0, i16 %1) { + // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1) + // CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8> + // CHECK-NEXT: ret <2 x i8> %4 + // CHECK-NEXT: } + %res1 = 
nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale : vector<2xbf16> -> vector<2xi8> + llvm.return %res1 : vector<2xi8> +} + +// 1. no relu, no scale, no satfinite +llvm.func @convert_s2f6x2_to_bf16x2(%src : vector<2xi8>) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2(<2 x i8> %0) { + // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16 + // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %2, i16 32639) + // CHECK-NEXT: ret <2 x bfloat> %3 + // CHECK-NEXT: } + %res = nvvm.convert.s2f6x2.to.bf16x2 %src : vector<2xi8> -> vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// 2. relu, no scale, no satfinite +llvm.func @convert_s2f6x2_to_bf16x2_relu(%src : vector<2xi8>) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu(<2 x i8> %0) { + // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16 + // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %2, i16 32639) + // CHECK-NEXT: ret <2 x bfloat> %3 + // CHECK-NEXT: } + %res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true} : vector<2xi8> -> vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// 3. no relu, with scale, no satfinite +llvm.func @convert_s2f6x2_to_bf16x2_scale(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale(<2 x i8> %0, i16 %1) { + // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16 + // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %3, i16 %1) + // CHECK-NEXT: ret <2 x bfloat> %4 + // CHECK-NEXT: } + %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale : vector<2xi8> -> vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// 4. 
relu, with scale, no satfinite +llvm.func @convert_s2f6x2_to_bf16x2_scale_relu(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu(<2 x i8> %0, i16 %1) { + // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16 + // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %3, i16 %1) + // CHECK-NEXT: ret <2 x bfloat> %4 + // CHECK-NEXT: } + %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true} : vector<2xi8> -> vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// 5. no relu, no scale, satfinite +llvm.func @convert_s2f6x2_to_bf16x2_satfinite(%src : vector<2xi8>) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_satfinite(<2 x i8> %0) { + // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16 + // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %2, i16 32639) + // CHECK-NEXT: ret <2 x bfloat> %3 + // CHECK-NEXT: } + %res = nvvm.convert.s2f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode} : vector<2xi8> -> vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// 6. relu, no scale, satfinite +llvm.func @convert_s2f6x2_to_bf16x2_relu_satfinite(%src : vector<2xi8>) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu_satfinite(<2 x i8> %0) { + // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16 + // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %2, i16 32639) + // CHECK-NEXT: ret <2 x bfloat> %3 + // CHECK-NEXT: } + %res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode} : vector<2xi8> -> vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// 7. 
no relu, with scale, satfinite +llvm.func @convert_s2f6x2_to_bf16x2_scale_satfinite(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_satfinite(<2 x i8> %0, i16 %1) { + // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16 + // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %3, i16 %1) + // CHECK-NEXT: ret <2 x bfloat> %4 + // CHECK-NEXT: } + %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {sat = #nvvm.sat_mode} : vector<2xi8> -> vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// 8. relu, with scale, satfinite +llvm.func @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(<2 x i8> %0, i16 %1) { + // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16 + // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %3, i16 %1) + // CHECK-NEXT: ret <2 x bfloat> %4 + // CHECK-NEXT: } + %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true, sat = #nvvm.sat_mode} : vector<2xi8> -> vector<2xbf16> + llvm.return %res : vector<2xbf16> +} diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 785f59b22801..2726fc7a40ef 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -209,22 +209,6 @@ llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) { // ----- -llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) { - // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}} - %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E4M3FN) - llvm.return -} - -// ----- - -llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) { - // expected-error 
@below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}} - %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) - llvm.return -} - -// ----- - llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) { // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}} %res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU)