Neural networks trained with first-order optimisers such as SGD and Adam are the default choice for training LLMs, building evaluations, and interpreting models in AI safety. Optimisation itself, however, is a hard problem that machine learning has tackled in many ways. In this blog post, we look at the intersection of interpretability and optimisation, and what it means for the AI safety space. As a brief overview, we’ll consider:
We’ll look closely at each of these to better understand models and optimisation procedures, and see whether we can tie them together into a clearer picture of how these systems work in the context of AI safety.
<aside> 💡
The code for this can be found on my GitHub. It includes a general feature visualisation generator that works with a variety of models, datasets, and optimisers, and also supports training and checkpointing. Enjoy!
</aside>
During the mechanistic interpretability week, we discussed the A1 dictionary learning features from one of Anthropic’s models. Dictionary learning, a technique that learns a sparse representation of the data, is known to have **saddle points.**
Saddle points are a recurring theme in the loss landscapes of models, and first-order optimisers such as SGD and Adam have trouble navigating through them. One likely contributing factor when a standard loss curve plateaus a few iterations in is that the optimisation process is stuck near a saddle point. Here are two toy examples of saddle functions that I’ve visualised.
Left: Monkey Saddle defined as $z = x^3 - 3xy^2$. Right: Classic Saddle defined as $z = x^2 - y^2$. Both functions were optimised with Momentum SGD for 200 iterations (LR = 0.001, Momentum = 0.90).
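For anyone who wants to poke at these toy surfaces directly, here is a minimal NumPy sketch of momentum SGD on both functions, using the hyperparameters from the caption (LR = 0.001, momentum = 0.90, 200 iterations). A few things here are my assumptions rather than what produced the plots: the starting point `(0.5, 0.01)`, the helper names, and the PyTorch-style momentum update (`v = momentum * v + grad`, then `p = p - lr * v`). Gradients are written out by hand from the two formulas.

```python
import numpy as np

def classic_saddle(p):
    """z = x^2 - y^2"""
    x, y = p
    return x**2 - y**2

def classic_saddle_grad(p):
    x, y = p
    return np.array([2.0 * x, -2.0 * y])

def monkey_saddle(p):
    """z = x^3 - 3xy^2"""
    x, y = p
    return x**3 - 3.0 * x * y**2

def monkey_saddle_grad(p):
    x, y = p
    return np.array([3.0 * x**2 - 3.0 * y**2, -6.0 * x * y])

def momentum_sgd(grad_fn, start, lr=0.001, momentum=0.90, steps=200):
    """Run momentum SGD and return the visited points, shape (steps + 1, 2).

    Assumed update rule (PyTorch-style, no dampening):
        v <- momentum * v + grad(p)
        p <- p - lr * v
    """
    p = np.asarray(start, dtype=float)
    v = np.zeros_like(p)
    path = [p.copy()]
    for _ in range(steps):
        v = momentum * v + grad_fn(p)
        p = p - lr * v
        path.append(p.copy())
    return np.array(path)

# Hypothetical starting point: close to the x-axis, so the optimiser first
# slides towards the saddle at the origin before drifting away along y.
start = (0.5, 0.01)
classic_path = momentum_sgd(classic_saddle_grad, start)
monkey_path = momentum_sgd(monkey_saddle_grad, start)

print("classic saddle, final point:", classic_path[-1])
print("monkey saddle, final point:", monkey_path[-1])
```

Plotting `classic_path` or `monkey_path` over a contour of the corresponding function shows how slowly the iterate moves while it is near the origin, where the gradient is close to zero.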
But these are just toy examples. What does an actual loss landscape look like? Let’s find out: the following is a low-dimensional projection of the loss landscape of a ResNet-110.
Left: Side view of ResNet-110 projection. Right: Bird’s eye view of ResNet-110 projection. Visualised using https://github.com/tomgoldstein/loss-landscape.
We can see that while the landscape doesn’t map exactly onto a saddle point, the features that characterise one are still present, with parts of the landscape being hard to navigate due to plateaus, ragged edges, and pitfalls. So even if they don’t look exactly like the toy examples, saddle points are known to proliferate exponentially throughout the loss landscapes of our models. This is bad because they can: