"""Implementation of the multimodal predictor"""
from __future__ import annotations
import json
import logging
import os
import warnings
from typing import Dict, List, Optional, Union
import pandas as pd
import transformers
from autogluon.common.utils.log_utils import set_logger_verbosity, verbosity2loglevel
from autogluon.core.metrics import Scorer
from .constants import AUTOMM_TUTORIAL_MODE, NER, OBJECT_DETECTION, SEMANTIC_SEGMENTATION
from .learners import BaseLearner, MultiModalMatcher, NERLearner, ObjectDetectionLearner, SemanticSegmentationLearner
from .problem_types import PROBLEM_TYPES_REG
from .utils import get_dir_ckpt_paths, handle_deprecated_args
# Logger for messages emitted by this module.
logger = logging.getLogger(__name__)

# Stop records on the "lightning" logger from being passed on to ancestor loggers.
# Background: https://github.com/Lightning-AI/lightning/issues/4621
pl_logger = logging.getLogger("lightning")
pl_logger.propagate = False
[docs]class MultiModalPredictor:
"""
MultiModalPredictor is a deep learning "model zoo" of model zoos. It can automatically build deep learning models that
are suitable for multimodal datasets. You will only need to preprocess the data in the multimodal dataframe format
and the MultiModalPredictor can predict the values of one column conditioned on the features from the other columns.
The prediction can be either classification or regression. The feature columns can contain
image paths, text, numerical, and categorical values.
"""
def __init__(
    self,
    label: Optional[str] = None,
    problem_type: Optional[str] = None,
    query: Optional[Union[str, List[str]]] = None,
    response: Optional[Union[str, List[str]]] = None,
    match_label: Optional[Union[int, str]] = None,
    pipeline: Optional[str] = None,
    presets: Optional[str] = None,
    eval_metric: Optional[Union[str, Scorer]] = None,
    hyperparameters: Optional[dict] = None,
    path: Optional[str] = None,
    verbosity: Optional[int] = 2,
    num_classes: Optional[int] = None,  # TODO: can we infer this from data?
    classes: Optional[list] = None,
    warn_if_exist: Optional[bool] = True,
    enable_progress_bar: Optional[bool] = None,
    init_scratch: Optional[bool] = False,
    pretrained: Optional[bool] = True,
    validation_metric: Optional[str] = None,
    sample_data_path: Optional[str] = None,
):
    """
    Build the predictor and the task-specific learner it delegates to.

    Parameters
    ----------
    label
        Name of the column that holds the target variable.
    problem_type
        Kind of prediction problem. Standard choices are
            - 'binary': Binary classification
            - 'multiclass': Multi-class classification
            - 'regression': Regression
            - 'classification': Covers both 'binary' and 'multiclass' classification.
        Advanced choices are
            - 'object_detection': Object detection
            - 'open_vocabulary_object_detection': Zero-shot object detection
              (inference only for now, finetuning TBC)
            - 'ner' or 'named_entity_recognition': Named entity extraction
            - 'text_similarity': Text-text similarity
            - 'image_similarity': Image-image similarity
            - 'image_text_similarity': Text-image similarity
            - 'feature_extraction': Feature extraction (inference only)
            - 'zero_shot_image_classification': Zero-shot image classification (inference only)
            - 'few_shot_text_classification': (experimental) Few-shot text classification
        The value must be a key of `PROBLEM_TYPES_REG`; it is replaced by the canonical name
        stored in the registry. Experimental problem types trigger a `UserWarning`.

        For some problem types a pretrained model is loaded from the presets / hyperparameters,
        so the predictor can run zero-shot inference (without calling .fit()):
            - 'object_detection'
            - 'open_vocabulary_object_detection'
            - 'text_similarity'
            - 'image_similarity'
            - 'image_text_similarity'
            - 'feature_extraction'
            - 'zero_shot_image_classification'
            - 'few_shot_text_classification' (experimental)
    query
        Column name(s) of the query data (matching problems).
    response
        Column name(s) of the response data (matching problems). Without a label column,
        query and response columns are treated as positive pairs.
    match_label
        Label value that marks a <query, response> pair as a "match". Used for matching
        problem types with binary labels, e.g. the label column holds
        ["duplicate", "not duplicate"] and match_label is "duplicate".
        If omitted, every sample is assumed to have a unique label.
    pipeline
        Deprecated; merged into `problem_type`.
    presets
        Model-quality presets, e.g., best_quality, high_quality, and medium_quality.
    eval_metric
        Evaluation metric name. When None it is chosen from `problem_type`:
        'accuracy' for multiclass classification, `roc_auc` for binary classification,
        and 'root_mean_squared_error' for regression.
    hyperparameters
        Overrides for the default configuration. For instance, the text and image backbones
        can be changed with
            a string
                hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
            or a list of strings
                hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
            or a dictionary
                hyperparameters = {
                    "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
                    "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
                }
    path
        Directory for models and intermediate outputs. If unspecified, a time-stamped folder
        "AutogluonAutoMM/ag-[TIMESTAMP]" is created in the working directory.
        Note: to call `fit()` twice and keep both results, use different `path` values or
        leave `path` unset; otherwise the second `fit()` overwrites files of the first.
    verbosity
        Verbosity level from 0 to 4; higher means more detailed output (0 suppresses warnings).
        When not None it is applied through `set_logger_verbosity`.
        With `logging`, the amount of output can alternatively be controlled via
        `logger.setLevel(L)` with `L` from 0 to 50 (higher `L` means fewer messages,
        the opposite of verbosity levels).
    num_classes
        Number of classes for classification. If it differs from the pretrained model's
        output size, the model head is changed to produce <num_classes> outputs.
    classes
        All classes in this dataset.
    warn_if_exist
        Whether to warn when the specified path already exists.
    enable_progress_bar
        Whether to show a progress bar. It is forced to False (and the transformers progress
        bar is disabled as well) when the environment variable named by
        `AUTOMM_TUTORIAL_MODE` is set to a non-empty value.
    init_scratch
        Legacy flag; it is handed to `handle_deprecated_args` together with `pretrained`,
        which returns the effective `pretrained` value.
    pretrained
        Whether to initialise the model with pretrained weights. If False, the model is
        randomly initialised.
    validation_metric
        Validation metric name. When None it is chosen from `problem_type` with the same
        defaults as `eval_metric`.
    sample_data_path
        Sample data used to infer num_classes, classes, or label automatically.
    """
    # Fold the deprecated arguments into their current equivalents.
    problem_type, pretrained = handle_deprecated_args(
        init_scratch=init_scratch,
        pipeline=pipeline,
        problem_type=problem_type,
        pretrained=pretrained,
    )

    # Validate the problem type against the registry and switch to its canonical name.
    problem_property = None
    if problem_type is not None:
        assert problem_type in PROBLEM_TYPES_REG, (
            f"problem_type='{problem_type}' is not supported yet. You may pick a problem type from"
            f" {PROBLEM_TYPES_REG.list_keys()}."
        )
        problem_property = PROBLEM_TYPES_REG.get(problem_type)
        if problem_property.experimental:
            warnings.warn(
                f"problem_type='{problem_type}' is currently experimental.",
                UserWarning,
            )
        problem_type = problem_property.name

    tutorial_mode = os.environ.get(AUTOMM_TUTORIAL_MODE)
    if tutorial_mode:
        enable_progress_bar = False
        # Also disable progress bar of transformers package
        transformers.logging.disable_progress_bar()

    if verbosity is not None:
        set_logger_verbosity(verbosity)
    self._verbosity = verbosity

    # Matching problems take precedence; other problem types map to a dedicated learner,
    # and everything else falls back to the generic one.
    if problem_property and problem_property.is_matching:
        learner_class = MultiModalMatcher
    else:
        dedicated_learners = {
            OBJECT_DETECTION: ObjectDetectionLearner,
            NER: NERLearner,
            SEMANTIC_SEGMENTATION: SemanticSegmentationLearner,
        }
        learner_class = dedicated_learners.get(problem_type, BaseLearner)

    self._learner = learner_class(
        label=label,
        problem_type=problem_type,
        presets=presets,
        eval_metric=eval_metric,
        hyperparameters=hyperparameters,
        path=path,
        verbosity=verbosity,
        num_classes=num_classes,
        classes=classes,
        warn_if_exist=warn_if_exist,
        enable_progress_bar=enable_progress_bar,
        pretrained=pretrained,
        sample_data_path=sample_data_path,
        validation_metric=validation_metric,
        query=query,
        response=response,
        match_label=match_label,
    )
