Single GPU Billion-scale Model Training via Parameter-Efficient Finetuning#


As pointed out by a recent paper from the Stanford Institute for Human-Centered Artificial Intelligence, AI is undergoing a paradigm shift with the rise of "foundation models", i.e., giant models that are trained on a diverse collection of datasets, generally in a self-supervised way. These foundation models, which are the key to AutoMM, can be easily adapted to downstream applications. However, as these foundation models grow in size, finetuning them becomes increasingly difficult. The following figure from the Microsoft research blog demonstrates the trend:

Scaling of foundation models

The goal of AutoMM is to help anyone solve machine learning problems via open source foundation models, including these giant models. To finetune these large-scale models, we adopt the recently popularized parameter-efficient finetuning techniques. The idea is to either finetune only a small subset of the weights in the foundation model (e.g., BitFit), or add a tiny tunable structure on top of the fixed backbone (e.g., Prompt Tuning, LoRA, Adapter, MAM Adapter, IA^3). These techniques effectively reduce peak memory usage and model training time, while maintaining performance.
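To make the idea concrete, below is a minimal PyTorch sketch of the IA^3 mechanism: the backbone weights stay frozen, and only a small learned vector that rescales a layer's output is trained. The `IA3Linear` wrapper here is a hypothetical illustration, not AutoMM's internal implementation.

```python
import torch
import torch.nn as nn

class IA3Linear(nn.Module):
    """Wrap a frozen linear layer and rescale its output with a learned
    vector -- the core idea behind IA^3 (illustrative sketch only)."""

    def __init__(self, frozen_linear: nn.Linear):
        super().__init__()
        self.linear = frozen_linear
        for p in self.linear.parameters():
            p.requires_grad = False  # the backbone weights stay fixed
        # One trainable scale per output dimension, initialized to 1 so
        # the wrapped layer starts out identical to the original.
        self.scaling = nn.Parameter(torch.ones(frozen_linear.out_features))

    def forward(self, x):
        return self.linear(x) * self.scaling
```

Because only the rescaling vectors (plus, in "ia3_bias", the bias terms) receive gradients, the number of tunable parameters is a tiny fraction of the full model.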

In this tutorial, we introduce how to apply parameter-efficient finetuning in MultiModalPredictor. We first show how to adopt the "ia3_bias" algorithm for parameter-efficient finetuning. Afterwards, we show how you can combine "ia3_bias" with gradient checkpointing to finetune the XL variant of Google's FLAN-T5 on a single NVIDIA T4 GPU.

Prepare Dataset#

The Cross-Lingual Amazon Product Review Sentiment dataset contains Amazon product reviews in four languages. Here, we load the English, German, and Japanese folds of the dataset. In the label column, 0 means negative sentiment and 1 means positive sentiment. For the purpose of demonstration, we downsample the training data to 1,000 samples. We will train the model on the English data and directly evaluate its performance on the German and Japanese test sets.

!wget --quiet https://automl-mm-bench.s3.amazonaws.com/multilingual-datasets/amazon_review_sentiment_cross_lingual.zip -O amazon_review_sentiment_cross_lingual.zip
!unzip -q -o amazon_review_sentiment_cross_lingual.zip -d .
import os
import shutil
os.environ["TRANSFORMERS_CACHE"] = "cache"

def clear_cache():
    if os.path.exists("cache"):
        shutil.rmtree("cache")

clear_cache()
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

train_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_train.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
                .sample(1000, random_state=123).reset_index(drop=True)

test_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_test.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)
test_de_df = pd.read_csv("amazon_review_sentiment_cross_lingual/de_test.tsv",
                          sep="\t", header=None, names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)

test_jp_df = pd.read_csv('amazon_review_sentiment_cross_lingual/jp_test.tsv',
                          sep='\t', header=None, names=['label', 'text']) \
               .sample(200, random_state=123).reset_index(drop=True)
train_en_df.head(5)
label text
0 0 This is a film that literally sees little wron...
1 0 This music is pretty intelligent, but not very...
2 0 One of the best pieces of rock ever recorded, ...
3 0 Reading the posted reviews here, is like revis...
4 1 I've just finished page 341, the last page. It...
test_jp_df.head(5)
label text
0 1 原作はビクトル・ユーゴの長編小説だが、私が子供の頃読んだのは短縮版の「ああ無情」。それでもこ...
1 1 ほかの作品のレビューにみんな書いているのに、何故この作品について書いている人が一人しかいない...
2 0 一番の問題点は青島が出ていない事でしょう。 TV番組では『芸人が出ていればバラエティだから...
3 0 昔、 りんたろう監督によるアニメ「カムイの剣」があった。 「カムイの剣」…を観た人なら本作...
4 1 以前のアルバムを聴いていないのでなんとも言えないが、クラシックなメタルを聞いてきた耳には、と...

Finetuning Multilingual Model with IA3 + BitFit#

In AutoMM, to enable efficient finetuning, just set optimization.efficient_finetune to "ia3_bias".

from autogluon.multimodal import MultiModalPredictor
import uuid

model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
predictor = MultiModalPredictor(label="label",
                                path=model_path)
predictor.fit(train_en_df,
              presets="multilingual",
              hyperparameters={
                  "optimization.efficient_finetune": "ia3_bias",
                  "optimization.lr_decay": 0.9,
                  "optimization.learning_rate": 3e-03,
                  "optimization.end_lr": 3e-03,
                  "optimization.max_epochs": 2,
                  "optimization.warmup_steps": 0,
                  "env.batch_size": 32,
              })
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [0, 1]
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])
Global seed set to 0
AutoMM starts to create your model. ✨

- AutoGluon version is 0.8.1b20230622.

- Pytorch version is 1.13.1+cu117.

- Model will be saved to "/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/9cdb6292a5704d8e8a0bd70f028fdc3d-multilingual_ia3".

- Validation metric is "roc_auc".

- To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/9cdb6292a5704d8e8a0bd70f028fdc3d-multilingual_ia3
    ```

Enjoy your coffee, and let AutoMM do the job ☕☕☕ Learn more at https://auto.gluon.ai

The fraction of tunable parameters is around 0.5% of all parameters. Notably, the model trained purely on English data achieves good performance on all test sets, even the German and Japanese ones, obtaining results comparable to the full finetuning shown in AutoMM for Text - Multilingual Problems.
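If you want to verify the tunable fraction yourself, a generic helper like the hypothetical `tunable_fraction` below works on any `torch.nn.Module`. Note that the attribute holding the predictor's underlying torch module is internal to AutoMM and may change between versions, so treat this as a sketch rather than a supported API.

```python
import torch.nn as nn

def tunable_fraction(module: nn.Module) -> float:
    # Ratio of parameters that will actually receive gradient updates
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable / total
```

For "ia3_bias", this ratio should come out to roughly 0.005, matching the trainable vs. total parameter counts reported in the training logs.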

score_in_en = predictor.evaluate(test_en_df)
score_in_de = predictor.evaluate(test_de_df)
score_in_jp = predictor.evaluate(test_jp_df)
print("Score in the English test set:", score_in_en)
print("Score in the German test set:", score_in_de)
print("Score in the Japanese test set:", score_in_jp)

Training FLAN-T5-XL on Single GPU#

By combining gradient checkpointing and parameter-efficient finetuning, it is feasible to finetune google/flan-t5-xl, which has close to two billion parameters, on a single T4 GPU, as available in AWS G4 instances. To turn on gradient checkpointing, you just need to set "model.hf_text.gradient_checkpointing" to True. To accelerate the training, we downsample the training data to 200 samples.
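AutoMM builds on Hugging Face transformers, and both memory-saving mechanisms used here are available directly in the public transformers API. For reference, the sketch below shows how they would be enabled on a raw transformers model, independent of AutoMM (this illustrates the underlying features, not AutoMM's exact internal wiring):

```python
from transformers import AutoModelForSeq2SeqLM

# low_cpu_mem_usage avoids materializing a second full copy of the
# weights in CPU RAM while the checkpoint is being loaded.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-xl", low_cpu_mem_usage=True
)

# Recompute activations during the backward pass instead of storing
# them, trading extra compute for a much lower peak memory footprint.
model.gradient_checkpointing_enable()
```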

# Clean up the checkpoint cache and the previous model to free disk space
clear_cache()
shutil.rmtree(model_path)
from autogluon.multimodal import MultiModalPredictor

train_en_df_downsample = train_en_df.sample(200, random_state=123)

new_model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3_gradient_checkpoint"
predictor = MultiModalPredictor(label="label",
                                path=new_model_path)
predictor.fit(train_en_df_downsample,
              presets="multilingual",
              hyperparameters={
                  "model.hf_text.checkpoint_name": "google/flan-t5-xl",
                  "model.hf_text.gradient_checkpointing": True,
                  "model.hf_text.low_cpu_mem_usage": True,
                  "optimization.efficient_finetune": "ia3_bias",
                  "optimization.lr_decay": 0.9,
                  "optimization.learning_rate": 3e-03,
                  "optimization.end_lr": 3e-03,
                  "optimization.max_epochs": 1,
                  "optimization.warmup_steps": 0,
                  "env.batch_size": 1,
                  "env.eval_batch_size_ratio": 1
              })

Global seed set to 123
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params
-------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 1.2 B 
1 | validation_metric | AUROC                        | 0     
2 | loss_func         | CrossEntropyLoss             | 0     
-------------------------------------------------------------------
203 K     Trainable params
1.2 B     Non-trainable params
1.2 B     Total params
4,894.913 Total estimated model params size (MB)
Epoch 0, global step 20: 'val_roc_auc' reached 0.88802 (best 0.88802), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=20.ckpt' as top 1
Epoch 0, global step 40: 'val_roc_auc' reached 0.94531 (best 0.94531), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=40.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=1` reached.

<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fd58c4dbca0>
score_in_en = predictor.evaluate(test_en_df)
print("Score in the English test set:", score_in_en)
Score in the English test set: {'roc_auc': 0.931263189629183}
# Clean up the checkpoint cache and the saved model to free disk space
clear_cache()
shutil.rmtree(new_model_path)

Other Examples#

You may go to AutoMM Examples to explore other examples of AutoMM.

Customization#

To learn how to customize AutoMM, please refer to Customize AutoMM.