7 changes: 6 additions & 1 deletion transformer_engine/common/common.h
@@ -394,6 +394,10 @@ struct QuantizationConfig {
NVTETensor rng_state = nullptr;
bool nvfp4_2d_quantization = false;
bool stochastic_rounding = false;
// Scale factor for estimating post-RHT amax from pre-RHT amax.
// When <= 0.0f, true post-RHT amax is used (default behavior).
// When > 0.0f, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
float amax_estimation_scale = 0.0f;

static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
@@ -402,7 +406,8 @@ struct QuantizationConfig {
sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format
sizeof(NVTETensor), // rng_seed and offset
sizeof(bool), // nvfp4_2d_quantization
sizeof(bool) // stochastic_rounding
sizeof(bool), // stochastic_rounding
sizeof(float) // amax_estimation_scale
};
};

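The convention for the new field is applied on the host side and consumed by the kernel in the next file: a non-positive value keeps the existing behavior (the kernel receives 1.0f and uses the true post-RHT amax), while a positive value replaces the measured post-RHT amax with a scaled pre-RHT amax. A minimal standalone sketch of that arithmetic, with an illustrative pre-RHT amax value (only the `amax_estimation_scale` convention comes from this change):

```cpp
#include <cstdio>

// Sketch of how amax_estimation_scale is interpreted (mirrors the logic added
// in hadamard_transform_cast_fusion_columnwise and rht_gemm_device below).
int main() {
  const float amax_estimation_scale = 2.0f;  // <= 0.0f would disable estimation
  const float pre_rht_amax = 3.25f;          // hypothetical amax of the unrotated input

  // Host side: a non-positive setting falls back to 1.0f, so the kernel's
  // multiply becomes a no-op and the true post-RHT amax is used instead.
  const float amax_scale =
      (amax_estimation_scale > 0.0f) ? amax_estimation_scale : 1.0f;

  // Kernel side: the global amax read by the epilogue is scaled by amax_scale.
  // With estimation enabled, that global amax is the pre-RHT amax.
  const float estimated_post_rht_amax = pre_rht_amax * amax_scale;
  std::printf("estimated post-RHT amax = %.2f\n", estimated_post_rht_amax);
  return 0;
}
```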
@@ -139,7 +139,8 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
TSFC * SFC,
TiledMMA mma,
float const* global_amax,
const size_t* rng_state)
const size_t* rng_state,
float amax_scale)
{
using namespace cute;
using X = Underscore;
@@ -407,7 +408,8 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
} else if (is_epilogue_warp) {
const float global_amax_val = *global_amax;
// Apply the amax estimation scale (amax_scale is 1.0f when estimation is disabled, so this is a no-op)
const float global_amax_val = (*global_amax) * amax_scale;
static constexpr int FragmentSize = 256 / sizeof_bits_v<TC>;

tmem_allocation_result_barrier.arrive_and_wait();
@@ -543,7 +545,8 @@ rht_gemm_ntt_w_sfc(int m, int n,
const size_t* rng_state,
uint32_t sm_count,
cudaStream_t stream,
int k_tile_size = 2048)
int k_tile_size = 2048,
float amax_scale = 1.0f)
{
using namespace cute;

@@ -662,7 +665,8 @@ rht_gemm_ntt_w_sfc(int m, int n,
C, dC, sC,
SFC,
mma, global_amax,
rng_state);
rng_state,
amax_scale);
}

// this function is used to wrap the rht_gemm_ntt_w_sfc function
@@ -678,7 +682,8 @@ rht_gemm_ttt_wrapper(int m, int n,
const size_t* rng_state,
uint32_t sm_count,
cudaStream_t stream,
int k_tile_size = 1024)
int k_tile_size = 1024,
float amax_scale = 1.0f)
{
// in addition to transposing the input tensor A
// we also need to reshape m, n to at best
@@ -696,7 +701,8 @@
SFC, global_amax,
rng_state,
sm_count, stream,
k_tile_size);
k_tile_size,
amax_scale);
}

} // namespace
@@ -734,6 +740,11 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}

// Amax estimation scale: when > 0, amax is scaled by this factor
// This allows estimating post-RHT amax from pre-RHT amax
const float amax_scale =
(quant_config.amax_estimation_scale > 0.0f) ? quant_config.amax_estimation_scale : 1.0f;

// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
@@ -813,7 +824,8 @@ void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &out
/*rng_state=*/rng_state,
/*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size););
/*k_tile_size=*/k_tile_size,
/*amax_scale=*/amax_scale););
}

} // namespace transformer_engine
@@ -337,6 +337,11 @@ enum NVTEQuantizationConfigAttribute {
kNVTEQuantizationConfigNVFP42DQuantization = 5,
/*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding = 6,
/*! Scale factor for estimating post-RHT amax from pre-RHT amax.
* When <= 0.0f, true post-RHT amax is used (default behavior).
* When > 0.0f, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
*/
kNVTEQuantizationConfigAmaxEstimationScale = 7,
kNVTEQuantizationConfigNumAttributes
};

@@ -997,6 +1002,16 @@ class QuantizationConfigWrapper {
&stochastic_rounding, sizeof(bool));
}

/*! \brief Set amax estimation scale for post-RHT amax estimation
*
* When <= 0.0f, true post-RHT amax is used (default behavior).
* When > 0.0f, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
*/
void set_amax_estimation_scale(float amax_estimation_scale) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigAmaxEstimationScale,
&amax_estimation_scale, sizeof(float));
}

private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
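A short usage sketch for the wrapper method added above; the header path, namespace, and construction details are assumptions, while `QuantizationConfigWrapper`, `set_amax_estimation_scale`, and `kNVTEQuantizationConfigAmaxEstimationScale` come from this change:

```cpp
// Sketch only; header path and namespace are assumed, not taken from this diff.
#include <transformer_engine/transformer_engine.h>

void configure_nvfp4_quant_config() {
  // Assumes the wrapper's default constructor allocates an NVTEQuantizationConfig.
  transformer_engine::QuantizationConfigWrapper quant_config;

  // Estimate the post-RHT amax as 2x the pre-RHT amax instead of measuring it
  // with a separate RHT + amax kernel. A value <= 0.0f keeps the default
  // behavior (true post-RHT amax).
  quant_config.set_amax_estimation_scale(2.0f);
}
```

As the wrapper body shows, the same attribute can also be set through the C API via `nvte_set_quantization_config_attribute` with `kNVTEQuantizationConfigAmaxEstimationScale` and a `float` payload.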
43 changes: 42 additions & 1 deletion transformer_engine/common/recipe/__init__.py
@@ -65,21 +65,26 @@ class QParams:
amax_epsilon: optional minimum value of abs max
random_hadamard_transform: whether to use random hadamard transform
stochastic_rounding: whether to use stochastic rounding
amax_estimation_scale: scale factor for estimating post-RHT amax from pre-RHT amax.
When None, true post-RHT amax is computed (default behavior).
When set to a float, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
"""

power_2_scale: bool = False
amax_epsilon: float = 0.0
random_hadamard_transform: bool = False
stochastic_rounding: bool = False
fp4_2d_quantization: bool = False
amax_estimation_scale: float | None = None

def __repr__(self) -> str:
return (
f"Qparams(\npower_2_scale={self.power_2_scale},\n"
f"amax_epsilon={self.amax_epsilon},\n"
f"random_hadamard_transform={self.random_hadamard_transform},\n"
f"stochastic_rounding={self.stochastic_rounding},\n"
f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
f"fp4_2d_quantization={self.fp4_2d_quantization},\n"
f"amax_estimation_scale={self.amax_estimation_scale}\n)"
)


@@ -428,6 +433,16 @@ class NVFP4BlockScaling(Recipe):
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
use_post_rht_amax_estimation : bool, default = False
**EXPERIMENTAL**: If set to `True`, post-RHT amax is estimated from pre-RHT amax
instead of being computed by a separate RHT+amax kernel. This can reduce the
number of kernel launches but may affect numerical accuracy.
post_rht_amax_estimation_scale_fwd_inp : float, default = 2.0
Scale factor for estimating post-RHT amax for forward input activations.
Only used when `use_post_rht_amax_estimation=True`.
post_rht_amax_estimation_scale_bwd_grad : float, default = 1.0
Scale factor for estimating post-RHT amax for backward gradients.
Only used when `use_post_rht_amax_estimation=True`.
"""

# Configuration envvars
Expand All @@ -444,17 +459,41 @@ class NVFP4BlockScaling(Recipe):
fp8_dpa: bool = False
fp8_mha: bool = False

# Experimental: Post-RHT amax estimation
use_post_rht_amax_estimation: bool = (
os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION", "0") == "1"
)
post_rht_amax_estimation_scale_fwd_inp: float = float(
os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_X_SCALE", "2.0")
)
post_rht_amax_estimation_scale_bwd_grad: float = float(
os.getenv("NVTE_NVFP4_POST_RHT_AMAX_ESTIMATION_G_SCALE", "1.0")
)

def __post_init__(self) -> None:
assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"

# Determine amax estimation scales (None = use true post-RHT amax)
amax_scale_fwd_inp = (
self.post_rht_amax_estimation_scale_fwd_inp
if self.use_post_rht_amax_estimation
else None
)
amax_scale_bwd_grad = (
self.post_rht_amax_estimation_scale_bwd_grad
if self.use_post_rht_amax_estimation
else None
)

# Quantization params
# Note: RHT is currently only applied to column-wise usage so that
# it can be used for wgrad GEMM.
self.fp4_quant_fwd_inp = QParams(
random_hadamard_transform=not self.disable_rht,
stochastic_rounding=False,
fp4_2d_quantization=False,
amax_estimation_scale=amax_scale_fwd_inp,
)
self.fp4_quant_fwd_weight = QParams(
random_hadamard_transform=False,
@@ -465,6 +504,7 @@ def __post_init__(self) -> None:
random_hadamard_transform=not self.disable_rht,
stochastic_rounding=not self.disable_stochastic_rounding,
fp4_2d_quantization=False,
amax_estimation_scale=amax_scale_bwd_grad,
)

def __repr__(self) -> str:
@@ -474,6 +514,7 @@ def __repr__(self) -> str:
f"fp8_format={str(self.fp8_format).split('.')[1]}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}, "
f"use_post_rht_amax_estimation={self.use_post_rht_amax_estimation}, "
f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
3 changes: 3 additions & 0 deletions transformer_engine/common/transformer_engine.cpp
@@ -933,6 +933,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(&config_.stochastic_rounding, buf, attr_size);
break;
case kNVTEQuantizationConfigAmaxEstimationScale:
std::memcpy(&config_.amax_estimation_scale, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/common.h
@@ -296,6 +296,10 @@ class NVFP4Quantizer : public Quantizer {
// 2D block scaling
bool with_2d_quantization;
bool stochastic_rounding;
// Scale factor for estimating post-RHT amax from pre-RHT amax.
// When <= 0.0f, true post-RHT amax is used (default behavior).
// When > 0.0f, post-RHT amax is estimated as: pre_rht_amax * amax_estimation_scale
float amax_estimation_scale;

int rht_matrix_random_sign_mask_t;
at::Tensor rht_matrix;
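The forwarding of this new member into the NVTE quantization config is not shown in this excerpt; a hedged sketch of what that wiring presumably looks like (the function name and surrounding code are hypothetical; only the `amax_estimation_scale` member and `set_amax_estimation_scale` setter appear in the PR):

```cpp
// Hypothetical wiring sketch, not the actual NVFP4Quantizer implementation.
// Includes for NVFP4Quantizer and QuantizationConfigWrapper omitted.
void apply_nvfp4_quantizer_options(const NVFP4Quantizer& quantizer,
                                   transformer_engine::QuantizationConfigWrapper& quant_config) {
  // Forward the scale only when estimation is enabled; a non-positive value
  // matches the config default (0.0f), i.e. true post-RHT amax.
  if (quantizer.amax_estimation_scale > 0.0f) {
    quant_config.set_amax_estimation_scale(quantizer.amax_estimation_scale);
  }
}
```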
18 changes: 14 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -42,10 +42,15 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
// Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!use_amax_estimation) {
// Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
impl = Impl::UNFUSED;
} else {
// When use_amax_estimation is true, activation kernel computes pre-RHT amax,
// and the quantizer will scale it to estimate post-RHT amax.
impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
}
}
@@ -154,10 +159,15 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer*>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
// Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!use_amax_estimation) {
// Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
impl = Impl::UNFUSED;
} else {
// When use_amax_estimation is true, activation kernel computes pre-RHT amax,
// and the quantizer will scale it to estimate post-RHT amax.
impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4;
}
}
9 changes: 7 additions & 2 deletions transformer_engine/pytorch/csrc/extensions/bias.cpp
@@ -152,10 +152,15 @@ std::vector<py::object> dact_dbias(
} else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
// Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!use_amax_estimation) {
// Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
impl = Impl::UNFUSED;
} else {
// When use_amax_estimation is true, dact kernel computes pre-RHT amax,
// and the quantizer will scale it to estimate post-RHT amax.
impl = Impl::FUSED_DACT_AMAX_NVFP4;
}
}
22 changes: 16 additions & 6 deletions transformer_engine/pytorch/csrc/extensions/normalization.cpp
@@ -126,11 +126,16 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
// Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!use_amax_estimation) {
// Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
impl = Impl::UNFUSED;
} else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
// TE kernel supports amax output
// TE kernel supports amax output.
// When use_amax_estimation is true, LayerNorm computes pre-RHT amax,
// and the quantizer will scale it to estimate post-RHT amax.
impl = Impl::FUSED_NORM_AMAX_NVFP4;
}
}
@@ -355,11 +360,16 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
} else if (detail::IsNVFP4Quantizers(quantizer.ptr())) {
auto nvfp4_quantizer_cpp = dynamic_cast<NVFP4Quantizer *>(quantizer_cpp.get());
NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer");
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) {
// Post-RHT amax is handled within NVFP4 quantizer
// Check if amax estimation is enabled (scale > 0 means we can use pre-RHT amax)
const bool use_amax_estimation = nvfp4_quantizer_cpp->amax_estimation_scale > 0.0f;
if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax &&
!use_amax_estimation) {
// Post-RHT amax is handled within NVFP4 quantizer (need true post-RHT amax)
impl = Impl::UNFUSED;
} else if (!transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
// TE kernel supports amax output
// TE kernel supports amax output.
// When use_amax_estimation is true, RMSNorm computes pre-RHT amax,
// and the quantizer will scale it to estimate post-RHT amax.
impl = Impl::FUSED_NORM_AMAX_NVFP4;
}
}