autogluon.multimodal.MultiModalPredictor
- class autogluon.multimodal.MultiModalPredictor(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None)
- AutoMM is designed to simplify fine-tuning foundation models for downstream applications with just three lines of code. It integrates with popular model zoos such as HuggingFace Transformers, TIMM, and MMDetection, and handles image, text, tabular, and document data, either individually or in combination. Supported tasks include classification, regression, object detection, named entity recognition, semantic matching, and image segmentation.
- __init__(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None)
- Parameters:
- label – Name of one pd.DataFrame column that contains the target variable to predict. 
- problem_type – Type of problem. Standard problems include:
  - 'binary': Binary classification
  - 'multiclass': Multi-class classification
  - 'regression': Regression
  - 'classification': Classification, covering both 'binary' and 'multiclass'
  In addition, the following advanced problems are supported:
  - 'object_detection': Object detection
  - 'ner' or 'named_entity_recognition': Named entity recognition
  - 'text_similarity': Text-text semantic matching
  - 'image_similarity': Image-image semantic matching
  - 'image_text_similarity': Text-image semantic matching
  - 'feature_extraction': Feature extraction (inference only)
  - 'zero_shot_image_classification': Zero-shot image classification (inference only)
  - 'few_shot_classification': Few-shot classification for image or text data
  - 'semantic_segmentation': Semantic segmentation with the Segment Anything Model
  For the following problem types, a pretrained model is loaded by default based on the presets/hyperparameters, so the predictor can run zero-shot inference (inference without calling .fit()):
  - 'object_detection'
  - 'text_similarity'
  - 'image_similarity'
  - 'image_text_similarity'
  - 'feature_extraction'
  - 'zero_shot_image_classification'
 
- query – Name of the pd.DataFrame column that contains the query data in semantic matching tasks.
- response – Name of the pd.DataFrame column that contains the response data in semantic matching tasks. If no label column is provided, the query and response in each row are assumed to form a positive pair.
- match_label – The label class indicating that a <query, response> pair counts as a "match". It is used in semantic matching tasks with binary labels. For example, in a duplicate detection task whose label column contains ["duplicate", "not duplicate"], match_label should be "duplicate", since that class means the two items match.
- presets – Presets regarding model quality, e.g., 'best_quality', 'high_quality' (default), and 'medium_quality'. Each quality level has a corresponding HPO preset: 'best_quality_hpo', 'high_quality_hpo', and 'medium_quality_hpo'.
- eval_metric – Evaluation metric name. If eval_metric is None, it is chosen automatically based on problem_type: 'accuracy' for multiclass classification, 'roc_auc' for binary classification, and 'root_mean_squared_error' for regression.
- hyperparameters – Overrides some of the default configurations. For example, the text and image backbones can be changed using any of the following formats:
  - a string:
    hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
  - a list of strings:
    hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
  - a dictionary:
    hyperparameters = {
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
    }
- path – Path to the directory where models and related artifacts are saved. If unspecified, a time-stamped folder called "AutogluonAutoMM/ag-[TIMESTAMP]" is created in the working directory. Note: to call fit() twice and keep the results of each fit, either specify a different path for each fit or leave path unspecified.
- verbosity – Verbosity levels range from 0 to 4, controlling how much logging information is printed. Higher levels correspond to more detailed print statements. You can set verbosity = 0 to suppress warnings. 
- num_classes – Number of classes (used for object detection). If specified and different from the pretrained model's output shape, the model's head is changed to have <num_classes> outputs.
- classes – All class names (used for object detection).
- warn_if_exist – Whether to raise a warning if the specified path already exists (default True).
- enable_progress_bar – Whether to show a progress bar (default True). The progress bar is disabled if the environment variable os.environ["AUTOMM_DISABLE_PROGRESS_BAR"] is set.
- pretrained – Whether to initialize the model with pretrained weights (default True). If False, the model is randomly initialized.
- validation_metric – Validation metric used to select the best model and for early stopping during training. If not provided, it is chosen automatically based on the problem type.
- sample_data_path – Path to sample data from which num_classes or classes can be inferred for object detection.
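In semantic matching, match_label decides which rows count as positive pairs, and when no label column is given every row is treated as a positive pair. The sketch below mimics that documented rule on plain Python rows. The helper positive_pairs and the column names "question", "candidate", and "is_dup" are hypothetical; this is an illustration of the rule, not AutoGluon's internal code.

```python
from typing import Dict, List, Optional, Tuple

Row = Dict[str, str]


def positive_pairs(
    rows: List[Row],
    query: str,
    response: str,
    label: Optional[str] = None,
    match_label: Optional[str] = None,
) -> List[Tuple[str, str]]:
    """Return the (query, response) pairs treated as matches.

    Hypothetical helper illustrating the documented rule:
    - without a label column, every row is a positive pair;
    - with a label column, only rows whose label equals match_label are positive.
    """
    if label is None:
        # No label column: each row's query/response is assumed to match.
        return [(row[query], row[response]) for row in rows]
    if match_label is None:
        # This helper insists on match_label once labels are present.
        raise ValueError("match_label is required when a label column is given")
    return [(row[query], row[response]) for row in rows if row[label] == match_label]


# Made-up duplicate-detection rows with binary labels.
rows = [
    {"question": "How do I reset my password?", "candidate": "Steps to reset a password", "is_dup": "duplicate"},
    {"question": "How do I reset my password?", "candidate": "Shipping times for orders", "is_dup": "not duplicate"},
]

print(positive_pairs(rows, query="question", response="candidate", label="is_dup", match_label="duplicate"))
print(len(positive_pairs(rows, query="question", response="candidate")))  # no label column: all rows
```

With match_label="duplicate" only the first row survives; dropping the label column turns both rows into positive pairs.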
 
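When eval_metric is None, the default depends on problem_type, and the umbrella 'classification' type covers both 'binary' and 'multiclass'. The sketch below encodes the defaults listed for eval_metric. The helper names are hypothetical, and the rule "two distinct labels means 'binary', more means 'multiclass'" is a simplifying assumption rather than AutoGluon's exact problem-type inference.

```python
from typing import Optional, Sequence

# Defaults as listed in the eval_metric description.
DEFAULT_EVAL_METRICS = {
    "binary": "roc_auc",
    "multiclass": "accuracy",
    "regression": "root_mean_squared_error",
}


def resolve_classification(labels: Sequence) -> str:
    """Simplified assumption: two distinct labels -> 'binary', more -> 'multiclass'."""
    n_unique = len(set(labels))
    if n_unique < 2:
        raise ValueError("classification needs at least two distinct labels")
    return "binary" if n_unique == 2 else "multiclass"


def default_eval_metric(problem_type: str, labels: Optional[Sequence] = None) -> str:
    """Pick a default metric when eval_metric is None (hypothetical helper)."""
    if problem_type == "classification":
        if labels is None:
            raise ValueError("labels are needed to resolve 'classification'")
        problem_type = resolve_classification(labels)
    try:
        return DEFAULT_EVAL_METRICS[problem_type]
    except KeyError:
        raise ValueError(f"no default listed for problem_type={problem_type!r}") from None


print(default_eval_metric("classification", ["cat", "dog", "cat"]))   # binary -> roc_auc
print(default_eval_metric("classification", ["cat", "dog", "bird"]))  # multiclass -> accuracy
print(default_eval_metric("regression"))
```

Problem types without a listed default (for example 'object_detection') raise here on purpose, since the description above does not state their defaults.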
 
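The three hyperparameters formats (string, list of strings, dictionary) carry the same key=value overrides. The sketch below converts each format into a plain dict to show that they are equivalent. to_override_dict is a hypothetical helper and is not how AutoGluon parses overrides internally; in particular, it assumes that values in the string form contain no spaces.

```python
from typing import Dict, List, Union

Hyperparameters = Union[str, List[str], Dict[str, str]]


def to_override_dict(hyperparameters: Hyperparameters) -> Dict[str, str]:
    """Normalize the three documented override formats into one dict.

    Hypothetical helper: accepts a space-separated string of key=value
    pairs, a list of key=value strings, or a dict, and returns a dict.
    """
    if isinstance(hyperparameters, dict):
        return dict(hyperparameters)
    if isinstance(hyperparameters, str):
        items = hyperparameters.split()  # assumes values contain no spaces
    else:
        items = list(hyperparameters)
    overrides: Dict[str, str] = {}
    for item in items:
        # Split on the first '=' only, so values may themselves contain '='.
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        overrides[key] = value
    return overrides


TEXT = "model.hf_text.checkpoint_name=google/electra-small-discriminator"
IMAGE = "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"

as_string = to_override_dict(f"{TEXT} {IMAGE}")
as_list = to_override_dict([TEXT, IMAGE])
as_dict = to_override_dict({
    "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
    "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
})
print(as_string == as_list == as_dict)
```

The dictionary form avoids the string form's reliance on whitespace as a separator, which is one reason to prefer it when values are long or generated programmatically.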
- Methods
  - Save model weights and config to a local directory.
  - Evaluate the model on a given dataset.
  - Export this predictor's model to an ONNX file.
  - Extract features for each sample, i.e., each row in the provided pd.DataFrame.
  - Fit models to predict a column of a data table (label) based on the other columns (features).
  - Output the training summary information from fit().
  - Get the number of GPUs from the config.
  - List supported models for each problem type.
  - Load a predictor object from a directory specified by path.
  - Optimize the predictor's model for inference.
  - Predict the label column values for new data.
  - Predict class probabilities rather than class labels.
  - Save this predictor to files in the directory specified by path.
  - Set the number of GPUs in the config.
  - Set the verbosity level of the log.
- Attributes
  - class_labels – The original names of the class labels.
  - classes – Object classes for the object detection problem type.
  - column_types – Column types in the pd.DataFrame.
  - eval_metric – Metric used to evaluate predictive performance.
  - label – Name of the pd.DataFrame column that contains the target variable to predict.
  - match_label – The label class indicating that a <query, response> pair counts as a "match" in semantic matching tasks.
  - model_size – Model size in megabytes.
  - path – Path to the directory where the model and related artifacts are stored.
  - positive_class – Name of the class label that is mapped to 1.
  - problem_property – Property of the problem, storing the problem type and its related properties.
  - problem_type – Type of prediction problem this predictor has been trained for.
  - query – Name of the pd.DataFrame column that contains the query data in semantic matching tasks.
  - response – Name of the pd.DataFrame column that contains the response data in semantic matching tasks.
  - total_parameters – Number of model parameters.
  - trainable_parameters – Number of trainable model parameters, usually those with requires_grad=True.
  - validation_metric – Validation metric used to select the best model and for early stopping during training.
  - verbosity – Verbosity level, ranging from 0 to 4, which controls how much information is printed.
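The Methods list distinguishes predicting class labels from predicting class probabilities (predict() and predict_proba() on the predictor). Conceptually, the predicted label is the class with the highest probability, ordered like the class_labels attribute. The sketch below shows that relationship on plain lists; labels_from_proba is a hypothetical helper, and the real predictor may apply further logic (for example, a decision threshold for binary problems), so treat this as an assumption rather than the library's implementation.

```python
from typing import List, Sequence


def labels_from_proba(proba: Sequence[Sequence[float]], class_labels: Sequence[str]) -> List[str]:
    """Map each row of class probabilities to the label with the highest probability."""
    predictions = []
    for row in proba:
        if len(row) != len(class_labels):
            raise ValueError("each probability row needs one entry per class label")
        # Index of the largest probability; ties resolve to the first maximum.
        best = max(range(len(row)), key=row.__getitem__)
        predictions.append(class_labels[best])
    return predictions


class_labels = ["negative", "positive"]  # stands in for the class_labels attribute
proba = [[0.9, 0.1], [0.3, 0.7]]          # made-up probabilities, one row per sample
print(labels_from_proba(proba, class_labels))
```

Keeping the probability columns in the same order as class_labels is what makes the argmax index meaningful; a mismatch in length is rejected rather than silently misaligned.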