@property
def path(self):
    """Path exposed by the underlying learner."""
    learner = self._learner
    return learner.path
@property
def label(self):
    """Label column name exposed by the underlying learner."""
    learner = self._learner
    return learner.label
@property
def query(self):
    """Query column(s) exposed by the underlying learner."""
    learner = self._learner
    return learner.query
@property
def response(self):
    """Response column(s) exposed by the underlying learner."""
    learner = self._learner
    return learner.response
@property
def match_label(self):
    """The `match_label` value held by the underlying learner."""
    learner = self._learner
    return learner.match_label
@property
def problem_type(self):
    """Problem type reported by the underlying learner."""
    learner = self._learner
    return learner.problem_type
@property
def problem_property(self):
    """The `problem_property` object held by the underlying learner."""
    learner = self._learner
    return learner.problem_property
@property
def column_types(self):
    """Column types reported by the underlying learner."""
    learner = self._learner
    return learner.column_types
@property
def eval_metric(self):
    """Evaluation metric reported by the underlying learner."""
    learner = self._learner
    return learner.eval_metric
@property
def validation_metric(self):
    """Validation metric reported by the underlying learner."""
    learner = self._learner
    return learner.validation_metric
@property
def verbosity(self):
    """Verbosity level stored on this predictor (not read from the learner)."""
    current_level = self._verbosity
    return current_level
