AutoGluon Tabular - Feature Engineering


Introduction

Feature engineering involves taking raw tabular data and

  1. converting it into a format ready for the machine learning model to read

  2. trying to enhance some columns (‘features’ in ML jargon) to give the ML models more information, hoping to get more accurate results.

AutoGluon does some of this for you. This document describes how that works and how you can extend it. It covers the default behaviour, much of which is configurable, and gives pointers on how to change it.

Column Types

AutoGluon Tabular recognises the following types of features, and has separate processing for them:

Feature Type    Example Values
boolean         A, B
numerical       1.3, 2.0, -1.6
categorical     Red, Blue, Yellow
datetime        1/31/2021, Mar-31
text            Mary had a little lamb

In addition, other AutoGluon prediction modules recognise further feature types; these can also be enabled in AutoGluon Tabular through the MultiModal option.

Feature Type    Example Values
image           path/image123.png

Column Type Detection

  • Boolean columns are any columns with only 2 unique values.

  • String columns are treated as categorical unless they are detected as text (see below). Some models perform better if you tell them which columns are categorical and which are continuous.

  • Numeric columns are passed through unchanged, apart from being identified as float or int. Currently, numeric columns are not tested to determine whether they are likely to be categorical. You can force them to be treated as categorical with the Pandas syntax .astype("category"), as shown below.

  • Text columns are detected by first checking that most rows are unique. If they are, and most rows contain multiple separate words, the column is treated as text. For details see common/features/infer_types.py in the source.

  • Datetime columns are detected by trying to convert them to Pandas datetimes. Pandas recognises a wide range of datetime formats. If many of the values in a column convert successfully, the column is treated as datetime. Currently, datetimes that appear to be purely numeric (e.g. 20210530) are not detected correctly. Any NaN values are set to the column mean. For details see common/features/infer_types.py.
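The detection rules above can be approximated in a few lines of plain Python. This is a hypothetical sketch, not AutoGluon's implementation: `infer_column_type`, the ratio thresholds and the two date formats are all assumptions, and pandas recognises far more datetime formats than the ones tried here.

```python
from datetime import datetime

# Hypothetical thresholds, chosen only for illustration. AutoGluon's real
# rules live in common/features/infer_types.py and differ in detail.
DATETIME_PARSE_RATIO = 0.8   # "many values convert successfully"
TEXT_UNIQUE_RATIO = 0.5      # "most rows are unique"
TEXT_MULTIWORD_RATIO = 0.5   # "most rows contain several words"
# A tiny, assumed subset of the formats pandas can parse.
DATE_FORMATS = ("%m/%d/%Y", "%Y-%m-%d")


def parses_as_date(value: str) -> bool:
    """Return True if value matches one of the assumed date formats."""
    for fmt in DATE_FORMATS:
        try:
            datetime.strptime(value, fmt)
            return True
        except ValueError:
            pass
    return False


def infer_column_type(values: list) -> str:
    """Guess boolean / numerical / datetime / text / categorical for a column."""
    present = [v for v in values if v is not None]
    if len(set(present)) == 2:
        return "boolean"
    if all(isinstance(v, (int, float)) for v in present):
        return "numerical"  # numbers are never re-tested as categorical
    strings = [str(v) for v in present]
    n = len(strings)
    if sum(parses_as_date(s) for s in strings) >= DATETIME_PARSE_RATIO * n:
        return "datetime"
    unique_ratio = len(set(strings)) / n
    multiword_ratio = sum(len(s.split()) > 1 for s in strings) / n
    if unique_ratio > TEXT_UNIQUE_RATIO and multiword_ratio > TEXT_MULTIWORD_RATIO:
        return "text"
    return "categorical"
```

Note the order of the checks: the two-value test runs first, so a column such as ["A", "B", "A"] is reported as boolean before any string heuristics are tried.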

Problem Type Detection

If the user does not specify whether the problem is classification or regression, the ‘label’ column is examined to guess. Two things point towards a regression problem: the values are floating point non-integers, and there is a large number of unique values. Within classification, both multiclass and binary (n=2 categories) are detected. For details see utils/utils.py.

To override the automatic inference, explicitly pass the problem_type (one of ‘binary’, ‘regression’, ‘multiclass’) to TabularPredictor(). For example:

predictor = TabularPredictor(label='class', problem_type='multiclass').fit(train_data)
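The label-based guess described above can be sketched as follows. The function name `infer_problem_type` and the cutoff of 20 unique values are hypothetical; AutoGluon's real logic in utils/utils.py uses its own rules and thresholds.

```python
# Hypothetical cutoff, for illustration only; AutoGluon's actual
# inference logic (see utils/utils.py) uses its own rules.
MAX_MULTICLASS_UNIQUE = 20


def infer_problem_type(labels: list) -> str:
    """Guess 'binary', 'multiclass' or 'regression' from label values."""
    unique = set(labels)
    if len(unique) == 2:
        return "binary"
    if not all(isinstance(v, (int, float)) for v in labels):
        return "multiclass"  # e.g. string labels with more than two values
    # Non-integer floats, or many distinct values, point towards regression.
    has_fraction = any(not float(v).is_integer() for v in labels)
    if has_fraction or len(unique) > MAX_MULTICLASS_UNIQUE:
        return "regression"
    return "multiclass"
```

Any heuristic like this can misfire, for example on an integer-valued count target with few distinct values, which is why passing problem_type explicitly is the reliable option.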

Automatic Feature Engineering

Numerical Columns

Numeric columns, both integer and floating point, currently have no automated feature engineering.

Categorical Columns

Since many downstream models require categories to be encoded as integers, each categorical feature is mapped to monotonically increasing integers.
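A minimal sketch of such an encoding, assuming categories are ordered by sorting and that missing or unseen values stay missing. `fit_category_codes` and `transform_category_codes` are hypothetical names, not AutoGluon functions. The sorted ordering happens to agree with the example later in this document, where ‘v’, ‘w’, ‘x’ and ‘y’ become 0 to 3.

```python
def fit_category_codes(values: list) -> dict:
    """Assign integer codes 0, 1, 2, ... to the distinct categories in sorted order."""
    categories = sorted({v for v in values if v is not None})
    return {category: code for code, category in enumerate(categories)}


def transform_category_codes(values: list, mapping: dict) -> list:
    """Replace each category with its code; missing or unseen values become None."""
    return [mapping.get(v) for v in values]
```

Fitting the mapping once and reusing it at prediction time keeps the codes stable between training and inference data.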

Datetime Columns

Columns recognised as datetime are converted into several features:

  • a numerical value: the Pandas datetime as an integer number of nanoseconds since the Unix epoch. Note that the minimum and maximum representable values are pandas.Timestamp.min and pandas.Timestamp.max respectively, which may affect dates very far in the past or future.

  • several extracted columns; the default is [year, month, day, dayofweek]. This is configurable via the DatetimeFeatureGenerator.

Note that missing, invalid and out-of-range features generated by the above logic will be converted to the mean value across all valid rows.
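The datetime expansion and mean-filling described above can be sketched with the standard library alone. `datetime_features` is a hypothetical helper, naive datetimes are assumed to be UTC, and the pandas.Timestamp range limits are not modelled.

```python
from datetime import datetime, timedelta

EPOCH = datetime(1970, 1, 1)  # naive datetimes are treated as UTC here


def datetime_features(values: list) -> list:
    """Expand datetimes (None = missing) into a nanosecond value plus parts."""
    ns = [
        None if v is None else (v - EPOCH) // timedelta(microseconds=1) * 1000
        for v in values
    ]
    valid = [x for x in ns if x is not None]
    if not valid:
        raise ValueError("need at least one valid datetime to compute the mean")
    mean_ns = sum(valid) // len(valid)  # integer mean over the valid rows
    rows = []
    for x in ns:
        value = mean_ns if x is None else x  # fill missing values with the mean
        dt = EPOCH + timedelta(microseconds=value // 1000)
        rows.append({
            "ns": value,
            "year": dt.year,
            "month": dt.month,
            "day": dt.day,
            "dayofweek": dt.weekday(),  # Monday=0, as in pandas
        })
    return rows
```

For example, `datetime_features([datetime(2000, 1, 1), None, datetime(2000, 1, 3)])` fills the missing middle row with the mean, which corresponds to 2000-01-02 (946771200000000000 nanoseconds, a Sunday, so `dayofweek` is 6).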

Text Columns

If the MultiModal option is enabled, then text columns are processed using a full Transformer neural network model with pretrained NLP models.

Otherwise, they are processed in two simpler ways:

  • an n-gram feature generator extracts n-grams (short strings) from the text feature, adding many additional columns, one for each n-gram. These columns are ‘n-hot’ encoded: each contains 1 or more if the original text contains the n-gram 1 or more times, and 0 otherwise. By default, all text columns are concatenated before this stage, and the n-grams are individual words, not substrings of words. You can configure this via the TextNgramFeatureGenerator class. The n-gram generation is done in generators/text_ngram.py.

  • some additional numerical features are calculated, such as word counts, character counts and the proportion of uppercase characters. This is configurable via the TextSpecialFeatureGenerator, and is implemented in generators/text_special.py.
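A minimal plain-Python sketch of both steps. The tokenizer regex copies scikit-learn's documented default `token_pattern` for `CountVectorizer` (which the fitting logs below show is used), and that default ignores one-character tokens; if AutoGluon keeps it, that would explain why the word ‘d’ is missing from the four-word vocabulary in the example. The feature names `char_count`, `word_count` and `capital_ratio`, and the reading of `__nlp__._total_` as the number of distinct vocabulary words present, are assumptions inferred from the example output, not AutoGluon's generators.

```python
import re
from collections import Counter

# Assumed tokenizer: scikit-learn's documented default CountVectorizer
# token_pattern, which only keeps tokens of two or more word characters.
TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")


def tokenize(text: str) -> list:
    return TOKEN_RE.findall(text.lower())


def fit_vocabulary(texts: list) -> list:
    """Collect the sorted set of word tokens seen in the training texts."""
    return sorted({token for text in texts for token in tokenize(text)})


def text_features(text: str, vocab: list) -> dict:
    """Word-count ('n-hot') columns plus a few simple summary statistics."""
    counts = Counter(tokenize(text))
    row = {f"__nlp__.{word}": counts[word] for word in vocab}
    # Number of distinct vocabulary words present, which is what
    # __nlp__._total_ appears to hold in the example output.
    row["__nlp__._total_"] = sum(1 for word in vocab if counts[word] > 0)
    letters = [c for c in text if c.isalpha()]
    row["char_count"] = len(text)
    row["word_count"] = len(text.split())
    row["capital_ratio"] = (
        sum(c.isupper() for c in letters) / len(letters) if letters else 0.0
    )
    return row
```

With the vocabulary fitted on texts such as "ghi d ghi abc" and "ef jkl ghi jkl", the first text yields counts 1, 0, 2, 0 for abc, ef, ghi, jkl and a `_total_` of 2, the same values as row 0 of the example output.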

Additional Processing

  • Columns containing only 1 value are dropped before passing to models.

  • Columns that duplicate other columns are removed before passing to models.
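Both clean-up steps can be sketched over a plain dict of column lists. `drop_useless_columns` is a hypothetical helper; AutoGluon performs these steps with its own DropUniqueFeatureGenerator and DropDuplicatesFeatureGenerator (both appear in the fitting logs below), whose exact rules may differ.

```python
def drop_useless_columns(columns: dict) -> dict:
    """Drop constant columns, then columns identical to one already kept."""
    kept = {}
    seen = set()
    for name, values in columns.items():
        if len(set(values)) <= 1:
            continue  # only one distinct value: carries no information
        key = tuple(values)
        if key in seen:
            continue  # exact duplicate of an earlier column
        seen.add(key)
        kept[name] = values
    return kept
```

Keeping the first occurrence of each duplicated column means the surviving name depends on column order, which is usually harmless since the values are identical.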

Feature Engineering Example

By default, a feature generator called AutoMLPipelineFeatureGenerator is used. Let’s see it in action. We’ll create a dataframe containing a floating point column, an integer column, a datetime column, a categorical column and a text column, and first take a look at the raw data we created.

from autogluon.tabular import TabularDataset, TabularPredictor
import pandas as pd
import numpy as np
import random
from sklearn.datasets import make_regression
from datetime import datetime

x, y = make_regression(n_samples = 100,n_features = 5,n_targets = 1, random_state = 1)
dfx = pd.DataFrame(x, columns=['A','B','C','D','E'])
dfy = pd.DataFrame(y, columns=['label'])

# Create an integer column, a datetime column, a categorical column and a string column to demonstrate how they are processed.
dfx['B'] = (dfx['B']).astype(int)
dfx['C'] = datetime(2000,1,1) + pd.to_timedelta(dfx['C'].astype(int), unit='D')
dfx['D'] = pd.cut(dfx['D'] * 10, [-np.inf,-5,0,5,np.inf],labels=['v','w','x','y'])
dfx['E'] = pd.Series(list(' '.join(random.choice(["abc", "d", "ef", "ghi", "jkl"]) for i in range(4)) for j in range(100)))
dataset=TabularDataset(dfx)
print(dfx)
           A  B          C  D               E
0  -0.545774  0 2000-01-01  y   ghi d ghi abc
1  -0.468674  0 2000-01-02  x   abc abc abc d
2   1.767960  0 1999-12-31  v      d ef ef ef
3  -0.118771  1 2000-01-01  y     jkl d d abc
4   0.630196  0 1999-12-31  w  ef jkl ghi jkl
..       ... ..        ... ..             ...
95 -1.182318 -1 2000-01-01  v       d ef ef d
96  0.562761  0 2000-01-01  v  abc ghi ef jkl
97 -0.797270  0 2000-01-01  w    abc jkl ef d
98  0.502741  0 1999-12-31  y   ef abc ef abc
99  2.056356  0 1999-12-30  w   jkl ghi d jkl

[100 rows x 5 columns]

Now let’s call the default feature generator AutoMLPipelineFeatureGenerator with no parameters and see what it does.

from autogluon.features.generators import AutoMLPipelineFeatureGenerator
auto_ml_pipeline_feature_generator = AutoMLPipelineFeatureGenerator()
auto_ml_pipeline_feature_generator.fit_transform(X=dfx)
A B D E C C.year C.month C.day C.dayofweek E.char_count E.symbol_ratio. __nlp__.abc __nlp__.ef __nlp__.ghi __nlp__.jkl __nlp__._total_
0 -0.545774 0 3 NaN 946684800000000000 2000 1 1 5 5 2 1 0 2 0 2
1 -0.468674 0 2 NaN 946771200000000000 2000 1 2 6 5 2 3 0 0 0 1
2 1.767960 0 0 NaN 946598400000000000 1999 12 31 4 2 5 0 3 0 0 1
3 -0.118771 1 3 NaN 946684800000000000 2000 1 1 5 3 4 1 0 0 1 2
4 0.630196 0 1 NaN 946598400000000000 1999 12 31 4 6 1 0 1 1 2 3
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
95 -1.182318 -1 0 NaN 946684800000000000 2000 1 1 5 1 6 0 2 0 0 1
96 0.562761 0 0 NaN 946684800000000000 2000 1 1 5 6 1 1 1 1 1 4
97 -0.797270 0 1 NaN 946684800000000000 2000 1 1 5 4 3 1 1 0 1 3
98 0.502741 0 3 NaN 946598400000000000 1999 12 31 4 5 2 2 2 0 0 2
99 2.056356 0 1 NaN 946512000000000000 1999 12 30 3 5 2 0 0 1 2 2

100 rows × 16 columns

We can see that:

  • The floating point and integer columns ‘A’ and ‘B’ are unchanged.

  • The datetime column ‘C’ has been converted to a raw value (in nanoseconds), as well as parsed into additional columns for the year, month, day and dayofweek.

  • The string categorical column ‘D’ has been mapped 1:1 to integers, since many models only accept numerical input.

  • The free-form text column ‘E’ has been mapped into some summary features (‘char_count’ etc.) as well as an n-hot matrix recording how many times each text contains each word.

To see more detail, we can run the pipeline as part of TabularPredictor.fit(). Since fit() expects a single dataframe, we first combine the dfx and dfy DataFrames:

df = pd.concat([dfx, dfy], axis=1)
predictor = TabularPredictor(label='label')
predictor.fit(df, hyperparameters={'GBM' : {}}, feature_generator=auto_ml_pipeline_feature_generator)
No path specified. Models will be saved in: "AutogluonModels/ag-20250618_162521"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.3.2b20250618
Python Version:     3.12.10
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Memory Avail:       28.79 GB / 30.95 GB (93.0%)
Disk Space Avail:   206.63 GB / 255.99 GB (80.7%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://auto.gluon.ai/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ...
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250618_162521"
Train Data Rows:    100
Train Data Columns: 5
Label Column:       label
AutoGluon infers your prediction problem is: 'regression' (because dtype of label-column == float and many unique label-values observed).
	Label info (max, min, mean, stddev): (186.98105511749836, -267.99365510467214, 9.38193, 71.29287)
	If 'regression' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       regression
Preprocessing data ...
Using Feature Generators to preprocess the data ...
AutoMLPipelineFeatureGenerator is already fit, so the training data will be processed via .transform() instead of .fit_transform().
	Types of features in original data (raw dtype, special dtypes):
		('category', [])     : 1 | ['D']
		('datetime', [])     : 1 | ['C']
		('float', [])        : 1 | ['A']
		('int', [])          : 1 | ['B']
		('object', ['text']) : 1 | ['E']
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])                    : 1 | ['D']
		('category', ['text_as_category'])  : 1 | ['E']
		('float', [])                       : 1 | ['A']
		('int', [])                         : 1 | ['B']
		('int', ['binned', 'text_special']) : 2 | ['E.char_count', 'E.symbol_ratio. ']
		('int', ['datetime_as_int'])        : 5 | ['C', 'C.year', 'C.month', 'C.day', 'C.dayofweek']
		('int', ['text_ngram'])             : 5 | ['__nlp__.abc', '__nlp__.ef', '__nlp__.ghi', '__nlp__.jkl', '__nlp__._total_']
Data preprocessing and feature engineering runtime = 0.02s ...
AutoGluon will gauge predictive performance using evaluation metric: 'root_mean_squared_error'
	This metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.
	To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 80, Val Rows: 20
User-specified model hyperparameters to be fit:
{
	'GBM': [{}],
}
Fitting 1 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM ...
	-57.1631	 = Validation score   (-root_mean_squared_error)
	0.34s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ...
	Ensemble Weights: {'LightGBM': 1.0}
	-57.1631	 = Validation score   (-root_mean_squared_error)
	0.0s	 = Training   runtime
	0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 0.4s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 5317.3 rows/s (20 batch size)
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250618_162521")
<autogluon.tabular.predictor.predictor.TabularPredictor at 0x7f2404460710>

Reading the output, note that:

  • the string-categorical column ‘D’, despite being mapped to integers, is still recognised as categorical.

  • the integer column ‘B’ has not been identified as categorical, even though it only has a few unique values:

print(len(set(dfx['B'])))
5

To have it treated as categorical, we can explicitly convert it in the original dataframe:

dfx["B"] = dfx["B"].astype("category")
auto_ml_pipeline_feature_generator = AutoMLPipelineFeatureGenerator()
auto_ml_pipeline_feature_generator.fit_transform(X=dfx)
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    29461.69 MB
	Train Data (Original)  Memory Usage: 0.01 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
		Fitting DatetimeFeatureGenerator...
		Fitting TextSpecialFeatureGenerator...
			Fitting BinnedFeatureGenerator...
			Fitting DropDuplicatesFeatureGenerator...
		Fitting TextNgramFeatureGenerator...
			Fitting CountVectorizer for text features: ['E']
			CountVectorizer fit with vocabulary size = 4
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Types of features in original data (raw dtype, special dtypes):
		('category', [])     : 2 | ['B', 'D']
		('datetime', [])     : 1 | ['C']
		('float', [])        : 1 | ['A']
		('object', ['text']) : 1 | ['E']
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])                    : 2 | ['B', 'D']
		('category', ['text_as_category'])  : 1 | ['E']
		('float', [])                       : 1 | ['A']
		('int', ['binned', 'text_special']) : 2 | ['E.char_count', 'E.symbol_ratio. ']
		('int', ['datetime_as_int'])        : 5 | ['C', 'C.year', 'C.month', 'C.day', 'C.dayofweek']
		('int', ['text_ngram'])             : 5 | ['__nlp__.abc', '__nlp__.ef', '__nlp__.ghi', '__nlp__.jkl', '__nlp__._total_']
	0.1s = Fit runtime
	5 features in original data used to generate 16 features in processed data.
	Train Data (Processed) Memory Usage: 0.01 MB (0.0% of available memory)
A B D E C C.year C.month C.day C.dayofweek E.char_count E.symbol_ratio. __nlp__.abc __nlp__.ef __nlp__.ghi __nlp__.jkl __nlp__._total_
0 -0.545774 1 3 NaN 946684800000000000 2000 1 1 5 5 2 1 0 2 0 2
1 -0.468674 1 2 NaN 946771200000000000 2000 1 2 6 5 2 3 0 0 0 1
2 1.767960 1 0 NaN 946598400000000000 1999 12 31 4 2 5 0 3 0 0 1
3 -0.118771 2 3 NaN 946684800000000000 2000 1 1 5 3 4 1 0 0 1 2
4 0.630196 1 1 NaN 946598400000000000 1999 12 31 4 6 1 0 1 1 2 3
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
95 -1.182318 0 0 NaN 946684800000000000 2000 1 1 5 1 6 0 2 0 0 1
96 0.562761 1 0 NaN 946684800000000000 2000 1 1 5 6 1 1 1 1 1 4
97 -0.797270 1 1 NaN 946684800000000000 2000 1 1 5 4 3 1 1 0 1 3
98 0.502741 1 3 NaN 946598400000000000 1999 12 31 4 5 2 2 2 0 0 2
99 2.056356 1 1 NaN 946512000000000000 1999 12 30 3 5 2 0 0 1 2 2

100 rows × 16 columns
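AutoGluon leaves this choice to you. A hypothetical helper like `suggest_categorical` (plain Python over a dict of column lists; the name and the default cutoff of 10 unique values are assumptions) can shortlist integer columns with few distinct values, whose names you could then convert with `.astype("category")` as above.

```python
def suggest_categorical(columns: dict, max_unique: int = 10) -> list:
    """Names of integer-valued columns with at most max_unique distinct values."""
    return [
        name
        for name, values in columns.items()
        # bool is a subclass of int, so exclude it explicitly
        if all(isinstance(v, int) and not isinstance(v, bool) for v in values)
        and len(set(values)) <= max_unique
    ]
```

A low unique count is only a hint: a small integer column may still be genuinely ordinal or numeric, so the final decision is yours.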

Missing Value Handling

To illustrate missing value handling, let’s set the first row to all NaNs:

dfx.iloc[0] = np.nan
dfx.head()
A B C D E
0 NaN NaN NaT NaN NaN
1 -0.468674 0 2000-01-02 x abc abc abc d
2 1.767960 0 1999-12-31 v d ef ef ef
3 -0.118771 1 2000-01-01 y jkl d d abc
4 0.630196 0 1999-12-31 w ef jkl ghi jkl

Now if we reprocess:

auto_ml_pipeline_feature_generator = AutoMLPipelineFeatureGenerator()
auto_ml_pipeline_feature_generator.fit_transform(X=dfx)
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    29462.08 MB
	Train Data (Original)  Memory Usage: 0.01 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
		Fitting DatetimeFeatureGenerator...
		Fitting TextSpecialFeatureGenerator...
			Fitting BinnedFeatureGenerator...
			Fitting DropDuplicatesFeatureGenerator...
		Fitting TextNgramFeatureGenerator...
			Fitting CountVectorizer for text features: ['E']
			CountVectorizer fit with vocabulary size = 4
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Types of features in original data (raw dtype, special dtypes):
		('category', [])     : 2 | ['B', 'D']
		('datetime', [])     : 1 | ['C']
		('float', [])        : 1 | ['A']
		('object', ['text']) : 1 | ['E']
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])                    : 2 | ['B', 'D']
		('category', ['text_as_category'])  : 1 | ['E']
		('float', [])                       : 1 | ['A']
		('int', ['binned', 'text_special']) : 3 | ['E.char_count', 'E.word_count', 'E.symbol_ratio. ']
		('int', ['datetime_as_int'])        : 5 | ['C', 'C.year', 'C.month', 'C.day', 'C.dayofweek']
		('int', ['text_ngram'])             : 5 | ['__nlp__.abc', '__nlp__.ef', '__nlp__.ghi', '__nlp__.jkl', '__nlp__._total_']
	4.4s = Fit runtime
	5 features in original data used to generate 17 features in processed data.
	Train Data (Processed) Memory Usage: 0.01 MB (0.0% of available memory)
A B D E C C.year C.month C.day C.dayofweek E.char_count E.word_count E.symbol_ratio. __nlp__.abc __nlp__.ef __nlp__.ghi __nlp__.jkl __nlp__._total_
0 NaN NaN NaN NaN 946687418181818240 2000 1 1 5 0 0 0 0 0 0 0 0
1 -0.468674 1 2 NaN 946771200000000000 2000 1 2 6 6 1 3 3 0 0 0 1
2 1.767960 1 0 NaN 946598400000000000 1999 12 31 4 3 1 6 0 3 0 0 1
3 -0.118771 2 3 NaN 946684800000000000 2000 1 1 5 4 1 5 1 0 0 1 2
4 0.630196 1 1 NaN 946598400000000000 1999 12 31 4 7 1 2 0 1 1 2 3
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
95 -1.182318 0 0 NaN 946684800000000000 2000 1 1 5 2 1 7 0 2 0 0 1
96 0.562761 1 0 NaN 946684800000000000 2000 1 1 5 7 1 2 1 1 1 1 4
97 -0.797270 1 1 NaN 946684800000000000 2000 1 1 5 5 1 4 1 1 0 1 3
98 0.502741 1 3 NaN 946598400000000000 1999 12 31 4 6 1 3 2 2 0 0 2
99 2.056356 1 1 NaN 946512000000000000 1999 12 30 3 6 1 3 0 0 1 2 2

100 rows × 17 columns

We see that the float column ‘A’, the categorical columns ‘B’ and ‘D’, and the text column ‘E’ have retained the NaNs, but the datetime column ‘C’ has been set to the mean of the non-NaN values.

Customization of Feature Engineering

To customize your feature generation pipeline, it is recommended to use PipelineFeatureGenerator, passing in non-default parameters to the other feature generators as required. For example, if we think downstream models would benefit from removing rare categorical values and replacing them with NaN, we can supply the parameter maximum_num_cat to CategoryFeatureGenerator, as below:

from autogluon.features.generators import PipelineFeatureGenerator, CategoryFeatureGenerator, IdentityFeatureGenerator
from autogluon.common.features.types import R_INT, R_FLOAT
mypipeline = PipelineFeatureGenerator(
    generators=[[
        CategoryFeatureGenerator(maximum_num_cat=10),  # Overridden from default.
        IdentityFeatureGenerator(infer_features_in_args=dict(valid_raw_types=[R_INT, R_FLOAT])),
    ]]
)

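If we then dump out the transformed data, we can see that all remaining columns are numeric, because that’s what most models require, and that rare categorical values have been replaced with NaN. The datetime column ‘C’ has been dropped, since neither generator in this pipeline handles datetime features: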

mypipeline.fit_transform(X=dfx)
Fitting PipelineFeatureGenerator...
	Available Memory:                    29435.65 MB
	Train Data (Original)  Memory Usage: 0.01 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
		Fitting IdentityFeatureGenerator...
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Unused Original Features (Count: 1): ['C']
		These features were not used to generate any of the output features. Add a feature generator compatible with these features to utilize them.
		Features can also be unused if they carry very little information, such as being categorical but having almost entirely unique values or being duplicates of other features.
		These features do not need to be present at inference time.
		('datetime', []) : 1 | ['C']
	Types of features in original data (raw dtype, special dtypes):
		('category', [])     : 2 | ['B', 'D']
		('float', [])        : 1 | ['A']
		('object', ['text']) : 1 | ['E']
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])                   : 2 | ['B', 'D']
		('category', ['text_as_category']) : 1 | ['E']
		('float', [])                      : 1 | ['A']
	0.0s = Fit runtime
	4 features in original data used to generate 4 features in processed data.
	Train Data (Processed) Memory Usage: 0.00 MB (0.0% of available memory)
B D E A
0 NaN NaN NaN NaN
1 1 2 NaN -0.468674
2 1 0 NaN 1.767960
3 2 3 NaN -0.118771
4 1 1 NaN 0.630196
... ... ... ... ...
95 0 0 NaN -1.182318
96 1 0 NaN 0.562761
97 1 1 NaN -0.797270
98 1 3 NaN 0.502741
99 1 1 NaN 2.056356

100 rows × 4 columns
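As a rough illustration of what maximum_num_cat does, the hypothetical helper `cap_categories` keeps only the most frequent categories and maps the rest to a missing value. It assumes the cap is applied by frequency; tie-breaking and any other filtering CategoryFeatureGenerator applies are not modelled.

```python
from collections import Counter


def cap_categories(values: list, maximum_num_cat: int) -> list:
    """Keep the maximum_num_cat most frequent categories; others become None."""
    counts = Counter(v for v in values if v is not None)
    keep = {category for category, _ in counts.most_common(maximum_num_cat)}
    return [v if v in keep else None for v in values]
```

Capping categories this way bounds the number of distinct codes a model sees, at the cost of merging all rare values into a single "missing" bucket.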

For more on custom feature engineering, see the detailed example script examples/tabular/example_custom_feature_generator.py.