Muon
Muon was introduced by Keller Jordan in this blog post.
- It’s amazingly fast: people have been breaking NanoGPT and CIFAR-10 training speed records with it.
- Muon saves space: it only stores momentum, while Adam stores both momentum and variance.
Recap on AdamW
I learned Adam in 2018 through Andrej Karpathy’s CS231n lectures. When I re-learned DL in 2024, Adam(W) was still the default and Andrej was still offering lectures. I didn’t miss much :)
The original Adam was introduced by Kingma and Ba. The goal is to minimize $\theta\mapsto \ell(\theta;\mathbf{X})$ using mini-batches, where $\theta$ is viewed as a vector.
The default hyperparameters are $\alpha=10^{-3},\beta_1=0.9,\beta_2=0.999,\epsilon=10^{-8}.$
Initialize everything as $0$. At step $t$, compute the gradient \(g_t = \nabla_\theta\ell(\theta_{t-1}; \mathbf{X}_{t})\) using batch $\mathbf{X}_t$.
\[\begin{align*} m_t &= \beta_1m_{t-1}+(1-\beta_1)g_t,\quad v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t\odot g_t,\\ \hat m_t &= \frac{m_t}{1-\beta_1^t},\quad \hat v_t = \frac{v_t}{1-\beta_2^t},\quad \theta_t = \theta_{t-1} - \alpha \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}. \end{align*}\]
- $m_t$ is the biased first-moment estimate of $g_t$, i.e., the momentum. Each step we roughly keep the momentum ($\beta_1\approx 1$) while nudging it slightly towards the new gradient.
- $v_t$ is the biased second-moment estimate, i.e., an estimate of $\mathbf{E}[g_t^{\odot 2}]$.
- Adam stores $m_t$ and $v_t$ as optimizer state, twice the number of parameters.
- $\hat m_t,\hat v_t$ are the bias-corrected estimates.
Why the corrections? If the $g_t$’s are i.i.d.,
\[\mathbf{E}[v_t] = (1-\beta_2)\sum_{s=1}^t \beta_2^{t-s} \mathbf{E}[g_s^{\odot 2}] = \mathbf{E}[g_t^{\odot 2}] (1-\beta_2^t).\]So we should divide by $1-\beta_2^t$ to remove the bias; $\hat m_t$ is similar.
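To make the update concrete, here is a minimal NumPy sketch of one Adam step following the formulas above (purely illustrative; `theta`, `m`, `v`, `t` mirror $\theta_t, m_t, v_t, t$).

```python
import numpy as np

def adam_step(theta, grad, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step, mirroring the formulas above (illustrative sketch)."""
    m = beta1 * m + (1 - beta1) * grad           # momentum / first-moment estimate m_t
    v = beta2 * v + (1 - beta2) * grad * grad    # second-moment estimate v_t
    m_hat = m / (1 - beta1 ** t)                 # bias corrections
    v_hat = v / (1 - beta2 ** t)
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v
```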
AdamW adds decoupled weight decay, which works well in practice and roughly amounts to adding some normalization on the weights.
\[g_t = \nabla_\theta \ell(\theta_{t-1};\mathbf{X}_t).\] \[\begin{align*} m_t &= \beta_1m_{t-1}+(1-\beta_1)g_t,\quad v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t\odot g_t,\\ \hat m_t &= \frac{m_t}{1-\beta_1^t},\quad \hat v_t = \frac{v_t}{1-\beta_2^t},\quad \theta_t = (1-\lambda \eta_t)\theta_{t-1} - \alpha\eta_t \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}. \end{align*}\]Here, $\eta_t$ is the learning-rate schedule multiplier and $\lambda$ is the weight-decay coefficient; unlike $\ell_2$ regularization, the decay does not enter the gradient or the moment estimates.
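Compared with the Adam sketch above, only the last line changes: the decay multiplies the weights directly instead of entering the gradient. `eta_t` stands in for the schedule multiplier $\eta_t$, and the hyperparameter values are placeholders (again, just a sketch).

```python
import numpy as np

def adamw_step(theta, grad, m, v, t, eta_t=1.0,
               alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, lam=1e-2):
    """One AdamW step: same moment estimates as Adam, decoupled weight decay."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = (1 - lam * eta_t) * theta - alpha * eta_t * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v
```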
Muon
Muon works on matrix parameters $W\in\mathbb{R}^{n\times d}.$ WLOG assume $d\le n$.
\[\begin{align*} M_t&= \beta M_{t-1} + (1-\beta)\nabla \ell_t(W_{t-1}),\\ W_t &= W_{t-1} - \eta_t(\pi(M_{t})+\lambda W_{t-1}). \end{align*}\]The idea is to replace the Adam normalization $\frac{\hat m_t}{\sqrt{\hat v_t}}$ by $\pi(M_t)$, which exploits the matrix structure.
Now we define $\pi$. Let
\[V_{n,d} :=\Big\lbrace V\in\mathbb{R}^{n\times d}:V^\top V = I\Big\rbrace\]denote the Stiefel manifold.
$\pi(M)$ denotes the projection of $M$ onto $V_{n,d}$, i.e.,
\[\|M-\pi(M)\|_{\rm F} = \min_{V\in V_{n,d}} \|M-V\|_{\rm F}\]where F denotes the usual Frobenius norm: \(\|A\|_{\rm F}^2=\sum_{i,j}A_{ij}^2=\mathrm{tr}(A^\top A).\)
If $d=1$, $V_{n,1}=S^{n-1}$, the unit sphere. Then $\pi(M)=M/\|M\|_2$, the usual normalization.
Let $M=U\Sigma V^\top$ be an SVD:
\[U\in V_{n,d},\quad V\in V_{d,d}, \quad \Sigma=\mathrm{diag}(\sigma_1\ge\cdots\ge\sigma_d).\]We claim that $UV^\top$ is a choice for $\pi(M)$.
Proof. For $\tilde O\in V_{n,d}$, let $O=\tilde O V\in V_{n,d}$ (this is a bijection of $V_{n,d}$ since $V$ is orthogonal).
\[\begin{align*} \|M-\tilde O\|^2 &= \|M\|^2+d- 2\mathrm{tr}(V\Sigma U^\top \tilde O)\\ &= \|M\|^2+d-2\mathrm{tr}(\Sigma U^\top O). \end{align*}\]To minimize $\|M-\tilde O\|$, we maximize $\mathrm{tr}(\Sigma U^\top O)=\sum_i \sigma_i (U^\top O)_{ii}$. As $\sigma_i\ge 0$, it suffices to maximize each diagonal entry of $U^\top O$.
If $U=[u_1\cdots u_d], O=[o_1\cdots o_d]$, then
\[(U^\top O)_{ii} = u_i^\top o_i \le \|u_i\|\|o_i\|=1,\]and the max is attained at $O=U$, i.e., $\tilde O=UV^\top.$ Q.E.D.
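A quick numerical check of the claim, sketched with NumPy’s `svd` (the sizes and seed are arbitrary): $UV^\top$ lies on $V_{n,d}$ and is at least as close to $M$ as randomly sampled Stiefel matrices.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 5
M = rng.standard_normal((n, d))

# pi(M) = U V^T from the thin SVD of M
U, s, Vt = np.linalg.svd(M, full_matrices=False)
pi_M = U @ Vt

assert np.allclose(pi_M.T @ pi_M, np.eye(d))             # pi(M) is in V_{n,d}

dist2 = np.linalg.norm(M - pi_M) ** 2
# matches ||M||_F^2 + d - 2 tr(Sigma)  (see the observation below)
assert np.isclose(dist2, np.linalg.norm(M) ** 2 + d - 2 * s.sum())

# pi(M) beats randomly sampled points of the Stiefel manifold
for _ in range(1000):
    Q, _ = np.linalg.qr(rng.standard_normal((n, d)))     # random element of V_{n,d}
    assert dist2 <= np.linalg.norm(M - Q) ** 2 + 1e-12
```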
A quick observation: recall that Ky Fan’s nuclear norm is $\|M\|_*=\mathrm{tr}\,\Sigma.$ So,
\[\|M-\pi(M)\|^2 = \|M\|^2+d-2\|M\|_* = \|\Sigma-I\|^2.\]Back to Muon,
\[W_t = W_{t-1} - \eta_t(\pi(M_{t})+\lambda W_{t-1})\]updates the weights with the orthogonalized momentum $\pi(M_t)$ in place of the raw momentum; $\lambda W_{t-1}$ is a weight-decay term, similar to the one in AdamW.
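Putting the update together, a minimal sketch of one Muon step. Here `orthogonalize` computes $\pi$ exactly via the SVD (the Newton-Schulz iteration below avoids the SVD), and the hyperparameter values are placeholders rather than tuned defaults.

```python
import numpy as np

def orthogonalize(M):
    """pi(M) = U V^T via the thin SVD."""
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ Vt

def muon_step(W, grad, M, eta=0.02, beta=0.95, lam=0.0):
    """One Muon step following the formulas above (illustrative sketch)."""
    M = beta * M + (1 - beta) * grad                 # momentum on the matrix gradient
    W = W - eta * (orthogonalize(M) + lam * W)       # orthogonalized update + weight decay
    return W, M
```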
How to compute $\pi$?
The Newton-Schulz iteration: $X_0=X$,
\[X_{k+1} = \frac{3}{2}X_k - \frac{1}{2}X_k(X_k^\top X_k).\]
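A NumPy sketch of this iteration. One detail left implicit above (an assumption worth stating): the cubic converges only when the singular values of $X_0$ are not too large, so it is common to first normalize $X$ by its Frobenius norm, which puts every singular value in $(0,1]$.

```python
import numpy as np

def newton_schulz(X, num_iters=10):
    """Approximate pi(X) = U V^T with the cubic Newton-Schulz iteration."""
    X = X / (np.linalg.norm(X) + 1e-12)       # Frobenius normalization: singular values <= 1
    for _ in range(num_iters):
        X = 1.5 * X - 0.5 * X @ (X.T @ X)
    return X
```

Convergence is slow for tiny singular values, which is one reason practical implementations replace the cubic with higher-degree polynomials, as discussed next.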
Does this come out of nowhere? In fact, consider the more general case: if $X=U\Sigma V^\top$ is an SVD,
\[\begin{align*} & \alpha_1 X + \alpha_{3}X(X^\top X) +\cdots + \alpha_{2m+1} X(X^\top X)^m\\ =&\ U(\alpha_1 \Sigma + \alpha_3 \Sigma^3 + \cdots + \alpha_{2m+1} \Sigma^{2m+1}) V^\top = U p(\Sigma)V^\top \end{align*}\]where $p(x)=\alpha_1 x + \cdots+\alpha_{2m+1} x^{2m+1}$ is a polynomial of degree $2m+1$. Denote by $\tilde p$ the induced map on $X$ given by the first line, so that $\tilde p(X)=Up(\Sigma)V^\top$.
A more general iteration scheme:
\[X_0=X,\quad X_{k} = \tilde p_{k}(X_{k-1})\]with possibly different polynomials.
\[\|X_{k}-UV^\top\|^2 = \| p_k\circ \cdots \circ p_1(\Sigma) - I\|^2.\]So the iteration converges to $UV^\top=\pi(X)$ iff $p_k\circ \cdots \circ p_1(\sigma_i)\to 1$ for every $i$.
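Since everything reduces to what the composed polynomial does to each singular value, convergence can be studied on scalars alone. A small check with the cubic $p(x)=\tfrac{3}{2}x-\tfrac{1}{2}x^3$ from above (the sample values of $\sigma$ are arbitrary):

```python
def p(x):
    return 1.5 * x - 0.5 * x ** 3

for sigma in (0.01, 0.1, 0.5, 1.0, 1.4):
    x = sigma
    for _ in range(10):
        x = p(x)
    print(f"sigma = {sigma:>4}: after 10 iterations, {x:.4f}")
# Every sigma in (0, sqrt(3)) is driven towards 1, but small sigmas take more iterations.
```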