Handling Class Imbalance with AutoMM - Focal Loss


In this tutorial, we show how to use focal loss with the AutoMM package for balanced training. Focal loss was first introduced by Lin et al. (2017; see Citations) and can be used to balance hard and easy samples as well as an uneven sample distribution among classes.
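For reference, the loss from the cited paper can be written for a single sample as below. Here $p_t$ is the predicted probability of the sample's true class, $\alpha_t$ is the weight of that class, and $\gamma \ge 0$ is the focusing parameter. With $\gamma = 0$ and no class weights it reduces to the standard cross-entropy loss.

```latex
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```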

Create Dataset

We use the Shopee dataset for demonstration in this tutorial. It contains 4 classes, each with 200 samples in the training set.

from autogluon.multimodal.utils.misc import shopee_dataset

download_dir = "./ag_automm_tutorial_imgcls_focalloss"
train_data, test_data = shopee_dataset(download_dir)

To demonstrate the effectiveness of focal loss on imbalanced training data, we artificially downsample the Shopee training data so that each class keeps roughly one third as many samples as the previous one.

import numpy as np
import pandas as pd

ds = 1  # fraction of samples to keep for the current class

imbalanced_train_data = []
for lb in range(4):
    class_data = train_data[train_data.label == lb]
    # Randomly keep a fraction `ds` of this class, sampling without replacement.
    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)
    ds /= 3  # downsample 1/3 each time for each class
    imbalanced_train_data.append(class_data.iloc[sample_index])
imbalanced_train_data = pd.concat(imbalanced_train_data)
print(imbalanced_train_data)

# Per-class weight: the inverse of the class's share of the training samples.
weights = []
for lb in range(4):
    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]
    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))
    print(f"class {lb}: num samples {len(class_data)}")
# Normalize the weights so that they sum to 1.
weights = list(np.array(weights) / np.sum(weights))
print(weights)
                                                 image  label
184  /home/ci/autogluon/docs/tutorials/multimodal/a...      0
22   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
134  /home/ci/autogluon/docs/tutorials/multimodal/a...      0
66   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
10   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
..                                                 ...    ...
677  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
738  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
796  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
728  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
645  /home/ci/autogluon/docs/tutorials/multimodal/a...      3

[295 rows x 2 columns]
class 0: num samples 200
class 1: num samples 66
class 2: num samples 22
class 3: num samples 7
[np.float64(0.0239850482815907), np.float64(0.07268196448966878), np.float64(0.21804589346900635), np.float64(0.6852870937597342)]
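As a sanity check, the class counts and normalized weights printed above can be reproduced with plain arithmetic. The sketch below assumes 200 training samples per class, as stated earlier, and mirrors the truncation by int() and the repeated division of the keep fraction by 3. The names check_counts and check_weights are illustrative only and are not part of the tutorial code.

```python
# Reproduce the class counts: each class keeps 1/3 of the previous keep fraction.
samples_per_class = 200  # assumption: stated size of each class in the training set
keep_fraction = 1.0
check_counts = []
for _ in range(4):
    check_counts.append(int(samples_per_class * keep_fraction))  # int() truncates
    keep_fraction /= 3

total = sum(check_counts)

# Inverse class frequency, then normalize so that the weights sum to 1.
inverse_freq = [total / count for count in check_counts]
norm = sum(inverse_freq)
check_weights = [w / norm for w in inverse_freq]

print(check_counts)  # [200, 66, 22, 7]
print(total)         # 295
print(check_weights)
```

The rarest class (7 samples) ends up with roughly 69% of the total weight, while the most common class (200 samples) gets about 2%.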

Create and train MultiModalPredictor

Train with Focal Loss

We tell the model to use focal loss by setting "optim.loss_func" to "focal_loss". Three optional parameters control its behavior.

optim.focal_loss.alpha - a list of floats giving the per-class loss weights, which can be used to balance an uneven sample distribution across classes. The length of the list must match the total number of classes in the training dataset. A good way to compute alpha for each class is to use the inverse of the class's share of the training samples.

optim.focal_loss.gamma - a float that controls how much to focus on the hard samples. A larger value means more focus on the hard samples.

optim.focal_loss.reduction - how to aggregate the per-sample loss values. Currently only "mean" and "sum" are supported.
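To make the roles of alpha, gamma, and reduction concrete, here is a minimal pure-Python sketch of the focal loss formula from the cited paper. It is not AutoGluon's internal implementation: the functions softmax and focal_loss and the example logits are written for illustration only, and the "mean" reduction is assumed to be a plain average over samples.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def focal_loss(batch_logits, targets, alpha=None, gamma=2.0, reduction="mean"):
    """Per-sample loss: -alpha[t] * (1 - p_t) ** gamma * log(p_t)."""
    losses = []
    for logits, target in zip(batch_logits, targets):
        p_t = softmax(logits)[target]  # probability assigned to the true class
        class_weight = 1.0 if alpha is None else alpha[target]
        losses.append(-class_weight * (1.0 - p_t) ** gamma * math.log(p_t))
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    raise ValueError(f"Unsupported reduction: {reduction!r}")

# Two samples whose true class is 0: one predicted confidently (easy),
# one predicted as class 3 instead (hard).
easy = [4.0, 0.0, 0.0, 0.0]
hard = [0.0, 0.0, 0.0, 4.0]

for gamma in (0.0, 1.0, 2.0):
    easy_loss = focal_loss([easy], [0], gamma=gamma)
    hard_loss = focal_loss([hard], [0], gamma=gamma)
    print(f"gamma={gamma}: easy={easy_loss:.4f} hard={hard_loss:.4f} ratio={hard_loss / easy_loss:.1f}")
```

With gamma set to 0 the loss equals ordinary cross-entropy. As gamma grows, the easy sample's loss shrinks much faster than the hard sample's, so the hard sample dominates the total. The alpha list simply rescales each sample's loss by the weight of its true class.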

import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_focal"

predictor = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)

predictor.fit(
    hyperparameters={
        "model.mmdet_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optim.loss_func": "focal_loss",
        # One weight per class (the shopee dataset has 4 classes).
        # Passing `weights` (a list of np.float64) directly failed in the recorded run with
        # "ValueError: too many dimensions 'str'" at `alpha = torch.tensor(alpha)` in
        # FocalLoss.__init__; converting to plain Python floats is assumed to avoid this.
        "optim.focal_loss.alpha": [float(w) for w in weights],
        "optim.focal_loss.gamma": 1.0,
        "optim.focal_loss.reduction": "sum",
        "optim.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)

predictor.evaluate(test_data, metrics=["acc"])

Train without Focal Loss

import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal"

predictor2 = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)

predictor2.fit(
    hyperparameters={
        "model.mmdet_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optim.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)

predictor2.evaluate(test_data, metrics=["acc"])

Compare the two accuracy scores: in this setup, the model trained with focal loss is expected to perform much better than the model trained without it. When your data is imbalanced, try focal loss to see whether it improves performance!

Citations

@misc{https://doi.org/10.48550/arxiv.1708.02002,
  doi = {10.48550/ARXIV.1708.02002},
  url = {https://arxiv.org/abs/1708.02002},
  author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {Focal Loss for Dense Object Detection},
  publisher = {arXiv},
  year = {2017},
  copyright = {arXiv.org perpetual, non-exclusive license}
}