@property
def total_parameters(self) -> int:
    """Total parameter count reported by the underlying learner."""
    learner = self._learner
    return learner.total_parameters
@property
def trainable_parameters(self) -> int:
    """Trainable parameter count reported by the underlying learner."""
    learner = self._learner
    return learner.trainable_parameters
@property
def model_size(self) -> float:
    """Model size reported by the underlying learner."""
    learner = self._learner
    return learner.model_size
@property
def classes(self):
    """
    Classes of object detection, as held by the underlying learner.
    """
    learner = self._learner
    return learner.classes
@property
def class_labels(self):
    """
    Original (raw) names of the class labels.

    For example, tabular data may use the classes "entailment", "contradiction" and
    "neutral", which are converted internally to 0, 1, 2, ... This property gives back
    the original names of those raw labels.

    Returns
    -------
    List containing the class names. It will be None if it's not a classification problem.
    """
    learner = self._learner
    return learner.class_labels
@property
def positive_class(self):
    """
    Name of the class label that is mapped to 1.

    Only meaningful for binary classification, where metrics such as F1 need a positive and
    a negative class (see https://en.wikipedia.org/wiki/F-score).
    In binary classification, :class:`MultiModalPredictor.predict_proba(as_multiclass=False)`
    returns the estimated probability that each row belongs to the positive class.
    The value comes from the underlying learner, which is documented to print a warning and
    return None when `predictor.problem_type != 'binary'`.

    Returns
    -------
    The positive class name in binary classification or None if the problem is not binary classification.
    """
    learner = self._learner
    return learner.positive_class
