Have you ever tried to build a massive Lego set, but your table is just too small? You have all these pieces, but you can only work on one small section at a time before you run out of space. In a weird way, that’s exactly the problem AI researchers face when training the gigantic models we hear about every day.
These models have gotten so deep, with so many layers, that training them all at once is a memory nightmare. The bigger the model, the more VRAM you need on your GPUs, and that hardware gets expensive, fast. We've come up with some clever workarounds, but they often feel like patches, not solutions.
But what if you could just… train the model in pieces? Like building that Lego set one section at a time, completely independently, and then just snapping them all together at the end? That’s the beautifully simple idea behind a new paper from the folks at Sakana AI and the University of Tokyo. They call it DiffusionBlocks, and it’s a genuinely fascinating approach to tackling one of AI's biggest bottlenecks.
So, What's the Big Deal with Memory Anyway?
Let’s get into the weeds for a second, but I promise to keep it simple. The standard way we train neural networks is with something called backpropagation. Think of it like a game of telephone, but in reverse. The model makes a guess, checks how wrong it was, and then sends a correction signal backward through every single layer, telling each one how to adjust itself.
To do this, the computer has to keep a record of everything that happened on the forward pass: all the intermediate calculations, or "activations." For a deep network, that’s a ton of data to hold in memory, and it grows roughly in proportion to the number of layers. More layers = more memory.
On top of that, for every parameter in your model, you also need to store its gradient (the direction it needs to change) and, if you're using a common optimizer like Adam, two more running statistics, usually called momentum and variance. So for every single number that defines your model, you’re storing four numbers in memory during training: the weight itself, its gradient, and the two Adam values.
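To put a rough number on that "four numbers per parameter" point, here is a back-of-the-envelope sketch. It assumes everything is stored as 32-bit floats (4 bytes each), which is my simplification; mixed-precision setups change the arithmetic, and activations are not counted at all. The function name is made up for illustration.

```python
def training_state_bytes(num_params: int, bytes_per_value: int = 4) -> int:
    """Bytes needed for weights, gradients, and Adam's two statistics.

    Activations are deliberately left out; they come on top of this.
    """
    values_per_param = 4  # weight + gradient + momentum + variance
    return num_params * values_per_param * bytes_per_value


one_billion = 1_000_000_000
print(training_state_bytes(one_billion) / 1e9, "GB")  # 16.0 GB
```

So under these assumptions a one-billion-parameter model needs about 16 GB before a single activation is stored.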
It’s a memory hog, plain and simple. This is the fundamental reason why training a billion-parameter model requires a whole room full of ridiculously expensive GPUs.
The "Aha!" Moment: Residual Networks Are Secretly Diffusion Models
This is where the Sakana AI team had a brilliant insight. They looked at a super common type of network architecture called a Residual Network (or ResNet). You'll find these structures inside everything from Vision Transformers (ViT) to Diffusion Transformers (DiT).
The core idea of a ResNet is a "skip connection," where the input to a layer is added to its output. The formula looks like this: output = input + f(input). It basically says, "take what you started with, and just add a small change."
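If it helps to see that formula as code, here is a tiny pure-Python sketch. The `small_change` function is a hypothetical stand-in for a learned layer, not anything from the paper.

```python
from typing import Callable, List

Vector = List[float]


def residual_block(x: Vector, f: Callable[[Vector], Vector]) -> Vector:
    """Return x + f(x): the skip connection adds the input back onto f's output."""
    return [xi + di for xi, di in zip(x, f(x))]


def small_change(x: Vector) -> Vector:
    # Hypothetical stand-in for a learned layer: nudge every value by 10%.
    return [0.1 * xi for xi in x]


x = [1.0, -2.0, 4.0]
for _ in range(3):  # stacking three blocks applies three small updates in a row
    x = residual_block(x, small_change)
```

Each pass through the loop keeps what it started with and adds a small correction, which is the property the rest of this post leans on.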
The researchers realized this looks a lot like a process from a completely different part of AI: score-based diffusion models, the same tech that powers image generators like Stable Diffusion. In diffusion, you take a clean image, slowly add noise step-by-step, and then train a model to reverse the process, removing the noise step-by-step.
They connected the dots: each residual block in a network can be seen as one single step in that denoising process. A whole stack of residual blocks isn't just a deep network; it's a multi-step journey from a very noisy state to a clean one.
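One way to see the resemblance is to put the two update rules side by side. The notation below is my own shorthand and an assumption, not necessarily the paper's: x_l is the activation after layer l, the sigma_i are a decreasing sequence of noise levels, and v(x, sigma) is whatever direction the denoiser says to move in at that noise level. The second line is a plain Euler step for an equation of the form dx/dsigma = v(x, sigma).

```latex
\begin{aligned}
\text{residual block:} \quad & x_{l+1} = x_l + f_l(x_l) \\
\text{Euler denoising step:} \quad & x_{i+1} = x_i + (\sigma_{i+1} - \sigma_i)\, v(x_i, \sigma_i)
\end{aligned}
```

Read this way, the learned function f_l plays the role of the step size times the denoising direction: both rules say "keep the current state and add a small correction."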
This is the magic trick: in diffusion models, you can train the model to denoise at any noise level independently. You don't need to know about the other steps. This means you can take a big network, chop it into blocks, and train each block on its own little denoising job, completely separate from the others.
How to Turn Any Residual Network into DiffusionBlocks
Okay, so how do you actually do it? The team laid out a three-step recipe to convert a standard network into a DiffusionBlocks-powered one.
- Chop it Up (Block Partitioning): First, you take your network with, say, 12 layers, and you split it into a few blocks. Let's say you choose 3 blocks (B=3). The first block gets layers 1-4, the second gets 5-8, and the third gets 9-12. Easy enough.
- Assign Each Block a Job (Noise Range Assignment): Next, you tell each block what kind of "noise" it's responsible for. Block 3, the one closest to the output, might be responsible for removing the last little bit of fine-grained noise. Block 1, near the input, gets the messy job of making sense of a super noisy mess. Each block is assigned its own specific range of noise levels to work on.
- Give it the Right Tools (Noise Conditioning): Finally, you modify each block so it knows what noise level it's supposed to be working on. You feed it the noisy data and a little hint about the noise level. This is often done with a technique called Adaptive Layer Normalization (AdaLN).
That's it. During training, you just pick a random block, give it a training example with the right amount of noise, and train only that block. The other blocks aren't even loaded into memory. Suddenly, your memory usage is slashed by a factor of B (in our example, 3x).
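Here is a minimal sketch of the bookkeeping behind that recipe, with no actual neural network involved. Everything in it is an assumption for illustration: the function names, the specific noise ranges, and the choice to draw the noise level uniformly inside a block's range are mine, not the paper's. Layers are numbered from 0 here, so "layers 1-4" become indices 0-3.

```python
import random
from typing import List, Tuple


def partition_layers(num_layers: int, num_blocks: int) -> List[range]:
    """Split layer indices 0..num_layers-1 into contiguous, equally sized blocks."""
    if num_layers % num_blocks != 0:
        raise ValueError("this sketch assumes num_layers is divisible by num_blocks")
    size = num_layers // num_blocks
    return [range(i * size, (i + 1) * size) for i in range(num_blocks)]


def sample_training_task(
    noise_ranges: List[Tuple[float, float]], rng: random.Random
) -> Tuple[int, float]:
    """Pick one block at random, then a noise level inside that block's range."""
    block_id = rng.randrange(len(noise_ranges))
    low, high = noise_ranges[block_id]
    sigma = rng.uniform(low, high)  # assumption: uniform within the range
    return block_id, sigma


blocks = partition_layers(num_layers=12, num_blocks=3)

# Hypothetical (low, high) noise ranges: the first block handles the noisiest
# inputs, the last block handles the nearly clean ones.
noise_ranges = [(2.0, 80.0), (0.3, 2.0), (0.002, 0.3)]

rng = random.Random(0)
block_id, sigma = sample_training_task(noise_ranges, rng)
layers_to_train = list(blocks[block_id])  # only these layers need to be in memory
```

In a real training loop you would then noise a training example to level `sigma`, run it through just `layers_to_train`, and update only those weights.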
Not All Noise is Created Equal
Here’s another clever detail. The team found that just splitting the noise levels evenly didn't work very well. It turns out that getting the middle levels of noise right is way more important for the final quality of the output than getting the super-noisy or almost-clean ends right.
So, they developed "equi-probability partitioning." Instead of giving each block an equal slice of the noise range, they give each block a slice that represents an equal amount of importance or probability. This means the blocks working on those crucial middle-ground noise levels get a much narrower, more specialized range to focus on.
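To make the idea concrete, here is a sketch of how you could cut a noise range into equal-probability slices. It assumes the log of the noise level follows a normal distribution, and the default mean, standard deviation, and min/max noise values are placeholder numbers I picked for illustration rather than settings confirmed by the paper. The function name is hypothetical too.

```python
import math
from statistics import NormalDist
from typing import List, Tuple


def equal_probability_ranges(
    num_blocks: int,
    log_mean: float = -1.2,    # assumed mean of log(sigma)
    log_std: float = 1.2,      # assumed std of log(sigma)
    sigma_min: float = 0.002,  # assumed smallest noise level
    sigma_max: float = 80.0,   # assumed largest noise level
) -> List[Tuple[float, float]]:
    """Cut [sigma_min, sigma_max] so every slice holds the same probability mass."""
    dist = NormalDist(mu=log_mean, sigma=log_std)
    p_low = dist.cdf(math.log(sigma_min))
    p_high = dist.cdf(math.log(sigma_max))
    edges = [sigma_min]
    for i in range(1, num_blocks):
        # Probability at the i-th interior cut, then map it back to a noise level.
        p = p_low + (p_high - p_low) * i / num_blocks
        edges.append(math.exp(dist.inv_cdf(p)))
    edges.append(sigma_max)
    # Ranges come out ordered from cleanest to noisiest.
    return list(zip(edges[:-1], edges[1:]))


ranges = equal_probability_ranges(num_blocks=3)
```

With these placeholder settings, the middle slice is the narrowest of the three when measured in log-noise terms, while the two end slices stretch out to cover the rare very clean and very noisy levels.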
The results speak for themselves. On an image generation task, this smart partitioning method scored a 38.03 FID (a quality score, lower is better), while the simple uniform method got a much worse 43.53. It’s a small change that makes a big difference.
Okay, But Does It Actually Work?
This all sounds great in theory, but the real test is performance. The researchers tested DiffusionBlocks on five different architectures across three different types of tasks, comparing it to the exact same model trained the old-fashioned, memory-hungry way.
The results are pretty stunning:
| Architecture | Dataset | Metric | Baseline (End-to-End) | DiffusionBlocks | Training Savings |
| :--- | :--- | :--- | :--- | :--- | :--- |
| ViT (12-layer) | CIFAR-100 | Accuracy (↑) | 60.25% | 59.30% | 3x memory |
| DiT-S (12-layer) | CIFAR-10 | FID (↓) | 39.83 | 37.20 | 3x memory |
| DiT-L (24-layer) | ImageNet | FID (↓) | 12.09 | 10.63 | 3x memory |
| AR Transformer | LM1B | MAUVE (↑) | 0.50 | 0.71 | 4x memory |
| Huginn (Recurrent) | LM1B | MAUVE (↑) | 0.49 | 0.70 | ~10x compute |
Look at those numbers. The ViT classifier lands within about one point of its baseline, and in the other rows (the DiT image models, the AR Transformer, and Huginn) DiffusionBlocks actually scores better. They achieved this while using 3x or 4x less memory during training.
The Huginn result is particularly wild. Huginn is a recurrent model that normally requires a complex and slow training process called backpropagation through time (BPTT), which sends the correction signal back through every repeated step. With DiffusionBlocks, they replaced that with a single forward pass, resulting in a 10x total reduction in training compute. That’s huge.
The Good and The Not-So-Good
Like any new technology, DiffusionBlocks has its pros and cons. Let's break them down.
The Strengths:
- Massive Memory Savings: This is the headline feature. A Bx reduction in memory is a game-changer for anyone without a national budget for GPUs.
- Grounded in Theory: This isn't just a random hack. It's built on a solid mathematical connection between two major fields of AI.
- It Just Works: They showed it works across a wide range of models (vision, diffusion, autoregressive, recurrent) without needing a bunch of custom tweaks for each one.
- Faster Inference (Sometimes): For diffusion models, you also get a Bx speedup during inference, because each denoising step only needs to run one block.
- Perfect for Parallelism: You can literally train each block on a different machine with zero communication needed between them until the very end.
The Weaknesses:
- Architecture Limitation: Right now, it requires the input and output dimensions of each block to be the same. This means it can't be applied to popular U-Net style architectures just yet.
- Untested on Fine-Tuning: All the experiments were for training models from scratch. We don't know how well it works for fine-tuning an existing pre-trained model.
- No Magic Number for B: There isn't a clear rule for how many blocks you should split your model into for the best results. It requires a bit of trial and error.
- Mixed Results on Text: On one text generation benchmark (OpenWebText), the performance was slightly worse than the baseline.
Even with these limitations, the potential here is enormous. The ability to train models in these independent, memory-light chunks could democratize training for larger models, allowing more people to experiment and innovate. It’s a clever, elegant solution to a very real and expensive problem, and I, for one, can't wait to see where the Sakana AI team and others take this idea next.




