Text-to-Text Semantic Matching with AutoMM


Computing the similarity between two sentences or passages is a common NLP task with many practical applications, such as web search, question answering, document deduplication, plagiarism detection, natural language inference, and recommendation engines. In general, a text similarity model takes two sentences or passages as input and transforms them into vectors. A similarity score computed with cosine similarity, dot product, or Euclidean distance then measures how alike or different the two pieces of text are.
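To make the vector-then-score pipeline concrete, here is a minimal plain-Python sketch. It assumes, for illustration only, that the vectors are raw term counts over whitespace tokens; a trained model would produce dense learned embeddings instead. It computes the three scores mentioned above. The helper names are hypothetical.

```python
import math
from collections import Counter


def term_frequency_vectors(text_a, text_b):
    """Build aligned term-count vectors over the shared vocabulary of two texts."""
    tokens_a = text_a.lower().split()
    tokens_b = text_b.lower().split()
    vocab = sorted(set(tokens_a) | set(tokens_b))
    counts_a, counts_b = Counter(tokens_a), Counter(tokens_b)
    return [counts_a[w] for w in vocab], [counts_b[w] for w in vocab]


def dot(u, v):
    """Dot product: larger means more similar."""
    return sum(x * y for x, y in zip(u, v))


def cosine_similarity(u, v):
    """Dot product of the length-normalized vectors, in [-1, 1]."""
    norm = math.sqrt(dot(u, u)) * math.sqrt(dot(v, v))
    return dot(u, v) / norm if norm else 0.0


def euclidean_distance(u, v):
    """Straight-line distance: smaller means more similar."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(u, v)))


u, v = term_frequency_vectors("some men are playing a sport", "some men are playing soccer")
print(dot(u, v), cosine_similarity(u, v), euclidean_distance(u, v))
```

Because these vectors only count shared words, "sport" and "soccer" contribute nothing to the match even though they are related in meaning; this is the lexical limitation that learned sentence embeddings aim to overcome.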

Prepare your Data

In this tutorial, we demonstrate how to use AutoMM for text-to-text semantic matching with the Stanford Natural Language Inference (SNLI) corpus. SNLI contains around 570k human-written sentence pairs, each labeled as entailment, contradiction, or neutral. It is a widely used benchmark for evaluating the representation and inference capability of machine learning methods. The following table shows three examples from this corpus.

Premise | Hypothesis | Label
A black race car starts up in front of a crowd of people. | A man is driving down a lonely road. | contradiction
An older and younger man smiling. | Two men are smiling and laughing at the cats playing on the floor. | neutral
A soccer game with multiple males playing. | Some men are playing a sport. | entailment

Here, we treat sentence pairs labeled entailment as positive pairs (label 1) and those labeled contradiction as negative pairs (label 0). Sentence pairs with a neutral relationship are discarded. The following code downloads the corpus and loads it into dataframes.

from autogluon.core.utils.loaders import load_pd
import pandas as pd

snli_train = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/snli/snli_train.csv', delimiter="|")
snli_test = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/snli/snli_test.csv', delimiter="|")
snli_train.head()
premise hypothesis label
0 A person on a horse jumps over a broken down a... A person is at a diner , ordering an omelette . 0
1 A person on a horse jumps over a broken down a... A person is outdoors , on a horse . 1
2 Children smiling and waving at camera There are children present 1
3 Children smiling and waving at camera The kids are frowning 0
4 A boy is jumping on skateboard in the middle o... The boy skates down the sidewalk . 0
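The CSV files loaded above already store the labels as 0/1. If you instead start from raw SNLI-style data with string labels, the filtering and mapping described earlier could be sketched with pandas as below. The column names premise, hypothesis, and label are assumed to match this tutorial's files, and the helper to_binary_pairs is hypothetical.

```python
import pandas as pd

# Raw SNLI-style rows with string labels (the three examples from the table above).
raw = pd.DataFrame({
    "premise": [
        "A black race car starts up in front of a crowd of people.",
        "An older and younger man smiling.",
        "A soccer game with multiple males playing.",
    ],
    "hypothesis": [
        "A man is driving down a lonely road.",
        "Two men are smiling and laughing at the cats playing on the floor.",
        "Some men are playing a sport.",
    ],
    "label": ["contradiction", "neutral", "entailment"],
})


def to_binary_pairs(df, label_col="label"):
    """Keep entailment/contradiction pairs, map them to 1/0, and drop neutral pairs."""
    label_map = {"entailment": 1, "contradiction": 0}
    kept = df[df[label_col].isin(list(label_map))].copy()
    kept[label_col] = kept[label_col].map(label_map).astype(int)
    return kept.reset_index(drop=True)


binary = to_binary_pairs(raw)
print(binary)
```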

Train your Model

Ideally, we want a model that returns high scores for positive text pairs and low scores for negative ones. Traditional text similarity methods, such as those based on term-frequency or tf-idf vectors, work only at the lexical level and ignore the semantics of the text. With AutoMM, we can easily train a model that captures the semantic relationship between sentences. Under the hood, it uses BERT to project each sentence into a high-dimensional vector and treats the matching problem as a classification problem, following the design of Sentence Transformers. You only need to specify the query, response, and label column names and fit the model on the training data, without worrying about the implementation details. Note that the labels must be binary, and you need to specify match_label, the label value indicating that the two sentences have the same semantic meaning. In practice, your task may use different labels, e.g., duplicate and not duplicate, so choose match_label according to your task's context.
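The classification design mentioned above can be illustrated with a toy version of the Sentence-BERT classification objective: both sentences pass through the same encoder to produce vectors u and v, and a softmax classifier is applied to the concatenation (u, v, |u - v|). This NumPy sketch rests on loudly stated assumptions: toy_encode is a hypothetical stand-in for BERT (an average of random word vectors), the classifier weights are untrained, and AutoMM's actual matching head may differ in its details.

```python
import numpy as np

rng = np.random.default_rng(0)
EMBED_DIM = 8
word_vectors = {}


def toy_encode(sentence):
    """Hypothetical stand-in for a BERT encoder: mean of fixed random word vectors."""
    vectors = []
    for word in sentence.lower().split():
        if word not in word_vectors:
            word_vectors[word] = rng.normal(size=EMBED_DIM)
        vectors.append(word_vectors[word])
    return np.mean(vectors, axis=0)


# Untrained two-class classifier over the concatenated features (u, v, |u - v|).
W = rng.normal(scale=0.1, size=(3 * EMBED_DIM, 2))
b = np.zeros(2)


def match_probability(query, response):
    """Encode both texts with the same (shared) encoder and classify the pair."""
    u, v = toy_encode(query), toy_encode(response)
    features = np.concatenate([u, v, np.abs(u - v)])
    logits = features @ W + b
    exp = np.exp(logits - logits.max())  # numerically stable softmax
    probs = exp / exp.sum()
    return float(probs[1])  # probability of the "match" class


p = match_probability("Some men are playing a sport.", "A soccer game with multiple males playing.")
print(round(p, 3))
```

In real training, the encoder and classifier weights would be learned by minimizing cross-entropy on the match labels; here the printed value only demonstrates the data flow and carries no meaning beyond lying in [0, 1].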

from autogluon.multimodal import MultiModalPredictor

# Initialize the model
predictor = MultiModalPredictor(
    problem_type="text_similarity",
    query="premise",  # the column name of the first sentence
    response="hypothesis",  # the column name of the second sentence
    label="label",  # the label column name
    match_label=1,  # the label value indicating that query and response have the same semantic meaning
    eval_metric='roc_auc',  # the evaluation metric ('auc' is not supported for matching and is replaced by 'roc_auc')
)

# Fit the model
predictor.fit(
    train_data=snli_train,
    time_limit=180,
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250515_220300"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250515
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.17 GB / 30.95 GB (91.0%)
Disk Space Avail:   168.81 GB / 255.99 GB (65.9%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [np.int64(0), np.int64(1)]
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
/home/ci/autogluon/multimodal/src/autogluon/multimodal/optim/metrics/utils.py:185: UserWarning: Metric auc is not supported as the evaluation metric for binary in matching tasks.The evaluation metric is changed to roc_auc by default.
  warnings.warn(

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250515_220300
    ```
Seed set to 0

Evaluate on Test Dataset

You can evaluate the matcher on the test dataset to see how it performs in terms of the roc_auc score:

score = predictor.evaluate(snli_test)
print("evaluation score: ", score)
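ROC AUC can be read as the probability that a randomly chosen positive pair receives a higher score than a randomly chosen negative pair. The pure-Python sketch below computes it directly from that definition (ties count as one half) on toy labels and scores; it is quadratic in the number of examples and meant only to explain the metric. In practice, scikit-learn's sklearn.metrics.roc_auc_score computes the same quantity efficiently.

```python
def roc_auc(labels, scores):
    """Fraction of (positive, negative) pairs where the positive scores higher; ties count 1/2."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("Need at least one positive and one negative example.")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Toy labels and match scores, not real model output.
labels = [1, 0, 1, 0, 1]
scores = [0.9, 0.3, 0.6, 0.7, 0.8]
print(roc_auc(labels, scores))
```

Here 5 of the 6 positive/negative pairs are ranked correctly (only 0.6 versus 0.7 is not), so the score is 5/6. An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect ranking.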

Predict on a New Sentence Pair

We create a new sentence pair with similar meaning (expected to be predicted as 1) and make predictions using the trained model.

pred_data = pd.DataFrame.from_dict({"premise":["The teacher gave his speech to an empty room."], 
                                    "hypothesis":["There was almost nobody when the professor was talking."]})

predictions = predictor.predict(pred_data)
print('Predicted label:', predictions[0])

Predict Matching Probabilities

We can also compute the matching probabilities of sentence pairs.

probabilities = predictor.predict_proba(pred_data)
print(probabilities)
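To turn matching probabilities into decisions, you can select the column for the match label and apply a threshold. This sketch assumes the output is a pandas DataFrame with one column per label value (0 and 1 here); please check the columns actually returned by your AutoGluon version. The probability values are placeholders rather than real model output, and the 0.5 threshold is an assumption you may want to tune on validation data.

```python
import pandas as pd

# Placeholder standing in for predictor.predict_proba(...) output; values are illustrative only.
probabilities = pd.DataFrame({0: [0.2, 0.7], 1: [0.8, 0.3]})

MATCH_LABEL = 1  # the same value passed as match_label when creating the predictor
THRESHOLD = 0.5  # assumed decision threshold

match_scores = probabilities[MATCH_LABEL]  # probability that each pair matches
is_match = (match_scores >= THRESHOLD).astype(int)
print(is_match.tolist())
```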

Extract Embeddings

You can also extract embeddings separately for each of the two sentence groups (the query and response columns).

embeddings_1 = predictor.extract_embedding({"premise":["The teacher gave his speech to an empty room."]})
print(embeddings_1.shape)
embeddings_2 = predictor.extract_embedding({"hypothesis":["There was almost nobody when the professor was talking."]})
print(embeddings_2.shape)
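Extracted embeddings can be used for retrieval, for example to rank candidate responses for a query by cosine similarity. The NumPy sketch below assumes the embeddings are arrays of shape (n, d) and that query and response embeddings live in the same vector space; the arrays are placeholders rather than real extract_embedding outputs, and cosine_rank is a hypothetical helper. With the outputs above, you would pass embeddings_1[0] as the query and a stacked array of response embeddings as the candidates.

```python
import numpy as np


def cosine_rank(query_embedding, candidate_embeddings):
    """Return candidate indices sorted by cosine similarity to the query (best first), plus the scores."""
    q = query_embedding / np.linalg.norm(query_embedding)
    c = candidate_embeddings / np.linalg.norm(candidate_embeddings, axis=1, keepdims=True)
    scores = c @ q  # one cosine similarity per candidate row
    return np.argsort(-scores), scores


# Placeholder embeddings for illustration (d = 3).
query = np.array([1.0, 0.0, 1.0])
candidates = np.array([
    [0.0, 1.0, 0.0],
    [1.0, 0.1, 0.9],
    [0.5, 0.5, 0.5],
])
order, scores = cosine_rank(query, candidates)
print(order.tolist())
```

Normalizing both sides first makes the score independent of vector length, so only the direction of each embedding matters.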

Other Examples

See AutoMM Examples to explore other examples of AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.