AutoMM for Semantic Segmentation - Quick Start

Semantic Segmentation is a computer vision task where the objective is to create a detailed pixel-wise segmentation map of an image, assigning each pixel to a specific class or object. This technology is crucial in various applications, such as in autonomous vehicles to identify vehicles, pedestrians, traffic signs, pavement, and other road features.
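
To make the idea of a pixel-wise segmentation map concrete, the sketch below stores the mask of a tiny 4×4 image as a nested list of class indices. The class names and mask values are made up for illustration and do not come from any dataset used in this tutorial.

```python
from collections import Counter

# Hypothetical class ids for a toy road scene (illustration only).
CLASS_NAMES = {0: "background", 1: "road", 2: "pedestrian"}

# A segmentation map has the same height and width as the image;
# each entry is the class id assigned to that pixel.
mask = [
    [0, 0, 0, 0],
    [0, 2, 2, 0],
    [1, 1, 1, 1],
    [1, 1, 1, 1],
]

# Count how many pixels were assigned to each class.
pixel_counts = Counter(class_id for row in mask for class_id in row)
for class_id, count in sorted(pixel_counts.items()):
    print(f"{CLASS_NAMES[class_id]}: {count} pixels")
```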

The Segment Anything Model (SAM) is a foundational model pretrained on a vast dataset with 1 billion masks and 11 million images. While SAM performs exceptionally well on generic scenes, it encounters challenges when applied to specialized domains like remote sensing, medical imagery, agriculture, and manufacturing. Fortunately, AutoMM comes to the rescue by facilitating the fine-tuning of SAM on domain-specific data.

This tutorial walks through fine-tuning SAM with AutoMM. A single call to the fit() API is enough to train the model.

Prepare Data

For demonstration purposes, we use the Leaf Disease Segmentation dataset from Kaggle. It is a good example of automating disease detection in plants, which can speed up plant pathology work. Segmenting specific regions on leaves or plants can be quite challenging, particularly when the diseased areas are small or several types of disease are present.

To begin, download and prepare the dataset.

download_dir = './ag_automm_tutorial'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/semantic_segmentation/leaf_disease_segmentation.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial/file.zip from https://automl-mm-bench.s3.amazonaws.com/semantic_segmentation/leaf_disease_segmentation.zip...
100%|██████████| 53.3M/53.3M [00:01<00:00, 30.2MiB/s]

Next, load the CSV files, ensuring that relative paths are expanded to facilitate correct data loading during both training and testing.

import pandas as pd
import os
dataset_path = os.path.join(download_dir, 'leaf_disease_segmentation')
train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
val_data = pd.read_csv(f'{dataset_path}/val.csv', index_col=0)
test_data = pd.read_csv(f'{dataset_path}/test.csv', index_col=0)
image_col = 'image'
label_col = 'label'
def path_expander(path, base_folder):
    # A cell may hold several ';'-separated relative paths; expand each one.
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])

for per_col in [image_col, label_col]:
    train_data[per_col] = train_data[per_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
    val_data[per_col] = val_data[per_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
    test_data[per_col] = test_data[per_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
    

print(train_data[image_col].iloc[0])
print(train_data[label_col].iloc[0])
/home/ci/autogluon/docs/tutorials/multimodal/image_segmentation/ag_automm_tutorial/leaf_disease_segmentation/train_images/00002.jpg
/home/ci/autogluon/docs/tutorials/multimodal/image_segmentation/ag_automm_tutorial/leaf_disease_segmentation/train_masks/00002.png
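
The path expansion above can be tried in isolation. The sketch below restates the tutorial's path_expander helper and applies it to a single path and to a ';'-separated pair. The base folder and file names are hypothetical and the files do not need to exist, because the helper only manipulates strings.

```python
import os

def path_expander(path, base_folder):
    # Split on ';' so a cell holding several relative paths is handled too.
    parts = path.split(';')
    return ';'.join(os.path.abspath(os.path.join(base_folder, p)) for p in parts)

# Hypothetical base folder and file names, for illustration only.
base = os.path.abspath(os.path.join("data", "leaf_demo"))

single = path_expander("train_images/a.jpg", base)
multi = path_expander("train_images/a.jpg;train_images/b.jpg", base)

print(single)
print(multi)
```

Every expanded entry is absolute, so the DataFrames keep working even if the working directory changes between training and testing.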

Each pandas DataFrame contains two columns: one for the image paths and one for the corresponding ground-truth masks. Let’s take a closer look at the training DataFrame.

train_data.head()
image label
0 /home/ci/autogluon/docs/tutorials/multimodal/i... /home/ci/autogluon/docs/tutorials/multimodal/i...
1 /home/ci/autogluon/docs/tutorials/multimodal/i... /home/ci/autogluon/docs/tutorials/multimodal/i...
2 /home/ci/autogluon/docs/tutorials/multimodal/i... /home/ci/autogluon/docs/tutorials/multimodal/i...
3 /home/ci/autogluon/docs/tutorials/multimodal/i... /home/ci/autogluon/docs/tutorials/multimodal/i...
4 /home/ci/autogluon/docs/tutorials/multimodal/i... /home/ci/autogluon/docs/tutorials/multimodal/i...

We can also visualize one image and its ground-truth mask.

from autogluon.multimodal.utils import SemanticSegmentationVisualizer
visualizer = SemanticSegmentationVisualizer()
visualizer.plot_image(test_data.iloc[0]['image'])
../../../_images/fe216d3b1ebdb5e9d170581c7865657759146bf00eb97ce753b6886966a80d80.png
visualizer.plot_image(test_data.iloc[0]['label'])
../../../_images/186e83cdfd2e474902dfff1e30dca3f527b7b36bf109e30dd23c2907c2482efd.png

Zero Shot Evaluation

Now, let’s see how well the pretrained SAM can segment the images. For this demonstration, we’ll use the base SAM model.

from autogluon.multimodal import MultiModalPredictor
predictor_zero_shot = MultiModalPredictor(
    problem_type="semantic_segmentation",
    label=label_col,
    hyperparameters={
        "model.sam.checkpoint_name": "facebook/sam-vit-base",
    },
    num_classes=1,  # foreground-background segmentation
)

After initializing the predictor, you can perform inference directly.

pred_zero_shot = predictor_zero_shot.predict({'image': [test_data.iloc[0]['image']]})
INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
visualizer.plot_mask(pred_zero_shot)
../../../_images/a34814264f16b4483c7cf578ff744d16c8eb39bcdf7e8bb57378282125e67976.png

It’s worth noting that SAM without prompts outputs a rough leaf mask instead of disease masks due to its lack of context about the domain task. While SAM can perform better with proper click prompts, it might not be an ideal end-to-end solution for some applications that require a standalone model for deployment.

You can also conduct a zero-shot evaluation on the test data.

scores = predictor_zero_shot.evaluate(test_data, metrics=["iou"])
print(scores)
{'iou': 0.14082679152488708}

As expected, the test score of the zero-shot SAM is relatively low. Next, let’s explore how to fine-tune SAM for enhanced performance.
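
For readers unfamiliar with the metric, the sketch below computes intersection over union (IoU) for a pair of binary masks in plain Python. It illustrates the general definition only. How AutoMM's "iou" metric aggregates over images, and how it treats empty masks, is not shown in this tutorial, so the empty-mask convention here is an assumption.

```python
def binary_iou(pred_mask, true_mask):
    """Intersection over union of two binary masks given as nested 0/1 lists."""
    intersection = 0
    union = 0
    for pred_row, true_row in zip(pred_mask, true_mask):
        for p, t in zip(pred_row, true_row):
            intersection += 1 if (p and t) else 0
            union += 1 if (p or t) else 0
    # Convention chosen here: two empty masks count as a perfect match.
    return intersection / union if union else 1.0

# Toy ground truth: a small "diseased" patch in the middle.
true_mask = [
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
# A prediction that covers most of the "leaf" rather than the diseased patch.
rough_pred = [
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [1, 1, 1, 1],
    [0, 0, 0, 0],
]

print(binary_iou(rough_pred, true_mask))
print(binary_iou(true_mask, true_mask))
```

A prediction that covers far more area than the target scores low even when it contains every target pixel, which matches the behavior of a rough leaf mask evaluated against disease masks.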

Fine-tune SAM

Initialize a new predictor and fit it with the training and validation data.

from autogluon.multimodal import MultiModalPredictor
import uuid
save_path = f"./tmp/{uuid.uuid4().hex}-automm_semantic_seg"
predictor = MultiModalPredictor(
    problem_type="semantic_segmentation",
    label="label",
    hyperparameters={
        "model.sam.checkpoint_name": "facebook/sam-vit-base",
    },
    path=save_path,
)
predictor.fit(
    train_data=train_data,
    tuning_data=val_data,
    time_limit=180, # seconds
)
=================== System Info ===================
AutoGluon Version:  1.4.1b20250917
Python Version:     3.12.10
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.7.1+cu126
CUDA Version:       12.6
GPU Count:          1
Memory Avail:       26.55 GB / 30.95 GB (85.8%)
Disk Space Avail:   180.34 GB / 255.99 GB (70.4%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/image_segmentation/tmp/a8fb6daa218842f28bbfc738aeb86ded-automm_semantic_seg
    ```
INFO: Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/ci/opt/venv/lib/python3.12/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.
INFO: 
  | Name              | Type                       | Params | Mode 
-------------------------------------------------------------------------
0 | model             | SAMForSemanticSegmentation | 93.4 M | train
1 | validation_metric | Binary_IoU                 | 0      | train
2 | loss_func         | StructureLoss              | 0      | train
-------------------------------------------------------------------------
3.6 M     Trainable params
89.8 M    Non-trainable params
93.4 M    Total params
373.703   Total estimated model params size (MB)
17        Modules in train mode
208       Modules in eval mode
/home/ci/opt/venv/lib/python3.12/site-packages/torch/nn/_reduction.py:51: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
  warnings.warn(warning.format(ret))
/home/ci/opt/venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:527: Found 208 module(s) in eval mode at the start of training. This may lead to unexpected behavior during training. If this is intentional, you can ignore this warning.
INFO: Time limit reached. Elapsed time is 0:03:00. Signaling Trainer to stop.
INFO: Epoch 0, global step 96: 'val_iou' reached 0.59721 (best 0.59721), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/image_segmentation/tmp/a8fb6daa218842f28bbfc738aeb86ded-automm_semantic_seg/epoch=0-step=96.ckpt' as top 3
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/image_segmentation/tmp/a8fb6daa218842f28bbfc738aeb86ded-automm_semantic_seg")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f655358e150>

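Under the hood, AutoMM uses LoRA for parameter-efficient fine-tuning. The training log above accordingly reports only 3.6 M trainable parameters out of 93.4 M in total. Note that if you do not customize the hyperparameters, the huge SAM is used as the default model, and a model of that size requires efficient fine-tuning in many cases.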
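
As a rough illustration of why LoRA is cheap to train, the sketch below counts the parameters of a low-rank update and applies it in a forward pass. LoRA keeps a pretrained weight W frozen and learns two small factors B and A, so the layer computes W x + (alpha / r) B (A x). The layer size, rank, and alpha used here are illustrative assumptions, not the values AutoMM configures for SAM.

```python
def lora_trainable_params(d_out, d_in, rank):
    """Parameters in the two low-rank factors B (d_out x rank) and A (rank x d_in)."""
    return rank * (d_out + d_in)

def matvec(matrix, vector):
    """Multiply a nested-list matrix by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def lora_forward(W, A, B, x, alpha, rank):
    """Compute W x + (alpha / rank) * B (A x) with W kept frozen."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]

# Illustrative sizes only: a 768 x 768 projection adapted with rank 4.
d_out, d_in, rank = 768, 768, 4
full = d_out * d_in
lora = lora_trainable_params(d_out, d_in, rank)
print(f"full: {full}, LoRA: {lora}, ratio: {lora / full:.4f}")

# Tiny numeric check with a 2 x 2 weight and rank 1.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]    # rank x d_in
B = [[0.5], [0.0]]  # d_out x rank
y = lora_forward(W, A, B, x=[2.0, 3.0], alpha=1.0, rank=1)
print(y)
```

Only A and B receive gradients in such a setup, so the number of trained parameters grows with the rank rather than with the full size of the weight matrix.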

After fine-tuning, evaluate SAM on the test data.

scores = predictor.evaluate(test_data, metrics=["iou"])
print(scores)
{'iou': 0.575222909450531}

Thanks to fine-tuning, the test IoU has risen from about 0.14 in the zero-shot setting to about 0.58.

To visualize the impact, let’s examine the predicted mask after fine-tuning.

pred = predictor.predict({'image': [test_data.iloc[0]['image']]})
visualizer.plot_mask(pred)
../../../_images/e4dbad5a648ea20d1ff3a86b2b8b707019b63ca994ad8300bf6b0d5da6c0964a.png

As the results show, the predicted mask is now much closer to the ground truth. This demonstrates the effectiveness of using AutoMM to fine-tune SAM for domain-specific applications such as leaf disease segmentation.

Save and Load

The trained predictor is automatically saved at the end of fit(), and you can easily reload it.

Warning

MultiModalPredictor.load() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source or that could have been tampered with. Only load data you trust.
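
The harmless sketch below shows the mechanism behind this warning using only the standard library. A pickled object can name any importable callable through __reduce__, and pickle.loads() calls it while loading. The Payload class is hypothetical and deliberately calls the built-in sum(); a malicious file could name a far more dangerous function in the same way.

```python
import pickle

class Payload:
    # pickle stores whatever __reduce__ returns: a callable plus its arguments.
    # On load, pickle calls that callable instead of rebuilding a Payload.
    def __reduce__(self):
        return (sum, ([1, 2, 3],))

blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # runs sum([1, 2, 3]) during unpickling

print(type(result).__name__, result)
```

The loaded value is the return value of the call, not a Payload instance, which shows that the code ran as a side effect of loading.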

loaded_predictor = MultiModalPredictor.load(save_path)
scores = loaded_predictor.evaluate(test_data, metrics=["iou"])
print(scores)
Load pretrained checkpoint: /home/ci/autogluon/docs/tutorials/multimodal/image_segmentation/tmp/a8fb6daa218842f28bbfc738aeb86ded-automm_semantic_seg/model.ckpt
{'iou': 0.575222909450531}

The evaluation score is identical to the one above, which confirms that the loaded predictor is the same model.

Other Examples

See AutoMM Examples for more examples of using AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.