.. _sec_textprediction_heterogeneous:

Text Prediction - Heterogeneous Data Types
==========================================

In real applications, text data is often mixed with other common data
types, such as numerical and categorical data, which are typical of
tabular datasets. The ``TextPrediction`` task in AutoGluon can train a
single neural network that jointly operates on multiple feature types,
including text, categorical, and numerical columns. Here we again use
the Semantic Textual Similarity (STS) dataset to illustrate this
functionality.

.. code:: python

    import numpy as np
    import warnings

    # Silence library warnings and fix the NumPy seed for reproducibility.
    warnings.filterwarnings('ignore')
    np.random.seed(123)

Load Data and Train Model
-------------------------

.. code:: python

    from autogluon.core.utils.loaders import load_pd

    # Load the STS training and development splits as pandas DataFrames.
    train_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet')
    dev_data = load_pd.load('https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet')
    train_data.head(10)

.. parsed-literal::
    :class: output

       sentence1                                      sentence2                                          genre          score
    0  A plane is taking off.                         An air plane is taking off.                        main-captions   5.00
    1  A man is playing a large flute.                A man is playing a flute.                          main-captions   3.80
    2  A man is spreading shreded cheese on a pizza.  A man is spreading shredded cheese on an uncoo...  main-captions   3.80
    3  Three men are playing chess.                   Two men are playing chess.                         main-captions   2.60
    4  A man is playing the cello.                    A man seated is playing the cello.                 main-captions   4.25
    5  Some men are fighting.                         Two men are fighting.                              main-captions   4.25
    6  A man is smoking.                              A man is skating.                                  main-captions   0.50
    7  The man is playing the piano.                  The man is playing the guitar.                     main-captions   1.60
    8  A man is playing on a guitar and singing.      A woman is playing an acoustic guitar and sing...  main-captions   2.20
    9  A person is throwing a cat on to the ceiling.  A person throws a cat on the ceiling.              main-captions   5.00
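
To make the idea of "a single neural network over mixed columns" more concrete, here is a toy forward pass in NumPy over three rows from the table above. It is not AutoGluon's architecture: the hashed bag-of-words text encoder, the reserved index 0 for unseen categories, the 8-dimensional embeddings, the helper names ``encode_text`` and ``encode_genre``, and the untrained linear head are all hypothetical choices made for illustration. Each column is encoded into a vector, the vectors are concatenated, and one head maps the fused vector to a predicted score.

.. code:: python

    import zlib

    import numpy as np
    import pandas as pd

    # Three rows copied from the STS training table above.
    df = pd.DataFrame({
        "sentence1": ["A plane is taking off.",
                      "A man is playing a large flute.",
                      "Three men are playing chess."],
        "sentence2": ["An air plane is taking off.",
                      "A man is playing a flute.",
                      "Two men are playing chess."],
        "genre": ["main-captions", "main-captions", "main-captions"],
        "score": [5.00, 3.80, 2.60],
    })

    rng = np.random.default_rng(123)
    DIM = 8         # hypothetical embedding size per column
    BUCKETS = 1000  # hypothetical number of hash buckets for text tokens

    # Text encoder (assumption): hash each lower-cased token to a bucket and
    # average the bucket embeddings into one fixed-size vector.
    token_emb = rng.normal(size=(BUCKETS, DIM))

    def encode_text(sentence: str) -> np.ndarray:
        ids = [zlib.crc32(tok.encode("utf-8")) % BUCKETS
               for tok in sentence.lower().split()]
        return token_emb[ids].mean(axis=0)

    # Categorical encoder (assumption): one embedding row per known category,
    # plus row 0 reserved for unseen or missing values.
    categories = ["main-captions", "main-forums", "main-news"]
    cat_index = {c: i + 1 for i, c in enumerate(categories)}
    cat_emb = rng.normal(size=(len(categories) + 1, DIM))

    def encode_genre(genre: str) -> np.ndarray:
        return cat_emb[cat_index.get(genre, 0)]

    # Fusion: concatenate the per-column vectors into one feature vector per row.
    X = np.stack([
        np.concatenate([encode_text(r.sentence1),
                        encode_text(r.sentence2),
                        encode_genre(r.genre)])
        for r in df.itertuples(index=False)
    ])

    # A single (untrained) linear head maps every fused vector to a score.
    W = rng.normal(size=X.shape[1]) * 0.1
    b = 0.0
    pred = X @ W + b
    y = df["score"].to_numpy()  # targets a trained model would be fit to

    print(X.shape, pred.shape)  # (3, 24) (3,)

Only the forward pass is shown. In a trained model, the embeddings and the head would be updated together by minimizing a regression loss against ``score``, which is what makes this one network rather than separate per-column models. The STS data has no numerical input columns besides the label, but such a column would typically be standardized and appended to the concatenation.
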

Note that the STS dataset contains two text fields (``sentence1`` and
``sentence2``), one categorical field (``genre``), and one numerical
field (``score``). Let's try to predict the **score** from the other
features: ``sentence1``, ``sentence2``, and ``genre``.

.. code:: python

    import autogluon.core as ag
    from autogluon.text import TextPrediction as task

    # Train on all non-label columns with a 60-second time budget on one GPU.
    predictor_score = task.fit(train_data, label='score',
                               time_limits=60, ngpus_per_trial=1, seed=123,
                               output_directory='./ag_sts_mixed_score')

.. parsed-literal::
    :class: output

    2021-02-23 19:31:54,207 - autogluon.text.text_prediction.text_prediction - INFO - All Logs will be saved to ./ag_sts_mixed_score/ag_text_prediction.log
    INFO:autogluon.text.text_prediction.text_prediction:All Logs will be saved to ./ag_sts_mixed_score/ag_text_prediction.log
    2021-02-23 19:31:54,244 - autogluon.text.text_prediction.text_prediction - INFO - Train Dataset:
    INFO:autogluon.text.text_prediction.text_prediction:Train Dataset:
    2021-02-23 19:31:54,244 - autogluon.text.text_prediction.text_prediction - INFO - Columns:

    - Text(
       name="sentence1"
       #total/missing=4599/0
       length, min/avg/max=16/57.62056968906284/367
    )
    - Text(
       name="sentence2"
       #total/missing=4599/0
       length, min/avg/max=15/57.47532072189606/311
    )
    - Categorical(
       name="genre"
       #total/missing=4599/0
       num_class (total/non_special)=4/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[1608, 366, 2625]
    )
    - Numerical(
       name="score"
       #total/missing=4599/0
       shape=()
    )

    INFO:autogluon.text.text_prediction.text_prediction:Columns:

    - Text(
       name="sentence1"
       #total/missing=4599/0
       length, min/avg/max=16/57.62056968906284/367
    )
    - Text(
       name="sentence2"
       #total/missing=4599/0
       length, min/avg/max=15/57.47532072189606/311
    )
    - Categorical(
       name="genre"
       #total/missing=4599/0
       num_class (total/non_special)=4/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[1608, 366, 2625]
    )
    - Numerical(
       name="score"
       #total/missing=4599/0
       shape=()
    )

    2021-02-23 19:31:54,246 - autogluon.text.text_prediction.text_prediction - INFO - Tuning Dataset:
    INFO:autogluon.text.text_prediction.text_prediction:Tuning Dataset:
    2021-02-23 19:31:54,247 - autogluon.text.text_prediction.text_prediction - INFO - Columns:

    - Text(
       name="sentence1"
       #total/missing=1150/0
       length, min/avg/max=16/58.06/315
    )
    - Text(
       name="sentence2"
       #total/missing=1150/0
       length, min/avg/max=15/57.76173913043478/256
    )
    - Categorical(
       name="genre"
       #total/missing=1150/0
       num_class (total/non_special)=4/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[392, 84, 674]
    )
    - Numerical(
       name="score"
       #total/missing=1150/0
       shape=()
    )

    INFO:autogluon.text.text_prediction.text_prediction:Columns:

    - Text(
       name="sentence1"
       #total/missing=1150/0
       length, min/avg/max=16/58.06/315
    )
    - Text(
       name="sentence2"
       #total/missing=1150/0
       length, min/avg/max=15/57.76173913043478/256
    )
    - Categorical(
       name="genre"
       #total/missing=1150/0
       num_class (total/non_special)=4/3
       categories=['main-captions', 'main-forums', 'main-news']
       freq=[392, 84, 674]
    )
    - Numerical(
       name="score"
       #total/missing=1150/0
       shape=()
    )

    WARNING:autogluon.core.utils.multiprocessing_utils:WARNING: changing multiprocessing start method to forkserver
    2021-02-23 19:31:54,254 - autogluon.text.text_prediction.text_prediction - INFO - All Logs will be saved to ./ag_sts_mixed_score/main.log
    INFO:autogluon.text.text_prediction.text_prediction:All Logs will be saved to ./ag_sts_mixed_score/main.log

.. parsed-literal::
    :class: output

    0%| | 0/3 [00:00