# The abstract trainer of TabularPredictor requires this method to exist.
def set_verbosity(self, verbosity: int):
    """Store a new verbosity level and apply it to the logger.

    Parameters
    ----------
    verbosity
        The verbosity level.
        0 --> only errors
        1 --> only warnings and critical print statements
        2 --> key print statements which should be shown by default
        3 --> more-detailed printing
        4 --> everything
    """
    # Remember the level first so `self.verbosity` reflects the request,
    # then hand it to the shared logger configuration helper.
    self._verbosity = verbosity
    set_logger_verbosity(verbosity)
# TODO: align verbosity2loglevel with https://huggingface.co/docs/transformers/main_classes/logging#transformers.utils.logging.get_verbosity
def set_num_gpus(self, num_gpus):
    """Forward `num_gpus` to the underlying learner's `set_num_gpus`.

    Parameters
    ----------
    num_gpus
        Value handed to the learner unchanged.
    """
    learner = self._learner
    learner.set_num_gpus(num_gpus)
def get_num_gpus(self):
    """Return the GPU count reported by the underlying learner.

    Returns
    -------
    Whatever the learner's `get_num_gpus()` returns.
    """
    # The result must be returned; without `return` this getter always yielded None.
    return self._learner.get_num_gpus()
def fit(
    self,
    train_data: Union[pd.DataFrame, str],
    presets: Optional[str] = None,
    config: Optional[dict] = None,
    tuning_data: Optional[Union[pd.DataFrame, str]] = None,
    max_num_tuning_data: Optional[int] = None,
    id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
    time_limit: Optional[int] = None,
    save_path: Optional[str] = None,
    hyperparameters: Optional[Union[str, Dict, List[str]]] = None,
    column_types: Optional[dict] = None,
    holdout_frac: Optional[float] = None,
    teacher_predictor: Union[str, MultiModalPredictor] = None,
    seed: Optional[int] = 0,
    standalone: Optional[bool] = True,
    hyperparameter_tune_kwargs: Optional[dict] = None,
    clean_ckpts: Optional[bool] = True,
):
    """
    Train the predictor to predict the label column of a dataframe from its other columns,
    which may hold image paths, text, numeric, or categorical features.
    All work is delegated to the underlying learner's `fit`.

    Parameters
    ----------
    train_data
        A dataframe containing training data.
    presets
        Model-quality presets, e.g., best_quality, high_quality, and medium_quality.
    config
        A dictionary with four keys "model", "data", "optimization", and "environment".
        Each value can be a string, a yaml file path, or an OmegaConf DictConfig.
        Strings should be file names (WITHOUT the ".yaml" postfix) found in
        automm/configs/model, automm/configs/data, automm/configs/optimization, and
        automm/configs/environment.
        For example, a late-fusion model for image, text, and tabular data can be set up with:
            config = {
                "model": "default",
                "data": "default",
                "optimization": "default",
                "environment": "default",
            }
            or
            config = {
                "model": "/path/to/model/config.yaml",
                "data": "/path/to/data/config.yaml",
                "optimization": "/path/to/optimization/config.yaml",
                "environment": "/path/to/environment/config.yaml",
            }
            or
            config = {
                "model": OmegaConf.load("/path/to/model/config.yaml"),
                "data": OmegaConf.load("/path/to/data/config.yaml"),
                "optimization": OmegaConf.load("/path/to/optimization/config.yaml"),
                "environment": OmegaConf.load("/path/to/environment/config.yaml"),
            }
    tuning_data
        A dataframe containing validation data with the same columns as train_data.
        If `tuning_data = None`, `fit()` automatically holds out some random validation
        examples from `train_data`.
    max_num_tuning_data
        Maximum number of tuning samples; only used in object detection.
    id_mappings
        Id-to-content mappings. The contents can be text, image, etc.
        Used when the dataframe contains query/response identifiers instead of their contents.
    time_limit
        How long `fit()` should run (wall clock time in seconds).
        If not specified, `fit()` runs until training completes.
    save_path
        Directory where models and intermediate outputs should be saved.
    hyperparameters
        Overrides for the default configuration. For instance, the text and image backbones
        can be changed with
            a string
                hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
            or a list of strings
                hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
            or a dictionary
                hyperparameters = {
                    "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
                    "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
                }
    column_types
        A dictionary mapping column names to their data types, e.g.
        `column_types = {"item_name": "text", "image": "image_path",
        "product_description": "text", "height": "numerical"}`.
        If None, column types are inferred from the data.
        Currently supported types:
            - "image_path": each row in this column is one image path.
            - "text": each row in this column contains text (sentence, paragraph, etc.).
            - "numerical": each row in this column contains a number.
            - "categorical": each row in this column belongs to one of K categories.
    holdout_frac
        Fraction of train_data to hold out as tuning_data for optimizing hyper-parameters or
        early stopping (ignored unless `tuning_data = None`).
        The default (None) is chosen from the number of training rows and whether
        hyper-parameter tuning is used.
    teacher_predictor
        A pre-trained teacher predictor or the path it was saved to. If provided, `fit()` can
        distill its knowledge into the current (student) predictor. A string is forwarded
        as-is; for a predictor object its internal learner is forwarded.
    seed
        Random seed for this training run. Defaults to 0.
    standalone
        Whether to save the entire model for offline deployment or only the trained
        parameters of a parameter-efficient fine-tuning strategy.
    hyperparameter_tune_kwargs
        Hyperparameter tuning strategy and kwargs (for example, how many HPO trials to run).
        If None, no hyperparameter tuning is performed.
            num_trials: int
                How many HPO trials to run. Either `num_trials` or `time_limit` to `fit`
                needs to be specified.
            scheduler: Union[str, ray.tune.schedulers.TrialScheduler]
                If a str is passed, AutoGluon creates the scheduler with default parameters.
                If a ray.tune.schedulers.TrialScheduler object is passed, you are responsible
                for initializing it.
            scheduler_init_args: Optional[dict] = None
                Custom init_args for the scheduler when `scheduler` is given as a str.
            searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]
                If a str is passed, AutoGluon creates the searcher with default parameters.
                If a searcher object is passed, you are responsible for initializing it.
                `metric` and `mode` of the searcher object are determined by AutoGluon.
            searcher_init_args: Optional[dict] = None
                Custom init_args for the searcher when `searcher` is given as a str.
                `metric` and `mode` are determined by AutoGluon.
    clean_ckpts
        Whether to clean the checkpoints of each validation step after training.

    Returns
    -------
    An "MultiModalPredictor" object (itself).
    """
    # None and saved paths are forwarded untouched; a predictor object contributes its learner.
    if teacher_predictor is None or isinstance(teacher_predictor, str):
        teacher_learner = teacher_predictor
    else:
        teacher_learner = teacher_predictor._learner

    learner_kwargs = {
        "train_data": train_data,
        "presets": presets,
        "config": config,
        "tuning_data": tuning_data,
        "max_num_tuning_data": max_num_tuning_data,
        "time_limit": time_limit,
        "save_path": save_path,
        "hyperparameters": hyperparameters,
        "column_types": column_types,
        "holdout_frac": holdout_frac,
        "teacher_learner": teacher_learner,
        "seed": seed,
        "standalone": standalone,
        "hyperparameter_tune_kwargs": hyperparameter_tune_kwargs,
        "clean_ckpts": clean_ckpts,
        "id_mappings": id_mappings,
    }
    self._learner.fit(**learner_kwargs)
    return self
