CLIP in AutoMM - Zero-Shot Image Classification#


When you want to classify an image into different classes, the standard approach is to train an image classifier on labeled examples of those classes. However, collecting training data is tedious, and if the collected data is too scarce or too imbalanced, the resulting classifier may not perform well. So you may wonder: is there a model strong enough to handle this situation without any training effort?

Actually there is! OpenAI has introduced a model named CLIP, which can be applied to any visual classification benchmark by simply providing the names of the visual categories to be recognized. And its accuracy is high: for example, CLIP achieves 76.2% top-1 accuracy on ImageNet without using any of its 1.28M training samples. This matches the performance of the original supervised ResNet-50 on ImageNet, which is quite promising for a classification task with 1,000 classes!
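Under the hood, CLIP embeds the image and each candidate caption into a shared space and ranks captions by cosine similarity. Below is a minimal NumPy sketch of that scoring step; the embeddings are random stand-ins for real CLIP features, so only the mechanics (normalize, dot product, softmax) reflect what CLIP does.

```python
import numpy as np

def zero_shot_scores(image_emb, text_embs, logit_scale=100.0):
    """Score one image embedding against candidate caption embeddings.

    Mirrors CLIP-style scoring: L2-normalize both sides, take dot
    products (cosine similarity), scale, and softmax into probabilities.
    """
    img = image_emb / np.linalg.norm(image_emb)
    txt = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    logits = logit_scale * txt @ img      # one logit per caption
    exp = np.exp(logits - logits.max())   # numerically stable softmax
    return exp / exp.sum()

rng = np.random.default_rng(0)
image_emb = rng.standard_normal(512)       # stand-in for a CLIP image feature
text_embs = rng.standard_normal((4, 512))  # stand-ins for 4 caption features
probs = zero_shot_scores(image_emb, text_embs)
print(probs)  # sums to 1; the highest entry is the predicted caption
```

With real CLIP embeddings, the caption whose probability is highest is returned as the zero-shot prediction.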

In this tutorial, we dive into CLIP and show you how to use it for zero-shot image classification in AutoGluon.

Simple Demo#

Here is a simple demo that classifies the dog breed in the picture below.

from IPython.display import Image, display
from autogluon.multimodal import download

url = "https://farm4.staticflickr.com/3445/3262471985_ed886bf61a_z.jpg"
dog_image = download(url)

pil_img = Image(filename=dog_image)
display(pil_img)
Downloading 3262471985_ed886bf61a_z.jpg from https://farm4.staticflickr.com/3445/3262471985_ed886bf61a_z.jpg...
                     
../../../_images/430b063e465770dc7edf6986e78c25ebec12897e4547065dac86c62cd6a9765e.jpg

Normally, to solve this task you would need to collect training data (e.g., the Stanford Dogs dataset) and train a dog breed classifier. But with CLIP, all you need to do is provide some candidate visual categories; CLIP handles the rest for you.

from autogluon.multimodal import MultiModalPredictor

predictor = MultiModalPredictor(problem_type="zero_shot_image_classification")
prob = predictor.predict_proba(
    {"image": [dog_image]},
    {"text": ["This is a Husky", "This is a Golden Retriever", "This is a German Shepherd", "This is a Samoyed"]},
)
print("Label probs:", prob)
The model does not support using an image size that is different from the default size. Provided image size=224. Default size=336. We have ignored the provided image size.
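`predict_proba` returns one probability per candidate text. To turn that into a predicted label, pair the probabilities with the captions and take the argmax. This is a hedged sketch assuming the probabilities come back as a 2-D array with one row per image; the `probs` values below are made up for illustration, not real model output.

```python
import numpy as np

captions = [
    "This is a Husky",
    "This is a Golden Retriever",
    "This is a German Shepherd",
    "This is a Samoyed",
]
# Illustrative probabilities only -- not real model output.
probs = np.array([[0.05, 0.02, 0.03, 0.90]])

best = int(np.argmax(probs[0]))  # index of the most likely caption
print(f"Predicted: {captions[best]} (p={probs[0, best]:.2f})")
```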
                                                                                                  
 /home/ci/opt/venv/lib/python3.8/site-packages/huggingface_hub/file_download.py:1364 in           
 hf_hub_download                                                                                  
                                                                                                  
   1361 │   │   with temp_file_manager() as temp_file:                                            
   1362 │   │   │   logger.info("downloading %s to %s", url, temp_file.name)                      
   1363 │   │   │                                                                                 
 1364 │   │   │   http_get(                                                                     
   1365 │   │   │   │   url_to_download,                                                          
   1366 │   │   │   │   temp_file,                                                                
   1367 │   │   │   │   proxies=proxies,                                                          
                                                                                                  
 /home/ci/opt/venv/lib/python3.8/site-packages/huggingface_hub/file_download.py:505 in http_get   
                                                                                                  
    502 if resume_size > 0:                                                                   
    503 │   │   headers["Range"] = "bytes=%d-" % (resume_size,)                                   
    504                                                                                       
  505 r = _request_wrapper(                                                                 
    506 │   │   method="GET",                                                                     
    507 │   │   url=url,                                                                          
    508 │   │   stream=True,                                                                      
                                                                                                  
 /home/ci/opt/venv/lib/python3.8/site-packages/huggingface_hub/file_download.py:442 in            
 _request_wrapper                                                                                 
                                                                                                  
    439 │   │   return response                                                                   
    440                                                                                       
    441 # 3. Exponential backoff                                                              
  442 return http_backoff(                                                                  
    443 │   │   method=method,                                                                    
    444 │   │   url=url,                                                                          
    445 │   │   max_retries=max_retries,                                                          
                                                                                                  
 /home/ci/opt/venv/lib/python3.8/site-packages/huggingface_hub/utils/_http.py:212 in http_backoff 
                                                                                                  
   209 │   │   │   │   kwargs["data"].seek(io_obj_initial_pos)                                    
   210 │   │   │                                                                                  
   211 │   │   │   # Perform request and return if status_code is not in the retry list.          
 212 │   │   │   response = session.request(method=method, url=url, **kwargs)                   
   213 │   │   │   if response.status_code not in retry_on_status_codes:                          
   214 │   │   │   │   return response                                                            
   215                                                                                            
                                                                                                  
 /home/ci/opt/venv/lib/python3.8/site-packages/requests/sessions.py:589 in request                
                                                                                                  
   586 │   │   │   "allow_redirects": allow_redirects,                                            
   587 │   │   }                                                                                  
   588 │   │   send_kwargs.update(settings)                                                       
 589 │   │   resp = self.send(prep, **send_kwargs)                                              
   590 │   │                                                                                      
   591 │   │   return resp                                                                        
   592                                                                                            
                                                                                                  
 /home/ci/opt/venv/lib/python3.8/site-packages/requests/sessions.py:703 in send                   
                                                                                                  
   700 │   │   start = preferred_clock()                                                          
   701 │   │                                                                                      
   702 │   │   # Send the request                                                                 
 703 │   │   r = adapter.send(request, **kwargs)                                                
   704 │   │                                                                                      
   705 │   │   # Total elapsed time of the request (approximately)                                
   706 │   │   elapsed = preferred_clock() - start                                                
                                                                                                  
 /home/ci/opt/venv/lib/python3.8/site-packages/requests/adapters.py:532 in send                   
                                                                                                  
   529 │   │   │   │   # This branch is for urllib3 versions earlier than v1.22                   
   530 │   │   │   │   raise SSLError(e, request=request)                                         
   531 │   │   │   elif isinstance(e, ReadTimeoutError):                                          
 532 │   │   │   │   raise ReadTimeout(e, request=request)                                      
   533 │   │   │   elif isinstance(e, _InvalidHeader):                                            
   534 │   │   │   │   raise InvalidHeader(e, request=request)                                    
   535 │   │   │   else:                                                                          
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ReadTimeout: HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10.0)

Clearly, according to the probabilities, we know there is a Husky in the photo (which I think is correct)!

Let’s try a harder example. Below is a photo of two Segways. This object class is not common in most existing vision datasets.

url = "https://live.staticflickr.com/7236/7114602897_9cf00b2820_b.jpg"
segway_image = download(url)

pil_img = Image(filename=segway_image)
display(pil_img)

Given several text queries, CLIP can still predict the segway class correctly with high confidence.

prob = predictor.predict_proba({"image": [segway_image]}, {"text": ['segway', 'bicycle', 'wheel', 'car']})
print("Label probs:", prob)
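Under the hood, zero-shot classification like the call above works by embedding the image and each candidate text into a shared space, then applying a softmax over the temperature-scaled cosine similarities. The sketch below illustrates this mechanism with synthetic embeddings standing in for real encoder outputs; the function name and temperature value are illustrative, not part of the AutoGluon API.

```python
import numpy as np

def zero_shot_probs(image_emb, text_embs, temperature=0.01):
    """Class probabilities from one image embedding and one embedding per class."""
    # Normalize so dot products become cosine similarities
    image_emb = image_emb / np.linalg.norm(image_emb)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    sims = text_embs @ image_emb / temperature
    sims = sims - sims.max()  # numerical stability
    exp = np.exp(sims)
    return exp / exp.sum()

# Synthetic embeddings: rows of text_embs are orthonormal "class" vectors,
# and the image embedding is closest to the first class.
text_embs = np.eye(4, 8)
image_emb = np.array([1.0, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
probs = zero_shot_probs(image_emb, text_embs)
print(probs.argmax())  # the first class should win
```

With a small temperature, even a modest similarity gap turns into a near-one-hot probability distribution, which is why CLIP's zero-shot predictions often look very confident.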

This is amazing, right? Now, a bit of background on why and how CLIP works. CLIP stands for Contrastive Language-Image Pre-training. It is trained on a massive amount of data (400M image-text pairs). With a simple contrastive loss objective, CLIP learns to predict which text, out of a set of randomly sampled candidates, is actually paired with a given image in the training dataset. As a result, CLIP models can be applied to nearly arbitrary visual classification tasks, just like the examples shown above.
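The contrastive objective described above can be sketched in a few lines of NumPy. This is a minimal illustration, not CLIP's actual implementation: the embeddings here are placeholders for real encoder outputs, and the temperature value is just a common choice.

```python
import numpy as np

def clip_contrastive_loss(image_embs, text_embs, temperature=0.07):
    """Symmetric contrastive loss over a batch of paired image/text embeddings."""
    # L2-normalize so similarities are cosine similarities
    image_embs = image_embs / np.linalg.norm(image_embs, axis=1, keepdims=True)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    # Entry (i, j) compares image i with text j; the correct pairs lie on the diagonal
    logits = image_embs @ text_embs.T / temperature
    labels = np.arange(len(logits))

    def cross_entropy(l, y):
        l = l - l.max(axis=1, keepdims=True)  # numerical stability
        log_probs = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -log_probs[np.arange(len(y)), y].mean()

    # Average the image-to-text and text-to-image directions
    return (cross_entropy(logits, labels) + cross_entropy(logits.T, labels)) / 2

# Perfectly aligned pairs give a near-zero loss; shuffled pairs do not
emb = np.eye(4, 8)
print(clip_contrastive_loss(emb, emb))               # close to 0
print(clip_contrastive_loss(emb, np.roll(emb, 1, axis=0)))  # much larger
```

Training pushes matched image-text pairs together and mismatched pairs apart, which is exactly what makes the class-name-as-query trick work at inference time.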

More about CLIP#

CLIP is powerful, and it was designed to mitigate several major problems in the standard deep learning approach to computer vision, such as costly datasets, closed-set prediction, and poor generalization performance. While CLIP is a good solution to many problems, it is not the ultimate one; it has its own limitations. For example, CLIP is vulnerable to typographic attacks: if you add some text to an image, CLIP’s predictions can easily be swayed by that text. Let’s see one example from OpenAI’s blog post on multimodal neurons.

Suppose we have a photo of a Granny Smith apple,

url = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
image_path = download(url)

pil_img = Image(filename=image_path)
display(pil_img)

We then try to classify this image into several classes, such as Granny Smith, iPod, library, pizza, toaster, and dough.

prob = predictor.predict_proba({"image": [image_path]}, {"text": ['Granny Smith', 'iPod', 'library', 'pizza', 'toaster', 'dough']})
print("Label probs:", prob)
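If you only need the top class and its confidence rather than the full distribution, a small helper like the one below does the job. Assume `prob_row` is one row of class probabilities in the same order as the text queries; the numbers here are placeholders, not actual CLIP output.

```python
import numpy as np

classes = ['Granny Smith', 'iPod', 'library', 'pizza', 'toaster', 'dough']

def top_prediction(prob_row, class_names):
    """Return the most likely class name and its probability."""
    i = int(np.argmax(prob_row))
    return class_names[i], float(prob_row[i])

# Placeholder probabilities standing in for the real predict_proba output
label, conf = top_prediction(np.array([0.998, 0.001, 0.0, 0.0, 0.001, 0.0]), classes)
print(f"{label}: {conf:.1%}")
```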

We can see that zero-shot classification works great: it predicts apple with almost 100% confidence. But if we attach a piece of text to the apple like this,

url = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
image_path = download(url)

pil_img = Image(filename=image_path)
display(pil_img)

Then we use the same class names to perform zero-shot classification,

prob = predictor.predict_proba({"image": [image_path]}, {"text": ['Granny Smith', 'iPod', 'library', 'pizza', 'toaster', 'dough']})
print("Label probs:", prob)

Suddenly, the apple becomes an iPod.

CLIP also has other limitations. If you are interested, you can read the CLIP paper for more details. Or you can stay here and play with your own examples!

Other Examples#

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization#

To learn how to customize AutoMM, please refer to Customize AutoMM.