Scaling Laws for Optimal LR and Batch Size - Part VIII
Trying out the Gemma 3 architecture
Introduction
In the Stability Across Time post, I discussed how the Gemma 3 architecture may approximately avoid the problem of the relevant features evolving at different rates over time.
In this post, I test whether this hypothesis is helpful in practice for avoiding the learning rate forecasting failures seen in previous parts.
Scaling Laws Setup
Open codebase. I use the research codebase at https://github.com/lucaslingle/babel with branch name ‘gemma3_scaling’.
Architecture. I use an architecture similar to Gemma 3, with parametric RMSNorm applied to the residual block inputs and outputs, and QK-LayerNorm applied to queries and keys. I use SwiGLU nonlinearities in the MLP sublayers, with d_ff = 3 * d_model. I use d_head = 128, and n_head = d_model / d_head. I use a locked aspect ratio of d_model/n_layer = 128.
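To make the block layout concrete, here is a minimal pure-JAX sketch of one residual block with this structure: parametric RMSNorm on each sublayer's input and output, QK-norm on queries and keys (implemented here with the same RMSNorm for simplicity), and a SwiGLU MLP with d_ff = 3 * d_model. This is an illustration with assumed parameter names and shapes, not the babel implementation.

import jax
import jax.numpy as jnp

def rmsnorm(x, gain, eps=1e-6):
    # Parametric RMSNorm: normalize to unit RMS, then scale by a learned gain.
    return x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * gain

def attention_sublayer(x, p, n_head, d_head):
    # Input norm -> Q/K/V projections -> QK-norm -> causal attention -> output norm -> residual add.
    h = rmsnorm(x, p["attn_in_gain"])
    B, T, D = h.shape
    q = (h @ p["wq"]).reshape(B, T, n_head, d_head)
    k = (h @ p["wk"]).reshape(B, T, n_head, d_head)
    v = (h @ p["wv"]).reshape(B, T, n_head, d_head)
    q = rmsnorm(q, p["q_gain"])  # QK-norm on queries
    k = rmsnorm(k, p["k_gain"])  # QK-norm on keys
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(d_head)
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))
    logits = jnp.where(causal, logits, -jnp.inf)
    o = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)
    o = o.reshape(B, T, D) @ p["wo"]
    return x + rmsnorm(o, p["attn_out_gain"])  # output norm applied before the residual add

def mlp_sublayer(x, p):
    # Input norm -> SwiGLU with d_ff = 3 * d_model -> output norm -> residual add.
    h = rmsnorm(x, p["mlp_in_gain"])
    o = (jax.nn.silu(h @ p["w_gate"]) * (h @ p["w_up"])) @ p["w_down"]
    return x + rmsnorm(o, p["mlp_out_gain"])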
Optimization. I use AdamW with beta1 = 0.9, beta2 = 0.95, and eps = 1e-8. The learning rate schedule uses 2% warmup, followed by cosine decay to 0.1x the peak value.
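For reference, a minimal sketch of this setup with optax is below. The weight_decay argument and the independent (not learning-rate-scaled) decay handling are taken from the Conclusion; the actual babel configuration may differ.

import optax

def make_optimizer(peak_lr, total_steps, weight_decay):
    # 2% linear warmup, then cosine decay to 0.1x the peak learning rate.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=int(0.02 * total_steps),
        decay_steps=total_steps,
        end_value=0.1 * peak_lr,
    )
    return optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-8),
        optax.scale_by_schedule(schedule),        # Adam update scaled by the LR schedule
        optax.add_decayed_weights(weight_decay),  # independent decay: not multiplied by the LR
        optax.scale(-1.0),                        # turn the combined update into a descent step
    )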
(N, D) grid. For efficiency, I use model sizes N = 46M, 109M, 368M and dataset sizes D = 128M, 256M, 512M, 1024M.
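As a sanity check on these sizes, the sketch below recovers shapes consistent with the quoted N, under the assumption that N counts non-embedding parameters (4*d_model^2 per attention sublayer plus 3*d_model*d_ff = 9*d_model^2 per SwiGLU sublayer) and using the locked aspect ratio above. The d_model values shown are my inference; the exact accounting in babel may differ.

def approx_nonembedding_params(d_model, aspect_ratio=128):
    # Per layer: 4*d_model^2 (attention Q/K/V/O) + 9*d_model^2 (SwiGLU) = 13*d_model^2.
    n_layer = d_model // aspect_ratio
    return n_layer * 13 * d_model ** 2

for d_model in (768, 1024, 1536):
    n = approx_nonembedding_params(d_model)
    print(f"d_model={d_model}, n_layer={d_model // 128}, N~{n / 1e6:.0f}M")
# d_model=768 -> ~46M, d_model=1024 -> ~109M, d_model=1536 -> ~368M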
(B, \eta) grid. I sweep over batch sizes from 32768 to 2097152 tokens in powers of 2, and over learning rates from 0.00012207031 to 0.0078125 in powers of 2^0.5.
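Concretely, these endpoints are 2^15 through 2^21 tokens and 2^-13 through 2^-7, so the grid can be generated as:

batch_sizes = [2 ** e for e in range(15, 22)]             # 32768, 65536, ..., 2097152 tokens
learning_rates = [2 ** (e / 2) for e in range(-26, -13)]  # 2^-13, 2^-12.5, ..., 2^-7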
Training data. I use the 350B token sample of FineWeb, and the LLaMA-2 tokenizer, which has approximately 32000 vocabulary items including special tokens.
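For reference, the tokenizer can be loaded as below; the Hugging Face hub id is my assumption, since the post does not say how the tokenizer is obtained.

from transformers import AutoTokenizer

# "meta-llama/Llama-2-7b-hf" is assumed here; any checkpoint shipping the LLaMA-2
# tokenizer would do (access to the gated repo is required).
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
print(len(tok))  # 32000 vocabulary items, including special tokens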
Evaluation. I use 100 held-out batches of size 2097152 tokens for loss evaluation. Evaluation occurs after training has concluded.
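A minimal sketch of what this evaluation amounts to, with a hypothetical loss_fn(params, batch) returning the mean per-token loss (not the babel API):

import jax.numpy as jnp

def eval_loss(params, loss_fn, eval_batches):
    # Mean per-token loss over the 100 held-out batches of 2097152 tokens each,
    # computed once with the final (post-training) parameters.
    losses = [loss_fn(params, batch) for batch in eval_batches]
    return float(jnp.mean(jnp.stack(losses)))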
Scaling Laws Result
The scaling law is obtained by fitting with Akima spline interpolation, as in Part VI.
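As a hedged illustration of the interpolation step (the full fitting procedure is described in Part VI and not reproduced here), Akima interpolation can be used to locate the loss-minimizing learning rate within a sweep; the loss values below are synthetic placeholders.

import numpy as np
from scipy.interpolate import Akima1DInterpolator

log2_lrs = np.arange(-13.0, -6.5, 0.5)           # the swept LR exponents, 2^-13 ... 2^-7
losses = 0.01 * (log2_lrs + 10.0) ** 2 + 3.0     # synthetic losses, for illustration only

spline = Akima1DInterpolator(log2_lrs, losses)   # Akima spline through the sweep points
fine = np.linspace(log2_lrs[0], log2_lrs[-1], 1001)
optimal_lr = 2.0 ** fine[np.argmin(spline(fine))]  # interpolated argmin, here ~2^-10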
Test Point
At the test point (N, D) = (4.7 * 10^9, 100 * 10^9), this yields a forecasted optimal batch size of 429,110 tokens and a forecasted optimal learning rate of 0.00101400484.
The empirical optima for the (N, D) pairs used in the fit were not significantly different from those found previously. If we assume the same holds at the test point, the true optimal learning rate is still about 0.0001, as in Part VII. The forecast is therefore off by roughly a factor of 10 (0.001014 / 0.0001 ≈ 10), which is far worse than the previous 4x error.
If anything, this suggests that adding more RMSNorms to the architecture moves us in the wrong direction. To improve scaling law forecasting accuracy, we may need to use nonparametric RMSNorm instead of adding more parametric RMSNorms.
I could also try one of the other methods mentioned in Stability Across Time, but I think it is better to try nonparametric RMSNorm first.
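For concreteness, nonparametric RMSNorm is simply the normalization without the learned gain, e.g.:

import jax.numpy as jnp

def rmsnorm_nonparametric(x, eps=1e-6):
    # Same normalization as the parametric version above, but with no learned per-channel gain.
    return x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)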
Conclusion
In this post, I tried using the Gemma 3 architecture and independent weight decay.
My belief, detailed in Stability Across Time, was that the features evolve over time at different rates depending on the layer. I hoped the Gemma 3 architecture could rectify this, standardizing the effective learning rate across all layers of the model and perhaps improving hyperparameter scaling law forecasting as a consequence.
In practice, forecasting quality did not improve; it was in fact much worse. This suggests that adding more parametric RMSNorms is a step in the wrong direction, and that nonparametric RMSNorm should perhaps be tried next.