def evaluate(
    self,
    data: Union[pd.DataFrame, dict, list, str],
    query_data: Optional[list] = None,
    response_data: Optional[list] = None,
    id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
    metrics: Optional[Union[str, List[str]]] = None,
    chunk_size: Optional[int] = 1024,
    similarity_type: Optional[str] = "cosine",
    cutoffs: Optional[List[int]] = [1, 5, 10],
    label: Optional[str] = None,
    return_pred: Optional[bool] = False,
    realtime: Optional[bool] = None,
    eval_tool: Optional[str] = None,
):
    """
    Evaluate model on a test dataset by delegating to the underlying learner.

    Parameters
    ----------
    data
        A dataframe, containing the same columns as the training data.
        Or a str, that is a path of the annotation file for detection.
    query_data
        Query data used for ranking.
    response_data
        Response data used for ranking.
    id_mappings
        Id-to-content mappings. The contents can be text, image, etc.
        This is used when data/query_data/response_data contain the query/response
        identifiers instead of their contents.
    metrics
        A list of metric names to report.
        If None, we only return the score for the stored `_eval_metric_name`.
    chunk_size
        Scan the response data by chunk_size each time. Increasing the value increases the
        speed, but requires more memory.
    similarity_type
        Which function (cosine/dot_prod) scores the similarity (default: cosine).
    cutoffs
        A list of cutoff values to evaluate ranking (default: [1, 5, 10]).
        A list is copied before it is forwarded, so the learner never receives the
        caller's (or the default's) list object itself.
    label
        The label column name in data. Some tasks, e.g., image<-->text matching, have no
        label column in training data, but the label column may still be required in evaluation.
    return_pred
        Whether to return the prediction result of each row.
    realtime
        Whether to do realtime inference, which is efficient for small data (default None).
        If not specified, it is inferred from the data modalities and sample number.
    eval_tool
        The eval_tool for object detection. Could be "pycocotools" or "torchmetrics".

    Returns
    -------
    A dictionary with the metric names and their corresponding scores.
    Optionally return a dataframe of prediction results.
    """
    # The default for `cutoffs` is one list object shared by every call. Forward a copy so
    # that an in-place change made downstream cannot leak into later calls.
    if isinstance(cutoffs, list):
        cutoffs = list(cutoffs)
    return self._learner.evaluate(
        data=data,
        metrics=metrics,
        return_pred=return_pred,
        realtime=realtime,
        eval_tool=eval_tool,
        query_data=query_data,
        response_data=response_data,
        id_mappings=id_mappings,
        chunk_size=chunk_size,
        similarity_type=similarity_type,
        cutoffs=cutoffs,
        label=label,
    )
