Muon was introduced by Keller Jordan in this blog post.

  • It’s amazingly fast: people have been breaking training speed records on NanoGPT and CIFAR-10 with it.

  • Muon saves space: it only stores momentum, while Adam stores momentum and variance.

Recap on AdamW

I learned Adam in 2018 through Andrej Karpathy’s CS231n lectures. When I re-learned DL in 2024, Adam(W) was still the default and Andrej was still offering lectures. I didn’t miss much :)

The original Adam was introduced by Kingma and Ba. The goal is to minimize $\theta\mapsto \ell(\theta;\mathbf{X})$ using mini-batches, where $\theta$ is viewed as a vector.

The default hyperparameters are $\alpha=10^{-3},\beta_1=0.9,\beta_2=0.999,\epsilon=10^{-8}.$

Initialize $m_0=v_0=0$. At step $t$, compute the gradient \(g_t = \nabla_\theta\ell(\theta_{t-1}; \mathbf{X}_{t})\) using batch $\mathbf{X}_t$.

\[\begin{align*} m_t &= \beta_1m_{t-1}+(1-\beta_1)g_t,\quad v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t\odot g_t,\\ \hat m_t &= \frac{m_t}{1-\beta_1^t},\quad \hat v_t = \frac{v_t}{1-\beta_2^t},\quad \theta_t = \theta_{t-1} - \alpha \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}. \end{align*}\]
  • $m_t$ is the biased first-moment estimate of $g_t$, i.e., the momentum. Each step we roughly keep the momentum ($\beta_1\approx 1$) while nudging it slightly towards the new gradient.

  • $v_t$ is the biased estimate of the second moment $\mathbf{E}[g_t^{\odot 2}]$.

  • Adam stores $m_t$ and $v_t$ as optimizer states, doubling the memory footprint of the parameters themselves.

  • $\hat m_t,\hat v_t$ are the bias-corrected estimates.

Why the corrections? If the $g_t$'s are i.i.d.,

\[\mathbf{E}[v_t] = (1-\beta_2)\sum_{s=1}^t \beta_2^{t-s} \mathbf{E}[g_s^{\odot 2}] = \mathbf{E}[g_t^{\odot 2}] (1-\beta_2^t).\]

So we should correct by dividing by $1-\beta_2^t$. The correction for $\hat m_t$ is similar.
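As a sanity check, here is a minimal NumPy sketch of a single Adam step implementing exactly the formulas above (the function and variable names are mine, not from any particular library):

```python
import numpy as np

def adam_step(theta, g, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step; theta, g, m, v are arrays of the same shape, t starts at 1."""
    m = beta1 * m + (1 - beta1) * g          # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g * g      # second moment
    m_hat = m / (1 - beta1 ** t)             # bias corrections
    v_hat = v / (1 - beta2 ** t)
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Toy usage: minimize ||theta - 1||^2.
theta, m, v = np.zeros(4), np.zeros(4), np.zeros(4)
for t in range(1, 2001):
    g = 2 * (theta - 1.0)
    theta, m, v = adam_step(theta, g, m, v, t)
print(theta)   # close to [1, 1, 1, 1]
```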

AdamW adds a decoupled weight decay, which works well in practice and acts like a gentle shrinkage of the weights toward zero.

\[g_t = \nabla_\theta \ell(\theta_{t-1};\mathbf{X}_t).\]
\[\begin{align*} m_t &= \beta_1m_{t-1}+(1-\beta_1)g_t,\quad v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t\odot g_t,\\ \hat m_t &= \frac{m_t}{1-\beta_1^t},\quad \hat v_t = \frac{v_t}{1-\beta_2^t},\quad \theta_t = (1-\lambda \eta_t)\theta_{t-1} - \alpha\eta_t \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}. \end{align*}\]

Here, $\eta_t$ is a learning-rate schedule multiplier. Note the decay is decoupled: it acts directly on $\theta_{t-1}$ and never passes through the moment estimates.
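The only change relative to the plain Adam step sketched above is the decoupled factor $(1-\lambda\eta_t)$ on the weights. A minimal sketch (again my own names; the default values of `lam` and `eta` are illustrative):

```python
import numpy as np

def adamw_step(theta, g, m, v, t, eta=1.0, lam=1e-2,
               alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One AdamW step: weight decay is applied directly to theta,
    not folded into the gradient or the moment estimates."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = (1 - lam * eta) * theta - alpha * eta * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v
```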

Muon

Muon works on matrix parameters $W\in\mathbb{R}^{n\times d}.$ WLOG assume $d\le n$.

\[\begin{align*} M_t&= \beta M_{t-1} + (1-\beta)\nabla \ell_t(W_{t-1}),\\ W_t &= W_{t-1} - \eta_t(\pi(M_{t})+\lambda W_{t-1}). \end{align*}\]

The idea is to replace Adam's coordinatewise normalization $\frac{\hat m_t}{\sqrt{\hat v_t}}$ by $\pi(M_t)$, which exploits the matrix structure of the parameter.

Now we define $\pi$. Let

\[V_{n,d} :=\Big\lbrace V\in\mathbb{R}^{n\times d}:V^\top V = I\Big\rbrace\]

denote the Stiefel manifold.

$\pi(M)$ denotes the projection of $M$ onto $V_{n,d}$, i.e.,

\[\|M-\pi(M)\|_{\rm F} = \min_{V\in V_{n,d}} \|M-V\|_{\rm F}\]

where F denotes the usual Frobenius norm: \(\|A\|_{\rm F}^2=\sum_{i,j}A_{ij}^2=\mathrm{tr}(A^\top A).\)

If $d=1$, then $V_{n,1}=S^{n-1}$, the unit sphere, and $\pi(M)=M/\|M\|_2$ is the usual normalization.

Let $M=U\Sigma V^\top$ be an SVD:

\[U\in V_{n,d},\quad V\in V_{d,d}, \quad \Sigma=\mathrm{diag}(\sigma_1\ge\cdots\ge\sigma_d).\]

We claim that $UV^\top$ is a choice for $\pi(M)$.

Proof. For $\tilde O\in V_{n,d}$, let $O=\tilde O V\in V_{n,d}.$

\[\begin{align*} \|M-\tilde O\|^2 &= \|M\|^2+d- 2\mathrm{tr}(V\Sigma U^\top \tilde O)\\ &= \|M\|^2+d-2\mathrm{tr}(\Sigma U^\top O). \end{align*}\]

To minimize $\|M-\tilde O\|$, we maximize $\mathrm{tr}(\Sigma U^\top O)$. As $\sigma_i\ge 0$, it suffices to maximize the diagonal entries of $U^\top O$.

If $U=[u_1\cdots u_d], O=[o_1\cdots o_d]$, then

\[(U^\top O)_{ii} = u_i^\top o_i \le \|u_i\|\|o_i\|=1,\]

and the max is attained at $O=U$, i.e., $\tilde O=UV^\top.$ Q.E.D.

A quick observation: recall that the nuclear norm (trace norm) is $\|M\|_*=\mathrm{tr}\,\Sigma.$ So,

\[\|M-\pi(M)\|^2 = \|M\|^2+d-2\|M\|_* = \sum_{i=1}^d(\sigma_i-1)^2 = \|\Sigma - I\|^2.\]
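Both the claim and this observation are easy to check numerically: compute $\pi(M)=UV^\top$ from the SVD, compare its distance to $M$ against randomly sampled Stiefel matrices (the QR factor of a Gaussian matrix is one way to sample such a point), and verify the identity above. A minimal NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 3
M = rng.standard_normal((n, d))

U, s, Vt = np.linalg.svd(M, full_matrices=False)   # M = U diag(s) Vt
pi_M = U @ Vt                                      # claimed projection onto V_{n,d}

# pi(M) is at least as close to M as any other Stiefel point we try.
for _ in range(1000):
    Q, _ = np.linalg.qr(rng.standard_normal((n, d)))   # random element of V_{n,d}
    assert np.linalg.norm(M - pi_M) <= np.linalg.norm(M - Q) + 1e-12

# ||M - pi(M)||_F^2 = sum_i (sigma_i - 1)^2
assert np.isclose(np.linalg.norm(M - pi_M) ** 2, np.sum((s - 1) ** 2))
```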

Back to Muon,

\[W_t = W_{t-1} - \eta_t(\pi(M_{t})+\lambda W_{t-1})\]

updates $W$ by stepping along the orthogonalized momentum $\pi(M_t)$. The $\lambda W_{t-1}$ term is weight decay, similar to the one in AdamW.
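Putting the pieces together, here is a minimal sketch of one Muon step that uses the exact SVD-based $\pi$; the default hyperparameter values are illustrative, and in practice the SVD is replaced by the Newton-Schulz iteration discussed next:

```python
import numpy as np

def muon_step(W, G, M, eta=0.02, beta=0.95, lam=0.0):
    """One Muon step on a matrix parameter W, given the gradient G = dl/dW."""
    M = beta * M + (1 - beta) * G                    # momentum on the gradient
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    pi_M = U @ Vt                                    # orthogonalized momentum
    W = W - eta * (pi_M + lam * W)                   # step plus weight decay
    return W, M
```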

How to compute $\pi$?

The Newton-Schulz iteration: $X_0=X$,

\[X_{k+1} = \frac{3}{2}X_k - \frac{1}{2}X_k(X_k^\top X_k).\]
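Before asking where this formula comes from, here is what it looks like in code. A hedged NumPy sketch: the iteration only converges when the singular values of $X_0$ lie in $(0,\sqrt 3)$, so I pre-scale by the Frobenius norm, which is a common choice rather than part of the formula above:

```python
import numpy as np

def newton_schulz(X, steps=30):
    """Approximate pi(X) = U V^T by the cubic Newton-Schulz iteration."""
    X = X / np.linalg.norm(X)            # now every singular value is <= 1
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ (X.T @ X)
    return X

# Compare against the exact projection from the SVD.
rng = np.random.default_rng(0)
M = rng.standard_normal((8, 3))
U, _, Vt = np.linalg.svd(M, full_matrices=False)
print(np.linalg.norm(newton_schulz(M) - U @ Vt))   # should be tiny
```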

Does this come out of nowhere? In fact, consider the more general situation: if $X=U\Sigma V^\top$ is an SVD,

\[\begin{align*} & \alpha_1 X + \alpha_{3}X(X^\top X) +\cdots + \alpha_{2m+1} X(X^\top X)^m\\ =&\ U(\alpha_1 \Sigma + \alpha_3 \Sigma^3 + \cdots + \alpha_{2m+1} \Sigma^{2m+1}) V^\top = U p(\Sigma)V^\top \end{align*}\]

where $p(x)=\alpha_1 x + \cdots+\alpha_{2m+1} x^{2m+1}$ is an odd polynomial of degree $2m+1$. Denote by $\tilde p$ the operation it induces on $X$, i.e., the first line above.

A more general iteration scheme:

\[X_0=X,\quad X_{k} = \tilde p_{k}(X_{k-1})\]

with possibly different polynomials $p_k$ at each step.

\[\|X_{k}-UV^\top\|^2 = \| p_k\circ \cdots \circ p_1(\Sigma) - I\|^2.\]

The iteration converges to $UV^\top$ if and only if $p_k\circ \cdots \circ p_1(\sigma_i)\to 1$ for every $i$.
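To see this on the scalar level, one can iterate the Newton-Schulz polynomial $p(x)=\tfrac{3}{2}x-\tfrac{1}{2}x^3$ on a few starting values in $(0,\sqrt 3)$ and watch them flow to $1$ (a toy check of the convergence criterion, not from the original post):

```python
import numpy as np

p = lambda x: 1.5 * x - 0.5 * x ** 3        # Newton-Schulz polynomial
sigmas = np.array([0.05, 0.3, 1.0, 1.5])    # singular values in (0, sqrt(3))
for _ in range(15):
    sigmas = p(sigmas)
print(sigmas)                                # all very close to 1
```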