"""Implementation of the multimodal predictor"""
from __future__ import annotations
import json
import logging
import os
import warnings
from typing import Dict, List, Optional, Union
import pandas as pd
import transformers
from autogluon.common.utils.log_utils import set_logger_verbosity, verbosity2loglevel
from autogluon.core.metrics import Scorer
from .constants import AUTOMM_TUTORIAL_MODE, FEW_SHOT_CLASSIFICATION, NER, OBJECT_DETECTION, SEMANTIC_SEGMENTATION
from .learners import (
    BaseLearner,
    FewShotSVMLearner,
    MultiModalMatcher,
    NERLearner,
    ObjectDetectionLearner,
    SemanticSegmentationLearner,
)
from .problem_types import PROBLEM_TYPES_REG
from .utils import get_dir_ckpt_paths
pl_logger = logging.getLogger("lightning")
pl_logger.propagate = False  # https://github.com/Lightning-AI/lightning/issues/4621
logger = logging.getLogger(__name__)
class MultiModalPredictor:
    """
    AutoMM is designed to simplify the fine-tuning of foundation models
    for downstream applications with just three lines of code.
    AutoMM seamlessly integrates with popular model zoos such as
    `HuggingFace Transformers <https://github.com/huggingface/transformers>`_,
    `TIMM <https://github.com/huggingface/pytorch-image-models>`_,
    and `MMDetection <https://github.com/open-mmlab/mmdetection>`_,
    accommodating a diverse range of data modalities,
    including image, text, tabular, and document data, whether used individually or in combination.
    It offers support for an array of tasks, encompassing classification, regression,
    object detection, named entity recognition, semantic matching, and image segmentation.
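    Examples
    --------
    A minimal sketch of the three-line workflow, assuming pd.DataFrames
    ``train_df`` and ``test_df`` with matching feature columns and a "label"
    column (all names are hypothetical):

    >>> from autogluon.multimodal import MultiModalPredictor
    >>> predictor = MultiModalPredictor(label="label")
    >>> predictor.fit(train_df)
    >>> predictions = predictor.predict(test_df)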
    """
    def __init__(
        self,
        label: Optional[str] = None,
        problem_type: Optional[str] = None,
        query: Optional[Union[str, List[str]]] = None,
        response: Optional[Union[str, List[str]]] = None,
        match_label: Optional[Union[int, str]] = None,
        presets: Optional[str] = None,
        eval_metric: Optional[Union[str, Scorer]] = None,
        hyperparameters: Optional[dict] = None,
        path: Optional[str] = None,
        verbosity: Optional[int] = 2,
        num_classes: Optional[int] = None,  # TODO: can we infer this from data?
        classes: Optional[list] = None,
        warn_if_exist: Optional[bool] = True,
        enable_progress_bar: Optional[bool] = None,
        pretrained: Optional[bool] = True,
        validation_metric: Optional[str] = None,
        sample_data_path: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        label
            Name of one pd.DataFrame column that contains the target variable to predict.
        problem_type
            Type of problem. We support standard problems like
            - 'binary': Binary classification
            - 'multiclass': Multi-class classification
            - 'regression': Regression
            - 'classification': Classification problems include 'binary' and 'multiclass' classification.
            In addition, we support advanced problems such as
            - 'object_detection': Object detection
            - 'ner' or 'named_entity_recognition': Named entity extraction
            - 'text_similarity': Text-text semantic matching
            - 'image_similarity': Image-image semantic matching
            - 'image_text_similarity': Text-image semantic matching
            - 'feature_extraction': Feature extraction (inference only)
            - 'zero_shot_image_classification': Zero-shot image classification (inference only)
            - 'few_shot_classification': Few-shot classification for image or text data.
            - 'semantic_segmentation': Semantic segmentation with Segment Anything Model.
            For certain problem types, the default behavior is to load a pretrained model based on
            the presets / hyperparameters and the predictor can do zero-shot inference
            (running inference without .fit()). Those include the following
            problem types:
            - 'object_detection'
            - 'text_similarity'
            - 'image_similarity'
            - 'image_text_similarity'
            - 'feature_extraction'
            - 'zero_shot_image_classification'
        query
            Name of one pd.DataFrame column that has the query data in semantic matching tasks.
        response
            Name of one pd.DataFrame column that contains the response data in semantic matching tasks.
            If no label column is provided, the query and response pairs in
            one pd.DataFrame row are assumed to be positive pairs.
        match_label
            The label class that indicates the <query, response> pair is counted as a "match".
            This is used when the task belongs to semantic matching, and the labels are binary.
            For example, the label column can contain ["duplicate", "not duplicate"] in a duplicate detection task.
            The match_label should be "duplicate" since it means that two items match.
        presets
            Presets regarding model quality, e.g., 'best_quality', 'high_quality' (default), and 'medium_quality'.
            Each quality has its corresponding HPO presets: 'best_quality_hpo', 'high_quality_hpo', and 'medium_quality_hpo'.
        eval_metric
            Evaluation metric name. If `eval_metric = None`, it is automatically chosen based on `problem_type`.
            Defaults to 'accuracy' for multiclass classification, 'roc_auc' for binary classification,
            and 'root_mean_squared_error' for regression.
        hyperparameters
            This is to override some default configurations.
            For example, changing the text and image backbones can be done by formatting:
            a string
            hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
            or a list of strings
            hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
            or a dictionary
            hyperparameters = {
                            "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
                            "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
                        }
        path
            Path to directory where models and related artifacts should be saved.
            If unspecified, a time-stamped folder called "AutogluonAutoMM/ag-[TIMESTAMP]"
            will be created in the working directory.
            Note: To call `fit()` twice and save all results of each fit,
            you must specify different `path` locations or don't specify `path` at all.
        verbosity
            Verbosity levels range from 0 to 4, controlling how much logging information is printed.
            Higher levels correspond to more detailed print statements.
            You can set verbosity = 0 to suppress warnings.
        num_classes
            Number of classes (used for object detection).
            If this is specified and differs from the pretrained model's output shape,
            the model's head will be replaced to produce `num_classes` outputs.
        classes
            All the classes (used for object detection).
        warn_if_exist
            Whether to raise a warning if the specified path already exists (default True).
        enable_progress_bar
            Whether to show the progress bar (default True). It is disabled if the
            environment variable AUTOMM_DISABLE_PROGRESS_BAR is set.
        pretrained
            Whether to initialize the model with pretrained weights (default True).
            If False, it creates a model with random initialization.
        validation_metric
            Validation metric for selecting the best model and early-stopping during training.
            If not provided, it is automatically chosen based on the problem type.
        sample_data_path
            The path to sample data from which we can infer num_classes or classes used for object detection.
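            The path to sample data from which we can infer num_classes or classes used for object detection.

        Examples
        --------
        A minimal construction sketch; the label column name is hypothetical:

        >>> from autogluon.multimodal import MultiModalPredictor
        >>> predictor = MultiModalPredictor(label="label", presets="medium_quality")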
        """
        if problem_type is not None:
            problem_type = problem_type.lower()
            assert problem_type in PROBLEM_TYPES_REG, (
                f"problem_type='{problem_type}' is not supported yet. You may pick a problem type from"
                f" {PROBLEM_TYPES_REG.list_keys()}."
            )
            problem_property = PROBLEM_TYPES_REG.get(problem_type)
            if problem_property.experimental:
                warnings.warn(
                    f"problem_type='{problem_type}' is currently experimental.",
                    UserWarning,
                )
            problem_type = problem_property.name
        else:
            problem_property = None
        if os.environ.get(AUTOMM_TUTORIAL_MODE):
            enable_progress_bar = False
            # Also disable progress bar of transformers package
            transformers.logging.disable_progress_bar()
        if verbosity is not None:
            set_logger_verbosity(verbosity)
        self._verbosity = verbosity
        if problem_property and problem_property.is_matching:
            learner_class = MultiModalMatcher
        elif problem_type == OBJECT_DETECTION:
            learner_class = ObjectDetectionLearner
        elif problem_type == NER:
            learner_class = NERLearner
        elif problem_type == FEW_SHOT_CLASSIFICATION:
            learner_class = FewShotSVMLearner
        elif problem_type == SEMANTIC_SEGMENTATION:
            learner_class = SemanticSegmentationLearner
        else:
            learner_class = BaseLearner
        self._learner = learner_class(
            label=label,
            problem_type=problem_type,
            presets=presets,
            eval_metric=eval_metric,
            hyperparameters=hyperparameters,
            path=path,
            verbosity=verbosity,
            num_classes=num_classes,
            classes=classes,
            warn_if_exist=warn_if_exist,
            enable_progress_bar=enable_progress_bar,
            pretrained=pretrained,
            sample_data_path=sample_data_path,
            validation_metric=validation_metric,
            query=query,
            response=response,
            match_label=match_label,
        ) 
    @property
    def path(self):
        """
        Path to directory where the model and related artifacts are stored.
        """
        return self._learner.path
    @property
    def label(self):
        """
        Name of one pd.DataFrame column that contains the target variable to predict.
        """
        return self._learner.label
    @property
    def query(self):
        """
        Name of one pd.DataFrame column that has the query data in semantic matching tasks.
        """
        return self._learner.query
    @property
    def response(self):
        """
        Name of one pd.DataFrame column that contains the response data in semantic matching tasks.
        """
        return self._learner.response
    @property
    def match_label(self):
        """
        The label class that indicates the <query, response> pair is counted as "match" in the semantic matching tasks.
        """
        return self._learner.match_label
    @property
    def problem_type(self):
        """
        What type of prediction problem this predictor has been trained for.
        """
        return self._learner.problem_type
    @property
    def problem_property(self):
        """
        Property of the problem, storing the problem type and its related properties.
        """
        return self._learner.problem_property
    @property
    def column_types(self):
        """
        Column types in the pd.DataFrame.
        """
        return self._learner.column_types
    @property
    def eval_metric(self):
        """
        What metric is used to evaluate predictive performance.
        """
        return self._learner.eval_metric
    @property
    def validation_metric(self):
        """
        Validation metric for selecting the best model and early-stopping during training.
        Note that the validation metric may be different from the evaluation metric.
        """
        return self._learner.validation_metric
    @property
    def verbosity(self):
        """
        Verbosity levels range from 0 to 4 and control how much information is printed.
        Higher levels correspond to more detailed print statements.
        """
        return self._verbosity
    @property
    def total_parameters(self) -> int:
        """
        The number of model parameters.
        """
        return self._learner.total_parameters
    @property
    def trainable_parameters(self) -> int:
        """
        The number of trainable model parameters, usually those with requires_grad=True.
        """
        return self._learner.trainable_parameters
    @property
    def model_size(self) -> float:
        """
        Returns the model size in Megabyte.
        """
        return self._learner.model_size
    @property
    def classes(self):
        """
        Object classes for the object detection problem type.
        """
        return self._learner.classes
    @property
    def class_labels(self):
        """
        The original name of the class labels.
        For example, the tabular data may contain classes equal to
        "entailment", "contradiction", "neutral". Internally, these will be converted to
        0, 1, 2, ...
        This function returns the original names of these raw labels.
        Returns
        -------
        List containing the class names. It will be None if it's not a classification problem.
        """
        return self._learner.class_labels
    @property
    def positive_class(self):
        """
        Name of the class label that will be mapped to 1.
        This is only meaningful for binary classification problems.
        It is useful for computing metrics such as F1 which require a positive and negative class.
        You may refer to https://en.wikipedia.org/wiki/F-score for more details.
        In binary classification, :meth:`MultiModalPredictor.predict_proba(as_multiclass=False)`
        returns the estimated probability that each row belongs to the positive class.
        Will print a warning and return None if called when `predictor.problem_type != 'binary'`.
        Returns
        -------
        The positive class name in binary classification or None if the problem is not binary classification.
        """
        return self._learner.positive_class
    # This func is required by the abstract trainer of TabularPredictor.
    def set_verbosity(self, verbosity: int):
        """Set the verbosity level of the log.
        Parameters
        ----------
        verbosity
            The verbosity level.
            0 --> only errors
            1 --> only warnings and critical print statements
            2 --> key print statements which should be shown by default
            3 --> more-detailed printing
            4 --> everything
        """
        self._verbosity = verbosity
        set_logger_verbosity(verbosity) 
        # TODO: align verbosity2loglevel with https://huggingface.co/docs/transformers/main_classes/logging#transformers.utils.logging.get_verbosity
    def set_num_gpus(self, num_gpus):
        """
        Set the number of GPUs in config.
        """
        self._learner.set_num_gpus(num_gpus) 
    def get_num_gpus(self):
        """
        Get the number of GPUs from config.
        """
        return self._learner.get_num_gpus()
    def fit(
        self,
        train_data: Union[pd.DataFrame, str],
        presets: Optional[str] = None,
        tuning_data: Optional[Union[pd.DataFrame, str]] = None,
        max_num_tuning_data: Optional[int] = None,
        id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
        time_limit: Optional[int] = None,
        save_path: Optional[str] = None,
        hyperparameters: Optional[Union[str, Dict, List[str]]] = None,
        column_types: Optional[dict] = None,
        holdout_frac: Optional[float] = None,
        teacher_predictor: Union[str, MultiModalPredictor] = None,
        seed: Optional[int] = 0,
        standalone: Optional[bool] = True,
        hyperparameter_tune_kwargs: Optional[dict] = None,
        clean_ckpts: Optional[bool] = True,
    ):
        """
        Fit models to predict a column of a data table (label) based on the other columns (features).
        Parameters
        ----------
        train_data
            A pd.DataFrame containing training data.
        presets
            Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.
            Each quality has its corresponding HPO presets: 'best_quality_hpo', 'high_quality_hpo', and 'medium_quality_hpo'.
        tuning_data
            A pd.DataFrame containing validation data, which should have the same columns as the train_data.
            If `tuning_data = None`, `fit()` will automatically hold out some random validation data from `train_data`.
        max_num_tuning_data
            The maximum number of tuning samples (used for object detection).
        id_mappings
             Id-to-content mappings (used for semantic matching). The contents can be text, image, etc.
             This is used when the pd.DataFrame contains the query/response identifiers instead of their contents.
        time_limit
            How long `fit()` should run for (wall clock time in seconds).
            If not specified, `fit()` will run until the model has completed training.
        save_path
            Path to directory where models and artifacts should be saved.
        hyperparameters
            This is to override some default configurations.
            For example, changing the text and image backbones can be done by formatting:
            a string
            hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
            or a list of strings
            hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
            or a dictionary
            hyperparameters = {
                            "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
                            "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
                        }
        column_types
            A dictionary that maps column names to their data types.
            For example: `column_types = {"item_name": "text", "image": "image_path",
            "product_description": "text", "height": "numerical"}`
            may be used for a table with columns: "item_name", "brand", "product_description", and "height".
            If None, column_types will be automatically inferred from the data.
            The current supported types are:
                - "image_path": each row in this column is one image path.
                - "text": each row in this column contains text (sentence, paragraph, etc.).
                - "numerical": each row in this column contains a number.
                - "categorical": each row in this column belongs to one of K categories.
        holdout_frac
            Fraction of train_data to holdout as tuning_data for optimizing hyperparameters or
            early stopping (ignored unless `tuning_data = None`).
            Default value (if None) is selected based on the number of rows in the training data
            and whether hyperparameter optimization is utilized.
        teacher_predictor
            The pre-trained teacher predictor or its saved path. If provided, `fit()` can distill its
            knowledge to a student predictor, i.e., the current predictor.
        seed
            The random seed to be used for training (default 0).
        standalone
            Whether to save the entire model for offline deployment.
        hyperparameter_tune_kwargs
                Hyperparameter tuning strategy and kwargs (for example, how many HPO trials to run).
                If None, then hyperparameter tuning will not be performed.
                    num_trials: int
                        How many HPO trials to run. Either `num_trials` or `time_limit` to `fit` needs to be specified.
                    scheduler: Union[str, ray.tune.schedulers.TrialScheduler]
                        If a str is passed, AutoGluon will create the scheduler for you with some default parameters.
                        If a ray.tune.schedulers.TrialScheduler object is passed, you are responsible for initializing the object.
                    scheduler_init_args: Optional[dict] = None
                        If a str was passed to `scheduler`, you can optionally provide custom init_args to the scheduler.
                    searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]
                        If a str is passed, AutoGluon will create the searcher for you with some default parameters.
                        If a ray.tune.search.Searcher object is passed, you are responsible for initializing the object.
                        You don't need to worry about `metric` and `mode` of the searcher object. AutoGluon will figure it out by itself.
                    searcher_init_args: Optional[dict] = None
                        If a str was passed to `searcher`, you can optionally provide custom init_args to the searcher.
                        You don't need to worry about `metric` and `mode`. AutoGluon will figure it out by itself.
        clean_ckpts
            Whether to clean the intermediate checkpoints after training.
        Returns
        -------
        A "MultiModalPredictor" object (itself).
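
        Examples
        --------
        A minimal sketch, assuming a pd.DataFrame ``train_df`` (hypothetical name)
        that contains the feature columns and the label column:

        >>> predictor = MultiModalPredictor(label="label")
        >>> predictor.fit(train_df, time_limit=600)  # train for at most 10 minutes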
        """
        if teacher_predictor is None:
            teacher_learner = None
        elif isinstance(teacher_predictor, str):
            teacher_learner = teacher_predictor
        else:
            teacher_learner = teacher_predictor._learner
        self._learner.fit(
            train_data=train_data,
            presets=presets,
            tuning_data=tuning_data,
            max_num_tuning_data=max_num_tuning_data,
            time_limit=time_limit,
            save_path=save_path,
            hyperparameters=hyperparameters,
            column_types=column_types,
            holdout_frac=holdout_frac,
            teacher_learner=teacher_learner,
            seed=seed,
            standalone=standalone,
            hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
            clean_ckpts=clean_ckpts,
            id_mappings=id_mappings,
        )
        return self 
    def evaluate(
        self,
        data: Union[pd.DataFrame, dict, list, str],
        query_data: Optional[list] = None,
        response_data: Optional[list] = None,
        id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
        metrics: Optional[Union[str, List[str]]] = None,
        chunk_size: Optional[int] = 1024,
        similarity_type: Optional[str] = "cosine",
        cutoffs: Optional[List[int]] = [1, 5, 10],
        label: Optional[str] = None,
        return_pred: Optional[bool] = False,
        realtime: Optional[bool] = False,
        eval_tool: Optional[str] = None,
    ):
        """
        Evaluate the model on a given dataset.
        Parameters
        ----------
        data
            A pd.DataFrame containing the same columns as the training data,
            or a str that is the path to the annotation file for object detection.
        query_data
            Query data used for ranking.
        response_data
            Response data used for ranking.
        id_mappings
             Id-to-content mappings. The contents can be text, image, etc.
             This is used when data/query_data/response_data contain the query/response identifiers instead of their contents.
        metrics
            A list of metric names to report.
            If None, we only return the score for the stored `_eval_metric_name`.
        chunk_size
            Scan the response data in chunks of this size. Increasing the value speeds up evaluation but requires more memory.
        similarity_type
            Which function to use to score similarity: "cosine" or "dot_prod" (default: "cosine").
        cutoffs
            A list of cutoff values to evaluate ranking.
        label
            The label column name in data. Some tasks, e.g., image<-->text matching, have no label column in their training data,
            but the label column may still be required for evaluation.
        return_pred
            Whether to return the prediction result of each row.
        realtime
            Whether to do realtime inference, which is efficient for small data (default False).
            If None is provided, we infer it based on the data modalities and the number of samples.
        eval_tool
            The eval_tool for object detection. Could be "pycocotools" or "torchmetrics".
        Returns
        -------
        A dictionary with the metric names and their corresponding scores.
        Optionally return a pd.DataFrame of prediction results.
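
        Examples
        --------
        A minimal sketch, assuming a held-out pd.DataFrame ``test_df``
        (hypothetical name) with the same columns as the training data:

        >>> scores = predictor.evaluate(test_df, metrics=["acc", "f1"])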
        """
        return self._learner.evaluate(
            data=data,
            metrics=metrics,
            return_pred=return_pred,
            realtime=realtime,
            eval_tool=eval_tool,
            query_data=query_data,
            response_data=response_data,
            id_mappings=id_mappings,
            chunk_size=chunk_size,
            similarity_type=similarity_type,
            cutoffs=cutoffs,
            label=label,
        ) 
    def predict(
        self,
        data: Union[pd.DataFrame, dict, list, str],
        candidate_data: Optional[Union[pd.DataFrame, dict, list]] = None,
        id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
        as_pandas: Optional[bool] = None,
        realtime: Optional[bool] = False,
        save_results: Optional[bool] = None,
        **kwargs,
    ):
        """
        Predict the label column values for new data.
        Parameters
        ----------
        data
            The data to make predictions for. Should contain the same column names as the training data
            and follow the same format (except for the `label` column).
        candidate_data
            The candidate data from which to search the query data's matches.
        id_mappings
            Id-to-content mappings. The contents can be text, image, etc.
            This is used when data contain the query/response identifiers instead of their contents.
        as_pandas
            Whether to return the output as a pandas DataFrame (or Series) (True) or a numpy array (False).
        realtime
            Whether to do realtime inference, which is efficient for small data (default False).
            If None is provided, we infer it based on the data modalities and the number of samples.
        save_results
            Whether to save the prediction results (currently only supported for object detection).
        **kwargs
            Additional keyword arguments to pass to the underlying learner's predict method.
            For example, `as_coco` for object detection tasks.
        Returns
        -------
        Array of predictions, one corresponding to each row in given dataset.
        Format depends on the specific learner and provided arguments.
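
        Examples
        --------
        A minimal sketch, assuming a pd.DataFrame ``test_df`` (hypothetical name)
        formatted like the training data:

        >>> predictions = predictor.predict(test_df)
        >>> predictions_np = predictor.predict(test_df, as_pandas=False)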
        """
        return self._learner.predict(
            data=data,
            candidate_data=candidate_data,
            as_pandas=as_pandas,
            realtime=realtime,
            save_results=save_results,
            id_mappings=id_mappings,
            **kwargs,
        ) 
    def predict_proba(
        self,
        data: Union[pd.DataFrame, dict, list],
        candidate_data: Optional[Union[pd.DataFrame, dict, list]] = None,
        id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
        as_pandas: Optional[bool] = None,
        as_multiclass: Optional[bool] = True,
        realtime: Optional[bool] = False,
    ):
        """
        Predict class probabilities rather than class labels.
        Note that this is only for the classification tasks.
        Calling it for a regression task will throw an exception.
        Parameters
        ----------
        data
            The data to make predictions for. Should contain the same column names as the training data
            and follow the same format (except for the `label` column).
        candidate_data
            The candidate data from which to search the query data's matches.
        id_mappings
             Id-to-content mappings. The contents can be text, image, etc.
             This is used when data contain the query/response identifiers instead of their contents.
        as_pandas
            Whether to return the output as a pandas DataFrame (or Series) (True) or a numpy array (False).
        as_multiclass
            Whether to return the probability of all labels or
            just the probability of the positive class for binary classification problems.
        realtime
            Whether to do realtime inference, which is efficient for small data (default False).
            If None is provided, we infer it based on the data modalities and the number of samples.
        Returns
        -------
        Array of predicted class-probabilities, corresponding to each row in the given data.
        When as_multiclass is True, the output will always have shape (#samples, #classes).
        Otherwise, the output will have shape (#samples,).
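
        Examples
        --------
        A minimal sketch for a classification predictor, assuming a pd.DataFrame
        ``test_df`` (hypothetical name):

        >>> proba = predictor.predict_proba(test_df)  # shape (#samples, #classes)
        >>> pos_proba = predictor.predict_proba(test_df, as_multiclass=False)  # binary only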
        """
        return self._learner.predict_proba(
            data=data,
            candidate_data=candidate_data,
            as_pandas=as_pandas,
            as_multiclass=as_multiclass,
            realtime=realtime,
            id_mappings=id_mappings,
        ) 
    def save(self, path: str, standalone: Optional[bool] = True):
        """
        Save this predictor to file in directory specified by `path`.
        Parameters
        ----------
        path
            The directory to save this predictor.
        standalone
            Whether to save the downloaded model for offline deployment.
            When standalone = True, save the transformers.CLIPModel and transformers.AutoModel to os.path.join(path, model_name),
            and reset the associated model.model_name.checkpoint_name in config.yaml to start with `local://`.
            When standalone = False, the saved artifact may require an online environment to be loaded in load().
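
        Examples
        --------
        A minimal sketch; the directory path is hypothetical:

        >>> predictor.save(path="./my_predictor", standalone=True)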
        """
        self._learner.save(path=path, standalone=standalone) 
    @classmethod
    def load(
        cls,
        path: str,
        resume: Optional[bool] = False,
        verbosity: Optional[int] = 3,
    ):
        """
        Load a predictor object from a directory specified by `path`. The to-be-loaded predictor
        can be completely or partially trained by .fit(). If a previous training run completed,
        it will load the checkpoint `model.ckpt`. Otherwise, if a previous training run crashed
        partway through, it can load the `last.ckpt` checkpoint by setting `resume=True`.
        It also supports loading one specific checkpoint given its path.
        .. warning::
            :meth:`autogluon.multimodal.MultiModalPredictor.load` uses `pickle` module implicitly, which is known to
            be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during
            unpickling. Never load data that could have come from an untrusted source, or that could have been tampered
            with. **Only load data you trust.**
        Parameters
        ----------
        path
            The directory to load the predictor object.
        resume
            Whether to resume training from `last.ckpt`. This is useful when a training run was
            accidentally interrupted, and we want to resume it from the last saved checkpoint.
        verbosity
            Verbosity levels range from 0 to 4 and control how much information is printed.
            Higher levels correspond to more detailed print statements.
            You can set verbosity = 0 to suppress warnings.
        Returns
        -------
        The loaded predictor object.
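
        Examples
        --------
        A minimal sketch; the directory path is hypothetical:

        >>> predictor = MultiModalPredictor.load(path="./my_predictor")
        >>> predictor = MultiModalPredictor.load(path="./my_predictor", resume=True)  # resume an interrupted run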
        """
        dir_path, ckpt_path = get_dir_ckpt_paths(path=path)
        assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
        predictor = cls(label="dummy_label")
        with open(os.path.join(dir_path, "assets.json"), "r") as fp:
            assets = json.load(fp)
        if "class_name" in assets and assets["class_name"] == "MultiModalMatcher":
            learner_class = MultiModalMatcher
        elif assets["problem_type"] == OBJECT_DETECTION:
            learner_class = ObjectDetectionLearner
        elif assets["problem_type"] == NER:
            learner_class = NERLearner
        elif assets["problem_type"] == FEW_SHOT_CLASSIFICATION:
            learner_class = FewShotSVMLearner
        elif assets["problem_type"] == SEMANTIC_SEGMENTATION:
            learner_class = SemanticSegmentationLearner
        else:
            learner_class = BaseLearner
        predictor._learner = learner_class.load(path=path, resume=resume, verbosity=verbosity)
        return predictor 
    def dump_model(self, save_path: Optional[str] = None):
        """
        Save model weights and config to a local directory.
        Model weights are saved in the file `pytorch_model.bin` (for `timm_image` or `hf_text`)
        or `<ckpt_name>.pth` (for `mmdet_image`).
        Configs are saved in the file `config.json` (for `timm_image` or `hf_text`)
        or `<ckpt_name>.py` (for `mmdet_image`).
        Parameters
        ----------
        save_path : str
           Path to directory where models and configs should be saved.
        """
        return self._learner.dump_model(save_path=save_path) 
    def export_onnx(
        self,
        data: Union[dict, pd.DataFrame],
        path: Optional[str] = None,
        batch_size: Optional[int] = None,
        verbose: Optional[bool] = False,
        opset_version: Optional[int] = 16,
        truncate_long_and_double: Optional[bool] = False,
    ):
        """
        Export this predictor's model to an ONNX file.
        When the `path` argument is not provided, the method does not save the model to disk;
        instead, it exports the ONNX model into a BytesIO buffer and returns its contents as bytes.
        Parameters
        ----------
        data
            Raw data used to trace and export the model.
            If this is None, will check if a processed batch is provided.
        path : str, default=None
            The export path of onnx model. If path is not provided, the method would export model to memory.
        batch_size
            The batch size of the exported model's input.
            Normally the batch size is a dynamic axis, so a small value can be used for faster export.
        verbose
            verbose flag in torch.onnx.export.
        opset_version
            opset_version flag in torch.onnx.export.
        truncate_long_and_double: bool, default False
            Truncate weights provided in int64 or double (float64) to int32 and float32
        Returns
        -------
        onnx_path : str or bytes
            A string that indicates location of the exported onnx model, if `path` argument is provided.
            Otherwise, would return the onnx model as bytes.
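
        Examples
        --------
        A minimal sketch, assuming ``sample_df`` (hypothetical name) holds a few
        raw rows for tracing; omitting `path` returns the model as bytes:

        >>> onnx_bytes = predictor.export_onnx(data=sample_df, batch_size=2)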
        """
        # Make sure _model is initialized
        self._learner.on_predict_start()
        return self._learner.export_onnx(
            data=data,
            path=path,
            batch_size=batch_size,
            verbose=verbose,
            opset_version=opset_version,
            truncate_long_and_double=truncate_long_and_double,
        ) 
    def optimize_for_inference(
        self,
        providers: Optional[Union[dict, List[str]]] = None,
    ):
        """
        Optimize the predictor's model for inference.
        Under the hood, the implementation would convert the PyTorch module into an ONNX module, so that
        we can leverage efficient execution providers in onnxruntime for faster inference.
        Parameters
        ----------
        providers : dict or list of str, default=None
            A list of execution providers for model prediction in onnxruntime.
            By default, the providers argument is None. The method generates an ONNX module that
            performs model inference with TensorrtExecutionProvider in onnxruntime if the tensorrt
            package is properly installed. Otherwise, onnxruntime falls back to CUDA or CPU
            execution providers.
        Returns
        -------
        onnx_module : OnnxModule
            The onnx-based module that can be used to replace predictor._model for model inference.
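
        Examples
        --------
        A minimal sketch using standard onnxruntime provider names; onnxruntime
        falls back to CPU if a provider is unavailable:

        >>> onnx_module = predictor.optimize_for_inference(
        ...     providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        ... )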
        """
        return self._learner.optimize_for_inference(providers=providers) 
    def fit_summary(self, verbosity=0, show_plot=False):
        """
        Output the training summary information from `fit()`.
        Parameters
        ----------
        verbosity : int, default = 0
            Verbosity levels range from 0 to 4 and control how much information is printed.
            verbosity = 0 for no output printing.
            TODO: Higher levels correspond to more detailed print statements
        show_plot : bool, default = False
            If True, shows the model summary plot in browser when verbosity > 1.
        Returns
        -------
        Dict containing various detailed information.
        We do not recommend directly printing this dict as it may be very large.
        """
        return self._learner.fit_summary(verbosity=verbosity, show_plot=show_plot) 
    def list_supported_models(self, pretrained=True):
        """
        List supported models for each problem type.
        Parameters
        ----------
        pretrained : bool, default = True
            If True, only return the models with pretrained weights.
            If False, return all the models as long as there is model definition.
        Returns
        -------
        A list of model names.
        """
        return self._learner.list_supported_models(pretrained=pretrained)