def predict(
    self,
    data: Union[pd.DataFrame, dict, list, str],
    candidate_data: Optional[Union[pd.DataFrame, dict, list]] = None,
    id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
    as_pandas: Optional[bool] = None,
    realtime: Optional[bool] = None,
    save_results: Optional[bool] = None,
):
    """
    Predict values for the label column of new data via the underlying learner.

    Parameters
    ----------
    data
        The data to predict on. It should have the same column names and format as the
        training data (except for the `label` column).
    candidate_data
        The candidate data in which to search for the query data's matches.
    id_mappings
        Id-to-content mappings. The contents can be text, image, etc.
        Used when data contain the query/response identifiers instead of their contents.
    as_pandas
        Whether to return a pandas DataFrame(Series) (True) or a numpy array (False).
    realtime
        Whether to do realtime inference, which is efficient for small data (default None).
        If not specified, it is inferred from the data modalities and sample number.
    save_results
        Whether to save the prediction results (only works for detection now)

    Returns
    -------
    Array of predictions, one corresponding to each row in given dataset.
    """
    forwarded = {
        "data": data,
        "candidate_data": candidate_data,
        "as_pandas": as_pandas,
        "realtime": realtime,
        "save_results": save_results,
        "id_mappings": id_mappings,
    }
    return self._learner.predict(**forwarded)
