Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 76 additions & 85 deletions docs/examples/quickstart_jax.ipynb

Large diffs are not rendered by default.

35 changes: 27 additions & 8 deletions docs/examples/quickstart_jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ def speedometer(
variables: Any,
input: jnp.ndarray,
output_grad: jnp.ndarray,
dropout_key: jax.random.PRNGKey,
model_init_fn: Callable = None,
forward_kwargs: dict = {},
autocast_kwargs: Optional[dict] = None,
timing_iters: int = 50,
warmup_iters: int = 50,
rngs: Dict[str, jax.random.PRNGKey] = None,
) -> None:
"""Measure average runtime for a JAX module
Perform forward and backward passes .
Expand All @@ -33,19 +33,21 @@ def speedometer(
autocast_kwargs = {"enabled": False}
model_init_fn = None

if rngs is None:
rngs = {}

train_step_fn = create_train_step_fn(model_apply_fn, autocast_kwargs, forward_kwargs)

# Warm up runs
key = dropout_key
for _ in range(warmup_iters):
key, step_key = jax.random.split(key)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
rngs, step_rngs = _split_step_rngs(rngs)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)

# Timing runs
start = time.time()
for _ in range(timing_iters):
key, step_key = jax.random.split(key)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_key)
rngs, step_rngs = _split_step_rngs(rngs)
loss, (param_grads, other_grads) = train_step_fn(variables, input, output_grad, step_rngs)
end = time.time()

print(f"Mean time: {(end - start) * 1000 / timing_iters} ms")
Expand All @@ -63,8 +65,12 @@ def create_train_step_fn(
if forward_kwargs is None:
forward_kwargs = {}

def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
rngs = {"dropout": dropout_key}
def loss_fn(
variables: Any,
inp: jnp.ndarray,
grad_target: jnp.ndarray,
rngs: Dict[str, jax.random.PRNGKey],
):
with te.autocast(**autocast_kwargs):
# Forward Pass: Apply the model using current parameters and variables
call_kwargs = {**forward_kwargs, "rngs": rngs}
Expand All @@ -84,3 +90,16 @@ def fwd_bwd_fn(*args, **kwargs):

# JIT-compile the fwd_bwd_fn
return jax.jit(fwd_bwd_fn)


def _split_step_rngs(
    rngs: Dict[str, jax.random.PRNGKey],
) -> Tuple[Dict[str, jax.random.PRNGKey], Dict[str, jax.random.PRNGKey]]:
    """Derive fresh per-step RNG keys from every named key in ``rngs``.

    Each key is passed through ``jax.random.split``; the first half is kept
    for later steps and the second half is handed to the current step.

    Args:
        rngs: Mapping from RNG name to its current key.

    Returns:
        A ``(new_rngs, step_rngs)`` pair of dictionaries with the same names
        as ``rngs``: ``new_rngs`` holds the keys to carry forward and
        ``step_rngs`` holds the keys to use for this step.
    """
    # Split every key exactly once, keeping both halves paired under its name.
    halves = {name: tuple(jax.random.split(rng)) for name, rng in rngs.items()}
    carried = {name: kept for name, (kept, _) in halves.items()}
    current = {name: used for name, (_, used) in halves.items()}
    return carried, current
Loading