Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unexpected behavior of optax.apply_every #2

Open
vwxyzjn opened this issue Apr 6, 2023 · 6 comments
Open

Unexpected behavior of optax.apply_every #2

vwxyzjn opened this issue Apr 6, 2023 · 6 comments

Comments

@vwxyzjn
Copy link

vwxyzjn commented Apr 6, 2023

Hi @jenkspt, I have been learning nanoGPT and reproducing it in JAX from scratch. Your repo has been a very helpful reference.

I encountered an issue with optax.apply_every and thought you might want to know. It turns out optax.apply_every is not equivalent to how nanoGPT updates, which is to accumulate gradient and then clip by grad norm.

See this snippet

optimizer = optax.chain(
    optax.clip_by_global_norm(0.2),
    optax.sgd(1e-4),
)
params = net.init(jax.random.PRNGKey(0), EXAMPLES)

new_params_single_batch = fit(
    optimizer,
    params,
    batches=[
        MiniBatch(image=EXAMPLES, label=LABELS),
    ],
)

new_params_gradient_accumulation = fit(
    optax.MultiSteps(optimizer, every_k_schedule=3),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation,
    atol=1e-7,
)


new_params_gradient_accumulation2 = fit(
    optax.chain(
        optax.clip_by_global_norm(0.2),
        optax.sgd(1e-4),
        optax.apply_every(3),
    ),
    params,
    batches=[
        MiniBatch(image=EXAMPLES[0:3], label=LABELS[0:3]),
        MiniBatch(image=EXAMPLES[3:6], label=LABELS[3:6]),
        MiniBatch(image=EXAMPLES[6:9], label=LABELS[6:9]),
    ],
)