def predict_proba(
    self,
    data: Union[pd.DataFrame, dict, list],
    candidate_data: Optional[Union[pd.DataFrame, dict, list]] = None,
    id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
    as_pandas: Optional[bool] = None,
    as_multiclass: Optional[bool] = True,
    realtime: Optional[bool] = None,
):
    """
    Predict class probabilities rather than class labels via the underlying learner.
    This is only for classification tasks; calling it for a regression task will throw
    an exception.

    Parameters
    ----------
    data
        The data to predict on. It should have the same column names and format as the
        training data (except for the `label` column).
    candidate_data
        The candidate data in which to search for the query data's matches.
    id_mappings
        Id-to-content mappings. The contents can be text, image, etc.
        Used when data contain the query/response identifiers instead of their contents.
    as_pandas
        Whether to return a pandas DataFrame(Series) (True) or a numpy array (False).
    as_multiclass
        Whether to return the probability of all labels or only the probability of the
        positive class for binary classification problems.
    realtime
        Whether to do realtime inference, which is efficient for small data (default None).
        If not specified, it is inferred from the data modalities and sample number.

    Returns
    -------
    Array of predicted class-probabilities, corresponding to each row in the given data.
    When as_multiclass is True, the output will always have shape (#samples, #classes).
    Otherwise, the output will have shape (#samples,)
    """
    forwarded = {
        "data": data,
        "candidate_data": candidate_data,
        "as_pandas": as_pandas,
        "as_multiclass": as_multiclass,
        "realtime": realtime,
        "id_mappings": id_mappings,
    }
    return self._learner.predict_proba(**forwarded)
def save(self, path: str, standalone: Optional[bool] = True):
    """
    Save this predictor into the directory `path` through the underlying learner.

    Parameters
    ----------
    path
        The directory to save this predictor.
    standalone
        Whether to save the downloaded model for offline deployment.
        When standalone = True, save the transformers.CLIPModel and transformers.AutoModel
        to os.path.join(path,model_name), and reset the associated
        model.model_name.checkpoint_name to start with `local://` in config.yaml.
        When standalone = False, the saved artifact may require an online environment to
        process in load().
    """
    save_kwargs = {"path": path, "standalone": standalone}
    self._learner.save(**save_kwargs)
@classmethod
def load(
    cls,
    path: str,
    resume: Optional[bool] = False,
    verbosity: Optional[int] = 3,
):
    """
    Load a predictor object from a directory specified by `path`. The predictor may have been
    completely or partially trained by .fit(). If a previous training has completed, the
    checkpoint `model.ckpt` is loaded. Otherwise, if a previous training accidentally
    collapsed in the middle, the `last.ckpt` checkpoint can be loaded by setting `resume=True`.
    Loading one specific checkpoint given its path is supported as well.

    Parameters
    ----------
    path
        The directory to load the predictor object.
    resume
        Whether to resume training from `last.ckpt`. Useful when a training run was
        interrupted and should continue from the last saved checkpoint.
    verbosity
        Verbosity levels range from 0 to 4 and control how much information is printed.
        Higher levels correspond to more detailed print statements (you can set
        verbosity = 0 to suppress warnings).

    Returns
    -------
    The loaded predictor object.
    """
    # Only the directory part is needed here; the learner receives the original `path`.
    dir_path, _ = get_dir_ckpt_paths(path=path)
    assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."

    # Placeholder predictor; its learner is swapped for the loaded one below.
    predictor = cls(label="dummy_label")

    assets_file = os.path.join(dir_path, "assets.json")
    with open(assets_file, "r") as fp:
        assets = json.load(fp)

    # The saved class name wins over the problem type when picking the learner class.
    saved_as_matcher = "class_name" in assets and assets["class_name"] == "MultiModalMatcher"
    if saved_as_matcher:
        learner_class = MultiModalMatcher
    else:
        saved_problem_type = assets["problem_type"]
        if saved_problem_type == OBJECT_DETECTION:
            learner_class = ObjectDetectionLearner
        elif saved_problem_type == NER:
            learner_class = NERLearner
        elif saved_problem_type == SEMANTIC_SEGMENTATION:
            learner_class = SemanticSegmentationLearner
        else:
            learner_class = BaseLearner

    predictor._learner = learner_class.load(path=path, resume=resume, verbosity=verbosity)
    return predictor
