How AlphaFold 3 Is Like DALLE 2 and Other Learnings

Diffusion (literally) from Unsplash

Understanding AI applications in bio for machine learning engineers

In our last article, we explored how AlphaFold 2 and BERT were connected through the transformer architecture. In this piece, we’ll learn how the most recent update, AlphaFold 3 (hereafter AlphaFold), is more similar to DALLE 2 (hereafter DALLE), and then dive into other changes to its architecture and training.

What’s the connection?

AlphaFold and DALLE are another example of how vastly different use cases can benefit from architectural learning across domains. DALLE is a text-to-image model that generates images from text prompts. AlphaFold 3 is a model for predicting biomolecular interactions. The applications of these two models sound like they couldn’t be any more different, yet both rely on diffusion model architecture.

Because reasoning about images and text is more intuitive than biomolecular interactions, we’ll first explore DALLE’s application. Then we’ll learn about how the same concepts are applied by AlphaFold.

Diffusion Models

A metaphor for understanding the diffusion model: consider tracing the origin of a drop of dye in a glass of water. As the dye disperses, it moves randomly through the liquid until it is evenly spread. To backtrack to the initial drop’s location, you must reconstruct its path step by step since each movement depends on the one before. If you repeat this experiment over and over, you’ll be able to build a model to predict the dye movement.

More concretely, diffusion models are trained to predict and remove noise from a dataset. Then upon inference, the model generates a new sample using random noise. The architecture comprises three core components: the forward process, the reverse process, and the sampling procedure. The forward process takes the training data and adds noise at each time step. As you might expect, the reverse process removes noise at each step. The sampling procedure (or inference) executes the reverse process using the trained model and a noise schedule, transforming an initial random noise input back into a structured data sample.

Simplified illustration of the forward and reverse processes where a pixelated heart has noise added and removed back to its original shape. (Created by author)
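To make these three components concrete, here is a minimal PyTorch-style sketch. The `model` callable, the alpha schedules, and the simplified DDPM-style update are illustrative assumptions rather than code from either paper.

```python
import torch

def forward_process(x0, t, alphas_cumprod):
    """Forward process: add Gaussian noise to a clean sample x0 at step t."""
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t]
    x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * noise
    return x_t, noise

def training_step(model, x0, alphas_cumprod, num_steps=1000):
    """Reverse-process training: the model learns to predict the added noise."""
    t = torch.randint(0, num_steps, (1,))
    x_t, noise = forward_process(x0, t, alphas_cumprod)
    return torch.nn.functional.mse_loss(model(x_t, t), noise)

@torch.no_grad()
def sample(model, shape, alphas, alphas_cumprod, num_steps=1000):
    """Sampling: start from pure noise and denoise step by step."""
    x = torch.randn(shape)
    for t in reversed(range(num_steps)):
        pred_noise = model(x, torch.tensor([t]))
        a, a_bar = alphas[t], alphas_cumprod[t]
        # Simplified DDPM mean update (per-step noise injection omitted for brevity)
        x = (x - (1 - a) / torch.sqrt(1 - a_bar) * pred_noise) / torch.sqrt(a)
    return x
```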

DALLE and Diffusion

DALLE incorporates diffusion model architecture in two major components, the prior and the decoder, and removes its predecessor’s autoregressive module. The prior model takes text embeddings generated by CLIP (Contrastive Language-Image Pre-training, a model trained on a dataset of image-caption pairs) and creates an image embedding. During training, the prior is given a text embedding and a noised version of the corresponding image embedding, and it learns to denoise the image embedding step by step. This process allows the model to learn a distribution over image embeddings representing the variability in possible images for a given text prompt.
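As a rough sketch of what one training step for such a prior could look like, assuming hypothetical `prior`, `clip_text_encoder`, and `clip_image_encoder` callables (this is not OpenAI's actual implementation):

```python
import torch

def prior_training_step(prior, clip_text_encoder, clip_image_encoder,
                        caption, image, alphas_cumprod, num_steps=1000):
    """Hypothetical denoising-training step for a DALLE-style prior."""
    text_emb = clip_text_encoder(caption)    # conditioning signal
    image_emb = clip_image_encoder(image)    # target embedding
    t = torch.randint(0, num_steps, (1,))
    noise = torch.randn_like(image_emb)
    a_bar = alphas_cumprod[t]
    noised_emb = torch.sqrt(a_bar) * image_emb + torch.sqrt(1 - a_bar) * noise
    # The prior sees the noised image embedding, the text embedding, and the
    # timestep, and is trained to recover the clean image embedding.
    pred_emb = prior(noised_emb, text_emb, t)
    return torch.nn.functional.mse_loss(pred_emb, image_emb)
```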

The decoder generates an image from the resulting image embedding starting with a random noise image. The reverse diffusion process iteratively removes noise from the image using the image embedding (from the prior) at each timestep according to the noise schedule. Time step embeddings tell the model about the current stage of the denoising process, helping it adjust the noise level removed based on how close it is to the final step.
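A compact sketch of that reverse loop, with the conditioning on the image embedding and the timestep embedding made explicit (the `decoder` and `timestep_embedding` callables are again hypothetical stand-ins):

```python
import torch

@torch.no_grad()
def decode_image(decoder, image_emb, timestep_embedding, alphas, alphas_cumprod,
                 shape=(3, 64, 64), num_steps=1000):
    """Hypothetical reverse-diffusion loop for a DALLE-style decoder."""
    x = torch.randn(shape)                      # start from a random noise image
    for t in reversed(range(num_steps)):
        t_emb = timestep_embedding(t)           # tells the model where it is in the schedule
        pred_noise = decoder(x, image_emb, t_emb)
        a, a_bar = alphas[t], alphas_cumprod[t]
        # Remove a little noise at each step, conditioned on the image embedding
        x = (x - (1 - a) / torch.sqrt(1 - a_bar) * pred_noise) / torch.sqrt(a)
    return x
```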

While DALLE 2 utilizes a diffusion model for generating images, its predecessor, DALLE 1, relied on an autoregressive approach, sequentially predicting image tokens based on the text prompt. This approach was much less computationally efficient, required a more complex training and inference process, struggled to produce high-resolution images, and often resulted in artifacts.

Its predecessor also did not make use of CLIP. Instead, DALLE 1 learned text-image representations directly. The introduction of CLIP embeddings unified these representations, making text-to-image generation more robust.

High-level overview of the DALLE 2 architecture, from Hierarchical Text-Conditional Image Generation with CLIP Latents by the OpenAI team

How AlphaFold uses Diffusion

While DALLE’s use of diffusion helps generate detailed visual content, AlphaFold leverages similar principles in biomolecular structure prediction (no longer just protein folding!).

AlphaFold 2 was not a generative model, as it predicted structures directly from given input sequences. Due to the introduction of the diffusion module, AlphaFold 3 IS a generative model. Just like with DALLE, noise is sampled and then recurrently denoised to produce a final structure.

The diffusion module is incorporated by replacing the structure module. This change greatly simplifies the model: the structure module predicted amino-acid-specific frames and side-chain torsion angles, whereas the diffusion module predicts raw atom coordinates directly, eliminating several intermediate steps in the inference process.

AF3 architecture for inference, showing where the diffusion module resides. From Accurate structure prediction of biomolecular interactions with AlphaFold 3
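To picture what that interface change means in practice, here is a highly simplified sampling sketch that works directly on a (num_atoms, 3) coordinate array. The `diffusion_module` callable and the noise-level loop are illustrative assumptions, not the published AF3 implementation:

```python
import torch

@torch.no_grad()
def generate_structure(diffusion_module, conditioning, num_atoms, noise_levels):
    """Hypothetical AF3-style sampling: denoise raw atom coordinates directly."""
    # Start from random 3D positions for every atom in the complex
    coords = torch.randn(num_atoms, 3) * noise_levels[0]
    for sigma in noise_levels:                 # iterate from high noise to low noise
        # The module predicts denoised coordinates given the current noisy
        # coordinates, the noise level, and the trunk's conditioning features.
        coords = diffusion_module(coords, sigma, conditioning)
    return coords  # (num_atoms, 3); no frames or torsion angles needed
```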

The impetus behind removing these intermediate steps was that the scope of training data for this iteration of the model grew substantially. AlphaFold 2 was only trained on protein structures, whereas AlphaFold 3 is a “multi-modal” model capable of predicting the joint structure of complexes including proteins, nucleic acids, small molecules, ions and modified residues. If the model still used the structure module, it would have required an excessive number of complex rules about chemical bonds and stereochemistry to create valid structures.

Diffusion does not require these rules because it operates at both coarse- and fine-grained levels. At high noise levels, the model focuses on capturing the global structure; at low noise levels, it refines the local details, such as the precise positions and orientations of atoms, which are crucial for accurate molecular modeling. This means the model can work with many different types of chemical components, not just standard amino acids or protein structures.
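A toy numerical illustration of how the noise level sets the scale the model has to attend to (all values below are invented for illustration):

```python
import torch

# Toy, log-spaced noise schedule from coarse to fine (illustrative values only)
noise_levels = torch.logspace(1.5, -2, steps=20)   # ~32 Å down to 0.01 Å

coords = torch.randn(100, 3) * 10.0                 # fake 100-atom structure, ~10 Å scale
for sigma in noise_levels:
    noisy = coords + sigma * torch.randn_like(coords)
    rmsd = (noisy - coords).pow(2).sum(dim=-1).mean().sqrt()
    # When sigma dwarfs the structure's own scale, only the global shape survives;
    # when sigma is well below 1 Å, only local atomic detail is left to refine.
    print(f"sigma={sigma.item():7.2f} Å   perturbation RMSD={rmsd.item():7.2f} Å")
```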

The benefit of working with different types of chemical components appears to be that the model can learn more about protein structures from other kinds of structures, such as protein-ligand interfaces. Integrating diverse data types seems to help models generalize better across tasks. This improvement is similar to how Gemini’s text comprehension improved when the model became multi-modal with the incorporation of image and video data.

Other Important Changes to AlphaFold

The role of MSA (Multiple Sequence Alignment) was significantly downgraded. The AF2 evoformer is replaced with the simpler pairformer module, and the MSA representation is now processed by only a small 4-block module rather than the evoformer’s 48 blocks. As you may recall from my previous article, the MSA was thought to help the model learn which parts of the amino acid sequence were evolutionarily important. Experiments showed that reducing the role of the MSA had a limited impact on model accuracy.

Hallucination had to be countered. Generative models are very exciting, but they come with the baggage of hallucination: researchers found the model would invent plausible-looking structures in unstructured regions. To overcome this, a cross-distillation method was used to augment the training data with structures predicted by AlphaFold-Multimer (v2.3). The cross-distillation approach teaches the model to better differentiate between structured and unstructured regions, helping it understand when to avoid adding artificial details.
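Conceptually, the augmentation step might look something like the sketch below; the function and the `af_multimer.predict` call are hypothetical stand-ins for the actual data pipeline:

```python
def build_cross_distillation_set(experimental_structures, extra_sequences, af_multimer):
    """Hypothetical cross-distillation: mix experimental structures with
    structures predicted by an earlier model (AlphaFold-Multimer v2.3)."""
    distilled = [af_multimer.predict(seq) for seq in extra_sequences]
    # The older model tends to render disordered regions as extended,
    # unstructured chains rather than compact folds, which teaches the new
    # model not to invent plausible-looking detail in those regions.
    return experimental_structures + distilled
```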

Some interactions were easier to predict than others. Sampling probabilities were adjusted for each class of interaction, i.e., fewer samples from simple types that can be learned in relatively few training steps and vice versa for complex ones. This helps avoid under- and overfitting across types.

Training curves for the initial training and fine-tuning stages illustrate how different classes reached their best performance at varying training steps. For this reason, the training data was subsampled to prevent under- and overfitting by class. From Accurate structure prediction of biomolecular interactions with AlphaFold 3
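A sketch of how such class-dependent sampling could be implemented when drawing training examples; the class names and weights below are invented for illustration, not DeepMind's actual values:

```python
import random

# Hypothetical sampling weights per interaction class (illustrative values only):
# quickly-learned classes get lower probability, harder ones get more.
class_weights = {
    "protein_monomer": 0.5,
    "protein_protein": 1.0,
    "protein_ligand":  2.0,
    "protein_nucleic": 3.0,
}

def sample_training_example(examples_by_class):
    """Draw one training example, oversampling the harder interaction classes."""
    classes = list(class_weights)
    weights = [class_weights[c] for c in classes]
    chosen = random.choices(classes, weights=weights, k=1)[0]
    return random.choice(examples_by_class[chosen])
```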

High-Level Learnings

1. DALLE 2 and AlphaFold 3 improved on their predecessors by using diffusion modules, which simultaneously simplified their architectures.

2. Training on a wider range of data types makes generative models more robust. Diversifying the types of structures in the AlphaFold training dataset allowed the model to improve protein folding predictions and generalize to other biomolecular interactions. Similarly, the diversity of text-image pairs used to train CLIP improved DALLE.

3. The noise schedule is an important knob when training diffusion models. Turning it up or down affects the model’s ability to learn both coarse and fine details. Working across noise levels allowed for a significant simplification of AlphaFold because it eliminated the need for intermediate predictions such as side-chain torsion angles.

Thank you again for reading, and stay tuned for the next installment. Until then, keep learning and keep exploring.
