AutoMM for Multimodal Named Entity Extraction#


We have introduced how to train an entity extraction model with text data. Here, we move a step further by integrating data of other modalities. In many real-world applications, textual data usually comes with data of other modalities. For example, Twitter allows you to compose tweets with text, photos, videos, and GIFs. Amazon.com uses text, images, and videos to describe their products. These auxiliary modalities can be leveraged as additional context for resolving entities. Now, with AutoMM, you can easily exploit multimodal data to enhance entity extraction without worrying about the details.

import os
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

Get the Twitter Dataset#

In the following example, we will demonstrate how to build a multimodal named entity recognition model with a real-world Twitter dataset. This dataset consists of tweets scraped from 2016 to 2017, and each tweet consists of one sentence and one image. Let’s download the dataset.

download_dir = './ag_automm_tutorial_ner'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/ner/multimodal_ner.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial_ner/file.zip from https://automl-mm-bench.s3.amazonaws.com/ner/multimodal_ner.zip...
100%|██████████| 423M/423M [00:07<00:00, 59.3MiB/s]

Next, we will load the CSV files.

dataset_path = download_dir + '/multimodal_ner'
train_data = pd.read_csv(f'{dataset_path}/twitter17_train.csv')
test_data = pd.read_csv(f'{dataset_path}/twitter17_test.csv')
label_col = 'entity_annotations'

We need to expand the image paths to load them during training.

image_col = 'image'
train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0]) # Use the first image for a quick tutorial
test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])

def path_expander(path, base_folder):
	path_l = path.split(';')
	p = ';'.join([os.path.abspath(base_folder+path) for path in path_l])
	return p

train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))

train_data[image_col].iloc[0]
'/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/ag_automm_tutorial_ner/multimodal_ner/twitter2017_images/17_06_1818.jpg'
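As a sanity check, the expansion logic can be exercised on standalone strings. The file names below are made up for illustration; only the semicolon-separated format mirrors the dataset:

```python
import os

# Hypothetical semicolon-separated relative paths, as they might appear in a CSV cell
cell = '/twitter2017_images/a.jpg;/twitter2017_images/b.jpg'
base_folder = './ag_automm_tutorial_ner/multimodal_ner'

# Same logic as path_expander: split on ';', absolutize each path, rejoin
expanded = ';'.join(os.path.abspath(base_folder + p) for p in cell.split(';'))
print(expanded.split(';'))  # every component is now an absolute path
```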

Each row consists of the text and image of a single tweet, plus the entity_annotations column, which contains the named entity annotations for the text column. Let’s look at an example row and display the text and picture of the tweet.

example_row = train_data.iloc[0]

example_row
text_snippet           Uefa Super Cup : Real Madrid v Manchester United
image                 /home/ci/autogluon/docs/tutorials/multimodal/m...
entity_annotations    [{"entity_group": "B-MISC", "start": 0, "end":...
Name: 0, dtype: object
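As the output above suggests, entity_annotations stores each label set as a JSON-encoded list of dicts, each with an entity_group and character-level start and end offsets into the text. A minimal sketch of decoding one annotation; the exact annotation string below is made up for illustration:

```python
import json

text = 'Uefa Super Cup : Real Madrid v Manchester United'
# Hypothetical annotation string in the same format as entity_annotations
ann = '[{"entity_group": "B-MISC", "start": 0, "end": 4}]'

for entity in json.loads(ann):
    span = text[entity['start']:entity['end']]
    print(span, '->', entity['entity_group'])  # Uefa -> B-MISC
```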

Below is the image of this tweet.

example_image = example_row[image_col]

from IPython.display import Image, display
pil_img = Image(filename=example_image, width=300)
display(pil_img)
(The tweet’s image is displayed here.)

As you can see, this photo contains the logos of the Real Madrid football club, Manchester United football club, and the UEFA Super Cup. Clearly, key information in the tweet sentence is encoded here in a different modality.

Training#

Now let’s fit the predictor on the training data. First, we need to set the problem_type to ner. Since the annotations apply to a text column, when multiple text columns are present we must set the corresponding column type to text_ner via the column_types parameter, so that the model locates the correct text column for entity extraction. Here we set a tight time budget for a quick demo.

from autogluon.multimodal import MultiModalPredictor
import uuid

label_col = "entity_annotations"
model_path = f"./tmp/{uuid.uuid4().hex}-automm_multimodal_ner"
predictor = MultiModalPredictor(problem_type="ner", label=label_col, path=model_path)
predictor.fit(
	train_data=train_data,
	column_types={"text_snippet":"text_ner"},
	time_limit=300, #second
)
INFO:lightning_fabric.utilities.seed:Global seed set to 0
AutoMM starts to create your model. ✨

- AutoGluon version is 0.8.1b20230622.

- Pytorch version is 1.13.1+cu117.

- Model will be saved to "/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/c8a63b2455a24ad5830da7b3cf8ce9ef-automm_multimodal_ner".

- Validation metric is "ner_token_f1".

- To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/c8a63b2455a24ad5830da7b3cf8ce9ef-automm_multimodal_ner
    ```

Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai

Under the hood, AutoMM automatically detects the data modalities, selects the related models from the multimodal model pools, and trains the selected models. If multiple backbones are available, AutoMM appends a late-fusion model on top of them.
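Conceptually, a late-fusion head combines per-modality embeddings, e.g., by concatenating them and feeding the result to a shared classification head. The sketch below is a simplified illustration of that idea (with made-up embedding sizes and class count), not AutoMM’s actual architecture:

```python
import numpy as np

rng = np.random.default_rng(0)

def late_fusion_logits(text_emb, image_emb, W, b):
    # Concatenate per-modality features, then apply a linear classification head
    fused = np.concatenate([text_emb, image_emb])
    return W @ fused + b

text_emb = rng.standard_normal(768)   # e.g., an embedding from a text backbone
image_emb = rng.standard_normal(512)  # e.g., an embedding from an image backbone
num_classes = 9                       # e.g., BIO tags over four entity types + 'O'
W = rng.standard_normal((num_classes, 768 + 512))
b = np.zeros(num_classes)

logits = late_fusion_logits(text_emb, image_emb, W, b)
print(logits.shape)  # (9,)
```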

Evaluation#

predictor.evaluate(test_data, metrics=['overall_recall', 'overall_precision', 'overall_f1'])
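The overall_precision, overall_recall, and overall_f1 metrics follow the usual entity-level definitions: precision and recall are computed from true-positive, false-positive, and false-negative entity counts, and F1 is their harmonic mean. A minimal sketch of the arithmetic, assuming hypothetical counts:

```python
def overall_prf(tp, fp, fn):
    # Entity-level precision, recall, and F1 from raw counts
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

print(overall_prf(tp=8, fp=2, fn=2))  # each value is 0.8 (up to floating point)
```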

Prediction#

You can easily obtain the predictions by calling predictor.predict().

prediction_input = test_data.drop(columns=label_col).head(1)
predictions = predictor.predict(prediction_input)
print('Tweet:', prediction_input.text_snippet[0])
print('Image path:', prediction_input.image[0])
print('Predicted entities:', predictions[0])

for entity in predictions[0]:
	print(f"Word '{prediction_input.text_snippet[0][entity['start']:entity['end']]}' belongs to group: {entity['entity_group']}")

Reloading and Continuous Training#

The trained predictor is automatically saved, and you can easily reload it from its path. If you are not satisfied with the current model performance, you can continue training the loaded model with new data.

new_predictor = MultiModalPredictor.load(model_path)
new_model_path = f"./tmp/{uuid.uuid4().hex}-automm_multimodal_ner_continue_train"
new_predictor.fit(train_data, time_limit=60, save_path=new_model_path)
test_score = new_predictor.evaluate(test_data, metrics=['overall_f1'])
print(test_score)

Other Examples#

You can go to AutoMM Examples to explore other examples of AutoMM.

Customization#

To learn how to customize AutoMM, please refer to Customize AutoMM.