def dump_model(self, save_path: Optional[str] = None):
    """
    Save model weights and config to a local directory through the underlying learner.
    Model weights are saved in file `pytorch_model.bin` (timm, hf) or '<ckpt_name>.pth' (mmdet);
    Configs are saved in file `config.json` (timm, hf) or '<ckpt_name>.py' (mmdet).

    Parameters
    ----------
    save_path : str
        Path to directory where models and configs should be saved.

    Returns
    -------
    Whatever the learner's `dump_model` returns.
    """
    learner = self._learner
    return learner.dump_model(save_path=save_path)
def export_onnx(
    self,
    data: Union[dict, pd.DataFrame],
    path: Optional[str] = None,
    batch_size: Optional[int] = None,
    verbose: Optional[bool] = False,
    opset_version: Optional[int] = 16,
    truncate_long_and_double: Optional[bool] = False,
):
    """
    Export this predictor's model to ONNX through the underlying learner.

    When `path` is not provided, the model is not written to disk. Instead it is exported
    into BytesIO and its binary content is returned as bytes.

    Parameters
    ----------
    data
        Raw data used to trace and export the model.
        If this is None, will check if a processed batch is provided.
    path : str, default=None
        The export path of the onnx model. Without a path the model is exported to memory.
    batch_size
        The batch_size of the exported model's input.
        Normally the batch_size is a dynamic axis, so a small value gives a faster export.
    verbose
        verbose flag in torch.onnx.export.
    opset_version
        opset_version flag in torch.onnx.export.
    truncate_long_and_double: bool, default False
        Truncate weights provided in int64 or double (float64) to int32 and float32

    Returns
    -------
    onnx_path : str or bytes
        A string that indicates location of the exported onnx model, if `path` argument is
        provided. Otherwise, the onnx model as bytes.
    """
    export_kwargs = {
        "data": data,
        "path": path,
        "batch_size": batch_size,
        "verbose": verbose,
        "opset_version": opset_version,
        "truncate_long_and_double": truncate_long_and_double,
    }
    return self._learner.export_onnx(**export_kwargs)
def optimize_for_inference(
    self,
    providers: Optional[Union[dict, List[str]]] = None,
):
    """
    Optimize the predictor's model for inference through the underlying learner.

    Under the hood, the implementation would convert the PyTorch module into an ONNX module,
    so that efficient execution providers in onnxruntime can be used for faster inference.

    Parameters
    ----------
    providers : dict or list of str, default=None
        Execution providers for model prediction in onnxruntime.
        By default, the providers argument is None. The method would generate an ONNX module
        that would perform model inference with TensorrtExecutionProvider in onnxruntime, if
        tensorrt package is properly installed. Otherwise, the onnxruntime would fallback to
        use CUDA or CPU execution providers instead.

    Returns
    -------
    onnx_module : OnnxModule
        The onnx-based module that can be used to replace predictor._model for model inference.
    """
    learner = self._learner
    return learner.optimize_for_inference(providers=providers)
def fit_summary(self, verbosity=0, show_plot=False):
    """
    Output summary of information about models produced during `fit()`,
    as provided by the underlying learner.

    Parameters
    ----------
    verbosity : int, default = 0
        Verbosity levels range from 0 to 4 and control how much information is printed.
        verbosity = 0 for no output printing.
        TODO: Higher levels correspond to more detailed print statements
    show_plot : bool, default = False
        If True, shows the model summary plot in browser when verbosity > 1.

    Returns
    -------
    Dict containing various detailed information.
    We do not recommend directly printing this dict as it may be very large.
    """
    summary_kwargs = {"verbosity": verbosity, "show_plot": show_plot}
    return self._learner.fit_summary(**summary_kwargs)
def list_supported_models(self, pretrained=True):
    """
    List the supported models for each problem_type, so users know which checkpoint names
    can be chosen during fit(). The list comes from the underlying learner.

    Parameters
    ----------
    pretrained : bool, default = True
        If True, only return the models with pretrained weights.
        If False, return all the models as long as there is model definition.

    Returns
    -------
    a list of model names
    """
    learner = self._learner
    return learner.list_supported_models(pretrained=pretrained)