Text Prediction - Multimodal Table with Text¶
In many applications, text data may be mixed with numeric/categorical
data. AutoGluon’s TextPredictor can train a single neural network
that jointly operates on multiple feature types, including text,
categorical, and numerical columns. The general idea is to embed the
text, categorical and numeric fields separately and fuse these features
across modalities. This tutorial demonstrates such an application.
import numpy as np
import pandas as pd
import os
import warnings
warnings.filterwarnings('ignore')
np.random.seed(123)
!python3 -m pip install openpyxl
Collecting openpyxl
Using cached openpyxl-3.0.9-py2.py3-none-any.whl (242 kB)
Collecting et-xmlfile
Using cached et_xmlfile-1.1.0-py3-none-any.whl (4.7 kB)
Installing collected packages: et-xmlfile, openpyxl
Successfully installed et-xmlfile-1.1.0 openpyxl-3.0.9
Book Price Prediction Data¶
For demonstration, we use the book price prediction dataset from the MachineHack Salary Prediction Hackathon. Our goal is to predict a book’s price given various features like its author, the abstract, the book’s rating, etc.
!mkdir -p price_of_books
!wget https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip -O price_of_books/Data.zip
!cd price_of_books && unzip -o Data.zip
!ls price_of_books/Participants_Data
--2022-01-21 06:09:02-- https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 52.216.129.115
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|52.216.129.115|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3521673 (3.4M) [application/zip]
Saving to: ‘price_of_books/Data.zip’
price_of_books/Data 100%[===================>] 3.36M 5.67MB/s in 0.6s
2022-01-21 06:09:03 (5.67 MB/s) - ‘price_of_books/Data.zip’ saved [3521673/3521673]
Archive: Data.zip
inflating: Participants_Data/Data_Test.xlsx
inflating: Participants_Data/Data_Train.xlsx
inflating: Participants_Data/Sample_Submission.xlsx
Data_Test.xlsx Data_Train.xlsx Sample_Submission.xlsx
train_df = pd.read_excel(os.path.join('price_of_books', 'Participants_Data', 'Data_Train.xlsx'), engine='openpyxl')
train_df.head()
| | Title | Author | Edition | Reviews | Ratings | Synopsis | Genre | BookCategory | Price |
|---|---|---|---|---|---|---|---|---|---|
| 0 | The Prisoner's Gold (The Hunters 3) | Chris Kuzneski | Paperback,– 10 Mar 2016 | 4.0 out of 5 stars | 8 customer reviews | THE HUNTERS return in their third brilliant no... | Action & Adventure (Books) | Action & Adventure | 220.00 |
| 1 | Guru Dutt: A Tragedy in Three Acts | Arun Khopkar | Paperback,– 7 Nov 2012 | 3.9 out of 5 stars | 14 customer reviews | A layered portrait of a troubled genius for wh... | Cinema & Broadcast (Books) | Biographies, Diaries & True Accounts | 202.93 |
| 2 | Leviathan (Penguin Classics) | Thomas Hobbes | Paperback,– 25 Feb 1982 | 4.8 out of 5 stars | 6 customer reviews | "During the time men live without a common Pow... | International Relations | Humour | 299.00 |
| 3 | A Pocket Full of Rye (Miss Marple) | Agatha Christie | Paperback,– 5 Oct 2017 | 4.1 out of 5 stars | 13 customer reviews | A handful of grain is found in the pocket of a... | Contemporary Fiction (Books) | Crime, Thriller & Mystery | 180.00 |
| 4 | LIFE 70 Years of Extraordinary Photography | Editors of Life | Hardcover,– 10 Oct 2006 | 5.0 out of 5 stars | 1 customer review | For seven decades, "Life" has been thrilling t... | Photography Textbooks | Arts, Film & Photography | 965.62 |
We do some basic preprocessing to convert Reviews and Ratings in
the data table to numeric values, and we transform prices to a
log-scale.
def preprocess(df):
df = df.copy(deep=True)
df.loc[:, 'Reviews'] = pd.to_numeric(df['Reviews'].apply(lambda ele: ele[:-len(' out of 5 stars')]))
df.loc[:, 'Ratings'] = pd.to_numeric(df['Ratings'].apply(lambda ele: ele.replace(',', '')[:-len(' customer reviews')]))
df.loc[:, 'Price'] = np.log(df['Price'] + 1)
return df
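To see what this preprocessing does, here is a quick sanity check on a single made-up row (the values are for illustration only):

```python
import numpy as np
import pandas as pd

# Toy row mimicking the raw format of the Reviews/Ratings/Price columns
toy = pd.DataFrame({
    'Reviews': ['4.0 out of 5 stars'],
    'Ratings': ['1,234 customer reviews'],
    'Price': [220.0],
})

# Strip the fixed suffixes and thousands separators, then convert to numbers
toy['Reviews'] = pd.to_numeric(toy['Reviews'].str[:-len(' out of 5 stars')])
toy['Ratings'] = pd.to_numeric(toy['Ratings'].str.replace(',', '').str[:-len(' customer reviews')])
# Log-transform the price (log(x + 1) keeps zero prices finite)
toy['Price'] = np.log(toy['Price'] + 1)

print(toy.iloc[0].tolist())  # [4.0, 1234, 5.398...]
```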
train_subsample_size = 1500 # subsample for faster demo, you can try setting to larger values
test_subsample_size = 5
train_df = preprocess(train_df)
train_data = train_df.iloc[100:].sample(train_subsample_size, random_state=123)
test_data = train_df.iloc[:100].sample(test_subsample_size, random_state=245)
train_data.head()
| | Title | Author | Edition | Reviews | Ratings | Synopsis | Genre | BookCategory | Price |
|---|---|---|---|---|---|---|---|---|---|
| 949 | Furious Hours | Casey Cep | Paperback,– 1 Jun 2019 | 4.0 | NaN | ‘It’s been a long time since I picked up a boo... | True Accounts (Books) | Biographies, Diaries & True Accounts | 5.743003 |
| 5504 | REST API Design Rulebook | Mark Masse | Paperback,– 7 Nov 2011 | 5.0 | NaN | In todays market, where rival web services com... | Computing, Internet & Digital Media (Books) | Computing, Internet & Digital Media | 5.786897 |
| 5856 | The Atlantropa Articles: A Novel | Cody Franklin | Paperback,– Import, 1 Nov 2018 | 4.5 | 2.0 | #1 Amazon Best Seller! Dystopian Alternate His... | Action & Adventure (Books) | Romance | 6.893656 |
| 4137 | Hickory Dickory Dock (Poirot) | Agatha Christie | Paperback,– 5 Oct 2017 | 4.3 | 21.0 | There’s more than petty theft going on in a Lo... | Action & Adventure (Books) | Crime, Thriller & Mystery | 5.192957 |
| 3205 | The Stanley Kubrick Archives (Bibliotheca Univ... | Alison Castle | Hardcover,– 21 Aug 2016 | 4.6 | 3.0 | In 1968, when Stanley Kubrick was asked to com... | Cinema & Broadcast (Books) | Humour | 6.889591 |
Training¶
We can simply create a TextPredictor and call predictor.fit() to
train a model that operates across all feature types. Internally,
the neural network is automatically generated based on the inferred
data type of each feature column. To save time, we subsample the data
and only train for three minutes.
from autogluon.text import TextPredictor
time_limit = 3 * 60 # set to larger value in your applications
predictor = TextPredictor(label='Price', path='ag_text_book_price_prediction')
predictor.fit(train_data, time_limit=time_limit)
Problem Type="regression"
Column Types:
- "Title": text
- "Author": text
- "Edition": text
- "Reviews": numerical
- "Ratings": numerical
- "Synopsis": text
- "Genre": text
- "BookCategory": categorical
- "Price": numerical
The GluonNLP V0 backend is used. We will use 8 cpus and 1 gpus to train each trial.
All Logs will be saved to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_text_book_price_prediction/task0/training.log
2022-01-21 06:09:09,035 - autogluon.text.text_prediction.mx.models - INFO - Fitting and transforming the train data...
2022-01-21 06:09:10,722 - autogluon.text.text_prediction.mx.models - INFO - Done! Preprocessor saved to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_text_book_price_prediction/task0/preprocessor.pkl
2022-01-21 06:09:10,733 - autogluon.text.text_prediction.mx.models - INFO - Process dev set...
2022-01-21 06:09:10,932 - autogluon.text.text_prediction.mx.models - INFO - Done!
2022-01-21 06:09:10,942 - autogluon.text.text_prediction.mx.models - INFO - Max length for chunking text: 480, Stochastic chunk: Train-False/Test-False, Test #repeat: 1.
2022-01-21 06:09:14,389 - autogluon.text.text_prediction.mx.models - INFO - #Total Params/Fixed Params=109338913/0
2022-01-21 06:09:14,407 - autogluon.text.text_prediction.mx.models - Level 15 - Using gradient accumulation. Global batch size = 128
2022-01-21 06:09:14,506 - autogluon.text.text_prediction.mx.models - INFO - Local training results will be saved to /var/lib/jenkins/workspace/workspace/autogluon-tutorial-text-v3/docs/_build/eval/tutorials/text_prediction/ag_text_book_price_prediction/task0/results_local.jsonl.
2022-01-21 06:09:25,355 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 1/100, Epoch 0] train loss=1.43e+00, gnorm=1.97e+01, lr=1.00e-05, #samples processed=128, #sample per second=11.80. ETA=17.90min
2022-01-21 06:09:34,026 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 2/100, Epoch 0] train loss=1.35e+00, gnorm=9.64e+00, lr=2.00e-05, #samples processed=128, #sample per second=14.76. ETA=15.94min
2022-01-21 06:09:44,021 - autogluon.text.text_prediction.mx.models - Level 25 - [Iter 2/100, Epoch 0] Validation r2=-9.0861e-01, root_mean_squared_error=1.1113e+00, mean_absolute_error=8.9773e-01, Time computing validation-score=8.862s, Total time spent=0.49min. Found improved model=True, Improved top-3 models=True
2022-01-21 06:09:54,066 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 3/100, Epoch 0] train loss=1.70e+00, gnorm=4.78e+01, lr=3.00e-05, #samples processed=128, #sample per second=6.39. ETA=21.32min
2022-01-21 06:10:04,399 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 4/100, Epoch 0] train loss=2.11e+00, gnorm=4.63e+01, lr=4.00e-05, #samples processed=128, #sample per second=12.39. ETA=19.96min
2022-01-21 06:10:14,722 - autogluon.text.text_prediction.mx.models - Level 25 - [Iter 4/100, Epoch 0] Validation r2=9.2734e-02, root_mean_squared_error=7.6619e-01, mean_absolute_error=6.1794e-01, Time computing validation-score=9.033s, Total time spent=1.00min. Found improved model=True, Improved top-3 models=True
2022-01-21 06:10:24,667 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 5/100, Epoch 0] train loss=1.22e+00, gnorm=2.71e+01, lr=5.00e-05, #samples processed=128, #sample per second=6.32. ETA=22.22min
2022-01-21 06:10:34,422 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 6/100, Epoch 0] train loss=1.07e+00, gnorm=2.43e+01, lr=6.00e-05, #samples processed=128, #sample per second=13.12. ETA=20.87min
2022-01-21 06:10:44,102 - autogluon.text.text_prediction.mx.models - Level 25 - [Iter 6/100, Epoch 0] Validation r2=-3.7551e-01, root_mean_squared_error=9.4342e-01, mean_absolute_error=7.4564e-01, Time computing validation-score=9.203s, Total time spent=1.49min. Found improved model=False, Improved top-3 models=True
2022-01-21 06:10:53,420 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 7/100, Epoch 0] train loss=1.45e+00, gnorm=3.12e+01, lr=7.00e-05, #samples processed=128, #sample per second=6.74. ETA=21.90min
2022-01-21 06:11:03,452 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 8/100, Epoch 0] train loss=1.14e+00, gnorm=1.45e+01, lr=8.00e-05, #samples processed=128, #sample per second=12.76. ETA=20.88min
2022-01-21 06:11:13,169 - autogluon.text.text_prediction.mx.models - Level 25 - [Iter 8/100, Epoch 0] Validation r2=-3.9416e-02, root_mean_squared_error=8.2010e-01, mean_absolute_error=6.6583e-01, Time computing validation-score=8.916s, Total time spent=1.98min. Found improved model=False, Improved top-3 models=True
2022-01-21 06:11:22,514 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 9/100, Epoch 0] train loss=1.17e+00, gnorm=2.16e+01, lr=9.00e-05, #samples processed=128, #sample per second=6.72. ETA=21.57min
2022-01-21 06:11:32,219 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 10/100, Epoch 0] train loss=1.12e+00, gnorm=1.60e+01, lr=1.00e-04, #samples processed=128, #sample per second=13.19. ETA=20.66min
2022-01-21 06:11:41,796 - autogluon.text.text_prediction.mx.models - Level 25 - [Iter 10/100, Epoch 0] Validation r2=9.5444e-03, root_mean_squared_error=8.0055e-01, mean_absolute_error=6.0912e-01, Time computing validation-score=8.774s, Total time spent=2.45min. Found improved model=False, Improved top-3 models=True
2022-01-21 06:11:51,391 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 11/100, Epoch 1] train loss=1.09e+00, gnorm=1.07e+01, lr=9.89e-05, #samples processed=128, #sample per second=6.68. ETA=21.16min
2022-01-21 06:12:01,142 - autogluon.text.text_prediction.mx.models - Level 15 - [Iter 12/100, Epoch 1] train loss=9.75e-01, gnorm=9.26e+00, lr=9.78e-05, #samples processed=128, #sample per second=13.13. ETA=20.37min
2022-01-21 06:12:11,466 - autogluon.text.text_prediction.mx.models - Level 25 - [Iter 12/100, Epoch 1] Validation r2=2.5167e-01, root_mean_squared_error=6.9585e-01, mean_absolute_error=5.4406e-01, Time computing validation-score=8.734s, Total time spent=2.95min. Found improved model=True, Improved top-3 models=True
Training completed. Auto-saving to "ag_text_book_price_prediction/". For loading the model, you can use predictor = TextPredictor.load("ag_text_book_price_prediction/")
<autogluon.text.text_prediction.predictor.predictor.TextPredictor at 0x7f1725bb55e0>
Prediction¶
We can easily obtain predictions and extract data embeddings using the TextPredictor.
predictions = predictor.predict(test_data)
print('Predictions:')
print('------------')
print(np.exp(predictions) - 1)
print()
print('True Value:')
print('------------')
print(np.exp(test_data['Price']) - 1)
Predictions:
------------
1 388.794800
31 463.092896
19 656.493774
45 641.799744
82 701.047729
Name: Price, dtype: float32
True Value:
------------
1 202.93
31 799.00
19 352.00
45 395.10
82 409.00
Name: Price, dtype: float64
performance = predictor.evaluate(test_data)
print(performance)
0.5703942775726318
embeddings = predictor.extract_embedding(test_data)
print(embeddings)
[[-0.3209607 -0.03355258 -0.00624458 ... -0.31315193 -0.40560308
-0.99085736]
[-0.26278123 -0.05829103 0.06610287 ... -0.38716957 -0.34932163
-1.0310341 ]
[-0.2889479 0.10081426 0.00895867 ... -0.07957182 -0.35302812
-0.37785846]
[-0.22750379 -0.04325046 0.0204105 ... 0.00270245 -0.16621079
0.35989 ]
[-0.2810258 0.16438672 0.01511844 ... -0.17571244 -0.0506384
0.35036042]]
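The extracted embeddings are one vector per row, so they can be fed into any downstream tool, for instance to find similar books via cosine similarity. A minimal sketch with random stand-in embeddings of the same 2-D shape that extract_embedding returns:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for predictor.extract_embedding(test_data): one vector per row
embeddings = rng.normal(size=(5, 768))

# Normalize each row, then pairwise cosine similarity is a matrix product
unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
similarity = unit @ unit.T  # (5, 5) matrix; diagonal is 1

# Most similar other example to example 0 (exclude self-similarity)
np.fill_diagonal(similarity, -np.inf)
print(int(similarity[0].argmax()))
```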
What’s happening inside?¶
Internally, we use different networks to encode the text columns, categorical columns, and numerical columns. The features generated by the individual networks are aggregated by a late-fusion aggregator. The aggregator can output either logits or score predictions. The architecture can be illustrated as follows:
Fig. 1 Multimodal Network with Late Fusion¶
Here, we use a pretrained NLP backbone to extract the text features, and two other towers to extract features from the categorical and numerical columns.
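The late-fusion idea can be sketched in a few lines of numpy: each modality gets its own encoder, the per-modality feature vectors are concatenated, and a small aggregator maps the fused vector to a prediction. Layer sizes and the random inputs below are purely illustrative, not AutoGluon's actual network.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_encoder(dim_in, dim_out):
    # One random linear layer + ReLU, standing in for a real tower
    W = rng.normal(size=(dim_in, dim_out)) * 0.1
    return lambda x: np.maximum(x @ W, 0.0)

text_enc = make_encoder(768, 64)  # e.g. pooled Transformer output
cat_enc = make_encoder(16, 64)    # embedded categorical features
num_enc = make_encoder(2, 64)     # e.g. Reviews, Ratings

# Fake per-modality inputs for a single example
text_feat = text_enc(rng.normal(size=(1, 768)))
cat_feat = cat_enc(rng.normal(size=(1, 16)))
num_feat = num_enc(rng.normal(size=(1, 2)))

# Late fusion: concatenate per-modality features, then aggregate
fused = np.concatenate([text_feat, cat_feat, num_feat], axis=1)  # (1, 192)
W_out = rng.normal(size=(192, 1)) * 0.1
prediction = fused @ W_out  # scalar score for regression
print(prediction.shape)  # (1, 1)
```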
In addition, to deal with multiple text fields, we separate these fields
with the [SEP] token and alternate 0s and 1s as the segment IDs,
which is shown as follows:
Fig. 2 Preprocessing¶
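The packing scheme above can be sketched as follows, using a toy whitespace "tokenizer" (real models use subword tokenization, and the field values here are just examples):

```python
# Join multiple text fields into one sequence: fields are separated by
# [SEP] and the segment (token type) IDs alternate 0/1 per field.
fields = ["The Prisoner's Gold", 'Chris Kuzneski', 'THE HUNTERS return ...']

tokens, segment_ids = [], []
for i, field in enumerate(fields):
    for tok in field.split():
        tokens.append(tok)
        segment_ids.append(i % 2)  # field 0 -> 0, field 1 -> 1, field 2 -> 0, ...
    tokens.append('[SEP]')
    segment_ids.append(i % 2)

print(tokens)
print(segment_ids)
```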
How does this compare with TabularPredictor?¶
Note that TabularPredictor can also handle data tables with text,
numeric, and categorical columns, but it uses an ensemble of many types
of models and may featurize text. TextPredictor instead directly fits
individual Transformer neural network models to the raw text (and these
models can also handle additional numeric/categorical columns). We
generally recommend TabularPredictor if your table contains mainly
numeric/categorical columns, and TextPredictor if your table contains
mainly text columns, but you can easily try both and we encourage this.
In fact, TabularPredictor.fit(..., hyperparameters='multimodal') will
train a TextPredictor along with many tabular models and ensemble them
together.
Refer to the tutorial “Multimodal Data Tables: Combining BERT/Transformers and Classical Tabular Models”
for more details.
Other Examples¶
You may go to https://github.com/awslabs/autogluon/tree/master/examples/text_prediction to explore other TextPredictor examples, including scripts to train a TextPredictor on the complete book price prediction dataset.