MiniML 0.7.0: the Grokking Update
A few months ago I released the first version of MiniML, a small machine learning framework powered by JAX. The main idea behind it was to have a very lean toolset to build models that would be as simple to train as a Scikit-learn one, while offering the same level of flexibility as PyTorch. I wrote all about it here.
A few versions later I’ve expanded on that base with a few much-needed basic modules for machine learning - radial basis function networks, multi-head self-attention, a testing system that uses gold standard data from Backblaze, support for non-Scipy optimizers and more. In this release though I want to focus on the addition of two features that support research into one particular phenomenon: “grokking”. Let’s see what it is!
The origin of “grokking”
The term “grokking” originates in Robert A. Heinlein’s novel Stranger in a Strange Land, where it’s used by an alien to mean achieving a deep, intimate level of understanding of something. In 2022 it was adopted in a paper by OpenAI researchers to describe a peculiar phenomenon they observed [1]. The phenomenon was apparently discovered serendipitously: some researchers who were experimenting on training a small model to learn basic modular arithmetic left a training run going overnight, far longer than they had planned. While the model had quickly learned how to reproduce its training data, it had overfit - it matched the training data very well and did very poorly on the test data. However, after the run had gone on for a million or so optimization steps, something else happened. All of a sudden, the loss function on the test set dropped too, and the model started performing beautifully on every data point, even those it had never seen! Beyond rote memorization, the model had “grokked” modular addition: it had found a way to reproduce with its weights the exact algorithm that produced the data, and was now able to generalize correctly beyond its training set. Subsequent inspection of the model’s weights revealed a lot of regular, periodic structures - the model had adopted some tricks involving sines and cosines that amount to an essentially exact implementation of modular arithmetic. That’s the gold standard for any machine learning model - to walk the line between fidelity and simplicity until it lands on some true regularity, some law underlying all of the data, rather than just a lot of unrelated anecdotal observations.
Grokking and numerical stability
Fast forward to early 2025, and a different paper came out, this time from researchers at Imperial College London [2]. This paper inspected the grokking phenomenon a bit more closely and managed to achieve grokking on the same problems from the original study in as little as 300 training epochs! They did so by identifying two changes to the model and the training process that address problems which would otherwise undermine grokking. The two changes they introduced (and, spoilers, that you will find implemented in MiniML 0.7.0) are StableMax and OrthoGrad.
StableMax: when SoftMax isn’t soft enough
SoftMax is a common function used to turn an array of unbounded floats into an array of probabilities by treating them as ’logits’ of a distribution over discrete outcomes. If a classifier gives you an array of outputs $z_i$, one per possible bin, you derive the probabilities as:
$$ \hat{y}_i = \text{SoftMax}(z)_i = \frac{\exp({z_i})}{\sum_j \exp({z_j})} $$
These probabilities feed the cross-entropy loss, $\mathcal{L} = -\sum_i y_i \log \hat{y}_i$. Since only the correct class has $y_{true} = 1$ and every other $y_i$ is zero, the loss simplifies to:
$$ \mathcal{L} = -z_{true}+\log\left[\sum_j \exp({z_j})\right] $$
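To make the simplification concrete, here is a minimal sketch in plain Python that computes the loss both ways and compares them. The function names are mine, chosen for illustration only; this is not MiniML code, and the logits are arbitrary.

```python
import math

def softmax(z):
    """Turn a list of logits into probabilities."""
    exps = [math.exp(v) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_from_probs(z, true_idx):
    """Cross-entropy against a one-hot target: -log of the true class probability."""
    return -math.log(softmax(z)[true_idx])

def cross_entropy_from_logits(z, true_idx):
    """The simplified form: -z_true + log(sum_j exp(z_j))."""
    return -z[true_idx] + math.log(sum(math.exp(v) for v in z))

logits = [2.0, -1.0, 0.5]  # arbitrary example values, true class is index 0
loss_a = cross_entropy_from_probs(logits, 0)
loss_b = cross_entropy_from_logits(logits, 0)
print(loss_a, loss_b)  # the two values agree up to rounding
```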
The paper points out a problem with this, one which the researchers observed in their experiments. If a model converges to a good answer, one of the $z_i$ is going to be very large compared to the others. Once these are exponentiated, the terms can differ by many orders of magnitude. Now, when computing the gradient of the loss, the only way for gradients to “flow through” any outputs other than $z_{true}$ is through the logarithm of the denominator, where all the exponentials are summed. But if the differences are large enough, floating point arithmetic makes the smaller terms vanish completely: they are so small that they don’t even register in the significant digits of the largest term! This effect gets the optimization stuck in a rut.
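Here is a small sketch of that rounding effect, using plain Python floats (float64) and made-up logits. With a gap of 40 between the winning logit and the others, the small exponentials are rounded away in the sum, the true-class probability becomes exactly 1, and both the loss and the gradient with respect to $z_{true}$ evaluate to exactly zero. Lower precision formats run into this at smaller gaps. The numbers are illustrative assumptions, not taken from the paper.

```python
import math

def softmax(z):
    """SoftMax with a plain left-to-right float64 accumulation of the sum."""
    exps = [math.exp(v) for v in z]
    total = 0.0
    for e in exps:
        total += e
    return [e / total for e in exps]

# The "+ 1.0" is smaller than the spacing between float64 values near exp(40).
big = math.exp(40.0)
absorbed = (big + 1.0 == big)

confident = [40.0, 0.0, 0.0]   # true class is index 0
probs = softmax(confident)
loss = -math.log(probs[0])     # cross-entropy for the true class
grad_true = probs[0] - 1.0     # d(loss)/d(z_true) = softmax(z)_true - 1

# For comparison, a milder gap still leaves a usable signal.
mild_probs = softmax([4.0, 0.0, 0.0])
mild_grad_true = mild_probs[0] - 1.0

print(absorbed, probs[0], loss, grad_true, mild_grad_true)
```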
To avoid this effect, the researchers propose replacing SoftMax with a slower-changing function with fatter tails, which they dub StableMax. First they define a function $s(x)$ as follows:
$$
s(x) = \begin{cases}
1+x & x > 0 \\
\frac{1}{1-x} & x \le 0
\end{cases}
$$
Then they define StableMax as:
$$ \text{StableMax}(z)_i = \frac{s(z_i)}{\sum_j s(z_j)} $$
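The two definitions above translate almost directly into code. This is a minimal plain-Python sketch for illustration, not MiniML’s own `stablemax` implementation, and the example logits are made up.

```python
def s(x):
    """Piecewise replacement for exp(x): 1 + x for x > 0, 1 / (1 - x) otherwise."""
    return 1.0 + x if x > 0 else 1.0 / (1.0 - x)

def stablemax(z):
    """Normalize s(z_i) over all entries, mirroring the SoftMax formula."""
    vals = [s(v) for v in z]
    total = sum(vals)
    return [v / total for v in vals]

probs = stablemax([40.0, 0.0, 0.0])
print(probs)  # the low-probability bins keep a visible share
```

Both branches of `s` give 1 at zero, just like the exponential, so the function is continuous there.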
You can see here how the two functions compare (on the left a comparison of exponential vs $s(x)$, on the right a comparison of the resulting sigmoid probabilities for a varying $x$ referenced against zero):

The slower decay in StableMax keeps the loss gradient sensitive even to the weights that mostly feed the low-probability bins.
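One rough way to put a number on this is to compare how much of the normalizing sum comes from the low-probability bins under each function. The sketch below uses plain Python and the same made-up logits with a gap of 40; the variable names are mine.

```python
import math

def s(x):
    """The StableMax building block: 1 + x for x > 0, 1 / (1 - x) otherwise."""
    return 1.0 + x if x > 0 else 1.0 / (1.0 - x)

logits = [40.0, 0.0, 0.0]  # true class is index 0

exp_terms = [math.exp(v) for v in logits]
s_terms = [s(v) for v in logits]

# Size of the two small terms relative to the dominant one.
exp_tail_share = (exp_terms[1] + exp_terms[2]) / exp_terms[0]
s_tail_share = (s_terms[1] + s_terms[2]) / s_terms[0]

# With s(x), the cross-entropy for the true class stays strictly positive.
stable_loss = -math.log(s_terms[0] / sum(s_terms))

print(exp_tail_share, s_tail_share, stable_loss)
```

With the exponential, the relative size of the tail falls below float64 resolution, while with $s(x)$ it stays at a few percent, so the loss still has something to push against.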
OrthoGrad: fixing the root of the issue
While StableMax can improve the numerical stability problem, it does not fix the root cause of the issue. That cause can be understood simply if you consider this toy example:
Imagine you have a logistic regressor: a simple linear layer applied to your input vector, whose outputs are treated as logits for the $N$ classes of your problem. You’ve reached a stage in training where the results are already good enough (if you pick the class with the highest logit, you get 100% accuracy on your training set). How can the loss improve any further? One easy answer is: by scaling up every weight in the layer! If the weights are scaled up, the logits are scaled up too. The positive ones will become more positive, and the negative ones more negative; the model will become more “certain” of the result and thus the loss will go down.
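A quick sketch of this effect in plain Python: take logits that already pick the right class, scale them up (which is what scaling the weights of a bias-free linear layer would do), and watch the loss shrink while the prediction stays the same. The logits and scale factors are arbitrary values chosen for illustration.

```python
import math

def cross_entropy(z, true_idx):
    """SoftMax cross-entropy from logits, shifted by the max to avoid overflow."""
    m = max(z)
    return -(z[true_idx] - m) + math.log(sum(math.exp(v - m) for v in z))

logits = [2.0, -1.0, 0.5]  # already classifies index 0 correctly
results = []
for scale in (1.0, 2.0, 4.0, 8.0):
    scaled = [scale * v for v in logits]
    predicted = max(range(len(scaled)), key=lambda i: scaled[i])
    results.append((scale, predicted, cross_entropy(scaled, 0)))

for scale, predicted, loss in results:
    print(f"scale={scale:>4}  predicted={predicted}  loss={loss:.6f}")
```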
This does not improve the actual predictive power of the model one bit. In fact, it actively worsens its ability to generalize, overfitting more and more to the specific pattern that was found in the training data. Normally, this trend is contained by regularization, but when trying to achieve “grokking” it’s actually usually better to apply no regularization at all - to avoid damping the model’s ability to converge to the correct algorithm, and because we know that our training data is entirely noiseless.
If you consider the weights of the model as a vector $\theta$, then this kind of scaling means performing a step that’s essentially parallel to the original vector. In practice, the gradient of the loss, $g = \nabla_\theta \mathcal{L}$, is going to have a component parallel to $\theta$ and one perpendicular to it, but the parallel component is going to be much more important if that’s an easy way for the model to reduce its loss. The OrthoGrad approach simply means that the update step is no longer computed using the full gradient $g$, but using only its component perpendicular to the weight vector, $g_\perp = g - \frac{g \cdot \theta}{\theta \cdot \theta}\,\theta$. This way, the updates keep exploring the space of possible solutions in more interesting directions rather than doubling down on the one that was already found. The paper uses AdamW but in theory this approach should be usable with some other optimizers too - not the ones relying on Hessian approximations though, I expect, unless the Hessian is similarly adapted.
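The projection itself is only a couple of lines. Here is a minimal plain-Python sketch that treats the weights as one flat vector; the helper names, the `eps` guard against a zero weight vector and the plain gradient descent step are my own assumptions for illustration, not the paper’s or MiniML’s actual implementation.

```python
def dot(a, b):
    """Dot product of two equally sized lists of floats."""
    return sum(x * y for x, y in zip(a, b))

def orthogonal_component(theta, g, eps=1e-30):
    """Remove from g its projection onto theta.

    eps is an assumed safeguard against dividing by zero when theta is all zeros.
    """
    coeff = dot(g, theta) / (dot(theta, theta) + eps)
    return [gi - coeff * ti for gi, ti in zip(g, theta)]

theta = [3.0, 4.0]  # toy weight vector
g = [1.0, 2.0]      # toy gradient
g_perp = orthogonal_component(theta, g)

# Illustrative update: plain gradient descent along g_perp with a made-up
# learning rate (the paper pairs the projection with AdamW instead).
lr = 0.1
new_theta = [t - lr * gp for t, gp in zip(theta, g_perp)]

print(g_perp, dot(g_perp, theta))  # the second value is zero up to rounding
```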
Grokking in MiniML
In MiniML 0.7.0 I implemented two things to support “grokking” experiments:
- I’ve added several utility functions to support StableMax: in particular `stablemax`, `log_stablemax` and `CrossEntropyStableMaxLogLoss`;
- I’ve added an `ortho_grad` option to the Adam optimizers, which enables orthogonalized gradients for them; the same option also exists in `OptimizationMethods.Config` so that it can be easily used in future optimizers as well;
- I’ve added a new example Marimo notebook that puts these to use in the same test as the 2025 paper.
You can see the results of the notebook here:

Early on, there’s a very fast drop in the train loss while the test loss remains high. Later, the test loss drops too, albeit not as dramatically; this drop corresponds to a jump in accuracy that hits 100%, with perfect generalization of the algorithm outside of the training domain!
Find MiniML here on GitHub; here are its docs, and here is its PyPI page.