Text-to-Text Semantic Matching with AutoMM¶
Computing the similarity between two sentences or passages is a common task in NLP, with many practical applications such as web search, question answering, document deduplication, plagiarism detection, natural language inference, and recommendation engines. In general, text similarity models take two sentences/passages as input and transform them into vectors; a similarity score computed with cosine similarity, dot product, or Euclidean distance then measures how alike or different the two text pieces are.
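For intuition, here is a minimal sketch of the three scoring functions mentioned above, using toy vectors rather than real sentence embeddings:

```python
import numpy as np

# Toy vectors standing in for sentence embeddings (illustration only;
# real embeddings come from a text encoder such as BERT).
u = np.array([0.2, 0.8, 0.4])
v = np.array([0.1, 0.9, 0.3])

cosine = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))  # in [-1, 1]; higher = more similar
dot = np.dot(u, v)                                               # unnormalized similarity
euclidean = np.linalg.norm(u - v)                                # a distance; lower = more similar

print(f"cosine={cosine:.3f}  dot={dot:.3f}  euclidean={euclidean:.3f}")
```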
Prepare your Data¶
In this tutorial, we will demonstrate how to use AutoMM for text-to-text semantic matching with the Stanford Natural Language Inference (SNLI) corpus. SNLI is a corpus containing around 570k human-written sentence pairs labeled with entailment, contradiction, or neutral. It is a widely used benchmark for evaluating the representation and inference capability of machine learning methods. The following table contains three examples taken from this corpus.
| Premise | Hypothesis | Label |
|---|---|---|
| A black race car starts up in front of a crowd of people. | A man is driving down a lonely road. | contradiction |
| An older and younger man smiling. | Two men are smiling and laughing at the cats playing on the floor. | neutral |
| A soccer game with multiple males playing. | Some men are playing a sport. | entailment |
Here, we consider sentence pairs labeled entailment as positive pairs (labeled as 1) and those labeled contradiction as negative pairs (labeled as 0). Sentence pairs with a neutral relationship are discarded. The following code downloads and loads the corpus into dataframes.
from autogluon.core.utils.loaders import load_pd
import pandas as pd
snli_train = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/snli/snli_train.csv', delimiter="|")
snli_test = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/snli/snli_test.csv', delimiter="|")
snli_train.head()
|   | premise | hypothesis | label |
|---|---|---|---|
| 0 | A person on a horse jumps over a broken down a... | A person is at a diner , ordering an omelette . | 0 |
| 1 | A person on a horse jumps over a broken down a... | A person is outdoors , on a horse . | 1 |
| 2 | Children smiling and waving at camera | There are children present | 1 |
| 3 | Children smiling and waving at camera | The kids are frowning | 0 |
| 4 | A boy is jumping on skateboard in the middle o... | The boy skates down the sidewalk . | 0 |
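Before training, a quick sanity check can confirm that only the binary labels 0 and 1 remain after the neutral pairs were discarded:

```python
# Confirm the labels are binary: 1 (entailment) and 0 (contradiction).
print(snli_train['label'].value_counts())
```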
Train your Model¶
Ideally, we want a model that returns high scores for positive pairs and low scores for negative pairs. Traditional text similarity methods, such as those based on term frequency or TF-IDF vectors, work only at the lexical level and do not take semantics into account. With AutoMM, we can easily train a model that captures the semantic relationship between sentences: it uses BERT to project each sentence into a high-dimensional vector and treats the matching problem as a classification problem, following the design in sentence transformers, as sketched below.
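To make that design concrete, here is a minimal sketch of the classification head used in the sentence-transformers NLI setup, with random tensors standing in for the BERT sentence embeddings (an illustration of the idea only; AutoMM's internal implementation may differ in its details):

```python
import torch
import torch.nn as nn

dim = 768  # hypothetical BERT embedding size
classifier = nn.Linear(3 * dim, 2)  # 2 classes: match / no match

u = torch.randn(1, dim)  # embedding of the premise (from a shared BERT encoder)
v = torch.randn(1, dim)  # embedding of the hypothesis (same encoder)

# Concatenate u, v, and |u - v|, then classify, as in the sentence-transformers design.
features = torch.cat([u, v, torch.abs(u - v)], dim=1)
logits = classifier(features)
print(logits.softmax(dim=1))
```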
With AutoMM, you just need to specify the query, response, and label column names and fit the model on the training dataset without worrying about the implementation details. Note that the labels should be binary, and you need to specify the match_label, i.e., the label meaning that the two sentences have the same semantic meaning. In practice, your task may use different labels, e.g., duplicate or not duplicate, so you may need to define the match_label according to your specific task context.
from autogluon.multimodal import MultiModalPredictor
# Initialize the model
predictor = MultiModalPredictor(
problem_type="text_similarity",
query="premise", # the column name of the first sentence
response="hypothesis", # the column name of the second sentence
label="label", # the label column name
match_label=1, # the label indicating that query and response have the same semantic meanings.
eval_metric='auc', # the evaluation metric
)
# Fit the model
predictor.fit(
train_data=snli_train,
time_limit=180,
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250515_220300"
=================== System Info ===================
AutoGluon Version: 1.3.1b20250515
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 28.17 GB / 30.95 GB (91.0%)
Disk Space Avail: 168.81 GB / 255.99 GB (65.9%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values: [np.int64(0), np.int64(1)]
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
/home/ci/autogluon/multimodal/src/autogluon/multimodal/optim/metrics/utils.py:185: UserWarning: Metric auc is not supported as the evaluation metric for binary in matching tasks.The evaluation metric is changed to roc_auc by default.
warnings.warn(
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250515_220300
```
Seed set to 0
Evaluate on Test Dataset¶
You can evaluate the matcher on the test dataset to see how it performs on the roc_auc score:
score = predictor.evaluate(snli_test)
print("evaluation score: ", score)
Predict on a New Sentence Pair¶
We create a new sentence pair with similar meaning (expected to be predicted as 1) and make predictions using the trained model.
pred_data = pd.DataFrame.from_dict({"premise":["The teacher gave his speech to an empty room."],
"hypothesis":["There was almost nobody when the professor was talking."]})
predictions = predictor.predict(pred_data)
print('Predicted label:', predictions[0])
Predict Matching Probabilities¶
We can also compute the matching probabilities of sentence pairs.
probabilities = predictor.predict_proba(pred_data)
print(probabilities)
Extract Embeddings¶
Moreover, we support extracting embeddings separately for the two sentence groups.
embeddings_1 = predictor.extract_embedding({"premise":["The teacher gave his speech to an empty room."]})
print(embeddings_1.shape)
embeddings_2 = predictor.extract_embedding({"hypothesis":["There was almost nobody when the professor was talking."]})
print(embeddings_2.shape)
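With both embeddings in hand, you can also score the pair yourself, for example with the cosine similarity from the introduction. Note this is a sketch for illustration; the raw score is not calibrated like the output of predict_proba:

```python
import numpy as np

# Cosine similarity between the two extracted embedding vectors
# (each array has shape (1, dim), so index the first row).
cos_sim = np.dot(embeddings_1[0], embeddings_2[0]) / (
    np.linalg.norm(embeddings_1[0]) * np.linalg.norm(embeddings_2[0])
)
print("cosine similarity:", cos_sim)
```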
Other Examples¶
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization¶
To learn how to customize AutoMM, please refer to Customize AutoMM.