CLIP in AutoMM - Zero-Shot Image Classification#
When you want to classify an image into different classes, it is standard to train an image classifier based on the class names. However, collecting training data is tedious, and if the collected data is too scarce or too imbalanced, you may not get a decent classifier. So you may wonder: is there a model strong enough to handle this situation without any training effort?
Actually, there is! OpenAI has introduced a model named CLIP, which can be applied to any visual classification benchmark by simply providing the names of the visual categories to be recognized. And its accuracy is high: for example, CLIP achieves 76.2% top-1 accuracy on ImageNet without using any of the 1.28M training samples. This matches the performance of the original supervised ResNet-50 on ImageNet, which is quite promising for a classification task with 1000 classes!
So in this tutorial, let’s dive deep into CLIP. We will show you how to use the CLIP model to do zero-shot image classification in AutoGluon.
Simple Demo#
Here we provide a simple demo to classify what dog breed is in the picture below.
from IPython.display import Image, display
from autogluon.multimodal import download
url = "https://farm4.staticflickr.com/3445/3262471985_ed886bf61a_z.jpg"
dog_image = download(url)
pil_img = Image(filename=dog_image)
display(pil_img)
Downloading 3262471985_ed886bf61a_z.jpg from https://farm4.staticflickr.com/3445/3262471985_ed886bf61a_z.jpg...
Normally to solve this task, you need to collect some training data (e.g., the Stanford Dogs dataset) and train a dog breed classifier. But with CLIP, all you need to do is provide some potential visual categories. CLIP will handle the rest for you.
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor(problem_type="zero_shot_image_classification")
prob = predictor.predict_proba({"image": [dog_image]}, {"text": ['This is a Husky', 'This is a Golden Retriever', 'This is a German Shepherd', 'This is a Samoyed']})
print("Label probs:", prob)
Clearly, according to the probabilities, we know there is a Husky in the photo (which I think is correct)!
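Turning the probabilities into a label is just an argmax over the candidate prompts, which are scored in the order you passed them in. A minimal sketch, using made-up probability values for illustration (the real values come from the `predict_proba` call above):

```python
import numpy as np

# Candidate prompts, in the same order they were passed to predict_proba.
labels = ["Husky", "Golden Retriever", "German Shepherd", "Samoyed"]

# Hypothetical probabilities for one image, just for illustration.
prob = np.array([[0.92, 0.03, 0.03, 0.02]])

# The predicted class is simply the highest-probability prompt.
predicted = labels[int(np.argmax(prob[0]))]
print(predicted)  # → Husky
```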
Let’s try a harder example. Below is a photo of two Segways. This object class is not common in most existing vision datasets.
url = "https://live.staticflickr.com/7236/7114602897_9cf00b2820_b.jpg"
segway_image = download(url)
pil_img = Image(filename=segway_image)
display(pil_img)
Given several text queries, CLIP can still predict the segway class correctly with high confidence.
prob = predictor.predict_proba({"image": [segway_image]}, {"text": ['segway', 'bicycle', 'wheel', 'car']})
print("Label probs:", prob)
This is amazing, right? Now, a bit of background on why and how CLIP works. CLIP stands for Contrastive Language-Image Pre-training. It is trained on a massive amount of data (400M image-text pairs). Using a simple contrastive objective, CLIP learns to predict which text, out of a set of randomly sampled candidates, is actually paired with a given image in the training dataset. As a result, CLIP models can be applied to nearly arbitrary visual classification tasks, just like the examples we have shown above.
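The zero-shot scoring step can be sketched in a few lines: embed the image and each candidate text, L2-normalize both, then take a softmax over the scaled cosine similarities. The toy 4-d embeddings below are made up for illustration; in the real model they come from CLIP's image and text encoders, and the logit scale is a learned parameter.

```python
import numpy as np

def zero_shot_probs(image_emb, text_embs, logit_scale=100.0):
    """Softmax over scaled cosine similarities, as in CLIP's zero-shot setup."""
    img = image_emb / np.linalg.norm(image_emb)
    txt = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    logits = logit_scale * txt @ img          # one logit per candidate text
    exp = np.exp(logits - logits.max())       # numerically stable softmax
    return exp / exp.sum()

# Toy embeddings: the image is closest to the first text candidate.
image_emb = np.array([1.0, 0.2, 0.0, 0.0])
text_embs = np.array([
    [1.0, 0.1, 0.0, 0.0],   # e.g. "a photo of a husky"
    [0.0, 1.0, 0.0, 0.0],   # e.g. "a photo of a golden retriever"
    [0.0, 0.0, 1.0, 0.0],   # e.g. "a photo of a german shepherd"
])
probs = zero_shot_probs(image_emb, text_embs)
print(probs.argmax())  # the first candidate wins
```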
More about CLIP#
CLIP is powerful, and it was designed to mitigate several major problems in the standard deep learning approach to computer vision, such as costly datasets, closed-set prediction, and poor generalization performance. CLIP is a good solution to many problems; however, it is not the ultimate solution, and it has its own limitations. For example, CLIP is vulnerable to typographic attacks: if you add some text to an image, CLIP’s predictions will be easily affected by that text. Let’s see one example from OpenAI’s blog post on multimodal neurons.
Suppose we have a photo of a Granny Smith apple,
url = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
image_path = download(url)
pil_img = Image(filename=image_path)
display(pil_img)
We then try to classify this image to several classes, such as Granny Smith, iPod, library, pizza, toaster and dough.
prob = predictor.predict_proba({"image": [image_path]}, {"text": ['Granny Smith', 'iPod', 'library', 'pizza', 'toaster', 'dough']})
print("Label probs:", prob)
We can see that zero-shot classification works great: it predicts Granny Smith with almost 100% confidence. But if we add a piece of text to the apple like this,
url = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
image_path = download(url)
pil_img = Image(filename=image_path)
display(pil_img)
Then we use the same class names to perform zero-shot classification,
prob = predictor.predict_proba({"image": [image_path]}, {"text": ['Granny Smith', 'iPod', 'library', 'pizza', 'toaster', 'dough']})
print("Label probs:", prob)
Suddenly, the apple becomes an iPod.
CLIP also has other limitations. If you are interested, you can read the CLIP paper for more details. Or you can stay here and play with your own examples!
Other Examples#
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization#
To learn how to customize AutoMM, please refer to Customize AutoMM.