
Commit 1e52343

added Float64

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
1 parent f8cb598 commit 1e52343

9 files changed (+33, -6 lines)

tests/cpp/test_common.cu

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ const std::string &typeName(DType type) {
     {DType::kInt32, "int32"},
     {DType::kInt64, "int64"},
     {DType::kFloat32, "float32"},
+    {DType::kFloat64, "float64"},
     {DType::kFloat16, "float16"},
     {DType::kBFloat16, "bfloat16"},
     {DType::kFloat8E4M3, "float8e4m3"},

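As a quick illustration of what the new map entry buys the C++ tests, a minimal hedged sketch (namespace qualification of typeName() as used elsewhere in the test suite is assumed):

    #include "test_common.h"  // test helpers; declares the typeName() shown above
    // The lookup table now covers the 64-bit float enum value:
    const std::string &name = typeName(transformer_engine::DType::kFloat64);  // yields "float64"
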
tests/cpp/test_common.h

Lines changed: 2 additions & 2 deletions
@@ -84,9 +84,9 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1, fp64>;
 #else
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp64>;
 #endif

 template <typename U, DType current>

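A hedged compile-time check that the appended element really lands in the typed-test list (standard library only; TypeInfo and fp64 come from the hunk above, namespace qualification omitted):

    #include <tuple>
    #include <type_traits>
    // fp64 was appended at the end of both branches of the #if, so it is the last tuple element.
    using test_types = TypeInfo<fp64>::types;
    static_assert(std::is_same<std::tuple_element_t<std::tuple_size<test_types>::value - 1, test_types>,
                               fp64>::value,
                  "fp64 is part of the typed-test list");
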
transformer_engine/common/common.cu

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
       return CUDA_R_16F;
     case DType::kFloat32:
       return CUDA_R_32F;
+    case DType::kFloat64:
+      return CUDA_R_64F;
     case DType::kBFloat16:
       return CUDA_R_16BF;
     case DType::kFloat8E4M3:

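A hedged usage sketch: callers that query get_cuda_dtype(), for example when filling in cuBLAS or cuSOLVER descriptors, now get the 64-bit real CUDA type for double tensors (qualification of get_cuda_dtype itself is omitted here):

    #include <cassert>
    #include <library_types.h>  // cudaDataType_t, CUDA_R_64F (CUDA toolkit header)
    // After this change the helper covers FP64 as well:
    cudaDataType_t cu_type = get_cuda_dtype(transformer_engine::DType::kFloat64);
    assert(cu_type == CUDA_R_64F);
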
transformer_engine/common/common.h

Lines changed: 8 additions & 2 deletions
@@ -321,6 +321,7 @@ using int16 = int16_t;
 using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;
+using fp64 = double;
 using fp16 = half;
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
@@ -349,6 +350,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
 TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
 TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
 TRANSFORMER_ENGINE_TYPE_NAME(float)
+TRANSFORMER_ENGINE_TYPE_NAME(double)
 TRANSFORMER_ENGINE_TYPE_NAME(half)
 TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
 TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
@@ -421,14 +423,14 @@ struct BitsNumber {
 template <typename T>
 struct TypeInfo {
 #if FP4_TYPE_SUPPORTED
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp4e2m1, fp64
 #if CUDA_VERSION >= 12080
       ,
       fp8e8m0
 #endif
       >;
 #else
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp64
 #if CUDA_VERSION >= 12080
       ,
       fp8e8m0
@@ -497,6 +499,10 @@ struct TypeInfo {
       using type = float; \
       { __VA_ARGS__ } \
     } break; \
+    case DType::kFloat64: { \
+      using type = float; \
+      { __VA_ARGS__ } \
+    } break; \
     case DType::kFloat16: { \
       using type = fp16; \
       { __VA_ARGS__ } \

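The new alias is simply double under another name, keeping the naming consistent with fp32/fp16/bf16; a trivial hedged sketch:

    #include <type_traits>
    // using fp64 = double;  (the alias added above)
    static_assert(std::is_same<fp64, double>::value, "fp64 aliases double");
    static_assert(sizeof(fp64) == 8, "fp64 occupies 64 bits");
    fp64 accum = 0.0;  // usable wherever the other fpXX aliases are used
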
transformer_engine/common/fused_router/utils.h

Lines changed: 9 additions & 1 deletion
@@ -215,10 +215,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
   }
 }

-// Current TE only support float32/bf16/fp16, float64 probs should be considered in the future
+// Current TE only support float32/bf16/fp16/fp64
 #define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \
   switch (dtype) { \
     using namespace transformer_engine; \
+    case DType::kFloat64: { \
+      using type = double; \
+      { __VA_ARGS__ } \
+    } break; \
     case DType::kFloat32: { \
       using type = float; \
       { __VA_ARGS__ } \
@@ -254,6 +258,10 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
       using type = float; \
       { __VA_ARGS__ } \
     } break; \
+    case DType::kFloat64: { \
+      using type = double; \
+      { __VA_ARGS__ } \
+    } break; \
     default: \
       NVTE_ERROR("Invalid type."); \
   }

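A hedged usage sketch of the extended dispatch macro: when the runtime dtype is kFloat64, the body's type alias is bound to double. Only the macro comes from the hunk above; the kernel name and launch parameters below are hypothetical.

    // probs_dtype is a runtime transformer_engine::DType value
    TE_ROUTER_PROBS_TYPE_SWITCH_ALL(probs_dtype, scalar_t, {
      // for DType::kFloat64 this block is instantiated with scalar_t == double
      hypothetical_router_kernel<scalar_t><<<grid, block, 0, stream>>>(probs_ptr, num_tokens);
    });
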
transformer_engine/common/include/transformer_engine/transformer_engine.h

Lines changed: 3 additions & 1 deletion
@@ -33,6 +33,7 @@ enum NVTEDType {
   kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
   kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
   kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
+  kNVTEFloat64 = 11, /*!< 64-bit float */
   kNVTENumTypes /*!< Number of supported types */
 };

@@ -418,6 +419,7 @@ enum class DType {
   kFloat8E5M2 = 8,
   kFloat8E8M0 = 9,
   kFloat4E2M1 = 10,
+  kFloat64 = 11,
   kNumTypes
 };

@@ -443,7 +445,7 @@ inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
  * \param[in] DType TE Datatype of interest
  */
 inline bool is_high_precision_dtype(const DType t) {
-  return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
+  return t == DType::kFloat64 || t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
 }

 /*! \struct TensorWrapper

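With the enum extension, 64-bit floats are also classified as a high-precision dtype; a minimal hedged sketch:

    using transformer_engine::DType;
    static_assert(static_cast<int>(DType::kFloat64) == 11, "kFloat64 takes the next free enum slot");
    bool hp = transformer_engine::is_high_precision_dtype(DType::kFloat64);  // true after this change
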
transformer_engine/common/transformer_engine.cpp

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,8 @@ std::string to_string(const DType type) {
       return "Float16";
     case DType::kFloat32:
       return "Float32";
+    case DType::kFloat64:
+      return "Float64";
     case DType::kFloat8E4M3:
       return "Float8E4M3";
     case DType::kFloat8E5M2:

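The added branch follows the existing naming pattern, so the printable form is simply (hedged sketch; to_string() is the helper shown in the hunk above):

    std::string s = to_string(transformer_engine::DType::kFloat64);  // "Float64"
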
transformer_engine/common/util/pybind_helper.h

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
   pybind11::enum_<transformer_engine::DType>(m, "DType", pybind11::module_local()) \
       .value("kByte", transformer_engine::DType::kByte) \
       .value("kInt32", transformer_engine::DType::kInt32) \
+      .value("kFloat64", transformer_engine::DType::kFloat64) \
       .value("kFloat32", transformer_engine::DType::kFloat32) \
       .value("kFloat16", transformer_engine::DType::kFloat16) \
       .value("kBFloat16", transformer_engine::DType::kBFloat16) \

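For context, a hedged, standalone pybind11 sketch of what an enum binding like the one above exposes to Python. The module and enum here are illustrative stand-ins, not TE's actual binding:

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    enum class DType { kInt32, kFloat32, kFloat64 };  // illustrative subset only

    PYBIND11_MODULE(dtype_demo, m) {  // hypothetical demo module, not TE's
      py::enum_<DType>(m, "DType", py::module_local())
          .value("kInt32", DType::kInt32)
          .value("kFloat64", DType::kFloat64)  // mirrors the line added above
          .value("kFloat32", DType::kFloat32);
      // The Python side can then reference dtype_demo.DType.kFloat64.
    }
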
transformer_engine/pytorch/csrc/common.h

Lines changed: 5 additions & 0 deletions
@@ -347,6 +347,7 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
 inline size_t typeToNumBits(transformer_engine::DType t) {
   switch (t) {
     case transformer_engine::DType::kInt64:
+    case transformer_engine::DType::kFloat64:
       return 64;
     case transformer_engine::DType::kInt32:
     case transformer_engine::DType::kFloat32:
@@ -376,6 +377,8 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) {
       return torch::kInt64;
     case transformer_engine::DType::kFloat32:
       return at::kFloat;
+    case transformer_engine::DType::kFloat64:
+      return at::kDouble;
     case transformer_engine::DType::kFloat16:
       return at::kHalf;
     case transformer_engine::DType::kBFloat16:
@@ -401,6 +404,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
       return transformer_engine::DType::kFloat16;
     case at::kFloat:
       return transformer_engine::DType::kFloat32;
+    case at::kDouble:
+      return transformer_engine::DType::kFloat64;
     case at::kBFloat16:
       return transformer_engine::DType::kBFloat16;
     case at::kBool:

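Taken together, these three hunks let double tensors round-trip between ATen and TE dtypes; a hedged sketch (namespace qualification of the helpers as in csrc/common.h is assumed):

    #include <ATen/ATen.h>
    // torch::kFloat64 tensors can now pass through code that consults these helpers:
    at::ScalarType aten_dtype = GetATenDType(transformer_engine::DType::kFloat64);  // at::kDouble
    transformer_engine::DType te_dtype = GetTransformerEngineDType(at::kDouble);    // DType::kFloat64
    size_t bits = typeToNumBits(transformer_engine::DType::kFloat64);               // 64
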