The Muon optimizer was developed inside the Modded NanoGPT speedrun, which has the express goal of training a GPT-style model as fast as possible 1 2. Since then, Muon has become widely used, and the Modded NanoGPT record has been pushed from 45 minutes down to just over 2 minutes. In this post, I’ll showcase some improvements to Muon from recent record runs and explain why they improve on standard Muon.


(1) Adaptive Muon

The Modded NanoGPT speedrun uses an adaptive variant of Muon called NorMuon 3. Several recent papers have proposed similar adaptive extensions 4 5 6. In this section, I want to explain what “adaptive Muon” means and why it helps.

Let’s consider how Muon differs from a much simpler optimizer, Stochastic Gradient Descent with Momentum (SGDM). SGDM updates parameters using an exponential moving average (EMA) of gradients and, at each step, nudges the weights a small distance in the opposite direction of that averaged gradient. Muon (MomentUm Orthogonalized by Newton-Schulz) takes this same gradient EMA but orthogonalizes it via a matrix-sign function before applying the update. Orthogonalization is a rich concept; for now, you can think of it as a special form of normalization that keeps the gradient’s directions but discards their magnitudes (see footnote for more info) 7.
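To make the contrast concrete, here is a minimal sketch of the two updates (omitting the Nesterov-style momentum used in practice); orthogonalize is a placeholder for the Newton-Schulz iteration discussed in Section 3:

def sgdm_update(grad, momentum, beta=0.95):
	# EMA of gradients; the update is simply this averaged gradient.
	momentum.lerp_(grad, 1 - beta)
	return momentum

def muon_update(grad, momentum, beta=0.95):
	# Same gradient EMA, but orthogonalized (matrix-sign) before being applied.
	momentum.lerp_(grad, 1 - beta)
	return orthogonalize(momentum)  # e.g. Newton-Schulz or Polar Express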

Adam (ADaptive Moment Estimation) also builds on SGDM, but in a different way. It keeps an EMA of the squared gradients for each parameter, and then forms a variance-corrected update by dividing the gradient EMA by the square root of this squared-gradient EMA.
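For comparison, a minimal sketch of the Adam update in the same style (bias correction omitted for brevity):

def adam_update(grad, momentum1, momentum2, beta1=0.9, beta2=0.999, eps=1e-8):
	# First moment: EMA of gradients. Second moment: EMA of squared gradients.
	momentum1.lerp_(grad, 1 - beta1)
	momentum2.lerp_(grad ** 2, 1 - beta2)
	# Per-parameter update: gradient EMA divided by the root of the squared-gradient EMA.
	return momentum1 / (momentum2.sqrt() + eps)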

Variance correction is hypothesized to be useful in two distinct ways:

  1. Stochastic noise: At a fixed batch size, some features may be very noisy while others carry strong signal. Adaptivity dampens the noisier gradient estimates, effectively giving each parameter its own learning rate.
  2. Curvature: how the gradient changes as we move through parameter space. Several papers examine how second-order gradient statistics approximate the Hessian (the second derivative) of the loss landscape, i.e. the per-parameter curvature 8 9. Normalizing by this estimate makes the optimizer take smaller steps in high-curvature regions, preventing overshooting and descent into “sharp” minima.

Despite not being a variance-adaptive method, Muon is “curvature-aware” 10 11 12, which addresses point (2). Intuitively, this comes from orthogonalization: after orthogonalization, Muon’s update is “well-rounded” in parameter space, avoiding directions of steep change. Formally, each update is perfectly conditioned (all of its singular values equal 1), which keeps the weights themselves well conditioned (their singular values remain small and close to one another) 13. Muon is effectively performing gradient descent down a constrained submanifold of the full parameter space. Along these lines, Jeremy Bernstein has proposed Manifold Muon, which tweaks Muon such that the weights remain perfectly conditioned 14.
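As a quick numerical check of this claim, the exact orthogonalization of a matrix (its matrix-sign, computed here via an SVD rather than Newton-Schulz) has all singular values equal to 1, no matter how ill-conditioned the original matrix is:

import torch

G = torch.randn(512, 256)                     # stand-in for a gradient matrix
U, S, Vh = torch.linalg.svd(G, full_matrices=False)
ortho = U @ Vh                                # exact matrix-sign / polar factor of G
print(S.max() / S.min())                      # raw gradient: condition number well above 1
print(torch.linalg.svdvals(ortho))            # orthogonalized: every singular value ~= 1.0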

To the best of my knowledge, there is no argument that Muon solves (1), i.e. that orthogonalization corrects for noise in the gradient estimate. To account for this, we want to estimate the noise along each spectral direction and correct for it. After orthogonalization, the columns of the update matrix correspond to the spectral directions, so we can estimate spectral variance by calculating column-wise RMS. For this reason, variance adaptation methods for Muon only require variance estimates along a single dimension 6.

def adaptive_muon_update(grad, momentum1, momentum2, beta1, beta2, eps=1e-8):
	# Gradient EMA, then a Nesterov-style interpolation for the update.
	momentum1.lerp_(grad, 1 - beta1)
	update = grad.lerp(momentum1, beta1)
	# Standard Muon: orthogonalize the update (e.g. Newton-Schulz or Polar Express).
	update = orthogonalize(update)
	# Adaptive part: EMA of the squared update, reduced along one dimension,
	# giving a variance estimate per spectral direction.
	momentum2.lerp_((update ** 2).mean(dim=-1, keepdim=True), 1 - beta2)
	update /= momentum2.sqrt() + eps
	# Shape-dependent scale factor used by Muon for non-square matrices.
	update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
	return update

Algorithm 1: Typical-ish Adaptive Muon update (differences from standard Muon highlighted).

The code above conveys the general idea behind adaptive Muon variants, though the variant used in Modded NanoGPT, NorMuon, is a bit longer. It renormalizes the update matrix so that it retains the same overall magnitude (measured by the Frobenius norm) after the division by the variance estimate. This method yields a decrease in training time when combined with learning rate tuning.
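A rough sketch of that renormalization idea (this is only the gist; see the NorMuon paper for the exact algorithm and constants):

def frobenius_preserving_adaptivity(update, momentum2, beta2=0.95, eps=1e-8):
	# Record the update magnitude before variance correction.
	norm_before = update.norm()
	# Low-rank second-moment estimate: one value per row of the orthogonalized update.
	momentum2.lerp_((update ** 2).mean(dim=-1, keepdim=True), 1 - beta2)
	update = update / (momentum2.sqrt() + eps)
	# Rescale so the Frobenius norm matches the pre-correction update.
	return update * (norm_before / (update.norm() + eps))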

(2) Batch size scheduling

One advantage of Muon over Adam is a higher critical batch size 15. To explain what this means, we need to consider two factors:

  1. Token efficiency: Smaller batch sizes are more token-efficient. For a fixed token budget, once the batch size is above some threshold, increasing it further tends to hurt final model performance. Beyond this point, the gradient signal from a single batch is saturated, so larger batches just waste tokens.
  2. Token speed: Training at a higher batch size is faster per token than training at a lower batch size for more steps. This is because GPUs parallelize over the batch dimension, data parallelism (DDP) is easy, and every step incurs fixed overhead (optimizer step, communication, etc.).

A critical batch size balances both of these considerations: low enough to be token-efficient, but high enough to be speed-efficient.

Figure 1: Critical batch sizes for Adam at various token budgets, from Allen AI 16. The critical batch size, chosen as the largest batch size whose loss stays within 1% of the lowest loss, is marked in red. Batch size here is in units of 4096 tokens.

To accurately determine the critical batch size, we need the optimal learning rate at each batch size. Larger batches average over more samples, which reduces gradient variance. For SGD, we can directly model the relationship between the optimal learning rate and batch size 17:

$$\eta_{\mathrm{opt}}(B) = \frac{\eta_{\max}}{1 + 1/(\gamma B)}$$

where $\gamma$ is the signal-to-noise ratio of the gradient estimate 18. When $\gamma B \ll 1$, this is approximately $\eta_{\mathrm{opt}} \approx \eta_{\max}\,\gamma B$, i.e. linear in $B$, which is why the “linear scaling law” is a good heuristic.

Adam dampens update variance, so its optimal learning rate scales closer to $\eta_{\mathrm{opt}} \propto \sqrt{B}$ 19 before also becoming asymptotic 20.

A power-law relationship ($\eta_{\mathrm{opt}} \propto B^{\alpha}$, typically with $\alpha \approx 0.5$, the square-root law) is often assumed to hold for Muon as well 21 22, though in practice it asymptotes quickly:

Figure 2: Sweeping Adaptive Muon at a learning rate x log(batch size) grid at three different token budgets. The lowest loss per batch size is selected in red.

At token budgets of , , tokens, the optimal learning rate seems to converge around , , and , respectively. Note that does not appear to converge at the highest token budget for any of the swept learning rates.

Below the asymptote/saturated batch size, we can approximate the exponent $\alpha$ via the log-slope of the red boxes. At this granularity, $\alpha \approx 0.5$ (square-root scaling) seems approximately correct. Convergence to the optimal learning rate appears to be very fast for this variant of Muon.
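As a small illustration of how such a log-slope can be read off, one can fit a line to (log batch size, log best learning rate) pairs; the numbers below are made up for illustration and are not taken from the sweep:

import numpy as np

batch_sizes = np.array([2**14, 2**15, 2**16, 2**17])   # illustrative values only
best_lrs    = np.array([0.011, 0.016, 0.023, 0.032])   # illustrative values only
alpha, _ = np.polyfit(np.log(batch_sizes), np.log(best_lrs), deg=1)
print(alpha)   # a slope near 0.5 corresponds to square-root scaling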

Figure 3: Results from the previous sweep (Fig. 2), choosing the optimal learning rate for each batch size. For a fixed token budget, pushing the batch size higher eventually hurts the final validation loss. The critical batch size is the highest acceptable batch size given some tolerance for loss.

The above graphs demonstrate an impressive stability in the critical batch size, especially at the higher token budgets. Consider the flatness of the curve for the highest token budget: a batch size several times larger converges to nearly the same loss as a much smaller one, while requiring correspondingly fewer steps.

Batch size can often be safely increased over the course of training 16. Conceptually, this may be because early training focuses on common patterns, so even small batches provide strong gradient signals. Later training involves learning rarer patterns. These sparse signals remain noisy even at high batch sizes, so the batch size can be increased without saturating the gradient. Notably, two strong models trained with Muon, Kimi K2 and GLM 4.5, both increase batch size mid-training 23 24.

Modded NanoGPT’s recent record 25 uses batch size scheduling in order to maximize token efficiency throughout training.
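A hypothetical batch-size schedule in this spirit might look like the following (the shape and the numbers are illustrative, not the actual Modded NanoGPT schedule):

def batch_size_at(step, total_steps, bs_start=2**16, bs_end=2**19, micro=2**13):
	# Grow the batch size geometrically over training, rounded to a multiple of
	# the per-device micro-batch so gradient accumulation stays an integer.
	frac = step / total_steps
	bs = bs_start * (bs_end / bs_start) ** frac
	return max(micro, int(round(bs / micro)) * micro)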

(3) Faster orthogonalization: Polar Express and beyond

The Newton-Schulz iterative algorithm approximates orthogonalization by repeatedly applying a fixed quintic polynomial. Its coefficients are agnostic to the conditioning of the underlying matrix. However, since each iteration improves the conditioning of the matrix, we can instead choose an optimal polynomial for every iteration. Amsel et al provide these optimal per-iteration coefficients via their algorithm Polar Express 26.
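For reference, a compact sketch of the quintic Newton-Schulz iteration; the fixed coefficients below are the commonly used constants from the Muon write-up 1, whereas Polar Express swaps in optimal coefficients at each step:

import torch

def newton_schulz_orthogonalize(G, steps=5, coeffs=(3.4445, -4.7750, 2.0315)):
	# Quintic iteration X <- a*X + b*(X X^T) X + c*(X X^T)^2 X, applied `steps` times.
	a, b, c = coeffs
	X = G.bfloat16()
	X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # bring singular values into [0, 1]
	transpose = G.size(-2) > G.size(-1)
	if transpose:
		X = X.mT                                         # iterate on the wide orientation
	for _ in range(steps):
		A = X @ X.mT
		B = b * A + c * (A @ A)
		X = a * X + B @ X
	if transpose:
		X = X.mT
	return X.to(G.dtype)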

Last month, a paper from Shulgin et al demonstrated that more precise orthogonalization improves Muon convergence, especially when accompanied by appropriate learning rate tuning 27:

Figure 4: Validation loss at two levels of orthogonalization precision. Caption from Shulgin et al: “the optimal learning rate couples with approximation quality […] higher precision → higher optimal LR + wider stability.”

It follows that using Polar Express instead of Newton-Schulz improves convergence, and doing so led to a new Modded NanoGPT record.

An ongoing effort in the Modded NanoGPT speedrun aims at even faster orthogonalization via Almost-Orthogonal Layers, though this has not been incorporated yet 28.

(4) Cautious weight decay

Decoupled weight decay has been noted as a crucial technique for training with Muon. The Kimi team writes: “While vanilla Muon initially converges faster, we observed that some model weights grew too large over time, potentially limiting the model’s long-term performances. Adding weight decay addressed this issue - the results demonstrate that Muon with weight decay outperforms both vanilla Muon and AdamW” 23.

Inside Modded NanoGPT, however, weight decay on Muon was doing more harm than good. In October, a variant of decoupled weight decay known as Cautious Weight Decay was discovered 29. It applies weight decay only to the coordinates where the decay direction agrees with the optimizer update, i.e. where the update and the parameter share a sign:

def apply_update(param, update, learning_rate, weight_decay):
	# Decay only where the update and the parameter share a sign, i.e. where
	# weight decay pushes in the same direction as the optimizer update.
	mask = (update * param) >= 0
	update = update + weight_decay * param * mask
	return param - learning_rate * update

Algorithm 2: Cautious weight decay (difference from decoupled weight decay highlighted).

This technique proved highly effective in Modded NanoGPT when paired with a schedule that anneals the weight-decay coefficient over the course of training.
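A hypothetical version of such a schedule (the shape and the values here are illustrative, not the exact schedule used in the record):

def weight_decay_at(step, total_steps, wd_start=0.01, wd_end=0.0):
	# Linearly anneal the weight-decay coefficient from wd_start to wd_end.
	frac = step / total_steps
	return wd_start + (wd_end - wd_start) * frac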

(5) Distributed and efficient computation

The Muon implementation has been optimized to distribute its work across 8 devices. First, several tricks speed up orthogonalization beyond the basic Newton-Schulz algorithm:

  • Parameters of the same shape are stacked together so that orthogonalization is vectorized (see the sketch after this list).
  • Additionally, the attention weights qkvo are concatenated so that they match the shape of the MLP weights. This allows attention and MLP weights to be stacked together.
  • The Newton-Schulz iteration involves the manipulation of symmetric matrices. Exploiting this symmetry can roughly halve the computation in some steps; custom Triton kernels have been written for them. 30
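A sketch of the stacking idea from the first two bullets, assuming a hidden size of 1024 and a 4x MLP expansion purely for illustration (and reusing the newton_schulz_orthogonalize sketch from Section 3):

import torch

hidden, n_layers = 1024, 12
mlp_grads = [torch.randn(4 * hidden, hidden) for _ in range(n_layers)]
attn_grads = [torch.randn(4 * hidden, hidden) for _ in range(n_layers)]  # q, k, v, o concatenated
stacked = torch.stack(mlp_grads + attn_grads)    # shape [2 * n_layers, 4 * hidden, hidden]
updates = newton_schulz_orthogonalize(stacked)   # one batched call, vectorized matmuls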

Second, there are a few tricks to distribute the Muon step over all the GPUs:

  • Each GPU receives an equal subset of the parameter gradients (via a reduce-scatter).
  • Each GPU processes gradients in groups. Each group has a “nice” number of parameters (e.g. a power of two), which is important for the underlying kernels.
  • The groups are handled concurrently so that gradient communication across GPUs overlaps with computation on each GPU (a simplified sketch of this ownership-and-overlap idea follows).
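The following is a heavily simplified sketch of the ownership-and-overlap idea only; the actual Modded NanoGPT pipeline uses reduce-scatter over grouped, stacked gradients rather than per-parameter broadcasts:

import torch.distributed as dist

def sharded_muon_step(params, lr=0.02):
	rank, world = dist.get_rank(), dist.get_world_size()
	handles = []
	for i, p in enumerate(params):
		if i % world == rank:                       # this rank "owns" parameter i
			p.grad.copy_(newton_schulz_orthogonalize(p.grad))
		# Async broadcast: communication of finished updates overlaps with the
		# orthogonalization work this rank still has to do for its own parameters.
		handles.append(dist.broadcast(p.grad, src=i % world, async_op=True))
	for handle, p in zip(handles, params):
		handle.wait()
		p.data.add_(p.grad, alpha=-lr)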

For additional information I direct you to Larry Dial’s blog post 31 and the Modded NanoGPT repo 2.

(6) Implementation Notes

“The reason I didn’t write a proper arxiv paper for Muon is because I simply don’t think there’s any relationship between the ability to publish a paper with lots of good-looking results about a new optimizer, and whether that optimizer actually works. I only trust speedruns.” —Keller Jordan 32

I hope this post is broadly useful for pretraining with Muon. I want to highlight some important considerations for the Modded NanoGPT recipe:

  • The model is very small (<124M active params) and has a unique architecture.
  • Adam is used instead of Muon on a few parameters: the linear output head, the embedding layer, and a few scalar constants, though this is “standard” for Muon.
  • Adam is stepped with twice the number of gradient accumulation steps as Muon; effectively, it sees a 2x larger batch size than Muon. In the experiments in Section 2, I scaled the Adam learning rate according to the square-root law.

This post summarizes the work of many people on the Modded NanoGPT speedrun. Section 1 primarily corresponds to the work of Zichong Li, an author of NorMuon. Sections 2, 3, and 4 correspond to records contributed by myself. Section 5 is the result of many people over many iterations, with Larry Dial contributing especially over the last few months.



Thank you to Prime Intellect, who sponsors my research with GPU credits. If this blog post was useful for you, you can cite:
@misc{
	srivastava2025,
	author = {Varun Srivastava},
	title = {Muon in Modded NanoGPT},
	year = {2025},
	url = {https://varunneal.github.io/essays/muon}
}

Footnotes

  1. Keller Jordan et al 2024b Muon: An optimizer for hidden layers in neural networks ^

  2. Keller Jordan et al 2024a Modded NanoGPT. The competition is for “the fastest algorithm to use 8 NVIDIA H100 GPUs to train a language model that attains 3.28 cross-entropy loss on the FineWeb validation set.” ^ ^

  3. Li et al 2025 NorMuon ^

  4. Si et al 2025 AdaMuon ^

  5. Zhang et al 2025 AdaGrad Meets Muon ^

  6. Frans et al 2025 What Really Matters in Matrix-Whitening Optimizers? ^ ^

  7. The spectral directions of a matrix correspond to how its transformation stretches the input/output space. Formally, the spectral directions are the left and right singular vectors in the Singular Value Decomposition (SVD) of a matrix. The amount each direction is stretched corresponds to the singular values of the SVD. Orthogonalization finds a matrix with identical spectral directions but with all singular values equal to 1. Why is this isometry useful? Bernstein and Newhouse 2024 relate a powerful intuition: implicitly, gradient descent uses the Euclidean/Frobenius norm. SGD updates the weights in the direction opposite the gradient, with step size measured in the Euclidean metric. The correct metric for the gradient should consider how much it transforms the input space, which corresponds to the spectral metric. Orthogonalization acts as a map from transformation space to weight space so that the update and the weights live in the same geometry. ^

  8. Cohen et al 2024 Understanding Optimization in Deep Learning with Central Flows, with a shorter accompanying blogpost here. ^

  9. Kunstner et al 2024 Heavy-Tailed Class Imbalance and Why Adam Outperforms Gradient Descent on Language Models ^

  10. Kovalev 2025 Understanding Gradient Orthogonalization for Deep Learning via Non-Euclidean Trust-Region Optimization ^

  11. Anonymous ICLR Conference Submission 2025 Long-tailed Learning with Muon Optimizer ^

  12. Su 2025 Isotropic Curvature Model for Understanding Deep Learning Optimization ^

  13. Boreiko et al 2025 Towards Understanding Orthogonalization in Muon ^

  14. Jeremy Bernstein 2025 Modular Manifolds ^

  15. Essential AI, Shah 2025 Practical Efficiency of Muon for Pretraining. ^

  16. Ai2, Merrill et al 2025 Critical Batch Size Revisited ^ ^

  17. McCandlish et al 2018 An Empirical Model of Large-Batch Training ^

  18. I’m being somewhat imprecise with how we define the signal-to-noise ratio. McCandlish et al 17 define a critical batch size for their analysis of SGD. I’m using its inverse as the SNR. ^

  19. Granziol et al 2020 Learning Rates as a Function of Batch Size contains a proof. ^

  20. Li et al 2024 Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling ^

  21. Sato et al 2025 Convergence Bound and Critical Batch Size of Muon Optimizer. ^

  22. Simo Ryu 2025 Adam vs Shampoo vs Muon on MNIST. All follow the lr ~ sqrt(BS) law. ^

  23. Moonshot AI, Liu et al 2025 Muon is Scalable for LLM Training ^ ^

  24. Zeng et al GLM-4.5 ^

  25. Srivastava 2025 Modded NanoGPT PR#163 ^

  26. Amsel et al 2025 The Polar Express ^

  27. Shulgin et al 2025 Beyond the Ideal: Analyzing the Inexact Muon Update. Corresponding tweet thread here. ^

  28. Thibaut Boissin 2025 Modded NanoGPT PR#155 ^

  29. Chen et al 2025 Cautious Weight Decay ^

  30. Xu 2025 Modded NanoGPT PR#109 (Triton kernels for symmetric matmul). Also part of the Dion repository. ^

  31. Larry Dial 2025 How the NanoGPT Speedrun WR dropped by 20% in 3 months ^

  32. Keller Jordan 2025 The reason I didn’t write a proper arxiv paper for Muon is because I simply don’t think there’s any relationship between the ability to publish a paper with lots of good-looking results about a new optimizer, and whether that optimizer actually works. I only trust speedruns. ^