train_test_split: Tutorial on how to use this function

A Machine Learning model is capable of autonomously learning from a dataset with the goal of predicting behaviors on another dataset. To achieve this, it finds underlying relationships between independent explanatory variables and a target variable in the initial dataset. Then, it uses these patterns to predict or classify new data.

How would you define the train_test_split function?

To verify the effectiveness of a Machine Learning model, the initial dataset is divided into two sets: a training set and a test set. The training set is used to fit, or train, the model on a portion of the data. The test set is used to evaluate the performance of this model on the remaining portion of the data. The train_test_split function of the scikit-learn (sklearn) library in Python allows you to split the dataset into these two sets.

Before anything else, you need to import the train_test_split function from the model_selection module of sklearn using the following code:

from sklearn.model_selection import train_test_split

Once imported, the function takes several arguments:

1) Arrays extracted from the dataset to be split

In supervised learning, these arrays are the input array X, consisting of explanatory variables in columns, and the output array y, consisting of the target variable (i.e., the labels).

In unsupervised learning, the only array passed as an argument is the input array X, consisting of explanatory variables in columns.

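Note: Pay attention to the dimensions! X should be a two-dimensional array, as scikit-learn estimators expect (one row per observation, one column per explanatory variable). y must be a one-dimensional array whose length equals the number of rows in X. To achieve this, do not hesitate to use the .reshape method.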
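The shape requirements above can be illustrated with a minimal sketch. The toy arrays feature and labels are hypothetical and exist only to show the effect of .reshape:

```python
import numpy as np

# Hypothetical toy data: 6 observations, 1 explanatory variable.
feature = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])   # shape (6,), one-dimensional
labels = np.array([[0], [1], [0], [1], [0], [1]])    # shape (6, 1), two-dimensional

# X: one row per observation, one column per explanatory variable.
X = feature.reshape(-1, 1)   # shape (6, 1)

# y: one label per row of X, flattened to one dimension.
y = labels.reshape(-1)       # shape (6,)

print(X.shape, y.shape)
```

Here -1 asks NumPy to infer that dimension from the number of elements.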

2) The size of the test set (test_size) and the size of the training set (train_size).

The size of each set can be either a decimal number between 0 and 1, representing a proportion of the dataset, or an integer, representing the number of examples placed in that set.

Note: It is sufficient to define only one of these arguments, the second one being complementary. If neither is defined, the test set takes 25% of the data.
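As a rough illustration of the two ways to express a size, the sketch below splits a hypothetical toy dataset of 8 examples, first with a proportion and then with an integer:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 8 examples, 2 explanatory variables.
X = np.arange(16).reshape(8, 2)
y = np.arange(8)

# Proportion: 25% of the 8 examples go to the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0
)
print(len(X_train), len(X_test))  # 6 2

# Integer: exactly 3 examples go to the test set.
X_train_n, X_test_n, y_train_n, y_test_n = train_test_split(
    X, y, test_size=3, random_state=0
)
print(len(X_train_n), len(X_test_n))  # 5 3
```

In both calls train_size is left undefined, so the training set receives the remaining examples.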

3) The random state (random_state)

The random state is a seed for the pseudo-random generator that controls how the data is shuffled before being split.

Note: Choosing an integer as the random state allows the data to be split in the same way every time the function is called. This makes the code reproducible.
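A short sketch of this reproducibility, on hypothetical toy data: two calls with the same integer return the same test set.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 10 examples, 2 explanatory variables.
X = np.arange(20).reshape(10, 2)
y = np.arange(10)

# Two independent calls with the same seed.
X_train_a, X_test_a, y_train_a, y_test_a = train_test_split(
    X, y, test_size=3, random_state=42
)
X_train_b, X_test_b, y_train_b, y_test_b = train_test_split(
    X, y, test_size=3, random_state=42
)

# The same examples end up in the test set both times.
print(np.array_equal(y_test_a, y_test_b))  # True
```

The value 42 is arbitrary; any fixed integer gives a repeatable split.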

4) The shuffle (shuffle)

The shuffle parameter is a boolean that selects whether the data should be shuffled before being split. If it is not shuffled, the data is split in its original order: the first rows form the training set and the last rows form the test set.

Note: The default value is True.
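To make the effect of disabling the shuffle concrete, here is a minimal sketch on hypothetical ordered data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data whose labels are simply the row indices 0..9.
X = np.arange(10).reshape(10, 1)
y = np.arange(10)

# Without shuffling, the original order is preserved.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=3, shuffle=False
)

print(y_train)  # [0 1 2 3 4 5 6]
print(y_test)   # [7 8 9]
```

This behaviour is typically useful for ordered data such as time series, where the test set should come after the training set.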

5) Stratify (stratify)

The stratify parameter controls whether the data is split so that the training and test sets keep the same proportions of observations in each class as the initial dataset. It is not a boolean: to enable stratification, pass the array of class labels (typically y).

  • This parameter is particularly useful for data with very imbalanced proportions between the different classes.
  • The default value is None.
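The sketch below shows stratification on a hypothetical imbalanced dataset with 16 observations of class 0 and 4 of class 1 (an 80/20 ratio):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced data: 80% class 0, 20% class 1.
X = np.arange(20).reshape(20, 1)
y = np.array([0] * 16 + [1] * 4)

# Passing the labels to stratify preserves the 80/20 ratio in both sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

print(np.bincount(y_train))  # [12  3]
print(np.bincount(y_test))   # [4 1]
```

Without stratify=y, a random split of such a small dataset could leave the test set with no observation of the minority class at all.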

The train_test_split function returns a number of outputs equal to twice its number of input arrays, in the form of a list. Thus, in supervised learning, it returns four outputs: X_train, X_test, y_train, and y_test. In unsupervised learning, it returns two outputs: X_train and X_test.
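A quick check of the number of outputs, using hypothetical toy arrays and the default sizes:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 8 examples, 2 explanatory variables.
X = np.arange(16).reshape(8, 2)
y = np.arange(8)

# One input array gives two outputs; two input arrays give four.
unsupervised = train_test_split(X, random_state=0)
supervised = train_test_split(X, y, random_state=0)
print(len(unsupervised), len(supervised))  # 2 4

# The outputs are usually unpacked directly into named variables.
X_train, X_test, y_train, y_test = supervised
print(X_train.shape, X_test.shape)  # (6, 2) (2, 2)
```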

How to evaluate the performance of a model using the train_test_split function?

Once called, the train_test_split function returns a training set and a test set. Splitting the data in this way makes it possible to evaluate a Machine Learning model from two different angles.

The model is trained on the training set returned by the function. Then its predictive capabilities are evaluated on the test set returned by the function. Several metrics can be used for this evaluation. In the case of regression, the coefficient of determination, RMSE, and MAE are commonly used. In the case of classification, accuracy, precision, recall, and F1-score are commonly used. These scores on the test set show whether the model is performing well and to what extent it needs to be improved before making predictions on a new dataset.
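As an illustration of the classification metrics mentioned above, the sketch below computes them with sklearn.metrics. The labels y_test and the predictions y_pred are hypothetical values written by hand, not the output of a real model:

```python
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Hypothetical values: true labels of a test set and a model's predictions.
y_test = [1, 0, 1, 1, 0, 1, 0, 0]
y_pred = [1, 0, 1, 0, 0, 1, 1, 0]

accuracy = accuracy_score(y_test, y_pred)    # 6 correct predictions out of 8
precision = precision_score(y_test, y_pred)  # 3 true positives / 4 predicted positives
recall = recall_score(y_test, y_pred)        # 3 true positives / 4 actual positives
f1 = f1_score(y_test, y_pred)                # harmonic mean of precision and recall

print(accuracy, precision, recall, f1)  # 0.75 0.75 0.75 0.75
```

For regression, sklearn.metrics offers r2_score and mean_absolute_error in the same spirit.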

The training and test sets returned by the train_test_split function also play an essential role in detecting overfitting or underfitting. Overfitting describes a situation where the constructed model is too complex (with too many explanatory variables, for example), such that it perfectly learns the training data but fails to generalize to other data. On the other hand, underfitting describes a situation where the model is too simple or poorly chosen (choosing linear regression on data that does not meet its assumptions, for example), such that it learns poorly. These two problems can be corrected by different techniques, but they must first be identified, which is possible thanks to the train_test_split function. Indeed, we can compare the model's performance on the training set and the test set created by the function. If the performance is good on the training set but poor on the test set, we are probably facing overfitting. If the performance is as poor on the training set as it is on the test set, we are probably facing underfitting. Therefore, the two sets returned by the function are essential in detecting these recurrent problems in Machine Learning.
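The comparison described above can be sketched as follows. Decision trees are used here purely for illustration, because their depth makes it easy to produce a very complex model and a very simple one; the helper train_and_test_scores is a hypothetical name, and the exact scores depend on the split:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

def train_and_test_scores(model):
    """Fit the model on the training set and return (train accuracy, test accuracy)."""
    model.fit(X_train, y_train)
    return model.score(X_train, y_train), model.score(X_test, y_test)

# An unrestricted tree can memorize the training set (candidate for overfitting).
deep_train, deep_test = train_and_test_scores(DecisionTreeClassifier(random_state=0))

# A tree limited to a single split is a very simple model (candidate for underfitting).
shallow_train, shallow_test = train_and_test_scores(
    DecisionTreeClassifier(max_depth=1, random_state=0)
)

print("deep tree    : train", deep_train, "test", deep_test)
print("shallow tree : train", shallow_train, "test", shallow_test)
```

A large gap between the train and test scores points towards overfitting, while two similarly low scores point towards underfitting.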

How to solve a complete Machine Learning problem using the train_test_split function?

Now that we have understood the use and features of the train_test_split function, let's put it into practice through a real Machine Learning problem.

Step 1: Understanding the problem

We choose to solve a supervised learning problem where the expected labels are known. Specifically, we focus on binary classification. The goal is to predict whether an individual has or does not have breast cancer based on their physical characteristics.

Step 2: Data retrieval

We use the "breast_cancer" dataset included in the scikit-learn library.

import numpy as np
from sklearn.datasets import load_breast_cancer

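In the following lines of code, we load the dataset and display the names of the explanatory variables (features) and of the classes of the target variable:

df = load_breast_cancer()  # dictionary-like object holding the data and its metadata
print("features :", df.feature_names)
print("target :", df.target_names)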

We find that the target variable to predict takes two values ("malignant" and "benign") and that the problem is indeed a binary classification.

Step 3: Creating X and y

We create the two-dimensional input array X and the one-dimensional output array y. For this dataset, the binary encoding of the target variable is done by sklearn and can be directly retrieved.

X, y = load_breast_cancer(return_X_y=True)
print("X :", X)
print("y :", y)
print("Dimensions of X :", X.shape)
print("Dimensions of y :", y.shape)

We verify that the dimensions of X and y are consistent: y has the same number of rows as X.

Step 4: Creating the train and test sets

We split the data into a train set and a test set. 

Since we provide two arrays X and y to the train_test_split function, it returns four elements. We choose a test set consisting of 10% of the data. We choose an integer as the random state to ensure code reproducibility. We do not use the remaining parameters of the function, which are not necessary for such a simple problem.

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=123)

Step 5: Classification model

To solve the classification task, we build a k-nearest neighbors model. We train the model on the train set using the .fit() method. Then we apply the model to the test set using the .predict() method. This gives us the predicted classes for the observations in the test set.

from sklearn.neighbors import KNeighborsClassifier
clf = KNeighborsClassifier().fit(X_train, y_train)  # train the model on the train set
prediction = clf.predict(X_test)

Step 6: Model evaluation

We choose accuracy as the metric. Accuracy is the proportion of correct predictions out of the total number of predictions. We calculate it on the train set and the test set using the .score() method, which compares the true classes of the dataset to the classes predicted by the classifier clf.

print(clf.score(X_train, y_train))
print(clf.score(X_test, y_test))

We achieve an accuracy of 0.95 on the training set and 0.93 on the test set. Therefore, the model has good classification performance.

Furthermore, the accuracy on the test set is only slightly lower than that on the training set. This means that the model generalizes well to new data. We are therefore not facing an overfitting problem.

Thus, the train_test_split function is easy to use and very effective for solving a complete machine-learning problem.

Are there any limitations to the train_test_split function?

Despite its effectiveness, the train_test_split function has one main limitation related to its random_state parameter. When the value given to random_state is an integer, the data is split using a pseudo-random generator initialized with this integer, called a random seed. The split performed is reproducible by keeping the same seed. However, the choice of the seed influences the measured performance of the associated machine learning model: different seeds create different training and test sets, and therefore variable scores.

One solution to this problem is to use the train_test_split function several times with different values for the random_state. We can then calculate the average of the obtained scores.
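This averaging can be sketched as follows, reusing the breast cancer dataset and a k-nearest neighbors classifier. The choice of five seeds (0 to 4) is an arbitrary assumption, and the scores obtained depend on the installed version of the library and on the seeds:

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_breast_cancer(return_X_y=True)

scores = []
for seed in range(5):  # five arbitrary seeds
    # Each seed produces a different train/test split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.1, random_state=seed
    )
    clf = KNeighborsClassifier().fit(X_train, y_train)
    scores.append(clf.score(X_test, y_test))

mean_score = np.mean(scores)
std_score = np.std(scores)

print("test accuracy per seed:", np.round(scores, 3))
print("mean:", round(mean_score, 3), "standard deviation:", round(std_score, 3))
```

The standard deviation gives an idea of how much the score depends on the particular split.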
