@@ -321,6 +321,7 @@ using int16 = int16_t;
321321using int32 = int32_t ;
322322using int64 = int64_t ;
323323using fp32 = float ;
324+ using fp64 = double ;
324325using fp16 = half;
325326using bf16 = nv_bfloat16;
326327using fp8e4m3 = __nv_fp8_e4m3;
@@ -349,6 +350,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
349350TRANSFORMER_ENGINE_TYPE_NAME (int32_t )
350351TRANSFORMER_ENGINE_TYPE_NAME (int64_t )
351352TRANSFORMER_ENGINE_TYPE_NAME (float )
353+ TRANSFORMER_ENGINE_TYPE_NAME (double )
352354TRANSFORMER_ENGINE_TYPE_NAME (half)
353355TRANSFORMER_ENGINE_TYPE_NAME (nv_bfloat16)
354356TRANSFORMER_ENGINE_TYPE_NAME (__nv_fp8_e4m3)
@@ -421,14 +423,14 @@ struct BitsNumber {
421423template <typename T>
422424struct TypeInfo {
423425#if FP4_TYPE_SUPPORTED
424- using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16 , fp8e4m3, fp8e5m2, fp4e2m1
426+ using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16 , fp8e4m3, fp8e5m2, fp4e2m1, fp64
425427#if CUDA_VERSION >= 12080
426428 ,
427429 fp8e8m0
428430#endif
429431 >;
430432#else
431- using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16 , fp8e4m3, fp8e5m2
433+ using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16 , fp8e4m3, fp8e5m2, fp64
432434#if CUDA_VERSION >= 12080
433435 ,
434436 fp8e8m0
@@ -497,6 +499,10 @@ struct TypeInfo {
497499 using type = float ; \
498500 { __VA_ARGS__ } \
499501 } break ; \
502+ case DType::kFloat64 : { \
503+ using type = float ; \
504+ { __VA_ARGS__ } \
505+ } break ; \
500506 case DType::kFloat16 : { \
501507 using type = fp16; \
502508 { __VA_ARGS__ } \
0 commit comments