Single GPU Billion-scale Model Training via Parameter-Efficient Finetuning¶
As pointed out by a recent paper from the Stanford Institute for Human-Centered Artificial Intelligence, AI is undergoing a paradigm shift with the rise of “foundation models”, i.e., giant models that are trained on a diverse collection of datasets, generally in a self-supervised way. These foundation models, which are at the core of AutoMM, can be easily adapted to downstream applications. However, as these foundation models grow in size, finetuning them becomes increasingly difficult. The following figure from the Microsoft Research blog demonstrates the trend:

The goal of AutoMM is to help anyone solve machine learning problems via open source foundation models, including these giant models. To finetune these large-scale models, we adopt the recently popularized parameter-efficient finetuning technique. The idea is to either finetune a small subset of the weights in the foundation model (e.g., BitFit), or add a tiny tunable structure on top of the fixed backbone (e.g., Prompt Tuning, LoRA, Adapter, MAM Adapter, IA^3). These techniques can effectively reduce peak memory usage and model training time while maintaining performance.
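To make the difference between these two families concrete, here is a minimal sketch that counts trainable parameters for a single dense layer under several regimes. The layer size and the `count_trainable` helper are illustrative assumptions, not AutoMM internals, and the IA^3 case is simplified to one learned scaling vector on the layer output.

```python
def count_trainable(d_in: int, d_out: int, method: str) -> int:
    """Count trainable parameters of one dense layer y = W x + b.

    Hypothetical helper for illustration only; it is not part of AutoMM.
    """
    weight, bias = d_in * d_out, d_out
    if method == "full":       # update both W and b
        return weight + bias
    if method == "bitfit":     # freeze W, update only the bias b
        return bias
    if method == "ia3":        # freeze W and b, learn one scale per output unit
        return d_out
    if method == "ia3_bias":   # learned scaling vector plus the bias
        return d_out + bias
    raise ValueError(f"unknown method: {method}")


d_in = d_out = 768  # assumed hidden size, typical of a base-sized Transformer
full = count_trainable(d_in, d_out, "full")
for method in ("full", "bitfit", "ia3", "ia3_bias"):
    n = count_trainable(d_in, d_out, method)
    print(f"{method:>8}: {n:>7} trainable ({n / full:.2%} of full finetuning)")
```

In this toy setting the parameter-efficient variants touch well under 1% of the layer's parameters, which is why optimizer state and gradient memory shrink so sharply.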
In this tutorial, we introduce how to apply parameter-efficient finetuning in MultiModalPredictor.
We first introduce how to adopt the "ia3_bias" algorithm for parameter-efficient finetuning. Afterwards, we show how you can simply combine "ia3_bias"
and gradient checkpointing to finetune the XL-variant of Google’s FLAN-T5 via a single NVIDIA T4 GPU.
Prepare Dataset¶
The Cross-Lingual Amazon Product Review Sentiment dataset contains Amazon product reviews in four languages.
Here, we load the English, German, and Japanese folds of the dataset. In the label column, 0 means negative sentiment and 1 means positive sentiment.
For demonstration purposes, we downsample the training data to 1000 samples and each test set to 200 samples. We will train the model on the English dataset and
directly evaluate its performance on the German and Japanese test sets.
!wget --quiet https://automl-mm-bench.s3.amazonaws.com/multilingual-datasets/amazon_review_sentiment_cross_lingual.zip -O amazon_review_sentiment_cross_lingual.zip
!unzip -q -o amazon_review_sentiment_cross_lingual.zip -d .
import os
import shutil
os.environ["TRANSFORMERS_CACHE"] = "cache"
def clear_cache():
    if os.path.exists("cache"):
        shutil.rmtree("cache")
clear_cache()
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
train_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_train.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
                .sample(1000, random_state=123).reset_index(drop=True)
test_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_test.tsv",
                         sep="\t",
                         header=None,
                         names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)
test_de_df = pd.read_csv("amazon_review_sentiment_cross_lingual/de_test.tsv",
                         sep="\t", header=None, names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)
test_jp_df = pd.read_csv("amazon_review_sentiment_cross_lingual/jp_test.tsv",
                         sep="\t", header=None, names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)
train_en_df.head(5)
| | label | text |
|---|---|---|
| 0 | 0 | This is a film that literally sees little wron... |
| 1 | 0 | This music is pretty intelligent, but not very... |
| 2 | 0 | One of the best pieces of rock ever recorded, ... |
| 3 | 0 | Reading the posted reviews here, is like revis... |
| 4 | 1 | I've just finished page 341, the last page. It... |
test_jp_df.head(5)
| | label | text |
|---|---|---|
| 0 | 1 | 原作はビクトル・ユーゴの長編小説だが、私が子供の頃読んだのは短縮版の「ああ無情」。それでもこ... |
| 1 | 1 | ほかの作品のレビューにみんな書いているのに、何故この作品について書いている人が一人しかいない... |
| 2 | 0 | 一番の問題点は青島が出ていない事でしょう。 TV番組では『芸人が出ていればバラエティだから... |
| 3 | 0 | 昔、 りんたろう監督によるアニメ「カムイの剣」があった。 「カムイの剣」…を観た人なら本作... |
| 4 | 1 | 以前のアルバムを聴いていないのでなんとも言えないが、クラシックなメタルを聞いてきた耳には、と... |
Finetuning Multilingual Model with IA3 + BitFit¶
In AutoMM, to enable parameter-efficient finetuning, just set the optim.peft hyperparameter to "ia3_bias".
from autogluon.multimodal import MultiModalPredictor
import uuid
model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
predictor = MultiModalPredictor(label="label",
                                path=model_path)
predictor.fit(train_en_df,
              presets="multilingual",
              hyperparameters={
                  "optim.peft": "ia3_bias",
                  "optim.lr_decay": 0.9,
                  "optim.lr": 3e-03,
                  "optim.end_lr": 3e-03,
                  "optim.max_epochs": 2,
                  "optim.warmup_steps": 0,
                  "env.batch_size": 32,
              })
/home/ci/opt/venv/lib/python3.11/site-packages/mmengine/optim/optimizer/zero_optimizer.py:11: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
from torch.distributed.optim import \
=================== System Info ===================
AutoGluon Version: 1.2.1b20250301
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Sat Jan 25 09:56:35 UTC 2025
CPU Count: 8
Pytorch Version: 2.5.1+cu124
CUDA Version: 12.4
Memory Avail: 28.41 GB / 30.95 GB (91.8%)
Disk Space Avail: 183.64 GB / 255.99 GB (71.7%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values: [0, 1]
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/8afa45d11eb7469f98828733d7cf8f58-multilingual_ia3
```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
---------------------------------------------------------------------------
0 | model | HFAutoModelForTextPrediction | 278 M | train
1 | validation_metric | BinaryAUROC | 0 | train
2 | loss_func | CrossEntropyLoss | 0 | train
---------------------------------------------------------------------------
122 K Trainable params
278 M Non-trainable params
278 M Total params
1,112.955 Total estimated model params size (MB)
28 Modules in train mode
213 Modules in eval mode
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 7
4 model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
5 predictor = MultiModalPredictor(label="label",
6 path=model_path)
----> 7 predictor.fit(train_en_df,
8 presets="multilingual",
9 hyperparameters={
10 "optim.peft": "ia3_bias",
11 "optim.lr_decay": 0.9,
12 "optim.lr": 3e-03,
13 "optim.end_lr": 3e-03,
14 "optim.max_epochs": 2,
15 "optim.warmup_steps": 0,
16 "env.batch_size": 32,
17 })
File ~/autogluon/multimodal/src/autogluon/multimodal/predictor.py:540, in MultiModalPredictor.fit(self, train_data, presets, tuning_data, max_num_tuning_data, id_mappings, time_limit, save_path, hyperparameters, column_types, holdout_frac, teacher_predictor, seed, standalone, hyperparameter_tune_kwargs, clean_ckpts, predictions, labels, predictors)
537 assert isinstance(predictors, list)
538 learners = [ele if isinstance(ele, str) else ele._learner for ele in predictors]
--> 540 self._learner.fit(
541 train_data=train_data,
542 presets=presets,
543 tuning_data=tuning_data,
544 max_num_tuning_data=max_num_tuning_data,
545 time_limit=time_limit,
546 save_path=save_path,
547 hyperparameters=hyperparameters,
548 column_types=column_types,
549 holdout_frac=holdout_frac,
550 teacher_learner=teacher_learner,
551 seed=seed,
552 standalone=standalone,
553 hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
554 clean_ckpts=clean_ckpts,
555 id_mappings=id_mappings,
556 predictions=predictions,
557 labels=labels,
558 learners=learners,
559 )
561 return self
File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:665, in BaseLearner.fit(self, train_data, presets, tuning_data, time_limit, save_path, hyperparameters, column_types, holdout_frac, teacher_learner, seed, standalone, hyperparameter_tune_kwargs, clean_ckpts, **kwargs)
658 self.fit_sanity_check()
659 self.prepare_fit_args(
660 time_limit=time_limit,
661 seed=seed,
662 standalone=standalone,
663 clean_ckpts=clean_ckpts,
664 )
--> 665 fit_returns = self.execute_fit()
666 self.on_fit_end(
667 training_start=training_start,
668 strategy=fit_returns.get("strategy", None),
(...)
671 clean_ckpts=clean_ckpts,
672 )
674 return self
File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:577, in BaseLearner.execute_fit(self)
575 return dict()
576 else:
--> 577 attributes = self.fit_per_run(**self._fit_args)
578 self.update_attributes(**attributes) # only update attributes for non-HPO mode
579 return attributes
File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:1358, in BaseLearner.fit_per_run(self, max_time, save_path, ckpt_path, resume, enable_progress_bar, seed, hyperparameters, advanced_hyperparameters, config, df_preprocessor, data_processors, model, standalone, clean_ckpts)
1339 config = self.post_update_config_per_run(
1340 config=config,
1341 num_gpus=num_gpus,
1342 precision=precision,
1343 strategy=strategy,
1344 )
1345 trainer = self.init_trainer_per_run(
1346 num_gpus=num_gpus,
1347 config=config,
(...)
1355 enable_progress_bar=enable_progress_bar,
1356 )
-> 1358 self.run_trainer(
1359 trainer=trainer,
1360 litmodule=litmodule,
1361 datamodule=datamodule,
1362 ckpt_path=ckpt_path,
1363 resume=resume,
1364 )
1365 self.on_fit_per_run_end(
1366 save_path=save_path,
1367 standalone=standalone,
(...)
1372 model=model,
1373 )
1375 best_score = (
1376 trainer.callback_metrics[f"val_{self._validation_metric_name}"].item()
1377 if f"val_{self._validation_metric_name}" in trainer.callback_metrics
1378 else self._best_score
1379 ) # https://github.com/autogluon/autogluon/issues/4428
File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:1211, in BaseLearner.run_trainer(self, trainer, litmodule, datamodule, ckpt_path, resume, pred_writer, is_train)
1209 warnings.filterwarnings("ignore", filter)
1210 if is_train:
-> 1211 trainer.fit(
1212 litmodule,
1213 datamodule=datamodule,
1214 ckpt_path=ckpt_path if resume else None, # this is to resume training that was broken accidentally
1215 )
1216 else:
1217 blacklist_msgs = []
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:539, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
537 self.state.status = TrainerStatus.RUNNING
538 self.training = True
--> 539 call._call_and_handle_interrupt(
540 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
541 )
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
45 if trainer.strategy.launcher is not None:
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
50 _call_teardown_hook(trainer)
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:575, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
568 assert self.state.fn is not None
569 ckpt_path = self._checkpoint_connector._select_ckpt_path(
570 self.state.fn,
571 ckpt_path,
572 model_provided=True,
573 model_connected=self.lightning_module is not None,
574 )
--> 575 self._run(model, ckpt_path=ckpt_path)
577 assert self.state.stopped
578 self.training = False
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:982, in Trainer._run(self, model, ckpt_path)
977 self._signal_connector.register_signal_handlers()
979 # ----------------------------
980 # RUN THE TRAINER
981 # ----------------------------
--> 982 results = self._run_stage()
984 # ----------------------------
985 # POST-Training CLEAN UP
986 # ----------------------------
987 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1024, in Trainer._run_stage(self)
1022 if self.training:
1023 with isolate_rng():
-> 1024 self._run_sanity_check()
1025 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
1026 self.fit_loop.run()
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1053, in Trainer._run_sanity_check(self)
1050 call._call_callback_hooks(self, "on_sanity_check_start")
1052 # run eval step
-> 1053 val_loop.run()
1055 call._call_callback_hooks(self, "on_sanity_check_end")
1057 # reset logger connector
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:179, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
177 context_manager = torch.no_grad
178 with context_manager():
--> 179 return loop_run(self, *args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py:144, in _EvaluationLoop.run(self)
142 self.batch_progress.is_last_batch = data_fetcher.done
143 # run step hooks
--> 144 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
145 except StopIteration:
146 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
147 break
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py:433, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
427 hook_name = "test_step" if trainer.testing else "validation_step"
428 step_args = (
429 self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
430 if not using_dataloader_iter
431 else (dataloader_iter,)
432 )
--> 433 output = call._call_strategy_hook(trainer, hook_name, *step_args)
435 self.batch_progress.increment_processed()
437 if using_dataloader_iter:
438 # update the hook kwargs now that the step method might have consumed the iterator
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:323, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
320 return None
322 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 323 output = fn(*args, **kwargs)
325 # restore current_fx when nested context
326 pl_module._current_fx_name = prev_fx_name
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:412, in Strategy.validation_step(self, *args, **kwargs)
410 if self.model != self.lightning_module:
411 return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 412 return self.lightning_module.validation_step(*args, **kwargs)
File ~/autogluon/multimodal/src/autogluon/multimodal/optim/lit_module.py:381, in LitModule.validation_step(self, batch, batch_idx)
365 def validation_step(self, batch, batch_idx):
366 """
367 Per validation step. This function is registered by LightningModule.
368 Refer to https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#validation-loop
(...)
379 Index of mini-batch.
380 """
--> 381 output, loss = self._shared_step(batch)
382 if self.model_postprocess_fn:
383 output = self.model_postprocess_fn(output)
File ~/autogluon/multimodal/src/autogluon/multimodal/optim/lit_module.py:305, in LitModule._shared_step(self, batch)
303 self.mixup_fn.mixup_enabled = self.training & (self.current_epoch < self.hparams.mixup_off_epoch)
304 batch, label = multimodel_mixup(batch=batch, model=self.model, mixup_fn=self.mixup_fn)
--> 305 output = run_model(self.model, batch)
306 loss = self._compute_loss(output=output, label=label)
307 return output, loss
File ~/autogluon/multimodal/src/autogluon/multimodal/models/utils.py:863, in run_model(model, batch, trt_model)
861 output_vec = pure_model(*tuple(input_vec))
862 else:
--> 863 output_vec = model(*tuple(input_vec))
865 output = pure_model.get_output_dict(*output_vec)
866 else:
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/autogluon/multimodal/src/autogluon/multimodal/models/hf_text.py:230, in HFAutoModelForTextPrediction.forward(self, text_token_ids, text_segment_ids, text_valid_length, text_column_names, text_column_indices)
228 else:
229 if "token_type_ids" in self.tokenizer.model_input_names:
--> 230 outputs = self.model(
231 input_ids=text_token_ids,
232 token_type_ids=text_segment_ids,
233 attention_mask=text_masks,
234 )
235 else:
236 outputs = self.model(
237 input_ids=text_token_ids,
238 attention_mask=text_masks,
239 )
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:870, in DebertaV2Model.forward(self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, output_attentions, output_hidden_states, return_dict)
860 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
862 embedding_output = self.embeddings(
863 input_ids=input_ids,
864 token_type_ids=token_type_ids,
(...)
867 inputs_embeds=inputs_embeds,
868 )
--> 870 encoder_outputs = self.encoder(
871 embedding_output,
872 attention_mask,
873 output_hidden_states=True,
874 output_attentions=output_attentions,
875 return_dict=return_dict,
876 )
877 encoded_layers = encoder_outputs[1]
879 if self.z_steps > 1:
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:674, in DebertaV2Encoder.forward(self, hidden_states, attention_mask, output_hidden_states, output_attentions, query_states, relative_pos, return_dict)
664 output_states, attn_weights = self._gradient_checkpointing_func(
665 layer_module.__call__,
666 next_kv,
(...)
671 output_attentions,
672 )
673 else:
--> 674 output_states, attn_weights = layer_module(
675 next_kv,
676 attention_mask,
677 query_states=query_states,
678 relative_pos=relative_pos,
679 rel_embeddings=rel_embeddings,
680 output_attentions=output_attentions,
681 )
683 if output_attentions:
684 all_attentions = all_attentions + (attn_weights,)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:442, in DebertaV2Layer.forward(self, hidden_states, attention_mask, query_states, relative_pos, rel_embeddings, output_attentions)
433 def forward(
434 self,
435 hidden_states,
(...)
440 output_attentions: bool = False,
441 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
--> 442 attention_output, att_matrix = self.attention(
443 hidden_states,
444 attention_mask,
445 output_attentions=output_attentions,
446 query_states=query_states,
447 relative_pos=relative_pos,
448 rel_embeddings=rel_embeddings,
449 )
450 intermediate_output = self.intermediate(attention_output)
451 layer_output = self.output(intermediate_output, attention_output)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:375, in DebertaV2Attention.forward(self, hidden_states, attention_mask, output_attentions, query_states, relative_pos, rel_embeddings)
366 def forward(
367 self,
368 hidden_states,
(...)
373 rel_embeddings=None,
374 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
--> 375 self_output, att_matrix = self.self(
376 hidden_states,
377 attention_mask,
378 output_attentions,
379 query_states=query_states,
380 relative_pos=relative_pos,
381 rel_embeddings=rel_embeddings,
382 )
383 if query_states is None:
384 query_states = hidden_states
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:267, in DisentangledSelfAttention.forward(self, hidden_states, attention_mask, output_attentions, query_states, relative_pos, rel_embeddings)
262 attention_scores = attention_scores.view(
263 -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
264 )
266 attention_mask = attention_mask.bool()
--> 267 attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
268 # bsz x height x length x dimension
269 attention_probs = nn.functional.softmax(attention_scores, dim=-1)
RuntimeError: value cannot be converted to type at::BFloat16 without overflow
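The final frame fills `attention_scores` with `torch.finfo(query_layer.dtype).min`. A plausible reading of the error, offered as an assumption because the log does not print the dtypes involved, is that `query_layer` is float32 while `attention_scores` is bfloat16 under AMP, so the float32 minimum does not fit. The sketch below uses only the standard library to compare the largest finite magnitudes of the two formats; `max_finite` is a hypothetical helper.

```python
def max_finite(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-style binary floating-point format."""
    bias = 2 ** (exponent_bits - 1) - 1
    max_exponent = bias  # the all-ones exponent field is reserved for inf/NaN
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** max_exponent


float32_max = max_finite(exponent_bits=8, mantissa_bits=23)
bfloat16_max = max_finite(exponent_bits=8, mantissa_bits=7)

print(f"float32  max: {float32_max:.7e}")
print(f"bfloat16 max: {bfloat16_max:.7e}")
print("float32 min fits in bfloat16:", -float32_max >= -bfloat16_max)
```

Both formats use 8 exponent bits, but bfloat16 keeps only 7 mantissa bits, so its largest finite magnitude is slightly smaller than that of float32 and the negated float32 maximum falls outside its range.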
The tunable parameters make up only about 0.04% of all parameters (122 K of 278 M in the model summary above). Even so, a model trained purely on English data can achieve good performance on all three test sets, including the German and Japanese ones. It obtains results comparable to full finetuning, as shown in AutoMM for Text - Multilingual Problems.
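As a quick check, the tunable fraction can be recomputed from the rounded counts in the model summary printed during training. The variable names are illustrative, and the result is only as precise as the rounded "K" and "M" figures in that summary.

```python
# Rounded counts copied from the model summary printed during training.
trainable_params = 122_000      # "122 K Trainable params"
total_params = 278_000_000      # "278 M Total params"

fraction = trainable_params / total_params
print(f"Trainable fraction: {fraction:.4%}")
```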
score_in_en = predictor.evaluate(test_en_df)
score_in_de = predictor.evaluate(test_de_df)
score_in_jp = predictor.evaluate(test_jp_df)
print('Score in the English Testset:', score_in_en)
print('Score in the German Testset:', score_in_de)
print('Score in the Japanese Testset:', score_in_jp)
Training FLAN-T5-XL on Single GPU¶
By combining gradient checkpointing and parameter-efficient finetuning, it is feasible to finetune
google/flan-t5-xl, which the model summary below reports at about 1.2 billion parameters as loaded here, with a single T4 GPU available in
AWS G4 instances.
To turn on gradient checkpointing, you just need to set "model.hf_text.gradient_checkpointing" to True.
To accelerate the training, we downsample the training set to 200 samples.
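Gradient checkpointing saves memory by keeping only some activations during the forward pass and recomputing the rest during the backward pass. The sketch below is a simplified, generic cost model of that trade-off, not a description of where Hugging Face Transformers or AutoMM actually place checkpoints. The helper name, the equal-size-activation assumption, and the 24-layer depth are all assumptions for illustration.

```python
import math


def peak_stored_activations(num_layers: int, segment: int) -> int:
    """Simplified count of layer activations held in memory at once.

    Assumes all layer activations have equal size, one checkpoint is kept
    per segment of `segment` layers, and a single segment is recomputed at
    a time during the backward pass. segment=1 means no checkpointing.
    """
    checkpoints = math.ceil(num_layers / segment)
    recomputed_in_segment = segment - 1
    return checkpoints + recomputed_in_segment


num_layers = 24  # assumed depth, for illustration only
for segment in (1, 2, 5, 12, 24):
    peak = peak_stored_activations(num_layers, segment)
    print(f"segment={segment:>2}: about {peak:>2} activations held at peak")
```

Under this model the peak is smallest when the segment length is near the square root of the depth. The price is one extra forward computation per segment during the backward pass, which is why training gets slower as memory use drops.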
# Clean up to free disk space
clear_cache()
shutil.rmtree(model_path)
from autogluon.multimodal import MultiModalPredictor
train_en_df_downsample = train_en_df.sample(200, random_state=123)
new_model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3_gradient_checkpoint"
predictor = MultiModalPredictor(label="label",
                                path=new_model_path)
predictor.fit(train_en_df_downsample,
              presets="multilingual",
              hyperparameters={
                  "model.hf_text.checkpoint_name": "google/flan-t5-xl",
                  "model.hf_text.gradient_checkpointing": True,
                  "model.hf_text.low_cpu_mem_usage": True,
                  "optim.peft": "ia3_bias",
                  "optim.lr_decay": 0.9,
                  "optim.lr": 3e-03,
                  "optim.end_lr": 3e-03,
                  "optim.max_epochs": 1,
                  "optim.warmup_steps": 0,
                  "env.batch_size": 1,
                  "env.inference_batch_size_ratio": 1
              })
Global seed set to 123
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-------------------------------------------------------------------
0 | model | HFAutoModelForTextPrediction | 1.2 B
1 | validation_metric | AUROC | 0
2 | loss_func | CrossEntropyLoss | 0
-------------------------------------------------------------------
203 K Trainable params
1.2 B Non-trainable params
1.2 B Total params
4,894.913 Total estimated model params size (MB)
Epoch 0, global step 20: 'val_roc_auc' reached 0.88802 (best 0.88802), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=20.ckpt' as top 1
Epoch 0, global step 40: 'val_roc_auc' reached 0.94531 (best 0.94531), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=40.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=1` reached.
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fd58c4dbca0>
score_in_en = predictor.evaluate(test_en_df)
print('Score in the English Testset:', score_in_en)
Score in the English Testset: {'roc_auc': 0.931263189629183}
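The model summary printed during training for this run reports "4,894.913 Total estimated model params size (MB)". As a rough sanity check, the sketch below converts that figure back to a parameter count and compares the weight footprint with a T4's memory. It assumes 4 bytes per parameter (float32), 1 MB = 10^6 bytes, and 16 GB of GPU memory on the T4. None of these assumptions is stated in the log itself.

```python
BYTES_PER_FP32_PARAM = 4       # assumption: weights counted as float32
T4_MEMORY_GB = 16              # assumption: NVIDIA T4 memory capacity

reported_size_mb = 4894.913    # copied from the model summary
implied_params = reported_size_mb * 1e6 / BYTES_PER_FP32_PARAM
weights_gb = reported_size_mb / 1e3

print(f"Implied parameter count: {implied_params / 1e9:.2f} B")
print(f"Weights alone: {weights_gb:.2f} GB of {T4_MEMORY_GB} GB "
      f"({weights_gb / T4_MEMORY_GB:.0%})")
```

The implied count agrees with the "1.2 B" figure in the summary. The weights by themselves take up only part of the GPU memory, and activations and optimizer state have to fit in the remainder, which is where gradient checkpointing and the small set of trainable parameters help.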
# Clean up to free disk space
clear_cache()
shutil.rmtree(new_model_path)
Other Examples¶
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization¶
To learn how to customize AutoMM, please refer to Customize AutoMM.