Loss spikes are a familiar sight when training neural networks: the loss drops steadily, then suddenly jumps before recovering. This post explains why they happen. Starting from the simple case of a quadratic loss, we build up to the edge of stability and derive, via a Taylor expansion of the gradient, why spikes are self-correcting.

Gradient Descent

Given a differentiable loss function \(f: \mathbf{R}^d \to \mathbf{R}\), gradient descent iteratively updates parameters according to

\[x_{k+1} = x_k - \eta \nabla f(x_k)\]

where \(\eta > 0\) is the learning rate (or step size). When \(f\) is complicated, as in the case of a neural network loss landscape, it’s difficult to choose an appropriate value of \(\eta\). Choosing a value that’s too small leads to very slow convergence while choosing a value that’s too large leads to divergence.
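As a concrete illustration, here is a minimal NumPy sketch of the update rule, with a toy quadratic standing in for a real loss (the objective, starting point, and learning rate are arbitrary choices for illustration):

```python
import numpy as np

def gradient_descent(grad_f, x0, eta, num_steps):
    """Run plain gradient descent from x0 and return all iterates."""
    xs = [np.asarray(x0, dtype=float)]
    for _ in range(num_steps):
        xs.append(xs[-1] - eta * grad_f(xs[-1]))
    return np.array(xs)

# Toy example: f(x) = ||x||^2 / 2, whose gradient is just x.
iterates = gradient_descent(lambda x: x, x0=[3.0, -2.0], eta=0.1, num_steps=50)
print(iterates[-1])  # close to the minimizer at the origin
```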

To understand this, it helps to first study a simple case that exposes the source of the instability.

1D Quadratic Case

Consider the simplest case of minimizing the 1D quadratic \(f(x) = \frac{S}{2}x^2\) where \(S > 0\) is the sharpness (curvature) of the function. Larger \(S\) means a steeper, narrower parabola as shown in Figure 1.

Figure 1: Larger S makes the parabola steeper and smaller S makes it flatter.

The gradient (derivative) is \(\nabla f(x) = Sx\), so gradient descent becomes

\[x_{k+1} = x_k - \eta \cdot S x_k = (1 - S\eta) x_k\]

This is a simple geometric sequence! Starting from \(x_0\), we have \(x_k = (1-S\eta)^k x_0.\)

For convergence to zero, we need

\[\lvert 1 - S\eta\rvert < 1\]

Since \(S, \eta > 0\), this means

\[0 < S\eta < 2 \quad \Rightarrow \quad \eta < \frac{2}{S}.\]

This says that for a quadratic function, the maximum stable learning rate is inversely proportional to the sharpness. Intuitively, sharper quadratics require smaller learning rates as demonstrated in Figure 2.
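This rule is easy to check numerically. A quick sketch (the sharpness and learning rates below are arbitrary illustrative choices):

```python
S = 4.0      # sharpness of f(x) = (S/2) x^2, so the critical learning rate is 2/S = 0.5
x0 = 1.0

for eta in [0.10, 0.40, 0.49, 0.51]:
    x = x0
    for _ in range(20):
        x = (1 - S * eta) * x        # the geometric recursion x_{k+1} = (1 - S*eta) x_k
    status = "converges" if abs(1 - S * eta) < 1 else "diverges"
    print(f"eta={eta:.2f}  |1 - S*eta|={abs(1 - S * eta):.2f}  x_20={x:.3e}  ({status})")
```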

Figure 2: Gradient descent for 15 steps on the 1D quadratic (S = 1.0, η = 0.50, ηcrit = 2.00). Left: iterates on the loss surface. Right: loss vs. step on a log scale.

The Multi-Dimensional Quadratic Case

In higher dimensions, the picture becomes richer. Consider a \(d\)-dimensional quadratic

\[f(x) = \frac{1}{2}x^T A x = \sum_{i=1}^d \sum_{j=1}^d A_{ij}x_ix_j\]

where \(A\in \mathbf{S}_+^{d}\)1, the gradient of which is \(\nabla f = Ax\).2 We can form the eigenvalue decomposition of the quadratic form as

\[A = VDV^T, \quad D = \mathbf{diag}(\lambda_1, \ldots, \lambda_d)\]

where \(V^TV = I\).

The quadratic form can then be written as

\[f(x) = \frac{1}{2} \sum_{k=1}^d \lambda_k \cdot (x^Tv_k)^2\]

where \(v_k\) is the \(k^{th}\) column of \(V\). Comparing this decomposition to our 1D case, we can see that the eigenvalues are exactly the sharpnesses of the quadratic along the directions \(v_1, \ldots, v_d\) (called the principal axes).3
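This decomposition can be verified numerically. A small sketch (random symmetric PSD matrix and test point, purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
B = rng.standard_normal((3, 3))
A = B.T @ B                                  # a random symmetric PSD matrix

lam, V = np.linalg.eigh(A)                   # columns of V are the principal axes v_k
x = rng.standard_normal(3)

f_direct = 0.5 * x @ A @ x
f_eig = 0.5 * np.sum(lam * (V.T @ x) ** 2)   # (1/2) sum_k lambda_k (x^T v_k)^2
print(np.isclose(f_direct, f_eig))           # True
```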

The level sets of a 2D quadratic are shown in Figure 3, which demonstrates how the eigenvalues affect the shape of the ellipses and thus the sharpness along the eigenvectors.

Figure 3: Level sets (contours) of the 2D quadratic f(x) = ½(λ₁x₁² + λ₂x₂²), here with λ₁ = 4.0 and λ₂ = 1.0. When λ₁ = λ₂ the contours are circles; unequal eigenvalues produce ellipses elongated along the less-sharp direction.

In the 1D quadratic case, we derived the simple rule that \(\eta < 2/S\) for gradient descent to converge. For the \(d\)-dimensional quadratic case, how should we set the learning rate to ensure convergence?

Figure 4 shows the trajectory of gradient descent on a 2D quadratic.

Figure 4: Gradient descent trajectory on a 2D quadratic (λmax = 4.00, ηcrit = 2/λmax = 0.50). The left panel shows iterates on the loss contours; the right panel shows loss vs. step. Increasing η toward the critical value 2/λmax causes oscillations along the sharpest eigendirection, and crossing it leads to divergence.

From the figure, we can see that convergence occurs when

\[\eta < \frac{2}{\max\{\lambda_1, \lambda_2\}}.\]

For the general \(d\)-dimensional case

\[\eta < \frac{2}{\max\{\lambda_1, \ldots, \lambda_d\}}.\]

This says that, for a quadratic, the maximal learning rate with which gradient descent will converge is governed by the sharpness along each principal axis. More specifically, it is determined by the sharpest of these directions. Since the principal axis with the largest sharpness is the only one that governs convergence, we define \(S= \max\{\lambda_1, \ldots, \lambda_d\}\) as the sharpness of a quadratic in the general case. This generalizes the 1D definition, where the single eigenvalue \(S\) of \(f(x) = \frac{S}{2}x^2\) was the sharpness.
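Numerically, the \(2/\lambda_{\max}\) threshold behaves exactly as in the 1D case. Here is a sketch comparing learning rates just below and just above the critical value (random PSD matrix, arbitrary dimension, purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
B = rng.standard_normal((5, 5))
A = B.T @ B                                  # random symmetric PSD quadratic form
eta_crit = 2.0 / np.linalg.eigvalsh(A).max()

for eta in [0.95 * eta_crit, 1.05 * eta_crit]:
    x = rng.standard_normal(5)
    for _ in range(200):
        x = x - eta * (A @ x)                # gradient descent on f(x) = (1/2) x^T A x
    # Below the threshold the iterates shrink; above it they blow up.
    print(f"eta / eta_crit = {eta / eta_crit:.2f}   ||x_200|| = {np.linalg.norm(x):.3e}")
```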

Non-convex Case

Let’s turn to the more complicated case of minimizing a general objective \(f(x)\). Writing the second order Taylor expansion about point \(a\in \mathbf{R}^d\)

\[f(x)\approx f(a) + \nabla f(a)^T(x-a) + \frac{1}{2}(x-a)^T\nabla^2 f(a) (x-a)\]

we define the sharpness at \(a\), denoted \(S(a)\in \mathbf{R}_+\), to be the maximum eigenvalue of the Hessian \(\nabla^2 f(a)\).4
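For small problems the sharpness can be estimated directly: build the Hessian by finite differences on the gradient and take its largest eigenvalue. A sketch (central differences; the step size h is a rough, illustrative choice):

```python
import numpy as np

def numerical_hessian(grad_f, a, h=1e-5):
    """Approximate the Hessian of f at a via central differences on its gradient."""
    a = np.asarray(a, dtype=float)
    d = a.size
    H = np.zeros((d, d))
    for i in range(d):
        e = np.zeros(d)
        e[i] = h
        H[:, i] = (grad_f(a + e) - grad_f(a - e)) / (2 * h)
    return 0.5 * (H + H.T)                    # symmetrize away finite-difference noise

def sharpness(grad_f, a):
    return np.linalg.eigvalsh(numerical_hessian(grad_f, a)).max()

# Sanity check on a quadratic f(x) = (1/2) x^T A x, whose Hessian is A everywhere.
A = np.array([[4.0, 1.0], [1.0, 2.0]])
print(sharpness(lambda x: A @ x, np.zeros(2)))   # ~4.414, the largest eigenvalue of A
```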

In order to have a more concrete example to visualize, let’s consider the optimization problem

\[\begin{align*} \underset{C,\alpha}{\text{minimize}}&\quad \frac{1}{2N}\sum_{k=1}^N (y_k - Cx_k^\alpha)^2\\ \end{align*}\]

which comes from fitting a power law \(y = Cx^\alpha\) to \(N\) data points \((x_k, y_k)\).
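In code, the objective and its gradient with respect to \((C, \alpha)\) look like the following. This is a sketch on noiseless synthetic data standing in for the word frequency data used in the figures:

```python
import numpy as np

# Synthetic data exactly following y = C x^alpha with C = 3, alpha = -0.8.
x_data = np.arange(1.0, 51.0)
y_data = 3.0 * x_data ** -0.8

def loss(p):
    C, alpha = p
    r = y_data - C * x_data ** alpha              # residuals
    return 0.5 * np.mean(r ** 2)                  # (1/2N) sum of squared residuals

def grad(p):
    C, alpha = p
    pred = C * x_data ** alpha
    r = y_data - pred
    dC = -np.mean(r * x_data ** alpha)            # d loss / d C
    dalpha = -np.mean(r * pred * np.log(x_data))  # d loss / d alpha
    return np.array([dC, dalpha])

print(loss(np.array([1.0, -0.5])), grad(np.array([1.0, -0.5])))
```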

Figure 5 shows the level sets of the objective along with the gradient descent trajectory.

Figure 5: Level sets of the power law non-linear least squares objective on word frequency data, with the gradient descent trajectory from start to end overlaid. Brighter regions indicate higher loss.

Notice how for \(\eta > 57\), the trajectory shows increasingly oscillatory behavior and eventually completely diverges. This indicates that our learning rate is somehow “too big” for our optimization landscape. We previously found the maximum learning rate we could tolerate for a quadratic was \(2/S\).

We can turn the question around and ask, “for a fixed learning rate \(\eta\), what’s the maximum sharpness \(S\) that can be tolerated?”. The answer is a sharpness less than \(2/\eta\), a threshold variously referred to as the critical sharpness or the edge of stability.
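Both the loss and the sharpness can be tracked by recomputing the largest Hessian eigenvalue at each step of gradient descent. Here is a self-contained sketch on the synthetic power-law data from above (the actual figures use word frequency data and a much larger \(\eta\); the learning rate below is deliberately conservative relative to the measured starting sharpness so the run stays well-behaved):

```python
import numpy as np

# Same noiseless synthetic power-law data as in the previous sketch.
x_data = np.arange(1.0, 51.0)
y_data = 3.0 * x_data ** -0.8

def loss(p):
    C, alpha = p
    return 0.5 * np.mean((y_data - C * x_data ** alpha) ** 2)

def grad(p):
    C, alpha = p
    pred = C * x_data ** alpha
    r = y_data - pred
    return np.array([-np.mean(r * x_data ** alpha),
                     -np.mean(r * pred * np.log(x_data))])

def sharpness(p, h=1e-5):
    """Largest eigenvalue of the 2x2 Hessian, built by finite differences on grad."""
    H = np.column_stack([(grad(p + h * e) - grad(p - h * e)) / (2 * h)
                         for e in np.eye(2)])
    return np.linalg.eigvalsh(0.5 * (H + H.T)).max()

p = np.array([1.0, -0.5])
eta = 0.5 / sharpness(p)          # well below 2/S at the starting point
for step in range(501):
    if step % 100 == 0:
        print(f"step {step:3d}  loss={loss(p):.6f}  "
              f"sharpness={sharpness(p):.3f}  2/eta={2 / eta:.3f}")
    p = p - eta * grad(p)
```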

Figure 6 plots both the loss and sharpness at each point along the gradient descent trajectory. The sharpness threshold for the chosen \(\eta\) is also displayed as a horizontal line on the sharpness plot.

Figure 6: Loss (log scale) and sharpness during gradient descent for 2500 steps on the non-linear least squares objective with η = 70 (2/η ≈ 0.0286). The dashed orange line marks S = 2/η (the critical sharpness above which gradient descent would diverge on a quadratic).

Notice that for \(\eta\) below ~57, the loss decreases monotonically and the sharpness along the gradient descent trajectory never exceeds the critical sharpness \(2/\eta\).

Once \(\eta > 57\), we see two things occur. The first is that loss spikes begin appearing, and we no longer observe a monotonically decreasing loss during gradient descent. The second is that the sharpness begins exceeding the critical sharpness and then rapidly plunges back down below it. This is very different from the quadratic case, where once we exceeded the critical threshold, gradient descent diverged.

Also notice that as we increase \(\eta\), we see more oscillations around the critical sharpness, initially with an exponentially decaying envelope. Simultaneously, we see larger and more frequent loss spikes. Beyond roughly \(\eta = 90\), the sharpness oscillations develop an exponentially growing envelope and the loss diverges.

To understand what causes the sharpness to suddenly decrease after exceeding the critical sharpness, Figure 7 decomposes the gradient at each step into its components along the dominant eigenvector \(v_1\) (sharpest direction) and the non-dominant eigenvector \(v_2\) of the local Hessian \(\nabla^2 f\).

Figure 7: Sharpness (top) and the gradient decomposed along the dominant eigenvector v₁ and non-dominant eigenvector v₂ of the local Hessian (bottom) for η = 70. The dominant component ∇f·v₁ oscillates (sign-flipping) once the sharpness exceeds 2/η.

Notice how, just as the sharpness drops precipitously back below the edge of stability, we see large back-and-forth oscillations in the component of the gradient along the direction of maximum sharpness. It turns out that this is the key to understanding the phenomenon of loss spikes, as we’ll see in the next section.

Mathematics of Loss Spikes

To study the dynamics of the loss curve, let’s look at the Taylor expansion, not of the loss, but of the loss’s gradient. For ease of notation, we’ll denote the gradient as \(g(x)=\nabla f(x)\).

We’ll perform the second order Taylor expansion around the point \(a\in\mathbf{R}^d\). Since the full expansion would involve a third order tensor, we’ll just look at a single scalar component \(g_k\) of the gradient approximation centered at \(a\)

\[g_k(x) \approx g_k(a) + Dg_k(a)(x-a) + \frac{1}{2}(x-a)^T \nabla^2 g_k(a) (x-a), \quad k=1,\ldots, d\]

Evaluating this quadratic approximation at a perturbed point \(a+\delta\in\mathbf{R}^d\) we have

\[g_k(a+\delta) \approx g_k(a) + Dg_k(a)\delta + \frac{1}{2}\delta^T \nabla^2 g_k(a) \delta, \quad k=1,\ldots, d\]

From analyzing the quadratic case, we observed that when the sharpness exceeds the critical threshold \(2/\eta\), the perturbation is largely in the direction of maximal sharpness. For this reason we set \(\delta = \sigma u\) where \(u\in\mathbf{R}^d\) is a unit vector in the direction of the dominant eigenvector of \(\nabla^2f\) and \(\sigma \in \mathbf{R}\). This gives,

\[g_k(a+\sigma u) \approx g_k(a) + \sigma Dg_k(a)u + \frac{\sigma^2}{2}u^T \nabla^2 g_k(a) u.\]

The term \(Dg_k(a)\in \mathbf{R}^{1\times d}\) is the total derivative of the \(k^{th}\) component of the gradient at \(a\). This is exactly the \(k^{th}\) row of the Hessian \(\nabla^2 f(a)\). This means \(Dg_k(a) u\) is simply \([\nabla^2 f(a) u]_k\), the \(k^{th}\) component of the product of the Hessian and its dominant eigenvector.

Since the eigenvalue associated with the dominant eigenvector is by definition the sharpness, we have \([\nabla^2 f(a) u]_k = S(a)u_k\). So the Taylor approximation of the gradient becomes

\[g_k(a+\sigma u) \approx g_k(a) + \sigma S(a)u_k + \frac{\sigma^2}{2}u^T \nabla^2 g_k(a) u.\]

We can also simplify the final term involving the Hessian of the gradient component (a third derivative!) using the definition \(g_k(a) = \partial f(a)/\partial x_k\) and the identity \(x^TAx = \sum_{i=1}^d\sum_{j=1}^d A_{ij}x_ix_j\)

\[\begin{align*} \frac{\sigma^2}{2}u^T \nabla^2 g_k(a) u &= \frac{\sigma^2}{2}\sum_{i=1}^d\sum_{j=1}^d \frac{\partial^2 g_k(a)}{ \partial x_i \partial x_j} u_i u_j \\ &= \frac{\sigma^2}{2}\sum_{i=1}^d\sum_{j=1}^d \frac{\partial^3 f(a)}{\partial x_k \partial x_i \partial x_j} u_i u_j \\ &= \frac{\sigma^2}{2}\left[\frac{\partial}{\partial x_k}\sum_{i=1}^d\sum_{j=1}^d \frac{\partial^2 f(x)}{\partial x_i \partial x_j} u_i u_j\right]_{x=a} \\ &= \frac{\sigma^2}{2}\left[\frac{\partial}{\partial x_k}\left(u^T \nabla^2 f(x) u\right)\right]_{x=a} \\ \end{align*}\]

Simplifying the last part is tricky, but a sketch of the argument comes from looking at the limit definition of the partial derivative with respect to \(x_k\), where \(e_k\) is the \(k^{th}\) standard basis vector.

\[\left[\frac{\partial}{\partial x_k}\left(u^T \nabla^2 f(x) u\right)\right]_{x=a} = \lim_{h\rightarrow 0} \frac{u^T \nabla^2 f(a+he_k) u - u^T\nabla^2 f(a)u}{h}\]

Since \(u\) is the dominant eigenvector of \(\nabla^2 f(a)\), we have \(\nabla^2 f(a)u = S(a)u\). Combined with the fact that \(u^Tu=1\), we can simplify the above to

\[\left[\frac{\partial}{\partial x_k}\left(u^T \nabla^2 f(x) u\right)\right]_{x=a} = \lim_{h\rightarrow 0} \frac{u^T \nabla^2 f(a+he_k) u - S(a)}{h}\]

Since \(u\) is an eigenvector of \(\nabla^2 f(a)\) and not \(\nabla^2 f(a+he_k)\), this does not immediately simplify as before. However, for very small \(h\) we can treat \(u\) as constant and as an eigenvector of \(\nabla^2 f(a+he_k)\), in which case the corresponding eigenvalue would be \(S(a+he_k)\). This simplifies the partial derivative to

\[\begin{align*} \left[\frac{\partial}{\partial x_k}\left(u^T \nabla^2 f(x) u\right)\right]_{x=a} &= \lim_{h\rightarrow 0} \frac{S(a+he_k) - S(a)}{h}\\ &= \frac{\partial S(a)}{\partial x_k} \end{align*}\]
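This heuristic step can be sanity-checked numerically on a made-up smooth test function (the function, evaluation point, and finite-difference step below are arbitrary illustrative choices; for a simple largest eigenvalue the step is supported by first order eigenvalue perturbation theory):

```python
import numpy as np

A = np.array([[2.0, 0.5], [0.5, 1.0]])

def hessian(x):
    # Hessian of the test function f(x) = 0.25 * sum(x_i^4) + 0.5 * x^T A x.
    return np.diag(3 * x ** 2) + A

def sharpness(x):
    return np.linalg.eigvalsh(hessian(x)).max()

a = np.array([0.7, -0.3])
u = np.linalg.eigh(hessian(a))[1][:, -1]       # dominant eigenvector at a, held fixed

h = 1e-5
for k, e in enumerate(np.eye(2)):
    lhs = (u @ hessian(a + h * e) @ u - u @ hessian(a - h * e) @ u) / (2 * h)
    rhs = (sharpness(a + h * e) - sharpness(a - h * e)) / (2 * h)
    print(f"k={k}  d/dx_k (u^T H u) = {lhs:.6f}   dS/dx_k = {rhs:.6f}")
```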

With this simplification, the second order gradient approximation for the \(k^{th}\) component at \(a+\sigma u\) is

\[g_k(a+\sigma u) \approx g_k(a) + \sigma S(a)u_k + \frac{\sigma^2}{2}\frac{\partial S(a)}{\partial x_k}.\]

The approximation for the entire gradient vector is then

\[g(a+\sigma u) \approx g(a) + \sigma S(a)u + \frac{\sigma^2}{2}\nabla S(a).\]

This says that when the perturbation is small, a step in the negative gradient direction is dominated by the term \(-\sigma S(a)u\), which pushes in the direction opposite to the original perturbation \(\sigma u\), causing the oscillatory behavior observed in Figure 7. When the perturbation is sufficiently large, however, a step in the negative gradient direction points strongly in the direction of decreasing sharpness (i.e. along \(-\frac{\sigma^2}{2}\nabla S(a)\)).
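The approximation itself is straightforward to check numerically, reusing the quartic test function from the previous sketch (again, the function, point, and step sizes are made-up illustrative choices; \(\nabla S\) is estimated by central differences):

```python
import numpy as np

A = np.array([[2.0, 0.5], [0.5, 1.0]])

def grad_f(x):
    # Test function f(x) = 0.25 * sum(x_i^4) + 0.5 * x^T A x, so grad f = x^3 + A x.
    return x ** 3 + A @ x

def hessian(x):
    return np.diag(3 * x ** 2) + A

def sharpness(x):
    return np.linalg.eigvalsh(hessian(x)).max()

a = np.array([0.7, -0.3])
S = sharpness(a)
u = np.linalg.eigh(hessian(a))[1][:, -1]       # dominant eigenvector of the Hessian at a

h = 1e-5                                       # gradient of the sharpness, central differences
grad_S = np.array([(sharpness(a + h * e) - sharpness(a - h * e)) / (2 * h)
                   for e in np.eye(2)])

sigma = 0.05
lhs = grad_f(a + sigma * u)
rhs = grad_f(a) + sigma * S * u + 0.5 * sigma ** 2 * grad_S
print(lhs, rhs, np.abs(lhs - rhs).max())       # agreement up to O(sigma^3)
```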

This effectively explains our observations! During gradient descent, loss spikes occur when the sharpness exceeds the critical sharpness \(2/\eta\) for the given learning rate. The loss drops back down because the sharpness rapidly falls back below \(2/\eta\). And the reason the sharpness falls after exceeding the edge of stability is that there is a built-in negative feedback! When the perturbation along the dominant eigenvector of the Hessian gets too large, the negative gradient of the loss also begins to point in the direction of the negative gradient of the sharpness, driving the sharpness back below the edge of stability.

Conclusion

Loss spikes are not noise. They are a predictable consequence of gradient descent briefly violating the edge of stability before snapping back. For quadratics, the stability condition \(\eta < 2/S\) is global, and crossing it sends the optimizer into immediate divergence. For general objectives, sharpness varies along the trajectory, so a learning rate that behaved stably early in training can locally violate the condition as the optimizer enters sharper regions. When it does, the quadratic term in the Taylor expansion of the gradient takes over, pointing in the direction of \(-\nabla S\). This drives the optimizer toward flatter regions, sharpness falls back below \(2/\eta\), and stable descent resumes. The spike is the mechanism, not a glitch. Without that restoring force, there would be no recovery.

References

  1. How does gradient descent work
  2. Boyd, S. & Vandenberghe, L. “Convex Optimization.” Cambridge University Press, 2004.
  1. \(\mathbf{S}_+^{d}\) denotes the set of \(d\times d\) symmetric positive semidefinite matrices. 

  2. More precisely, \(\nabla_x \tfrac{1}{2}x^T A x = \tfrac{1}{2}(A + A^T)x\). When \(A\) is symmetric this reduces to \(Ax\). 

  3. To see why, set \(x = t\,v_k\) for scalar \(t\). Because the columns of \(V\) are orthonormal, \(x^T v_j = t\,v_k^T v_j = t\,\delta_{kj}\), so every term in the sum vanishes except the \(k^{th}\) one: \(f(tv_k) = \tfrac{1}{2}\lambda_k t^2\). This is exactly the 1D quadratic with sharpness \(\lambda_k\), confirming that \(\lambda_k\) governs the curvature of \(f\) along \(v_k\). 

  4. For a quadratic \(f(x) = \tfrac{1}{2}x^TAx\), the Hessian is the constant matrix \(A\), so the sharpness is the same at every point (i.e. a global property). For a general nonlinear function the Hessian varies, making sharpness point-dependent. Note also that this definition is a direct generalization of the \(d\)-dimensional quadratic case: the second order Taylor expansion of a quadratic is just the quadratic itself, so the maximum eigenvalue of \(\nabla^2 f(a) = A\) recovers exactly \(S = \max\{\lambda_1,\ldots,\lambda_d\}\).