Skip to content

Commit 3550b96

Browse files
Fix lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
1 parent 7d1f2fb commit 3550b96

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

transformer_engine/jax/flax/module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,7 @@ def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str]
13771377
import transformer_engine.jax as te
13781378

13791379
class TEWrapper(te.flax.module.TransformerEngineBase):
1380+
""" Wrapper Flax module for TransformerEngine quantization support. """
13801381
def generate_quantizer_set(self, postfix: str = ""):
13811382
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
13821383
return super().generate_quantizer_set(
@@ -1416,6 +1417,8 @@ def make_dot_general_cls(quantization_recipe):
14161417
from transformer_engine.common.recipe import NVFP4BlockScaling
14171418

14181419
def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs):
1420+
""" Performs a dot_general operation using TransformerEngine with quantization. """
1421+
del kwargs # Unused
14191422
contracting_dims, batch_dims = dims
14201423
assert batch_dims == ((), ()), "Batch dimensions must be empty for TransformerEngine dot."
14211424

0 commit comments

Comments
 (0)