Ever have one of those moments where your deep learning model is training beautifully, the loss is going down, and then… BAM. Everything just explodes?
Suddenly, your loss shoots up to inf, your gradients become NaN (Not a Number), and your entire training run comes to a screeching halt. It’s one of the most frustrating things to debug. You start questioning everything: the learning rate, the data, the model architecture. But often, the culprit is something much smaller and sneakier, hidden inside a function you use every day: Softmax.
On the surface, Softmax seems so simple. Its job is to take the raw, messy scores your model spits out (we call these "logits") and turn them into clean, understandable probabilities. You know, "there's a 70% chance this is a cat, 25% a dog, and 5% a toaster." It’s the cornerstone of any model that needs to classify things into multiple categories.
But here’s the thing. There’s a huge gap between the neat mathematical formula for Softmax and what actually happens when you write the code. And if you’re not careful, you’ll fall right into a numerical stability trap that can completely derail your work. Let’s walk through this together, because understanding this little detail will save you a world of pain down the road.
The "Simple" Way That's Actually a Trap
Let's start by looking at the most straightforward way to write a Softmax function. If you look up the formula, it’s basically this: take the exponential of each logit, then divide it by the sum of all the exponentiated logits.
In PyTorch, that code looks clean and simple:
import torch
def softmax_naive(logits):
    exp_logits = torch.exp(logits)
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)
Looks perfectly fine, right? It’s a direct translation of the math. And for nice, small numbers, it works. But we're not always dealing with nice, small numbers. This is where the trouble begins.
This implementation is what we call "numerically unstable." It’s like building a bridge out of toothpicks. It might hold a feather, but as soon as you put any real weight on it, the whole thing collapses.
Let’s Break It on Purpose
To see what I mean, let's create a tiny batch of data. We'll have three samples. Two of them will have perfectly normal, reasonable logit values. But the second one? We're going to give it some extreme numbers to see how our naive function handles the pressure.
# A batch of 3 samples, each with 3 possible classes
logits = torch.tensor([
    [2.0, 1.0, 0.1],          # A normal, well-behaved sample
    [1000.0, 1.0, -1000.0],   # The troublemaker
    [3.0, 2.0, 1.0],          # Another normal one
], requires_grad=True)
targets = torch.tensor([0, 2, 1])  # The correct answers for our samples
Now, let's push these logits through our softmax_naive function and see what comes out the other side.
# Let's see the probabilities
probs = softmax_naive(logits)
print(probs)
Here's what you'll get:
tensor([[0.6590, 0.2424, 0.0986],
        [   nan, 0.0000, 0.0000],  # Uh oh...
        [0.6652, 0.2447, 0.0900]], grad_fn=<DivBackward0>)
Look at that second row. It’s a complete disaster. We got nan and 0.0000. That’s not a probability distribution; it's garbage. The first and third samples look fine, but that one bad apple has spoiled the batch.
So, What Just Happened? The Math vs. The Machine
The problem isn't with the math. The problem is with the limitations of computers.
Think of it like this: your computer can only store numbers up to a certain size. When you calculate torch.exp on a logit of 1000, you're asking it to compute e¹⁰⁰⁰, which is roughly 10⁴³⁴. The largest value a 32-bit float (PyTorch's default) can hold is about 3.4 × 10³⁸, so anything much past exp(88.7) simply doesn't fit.
- Overflow: When the number is too big, the computer just gives up and calls it "infinity" (inf). That's what happened with exp(1000).
- Underflow: When the number is too small (like exp(-1000)), it's so close to zero that the computer just rounds it down to 0.0.
So for our troublemaker sample, the exponentials became [inf, 2.718, 0.0]. When you try to normalize by dividing by the sum (inf + 2.718 + 0.0 is still inf), the first entry turns into inf / inf, which is undefined. And that's where NaN comes from. The other two entries are finite numbers divided by inf, so they collapse to 0.0.
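If you want to watch the overflow and underflow happen in isolation, here is a minimal sketch. The variable names are just for illustration, and it assumes PyTorch's default float32 dtype.

```python
import torch

# Largest finite float32 value, for reference (about 3.4e38)
print(torch.finfo(torch.float32).max)

x = torch.tensor([1000.0, 1.0, -1000.0])  # the troublemaker's logits
exp_x = torch.exp(x)
print(exp_x)   # first entry overflows to inf, last entry underflows to 0.0

total = exp_x.sum()    # inf + 2.718 + 0.0 is still inf
ratio = exp_x / total  # inf / inf is undefined, so the first entry becomes nan
print(ratio)
```

Only the first entry ends up as nan; the other two are finite values divided by inf, which gives 0.0.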
The Problem Gets Worse: The Loss and The Gradients
Okay, so we have some NaNs. Big deal? It’s a huge deal. Because the next step in training is to calculate the loss. For classification, we typically use cross-entropy loss, which involves taking the logarithm of the predicted probability for the correct class.
Let's see what happens when we try to calculate the loss for our broken sample. The correct class was index 2, and our model predicted a probability of 0.0 for it (due to underflow).
The loss calculation is -log(probability). What's log(0)? It's negative infinity, so -log(0) is positive infinity. When you average that with the other finite losses, the total loss for the batch becomes inf.
And once your loss is inf, it’s game over for backpropagation. When you try to calculate the gradients, you get a sea of NaNs.
# This is what happens when you call loss.backward()
print(logits.grad)
# Output:
tensor([[-0.1137,  0.0808,  0.0329],
        [    nan,     nan,     nan],  # Total corruption
        [ 0.2217, -0.2518,  0.0300]])
Those NaN gradients are like a virus. They’ll propagate backward through your network, corrupting all the model's weights. Your model hasn't just learned nothing; it has effectively destroyed itself.
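To reproduce the failure end to end, here is a rough sketch of a naive loss computation. The names picked, per_sample_loss, naive_loss, and row_has_nan are mine, and the loss is written as a plain mean of -log(p) over the batch, which is an assumption about how the naive version would be wired up.

```python
import torch

def softmax_naive(logits):
    exp_logits = torch.exp(logits)
    return exp_logits / exp_logits.sum(dim=1, keepdim=True)

logits = torch.tensor([
    [2.0, 1.0, 0.1],
    [1000.0, 1.0, -1000.0],
    [3.0, 2.0, 1.0],
], requires_grad=True)
targets = torch.tensor([0, 2, 1])

probs = softmax_naive(logits)

# Pick each sample's predicted probability for its correct class
picked = probs[torch.arange(len(targets)), targets]
per_sample_loss = -torch.log(picked)  # the second entry is -log(0), i.e. inf
naive_loss = per_sample_loss.mean()   # one inf makes the whole mean inf

naive_loss.backward()

# Flag which rows of the gradient contain NaNs
row_has_nan = torch.isnan(logits.grad).any(dim=1)
print(naive_loss)
print(row_has_nan)
```

Only the troublemaker's row is flagged here, because the gradient is taken with respect to the logits themselves. In a real network those NaNs would flow on into the shared weights.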
The Fix: A Clever Trick to Keep Numbers in Check
So, how do the big frameworks like PyTorch and TensorFlow avoid this? They don't calculate Softmax and Cross-Entropy Loss separately. They use a combined, numerically stable function that relies on a neat mathematical trick.
The core idea is simple: we can subtract any number from all the logits in a sample, and the final Softmax output will be exactly the same. It's a neat property of the function.
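Here's why that works, written out. The notation is mine: x_i is the i-th logit of a sample and c is whatever constant we subtract. The factor e^{-c} shows up in both the numerator and the denominator, so it cancels:

```latex
\mathrm{softmax}(x - c)_i
  = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}}
  = \frac{e^{-c}\, e^{x_i}}{e^{-c} \sum_j e^{x_j}}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \mathrm{softmax}(x)_i
```

The equality is exact on paper. In floating point the two sides can differ in the last few bits, but the shifted version is the one that never overflows.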
So, what number should we subtract? The smartest choice is to take the largest logit in a sample and subtract it from every logit in that same sample.
Let's look at our troublemaker again: [1000.0, 1.0, -1000.0]. The max value is 1000.0.
If we subtract 1000.0 from each, we get:
[1000.0 - 1000.0, 1.0 - 1000.0, -1000.0 - 1000.0]
Which simplifies to: [0.0, -999.0, -2000.0]
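Now, look at these numbers. The largest one is 0. exp(0) is just 1. The other numbers are large negative values, so their exponentials are vanishingly small. In floating point they simply underflow to 0.0, and that's harmless here: the sum always contains the 1 from the largest logit, so we never divide by zero or by inf. We've completely avoided the overflow problem! All our numbers are now in a safe, manageable range.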
This technique is often part of what’s known as the "LogSumExp" trick, which is a stable way to compute log(sum(exp(x))).
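As a rough sketch of the max-subtraction idea in code, here is one way a stable softmax could look. softmax_stable is a name I made up for illustration, and the check against torch.softmax and torch.logsumexp is only a sanity test, not a claim about how PyTorch implements them internally.

```python
import torch

def softmax_stable(logits):
    # Shift each row so its largest logit becomes 0; exp() can then never overflow
    shifted = logits - logits.max(dim=1, keepdim=True).values
    exp_shifted = torch.exp(shifted)
    return exp_shifted / exp_shifted.sum(dim=1, keepdim=True)

logits = torch.tensor([
    [2.0, 1.0, 0.1],
    [1000.0, 1.0, -1000.0],
    [3.0, 2.0, 1.0],
])

probs = softmax_stable(logits)
print(probs)  # every row is a valid distribution, including the troublemaker

# Sanity check against the built-in softmax
reference = torch.softmax(logits, dim=1)
print(torch.allclose(probs, reference))

# LogSumExp by hand: max + log(sum(exp(x - max))), compared with the built-in
row_max = logits.max(dim=1).values
lse_manual = row_max + torch.log(torch.exp(logits - row_max.unsqueeze(1)).sum(dim=1))
print(torch.allclose(lse_manual, torch.logsumexp(logits, dim=1)))
```

For the troublemaker row the result is [1, 0, 0]: the model is certain about class 0, and nothing overflowed along the way.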
The Stable Implementation in Action
Here’s what a stable function to calculate cross-entropy loss directly from logits looks like. It does all the work in one go, avoiding the unstable intermediate step of creating probability values.
def stable_cross_entropy(logits, targets):
    # Subtract the max logit for stability
    max_logits, _ = torch.max(logits, dim=1, keepdim=True)
    shifted_logits = logits - max_logits
    # Use the LogSumExp trick
    log_sum_exp = torch.log(torch.sum(torch.exp(shifted_logits), dim=1))
    # Calculate the final loss
    # Note: we use the ORIGINAL logits for the target class
    loss = log_sum_exp - logits[torch.arange(len(targets)), targets] + max_logits.squeeze(1)
    return loss.mean()
Let's run our same problematic logits through this new function:
logits.grad = None  # clear any gradients left over from an earlier backward() call
loss = stable_cross_entropy(logits, targets)
print("Stable loss:", loss)
loss.backward()
print("\nGradients:")
print(logits.grad)
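The output is like a breath of fresh air (the last digit or two may differ on your machine):
Stable loss: tensor(667.2749, grad_fn=<MeanBackward0>)
Gradients:
tensor([[-0.1137,  0.0808,  0.0329],
        [ 0.3333,  0.0000, -0.3333],
        [ 0.2217, -0.2518,  0.0300]])
Look at that! A finite loss. It's large because the troublemaker is confidently wrong (its own loss is 2000, and the batch mean works out to roughly 667), but it's a real number the optimizer can work with. And more importantly, perfectly reasonable gradients for every single sample, including our troublemaker. The training can continue, and the model can learn. No explosions, no NaNs, no headaches.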
The Takeaway: Don't Roll Your Own Crypto (or Softmax)
The moral of the story is pretty clear. While it’s fantastic to understand how these things work under the hood, you should almost never implement a naive Softmax function in your production code.
Frameworks like PyTorch (torch.nn.CrossEntropyLoss) and TensorFlow (tf.keras.losses.CategoricalCrossentropy with from_logits=True) have this stability built in. PyTorch's loss takes the raw logits directly as input, precisely to avoid this numerical trap. The Keras loss expects probabilities by default, so you need to pass from_logits=True to hand it raw logits. In both cases the Softmax and the log loss calculation are then performed together in a fused, stable way.
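Here is a small sketch of the PyTorch route, using the same toy batch as before. The variable names are illustrative; the key point is that the loss receives raw logits, with no softmax call in between.

```python
import torch

logits = torch.tensor([
    [2.0, 1.0, 0.1],
    [1000.0, 1.0, -1000.0],
    [3.0, 2.0, 1.0],
], requires_grad=True)
targets = torch.tensor([0, 2, 1])

# Pass raw logits, not probabilities; the loss handles softmax and log internally
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(logits, targets)
loss.backward()

print(loss)                                # finite, even with the extreme logits
print(torch.isfinite(logits.grad).all())   # no NaNs anywhere in the gradient
```

If your model ends with its own softmax layer and you then feed the result into this loss, you lose the protection (and apply softmax twice), so it's worth checking that the final layer outputs plain logits.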
So next time your model blows up, before you start tearing your hair out over hyperparameters, take a quick look to see if you’re accidentally calculating Softmax manually somewhere. It’s a simple mistake to make, but now you know exactly why it’s so dangerous and, more importantly, how to fix it.




