FedAvg vs. FedProx: A Hands-On Guide to Taming Unruly Data in Federated Learning

Akram Chauhan

Federated learning has this beautiful, almost utopian promise, doesn't it? Train a powerful AI model across multiple devices or silos without ever having to centralize the data. It’s a huge win for privacy and a big deal for industries like healthcare and finance.

But here’s the thing they don't always tell you in the intro slides: real-world data is a mess.

Imagine you're training a model to identify different types of vehicles. In a perfect world, every one of your clients—let's say, different city traffic cameras—would have a nice, balanced mix of cars, trucks, and buses. But reality is different. The camera on the highway sees mostly trucks, the one downtown sees mostly cars, and the one near the depot sees a ton of buses.

This is what we call "non-IID" data (not independent and identically distributed), and it's the bane of standard federated learning. The classic algorithm, FedAvg, can really struggle with this. It’s like trying to get a group of specialists to agree on a general topic: each one pulls the model in their own biased direction.

So, what can we do? Well, there's a clever modification called FedProx that offers a surprisingly simple solution. Today, we’re going to roll up our sleeves and put these two head-to-head in a real experiment. We'll use NVIDIA FLARE to orchestrate everything and see for ourselves how FedProx helps tame that unruly, real-world data.

Setting the Stage: Our Federated Showdown

Before we dive into the code, let's talk about what we're building. Our goal is to simulate a realistic federated learning environment where the data is intentionally unbalanced.

Here are our key ingredients:

  • The Framework: We'll be using NVIDIA FLARE. It's a fantastic open-source framework that makes setting up and running these kinds of complex federated experiments much, much easier. It handles all the communication between the server and clients for us.
  • The Dataset: We’re using CIFAR-10. It’s a classic image dataset with 10 classes (airplanes, cars, birds, etc.). It's simple enough that we can focus on the federated learning mechanics without needing a supercomputer.
  • The Twist (Non-IID Data): This is the most important part. We’re not just going to split the CIFAR-10 data evenly. Instead, we'll use something called a Dirichlet distribution to give each of our simulated clients a skewed, specialized slice of the data. For this experiment, we'll simulate 3 clients and set the Dirichlet concentration parameter ALPHA to 0.3. Smaller values of ALPHA produce more skew, so 0.3 leaves each client with a clearly imbalanced class mix.

Think of it this way: one client might get a dataset that's 70% dogs and 10% cats, while another gets 60% cats and only 5% dogs. This is the kind of challenge that FedAvg often fumbles.
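To make the Dirichlet split concrete, here is a minimal sketch using only the Python standard library. It is not the article's actual partitioning code: the function names, the seed, and the per-class slicing strategy are assumptions chosen for illustration. For each class it draws one share per client from a symmetric Dirichlet (via normalized Gamma draws) and cuts that class's samples accordingly.

```python
import random
from collections import defaultdict


def dirichlet_sample(alpha, k, rng):
    """Draw one sample from a symmetric Dirichlet(alpha) over k outcomes."""
    # Normalized Gamma(alpha, 1) draws follow a Dirichlet distribution.
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]


def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with per-class Dirichlet shares.

    Hypothetical helper: returns one list of sample indices per client.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    client_indices = [[] for _ in range(num_clients)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        shares = dirichlet_sample(alpha, num_clients, rng)
        start = 0
        cumulative = 0.0
        for client, share in enumerate(shares):
            cumulative += share
            # The last client takes whatever remains, so no sample is lost.
            if client == num_clients - 1:
                end = len(indices)
            else:
                end = round(cumulative * len(indices))
            client_indices[client].extend(indices[start:end])
            start = end
    return client_indices
```

Calling `dirichlet_partition(labels, num_clients=3, alpha=0.3)` with the CIFAR-10 training labels would give three index lists that together cover every sample exactly once. Because the random generator is seeded, repeating the call yields the same split.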

We'll set up some basic parameters like the number of training rounds, learning rate, and batch size. Nothing too fancy, just the standard stuff.

The Heart of the Operation: What Each Client Does

Now, let's get into what happens on each client device. We'll write a single Python script that every client will run. It’s their set of instructions for participating in the training.

The Model and the Data Split

First, each client needs a model. We're using a simple Convolutional Neural Network (CNN). It’s a pretty standard choice for image classification and perfect for CIFAR-10.

Next, and this is crucial, each client needs to figure out which data belongs to them. Remember that non-IID split we talked about? We have a function that uses the client's ID (e.g., "site-1", "site-2") to deterministically grab its unique, skewed chunk of the training data. This ensures every client knows its role without any confusion.
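One way to get that deterministic lookup, sketched under assumptions: every client computes the same seeded partition and then picks its own entry by converting its ID into an index. The helper names below are hypothetical, and the "site-N" naming is taken from the examples in the text.

```python
def client_index(site_name):
    """Map a client ID such as 'site-2' to a zero-based index (here, 1)."""
    prefix, _, number = site_name.rpartition("-")
    if prefix != "site" or not number.isdigit() or int(number) < 1:
        raise ValueError(f"unexpected client ID: {site_name!r}")
    return int(number) - 1


def my_shard(site_name, all_shards):
    """Pick this client's shard from a partition that all clients compute identically."""
    return all_shards[client_index(site_name)]
```

Since the partition depends only on a shared seed and the client ID, no coordination with the server is needed to agree on who owns which samples.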

The Training Loop: Where the Magic Happens

This is where the real action is. The client's life is a simple loop that repeats for every communication round:

  1. Receive the Global Model: The client first gets the latest version of the global model from the central server. Think of this as the "master blueprint."
  2. Test It Out: Before training, the client evaluates this global model on a shared test set. This gives us a baseline accuracy score for the current round.
  3. Local Training: The client then trains this model on its own local, skewed data for a few epochs.

This is the step where FedAvg and FedProx diverge.

With standard FedAvg, the client just tries its best to minimize the loss on its local data. If its data is mostly "dogs," it's going to get really good at identifying dogs, potentially at the expense of other classes.

With FedProx, we add a little something extra to the loss function. It’s called a proximal term. You can think of it like a rubber band or a leash. While the model is learning from the local data, this term gently pulls it back, reminding it not to stray too far from the global model it started the round with.
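In symbols, the FedProx local objective for client k in round t is usually written as follows. The notation is borrowed from the FedProx literature rather than from this article: F_k is the client's ordinary local loss, w^t is the global model received at the start of the round, and mu is a non-negative constant.

```latex
\min_{w} \; h_k(w;\, w^{t}) \;=\; F_k(w) \;+\; \frac{\mu}{2}\,\bigl\lVert w - w^{t} \bigr\rVert^{2},
\qquad
\nabla h_k(w;\, w^{t}) \;=\; \nabla F_k(w) \;+\; \mu\,\bigl(w - w^{t}\bigr)
```

The second expression shows the pull directly: every gradient step gets an extra push back toward w^t, proportional to how far the local weights have moved away from it.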

This "leash" (controlled by a parameter we call mu) is the secret sauce. The larger mu is, the tighter the leash; with mu set to 0 the term vanishes and we are back to plain FedAvg. It encourages the client to learn from its local data without completely forgetting the general knowledge contained in the global model. It’s a balancing act.
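To see the leash in action without a deep learning framework, here is a toy sketch in plain Python. It is an illustration, not the article's training code: the quadratic "local loss", the function names, the learning rate, and the step count are all assumptions. In a real PyTorch client, the same penalty would typically be added to the loss before the backward pass.

```python
def proximal_penalty(weights, global_weights, mu):
    """(mu / 2) * squared L2 distance between local and global weights."""
    return 0.5 * mu * sum((w - g) ** 2 for w, g in zip(weights, global_weights))


def local_train(global_weights, local_optimum, mu, lr=0.1, steps=200):
    """Gradient descent on a toy local loss plus the FedProx proximal term.

    The toy local loss is 0.5 * ||w - local_optimum||^2, so a client with
    no leash (mu = 0) simply walks to its own optimum.
    """
    weights = list(global_weights)  # each round starts from the global model
    for _ in range(steps):
        weights = [
            # loss gradient is (w - t); proximal gradient is mu * (w - g)
            w - lr * ((w - t) + mu * (w - g))
            for w, g, t in zip(weights, global_weights, local_optimum)
        ]
    return weights
```

With `mu=0.0` the client ends up at its own optimum. With `mu=0.1` it settles at a point between its optimum and the global weights, which for this toy loss is `(local_optimum + mu * global_weights) / (1 + mu)` in each coordinate.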

  4. Send It Back: After local training, the client packages up its updated model weights and sends them back to the server. The server will then average the updates from all clients to create the next-generation global model, and the whole cycle begins again.
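The server-side averaging in that last step can be sketched as below. This assumes the common FedAvg convention of weighting each client by its number of training samples. The framework's aggregator handles this in the real experiment and may weight differently, so treat the function name and the data layout as hypothetical.

```python
def fedavg_aggregate(client_updates):
    """Sample-weighted average of client weights.

    client_updates: list of (num_samples, weights) pairs, where weights maps
    a parameter name to a flat list of floats.
    """
    total = sum(n for n, _ in client_updates)
    if total <= 0:
        raise ValueError("need at least one client with training samples")

    first_weights = client_updates[0][1]
    aggregated = {}
    for name, values in first_weights.items():
        aggregated[name] = [
            sum(n * weights[name][i] for n, weights in client_updates) / total
            for i in range(len(values))
        ]
    return aggregated
```

A client that trained on three times as many samples counts three times as much in the average. That is why badly drifted updates from large clients can move the global model a long way.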

The Director's Chair: Running the Show with NVFlare

So we have our client script, but how do we manage the whole orchestra? That's where the NVIDIA FLARE Job API comes in. It’s our server-side control panel.

We'll define a simple function to set up and run an experiment. Here’s what it does:

  • It creates a FedAvgJob, which tells NVFlare the basic structure of our project (e.g., 3 clients, 5 rounds).
  • It uses a ScriptRunner to tell each client ("site-1", "site-2", etc.) to execute our client_train.py script.
  • It passes all the important settings as command-line arguments, like the learning rate, the non-IID alpha, and, most importantly, the FedProx mu value.

To run our comparison, we simply call this function twice:

  1. First Run (FedAvg): We set mu=0.0. This effectively turns off the proximal term, making it a standard FedAvg run. No leash!
  2. Second Run (FedProx): We set mu=0.1. This turns on the FedProx leash with a bit of tension, encouraging the local models to stay closer to the global one.
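The command-line hand-off behind those two runs might look like the sketch below. The flag names (`--lr`, `--alpha`, `--mu`) and the helper functions are assumptions, not the article's actual script, and the learning rate is deliberately left as a required argument because the text does not state its value. Only the standard library is used here. In the real setup, the resulting argument string would be passed along with the client script through the framework's script runner.

```python
import argparse


def build_parser():
    """Parser for the client script; the flag names are hypothetical."""
    parser = argparse.ArgumentParser(description="Federated client (sketch)")
    parser.add_argument("--lr", type=float, required=True,
                        help="local learning rate")
    parser.add_argument("--alpha", type=float, default=0.3,
                        help="Dirichlet concentration for the non-IID split")
    parser.add_argument("--mu", type=float, default=0.0,
                        help="FedProx strength; 0.0 means plain FedAvg")
    return parser


def format_script_args(lr, alpha, mu):
    """Build the argument string handed to each client for one experiment."""
    return f"--lr {lr} --alpha {alpha} --mu {mu}"


def experiment_args(lr, alpha=0.3):
    """Argument strings for the two runs: FedAvg (mu=0.0) and FedProx (mu=0.1)."""
    return {
        "fedavg": format_script_args(lr, alpha, mu=0.0),
        "fedprox": format_script_args(lr, alpha, mu=0.1),
    }
```

Keeping mu as the only difference between the two argument strings is what makes the comparison fair: same model, same split, same schedule, with or without the leash.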

NVFlare's simulator then spins everything up, managing the clients and server, passing models back and forth, and letting our experiment run automatically. It’s incredibly handy.

The Results Are In: Who Won the Fight?

After both experiments finish, we get to the best part: looking at the results. We saved the global model's test accuracy after every round for both FedAvg and FedProx. Now, let's plot them and see what happened.

You can almost immediately see the difference.

The FedAvg curve is often more erratic. It might jump up, then dip down, struggling to find a stable path to convergence. This is the classic symptom of non-IID data. Each client is pulling the model in a different direction, and the simple averaging process struggles to reconcile these conflicting updates.

Now look at the FedProx curve. It's generally much smoother and trends upward more consistently. In many cases, it reaches a higher final accuracy than FedAvg.

That's the leash doing its job. By preventing the local models from drifting too far, FedProx ensures that the updates sent back to the server are more consistent with each other. This leads to a more stable aggregation process and a better-performing final global model. It effectively dampens the noise caused by the skewed data distribution.
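If you would rather put a number on "erratic" than eyeball the plot, two simple summaries of a per-round accuracy list are sketched below. These helpers are hypothetical and are not part of the original experiment. They take whatever accuracy history you saved for each run.

```python
def mean_abs_change(accuracies):
    """Average absolute round-to-round change in accuracy."""
    if len(accuracies) < 2:
        return 0.0
    changes = [abs(b - a) for a, b in zip(accuracies, accuracies[1:])]
    return sum(changes) / len(changes)


def num_drops(accuracies):
    """Number of rounds in which accuracy fell compared with the previous round."""
    return sum(1 for a, b in zip(accuracies, accuracies[1:]) if b < a)
```

A curve that climbs steadily has few drops, while a jumpy curve has more drops and usually a larger mean change. Comparing both numbers for the FedAvg and FedProx histories gives a quick sanity check on what the plot suggests.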

So, What's the Takeaway?

If you're stepping into the world of federated learning, it's easy to get started with the vanilla FedAvg algorithm. But as we just saw, the moment you encounter the kind of messy, unbalanced data that defines the real world, you might run into trouble.

The beauty of FedProx is its simplicity and effectiveness. It's not a radical overhaul of the process; it's just one small, intelligent tweak to the local training objective. By adding that simple proximal term, you can significantly improve the stability and performance of your federated model on non-IID data.

So next time you're building a federated learning system and your accuracy charts look more like a seismograph than a learning curve, remember the leash. A little bit of regularization with FedProx might be all you need to tame your data and get your project back on track.


© 2026 Aicosoft. All rights reserved.