Image-to-Image Semantic Matching with AutoMM
Computing the similarity between two images is a common task in computer vision, with practical applications such as detecting whether two images show the same product. In general, image similarity models take two images as input and transform them into vectors. A similarity score computed with cosine similarity, dot product, or Euclidean distance then measures how alike or different the two images are.
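To make the three scoring options concrete, here is a minimal NumPy sketch. The 4-dimensional vectors are hypothetical stand-ins for image embeddings and are not produced by any model. Note that cosine similarity and dot product grow with similarity, whereas Euclidean distance shrinks.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine of the angle between a and b; insensitive to vector length."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def dot_product(a, b):
    """Unnormalized similarity; sensitive to both direction and length."""
    return float(np.dot(a, b))

def euclidean_distance(a, b):
    """Straight-line distance; smaller means more alike."""
    return float(np.linalg.norm(a - b))

# Hypothetical embeddings, for illustration only.
emb_a = np.array([1.0, 2.0, 0.0, 1.0])
emb_b = np.array([2.0, 4.0, 0.0, 2.0])  # same direction as emb_a, twice the length
emb_c = np.array([0.0, 0.0, 3.0, 0.0])  # orthogonal to emb_a

for name, other in [("a vs b", emb_b), ("a vs c", emb_c)]:
    print(
        name,
        "cosine:", cosine_similarity(emb_a, other),
        "dot:", dot_product(emb_a, other),
        "euclidean:", euclidean_distance(emb_a, other),
    )
```

In this toy case `emb_a` and `emb_b` point in the same direction, so their cosine similarity is maximal even though their Euclidean distance is not zero. This is one reason the choice of measure matters.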
import os
import pandas as pd
import warnings
from IPython.display import Image, display
warnings.filterwarnings('ignore')
Prepare your Data
In this tutorial, we will demonstrate how to use AutoMM for image-to-image semantic matching with the simplified Stanford Online Products dataset (SOP).
The Stanford Online Products dataset was introduced for metric learning. It contains 12 categories of products: bicycle, cabinet, chair, coffee maker, fan, kettle, lamp, mug, sofa, stapler, table and toaster. Each category has several products, and each product has several images captured from different views. Here, we treat different views of the same product as positive pairs (labeled as 1) and images of different products as negative pairs (labeled as 0).
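The pairing rule above can be sketched as follows. This is only an illustration: the file names, product ids, and the `build_pairs` helper are hypothetical, and the simplified dataset used in this tutorial already ships with precomputed pairs.

```python
from itertools import combinations

import pandas as pd

# Hypothetical toy catalog mapping image file -> product id.
# This is NOT the real SOP annotation format.
image_to_product = {
    "chair_001_view0.jpg": "chair_001",
    "chair_001_view1.jpg": "chair_001",
    "mug_007_view0.jpg": "mug_007",
    "mug_007_view1.jpg": "mug_007",
}

def build_pairs(image_to_product):
    """Label every unordered image pair: 1 if same product, else 0."""
    rows = []
    for (img1, prod1), (img2, prod2) in combinations(image_to_product.items(), 2):
        rows.append({"Image1": img1, "Image2": img2, "Label": int(prod1 == prod2)})
    return pd.DataFrame(rows)

pairs = build_pairs(image_to_product)
print(pairs)
```

Enumerating all pairs grows quadratically with the number of images, so a real pipeline would typically sample negatives rather than list them all.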
The following code downloads the dataset and unzips the images and annotation files.
download_dir = './ag_automm_tutorial_img2img'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/Stanford_Online_Products.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial_img2img/file.zip from https://automl-mm-bench.s3.amazonaws.com/Stanford_Online_Products.zip...
100%|██████████| 3.08G/3.08G [01:15<00:00, 40.8MiB/s]
Then we can load the annotations into dataframes.
dataset_path = os.path.join(download_dir, 'Stanford_Online_Products')
train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
test_data = pd.read_csv(f'{dataset_path}/test.csv', index_col=0)
image_col_1 = "Image1"
image_col_2 = "Image2"
label_col = "Label"
match_label = 1
Here you need to specify match_label, the label class indicating that a pair matches semantically. In this demo dataset, we use 1 because we assigned 1 to image pairs from the same product. Choose match_label according to your own task context.
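Before training, it can help to confirm that the chosen match_label actually occurs in the label column and to see how balanced the pairs are. The helper below is a hypothetical convenience written for this purpose. It is not part of AutoGluon, and the toy dataframe stands in for the real annotations.

```python
import pandas as pd

def check_match_label(df, label_col, match_label):
    """Hypothetical helper: verify match_label occurs in label_col
    and return the fraction of rows labeled as matching."""
    labels = df[label_col]
    if match_label not in set(labels.unique()):
        raise ValueError(
            f"match_label={match_label!r} not found in column {label_col!r}"
        )
    return float((labels == match_label).mean())

# Toy stand-in for the annotation dataframe.
toy = pd.DataFrame({"Label": [0, 1, 0, 1, 1]})
positive_share = check_match_label(toy, "Label", match_label=1)
print(f"share of matching pairs: {positive_share:.2f}")
```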
Next, we expand the image paths since the original paths are relative.
def path_expander(path, base_folder):
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, path)) for path in path_l])
for image_col in [image_col_1, image_col_2]:
    train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
    test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
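To see what the expansion does, here is a small standalone sketch that applies the same helper to a made-up entry. The relative paths and the `data` base folder are hypothetical. The helper splits on `;` so that an entry holding several paths is expanded path by path.

```python
import os

def path_expander(path, base_folder):
    # Same logic as the tutorial helper, repeated so this snippet runs on its own.
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])

# Hypothetical relative paths and base folder, for illustration only.
expanded = path_expander("chair/a.jpg;chair/b.jpg", base_folder="data")
parts = expanded.split(';')
for p in parts:
    print(p, os.path.isabs(p))
```

Each printed path is absolute and resolved against the current working directory, so the dataframe stays valid even if later code runs with a different relative base.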
The annotations consist only of image path pairs and their binary labels (1 means the image pair matches, 0 means it does not).
train_data.head()
| | Image1 | Image2 | Label |
|---|---|---|---|
| 0 | /home/ci/autogluon/docs/tutorials/multimodal/s... | /home/ci/autogluon/docs/tutorials/multimodal/s... | 0 |
| 1 | /home/ci/autogluon/docs/tutorials/multimodal/s... | /home/ci/autogluon/docs/tutorials/multimodal/s... | 1 |
| 2 | /home/ci/autogluon/docs/tutorials/multimodal/s... | /home/ci/autogluon/docs/tutorials/multimodal/s... | 0 |
| 3 | /home/ci/autogluon/docs/tutorials/multimodal/s... | /home/ci/autogluon/docs/tutorials/multimodal/s... | 1 |
| 4 | /home/ci/autogluon/docs/tutorials/multimodal/s... | /home/ci/autogluon/docs/tutorials/multimodal/s... | 1 |
Let’s visualize a matching image pair.
pil_img = Image(filename=train_data[image_col_1][5])
display(pil_img)
 
pil_img = Image(filename=train_data[image_col_2][5])
display(pil_img)
 
Here are two images that do not match.
pil_img = Image(filename=train_data[image_col_1][0])
display(pil_img)
 
pil_img = Image(filename=train_data[image_col_2][0])
display(pil_img)
 
Train your Model
Ideally, we want a model that returns high scores for positive image pairs and low scores for negative ones. With AutoMM, we can easily train a model that captures the semantic relationship between images. Under the hood, it uses a Swin Transformer to project each image into a high-dimensional vector and computes the cosine similarity of the resulting feature vectors.
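The goal of "high scores for positives, low scores for negatives" is what the AUC metric measures: the probability that a randomly chosen positive pair scores higher than a randomly chosen negative pair. The sketch below computes it directly from that definition with NumPy. The `pairwise_auc` function and the scores are hypothetical illustrations, not AutoMM's implementation or output.

```python
import numpy as np

def pairwise_auc(scores, labels, match_label=1):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.
    Ties count as half. Illustrative O(P*N) version, not optimized."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    pos = scores[labels == match_label]
    neg = scores[labels != match_label]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("need at least one positive and one negative pair")
    diff = pos[:, None] - neg[None, :]  # every positive score vs. every negative score
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

labels = [1, 1, 0, 1, 0]
# Hypothetical similarity scores, for illustration only.
well_separated = [0.9, 0.8, 0.35, 0.6, 0.1]  # every positive outranks every negative
one_mistake = [0.9, 0.2, 0.35, 0.6, 0.1]     # one positive falls below one negative

print(pairwise_auc(well_separated, labels))
print(pairwise_auc(one_mistake, labels))
```

Because AUC depends only on the ranking of scores, it does not require choosing a decision threshold, which suits similarity scores whose absolute scale is arbitrary.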
With AutoMM, you just need to specify the query, response, and label column names and fit the model on the training dataset, without worrying about the implementation details.
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor(
    problem_type="image_similarity",
    query=image_col_1,  # the column name of the first image
    response=image_col_2,  # the column name of the second image
    label=label_col,  # the label column name
    match_label=match_label,  # the label indicating that query and response have the same semantic meaning
    eval_metric='auc',  # the evaluation metric
)
    
# Fit the model
predictor.fit(
    train_data=train_data,
    time_limit=180,
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250729_000626"
=================== System Info ===================
AutoGluon Version:  1.4.0b20250728
Python Version:     3.12.10
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.7.1+cu126
CUDA Version:       12.6
GPU Count:          1
Memory Avail:       28.19 GB / 30.95 GB (91.1%)
Disk Space Avail:   173.84 GB / 255.99 GB (67.9%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [np.int64(0), np.int64(1)]
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250729_000626
    ```
INFO: Seed set to 0
WARNING:timm.models._builder:Unexpected keys (head.fc.fc1.bias, head.fc.fc1.weight, head.fc.norm.bias, head.fc.norm.weight) found while loading pretrained weights. This may be expected if model is being adapted.
GPU Count: 1
GPU Count to be Used: 1
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name              | Type                            | Params | Mode 
------------------------------------------------------------------------------
0 | query_model       | TimmAutoModelForImagePrediction | 93.3 M | train
1 | response_model    | TimmAutoModelForImagePrediction | 93.3 M | train
2 | validation_metric | BinaryAUROC                     | 0      | train
3 | loss_func         | ContrastiveLoss                 | 0      | train
4 | miner_func        | PairMarginMiner                 | 0      | train
------------------------------------------------------------------------------
93.3 M    Trainable params
0         Non-trainable params
93.3 M    Total params
373.248   Total estimated model params size (MB)
866       Modules in train mode
0         Modules in eval mode
INFO: Epoch 0, global step 15: 'val_roc_auc' reached 0.82797 (best 0.82797), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250729_000626/epoch=0-step=15.ckpt' as top 3
INFO: Time limit reached. Elapsed time is 0:03:00. Signaling Trainer to stop.
INFO: Epoch 0, global step 22: 'val_roc_auc' reached 0.88910 (best 0.88910), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250729_000626/epoch=0-step=22.ckpt' as top 3
Start to fuse 2 checkpoints via the greedy soup algorithm.
INFO: 💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉
To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250729_000626")
    ```
If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://auto.gluon.ai/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f85e82e84d0>
Evaluate on Test Dataset¶
You can evaluate the predictor on the test dataset to see how it performs with the roc_auc score:
score = predictor.evaluate(test_data)
print("evaluation score: ", score)
evaluation score:  {'roc_auc': np.float64(0.8890753046961162)}
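The roc_auc score can be read as the probability that a randomly chosen matching pair receives a higher score than a randomly chosen non-matching pair. The sketch below computes that quantity directly on a tiny made-up example; the helper `roc_auc` and the toy labels and scores are assumptions for illustration, and this is not necessarily how AutoMM computes the metric internally.

```python
import numpy as np

def roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a correct pair.
    """
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    # Compare every positive score with every negative score.
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)

# Hypothetical labels (1 = match) and match scores.
toy_labels = [1, 0, 1, 0]
toy_scores = [0.9, 0.2, 0.4, 0.5]
toy_auc = roc_auc(toy_labels, toy_scores)
print(toy_auc)
```

Because the metric depends only on how pairs are ranked, it does not change with the probability threshold used to turn scores into labels.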
Predict on Image Pairs¶
Given new image pairs, we can predict whether they match or not.
pred = predictor.predict(test_data.head(3))
print(pred)
0    1
1    1
2    1
Name: Label, dtype: int64
The predictions use a naive probability threshold of 0.5. That is, the predicted label is the one whose probability is larger than 0.5.
Predict Matching Probabilities¶
However, you can apply a customized threshold by getting the probabilities instead.
proba = predictor.predict_proba(test_data.head(3))
print(proba)
          0         1
0  0.301796  0.698204
1  0.048948  0.951052
2  0.073070  0.926930
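As a sketch of custom thresholding, the snippet below rebuilds the probabilities printed above in a stand-alone DataFrame and applies a stricter cut-off. The threshold of 0.9 and the variable names are hypothetical, and column 1 is assumed to be the matching label, as in this dataset's labeling. With the real predictor you would use the output of `predict_proba` directly.

```python
import pandas as pd

# Probabilities copied from the printed output above (column 1 = match).
proba_example = pd.DataFrame({
    0: [0.301796, 0.048948, 0.073070],
    1: [0.698204, 0.951052, 0.926930],
})

threshold = 0.9   # hypothetical stricter cut-off; tune it on validation data
match_column = 1  # assumed to be the label used for matching pairs

# Label a pair as matching only when its match probability reaches the threshold.
custom_pred = (proba_example[match_column] >= threshold).astype(int)
print(custom_pred.tolist())
```

With this stricter cut-off, the first pair (match probability of about 0.70) is no longer labeled as a match, whereas the default threshold of 0.5 labels all three pairs as matching.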
Extract Embeddings¶
You can also extract embeddings for each image of a pair.
embeddings_1 = predictor.extract_embedding({image_col_1: test_data[image_col_1][:5].tolist()})
print(embeddings_1.shape)
embeddings_2 = predictor.extract_embedding({image_col_2: test_data[image_col_2][:5].tolist()})
print(embeddings_2.shape)
(5, 768)
(5, 768)
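Once you have embeddings, you can score any query image against any response image yourself. The sketch below assumes the embeddings are NumPy arrays with the shapes printed above and uses random stand-in data so that it runs on its own; `emb_1`, `emb_2`, and `l2_normalize` are illustrative names, not part of AutoMM.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-ins with the shapes printed above; in practice,
# substitute the extracted embeddings_1 and embeddings_2.
emb_1 = rng.normal(size=(5, 768))
emb_2 = rng.normal(size=(5, 768))

def l2_normalize(x):
    """Scale each row to unit length so that dot products become cosine similarities."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

# Entry (i, j) is the cosine similarity between query i and response j.
sim_matrix = l2_normalize(emb_1) @ l2_normalize(emb_2).T

pair_scores = np.diag(sim_matrix)       # similarity of each original (query, response) pair
best_match = sim_matrix.argmax(axis=1)  # index of the most similar response for each query
print(sim_matrix.shape, pair_scores.shape, best_match.shape)
```

Normalizing once and using a single matrix product scales better than looping over pairs, which matters when you rank many candidate images for each query.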
Other Examples¶
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization¶
To learn how to customize AutoMM, please refer to Customize AutoMM.