chex.assert_trees_all_close(
    new_params_single_batch,
    new_params_gradient_accumulation2,
    atol=1e-7,
)
checking equivalence of single batch and optax.MultiSteps
checking equivalence of single batch and optax.apply_every
Traceback (most recent call last):
  File "/home/costa/Documents/go/src/github.com/vwxyzjn/envpool-xla-cleanrl/optax_grad_accu_clip.py", line 109, in <module>
    chex.assert_trees_all_close(
  File "/home/costa/.cache/pypoetry/virtualenvs/envpool-xla-cleanrl-xwPMbtrF-py3.9/lib/python3.9/site-packages/chex/_src/asserts_internal.py", line 197, in _chex_assert_fn
    host_assertion(*args, **kwargs)
  File "/home/costa/.cache/pypoetry/virtualenvs/envpool-xla-cleanrl-xwPMbtrF-py3.9/lib/python3.9/site-packages/chex/_src/asserts_internal.py", line 157, in _static_assert
    raise exception_type(error_msg)
AssertionError: [Chex] Assertion assert_trees_all_close failed:  Trees 0 and 1 differ in leaves 'mlp/~/linear_0/b': 
Not equal to tolerance rtol=1e-06, atol=1e-07
Error in value equality check: Values not approximately equal
Mismatched elements: 12 / 32 (37.5%)
Max absolute difference: 3.6430913e-07
Max relative difference: 1.4173905
 x: array([-7.326746e-08,  0.000000e+00,  1.287373e-07,  0.000000e+00,
        0.000000e+00, -1.868993e-07, -1.627599e-07, -7.771037e-08,
       -6.023089e-07,  1.824948e-08, -9.969744e-08,  0.000000e+00,...
 y: array([-1.193194e-07,  0.000000e+00,  1.980013e-07,  0.000000e+00,
        0.000000e+00, -3.954974e-07, -2.911101e-07, -1.588895e-07,
       -9.666180e-07, -4.372281e-08, -2.566443e-07,  0.000000e+00,... 
Original dtypes: float32, float32.

Empirically, it could have a significant impact on training as well. I am following nanoGPT's setting in a single GPU, which is accumulate gradient 40 times. As shown below, optax.apply_every is significantly more unstable than optax.MultiStep.

image

While the training might be more stable if the gradient accumulation steps are fewer, it still feels like an issue...

@jenkspt
Copy link
Owner

jenkspt commented Apr 12, 2023

Hi @vwxyzjn I'm glad you've found this repo useful. And good catch! But can you confirm that optax.MultiStep does in-fact match the nano-GPT implementation? I can also check this maybe in the next week.

@vwxyzjn
Copy link
Author

vwxyzjn commented Apr 12, 2023

Hi @jenkspt, thanks for your response. I can confirm w/ optax.MultiStep the learning curve matches closely against nanoGPT's learning curve in shakespeare_char. Haven't tested with openwebtext though.

image

Interestingly I found nanoGPT with torch.amp.autocast(device_type=device_type, dtype='torch.bfloat16') to be faster than my jax implementation with bfloat16 in my 3060 TI. In the following weeks, I can probably prepare a more detailed ablation study with A100s though.

@jenkspt
Copy link
Owner

jenkspt commented Apr 12, 2023

Are you setting the dtype=jnp.bfloat16 attribute of your flax Modules?

@vwxyzjn
Copy link
Author

vwxyzjn commented Apr 12, 2023

That is correct. See https://wandb.ai/costa-huang/cleanrlhf/runs/3g43kqom/logs. Without jnp.bfloat16 it's twice as slow.

GPT Summary                                       
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃�[1m path           �[22m┃�[1m module         �[22m┃�[1m inputs          �[22m┃�[1m outputs        �[22m┃�[1m params          �[22m┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│                │ GPT            │ -               │ bfloat16[64,2… │                 │
│                │                │ uint16[64,256]  │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ wte            │ Embed          │ uint16[64,256]  │ bfloat16[64,2… │ embedding:      │
│                │                │                 │                │ float32[65,384] │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m24,960 (99.8 �[22m   │
│                │                │                 │                │ �[1mKB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ wpe            │ Embed          │ int32[1,256]    │ bfloat16[1,25… │ embedding:      │
│                │                │                 │                │ float32[256,38… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m98,304 (393.2 �[22m  │
│                │                │                 │                │ �[1mKB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ Dropout_0      │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0              │ Block          │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/ln_1         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/attn         │ CausalSelfAtt… │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/attn/Dense_0 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,11… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m442,368 (1.8 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,6… │                 │
│                │                │ bfloat16[64,6,… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/attn/Dense_1 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,38… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m147,456 (589.8 �[22m │
│                │                │                 │                │ �[1mKB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/ln_2         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/mlp          │ MLP            │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/mlp/c_fc     │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,15… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/mlp/c_proj   │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[1536,3… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 0/mlp/Dropout… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1              │ Block          │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/ln_1         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/attn         │ CausalSelfAtt… │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/attn/Dense_0 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,11… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m442,368 (1.8 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,6… │                 │
│                │                │ bfloat16[64,6,… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/attn/Dense_1 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,38… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m147,456 (589.8 �[22m │
│                │                │                 │                │ �[1mKB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/ln_2         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/mlp          │ MLP            │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/mlp/c_fc     │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,15… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/mlp/c_proj   │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[1536,3… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 1/mlp/Dropout… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2              │ Block          │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/ln_1         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/attn         │ CausalSelfAtt… │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/attn/Dense_0 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,11… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m442,368 (1.8 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,6… │                 │
│                │                │ bfloat16[64,6,… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/attn/Dense_1 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,38… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m147,456 (589.8 �[22m │
│                │                │                 │                │ �[1mKB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/ln_2         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/mlp          │ MLP            │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/mlp/c_fc     │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,15… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/mlp/c_proj   │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[1536,3… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 2/mlp/Dropout… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3              │ Block          │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/ln_1         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/attn         │ CausalSelfAtt… │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/attn/Dense_0 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,11… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m442,368 (1.8 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,6… │                 │
│                │                │ bfloat16[64,6,… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/attn/Dense_1 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,38… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m147,456 (589.8 �[22m │
│                │                │                 │                │ �[1mKB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/ln_2         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/mlp          │ MLP            │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/mlp/c_fc     │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,15… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/mlp/c_proj   │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[1536,3… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 3/mlp/Dropout… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4              │ Block          │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/ln_1         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/attn         │ CausalSelfAtt… │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/attn/Dense_0 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,11… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m442,368 (1.8 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,6… │                 │
│                │                │ bfloat16[64,6,… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/attn/Dense_1 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,38… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m147,456 (589.8 �[22m │
│                │                │                 │                │ �[1mKB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/ln_2         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/mlp          │ MLP            │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/mlp/c_fc     │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,15… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/mlp/c_proj   │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[1536,3… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 4/mlp/Dropout… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5              │ Block          │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/ln_1         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/attn         │ CausalSelfAtt… │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/attn/Dense_0 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,11… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m442,368 (1.8 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,6… │                 │
│                │                │ bfloat16[64,6,… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/attn/Dense_1 │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,38… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m147,456 (589.8 �[22m │
│                │                │                 │                │ �[1mKB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/attn/Dropou… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ -               │                │                 │
│                │                │ deterministic:  │                │                 │
│                │                │ True            │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/ln_2         │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/mlp          │ MLP            │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/mlp/c_fc     │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[384,15… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/mlp/c_proj   │ Dense          │ bfloat16[64,25… │ bfloat16[64,2… │ kernel:         │
│                │                │                 │                │ float32[1536,3… │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m589,824 (2.4 �[22m   │
│                │                │                 │                │ �[1mMB)�[22m             │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ 5/mlp/Dropout… │ Dropout        │ -               │ bfloat16[64,2… │                 │
│                │                │ bfloat16[64,25… │                │                 │
│                │                │ - True          │                │                 │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│ ln_f           │ LayerNorm      │ bfloat16[64,25… │ bfloat16[64,2… │ scale:          │
│                │                │                 │                │ float32[384]    │
│                │                │                 │                │                 │
│                │                │                 │                │ �[1m384 (1.5 KB)�[22m    │
├────────────────┼────────────────┼─────────────────┼────────────────┼─────────────────┤
│�[1m                �[22m│�[1m                �[22m│�[1m                 �[22m│�[1m          Total �[22m│�[1m 10,745,088      �[22m│
│�[1m                �[22m│�[1m                �[22m│�[1m                 �[22m│�[1m                �[22m│�[1m (43.0 MB)       �[22m│
└────────────────┴────────────────┴─────────────────┴────────────────┴─────────────────┘
�[1m                                                                                        

@vwxyzjn
Copy link
Author

vwxyzjn commented Apr 12, 2023

I also just did a more comprehensive profile in A100. See report
image

How it compares w/ my JAX implementation, which is heavily influenced by your repo :)

image

@vwxyzjn
Copy link
Author

vwxyzjn commented Apr 19, 2023

FYI I did an ablation study with openwebtext:

image

I also found a library https://github.com/google/maxtext lately that is well-optimized and could be a useful reference.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants