Can you tell when a user is about to leave your service?

I came across this question recently and decided to take on the challenge. I wondered whether a company could detect when a user is most likely to churn and offer them a discount so that they stay with the service longer.

Udacity provided a dataset for Sparkify, a fictional music-streaming company. Every interaction a user has with the service is stored in the data file, which contains a lot of information about each user, such as gender, browser, song played, length of the song, the time the song was played, and whether the user liked a song.

An approach like this could recover a significant amount of revenue that currently goes unrealized. What business wouldn’t want more revenue?

So I set out to build something that scales to big data. I used Apache Spark both to clean the data and to create the machine learning model. This matters because it lets companies with HUGE datasets implement the same solution, and it means the solution scales as the dataset grows.

The plan for this project was to clean the dataset and then model the data with a Decision Tree. The metrics used to evaluate the model are accuracy and the F1 score. Accuracy can be calculated as:

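Accuracy = (True Positives + True Negatives) / All Samples

The F1 score, which combines precision and recall, is calculated as:

Precision = True Positives / (True Positives + False Positives)

Recall = True Positives / (True Positives + False Negatives)

F1 = 2 × Precision × Recall / (Precision + Recall)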
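As a sketch of how these two metrics are computed, here is a plain-Python version (no Spark needed) for binary labels where 1 means churned. The function names are illustrative, not from the original project. One hedged caveat: Spark's MulticlassClassificationEvaluator with metricName="f1" reports a support-weighted F1 across both classes, so its number can differ from the churn-class F1 computed here.

```python
def confusion_counts(y_true, y_pred, positive=1):
    """Return (tp, fp, tn, fn) for binary labels."""
    tp = fp = tn = fn = 0
    for actual, predicted in zip(y_true, y_pred):
        if predicted == positive:
            if actual == positive:
                tp += 1
            else:
                fp += 1
        elif actual == positive:
            fn += 1
        else:
            tn += 1
    return tp, fp, tn, fn


def accuracy(y_true, y_pred):
    """(TP + TN) / all samples."""
    tp, fp, tn, fn = confusion_counts(y_true, y_pred)
    return (tp + tn) / (tp + fp + tn + fn)


def f1_score(y_true, y_pred):
    """Harmonic mean of precision and recall; 0.0 when there are no true positives."""
    tp, fp, tn, fn = confusion_counts(y_true, y_pred)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Hypothetical labels: 1 = churned, 0 = stayed.
y_true = [1, 0, 1, 0, 0]
y_pred = [1, 0, 0, 0, 1]
print(accuracy(y_true, y_pred))  # 0.6
print(f1_score(y_true, y_pred))  # 0.5
```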

Problem Statement

The goal of this project is to predict whether a specific user is likely to churn, based on information that is already collected. Sparkify can then offer the users who are likely to churn a slightly lower rate to keep them from leaving.

The other key requirement to keep in mind while developing this solution is scalability: the model must work on big datasets.

Loading and Exploring

First I needed to import everything that would be necessary to clean, explore, and build my model.

I also needed to create a SparkSession so I could load the dataset.

After loading in the data from the json file by using the method, I checked the schema of the data.
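To make the schema check concrete without a Spark cluster, here is a small plain-Python sketch that reads newline-delimited JSON (one event per line, an assumption about the file layout) and reports the value types seen for each field. The two sample records are made up for illustration and assume camelCase keys such as userId and ts. In PySpark, spark.read.json followed by printSchema() infers and prints the schema for you.

```python
import json
from collections import defaultdict


def infer_schema(lines):
    """Map each field name to the sorted list of Python type names seen for it."""
    seen = defaultdict(set)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        for field, value in record.items():
            seen[field].add(type(value).__name__)
    return {field: sorted(types) for field, types in sorted(seen.items())}


# Hypothetical sample records (not real Sparkify data).
sample = [
    '{"userId": "42", "page": "NextSong", "ts": 1538352117000, "length": 233.5, "artist": "Some Band"}',
    '{"userId": "", "page": "Home", "ts": 1538352200000, "artist": null}',
]
for field, types in infer_schema(sample).items():
    print(field, types)
```

A field that shows more than one type (here artist, which is null on non-song pages) is a hint that it needs cleaning before use.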

Artist: The artist of the song that is currently being played

Auth: Tells you if the user is logged in to their account

FirstName: The First Name of the user

Gender: The gender of the user

ItemInSession: The position of this event within the session

LastName: The Last Name of the user

Length: The length of the song in seconds

Level: The type of subscription the user has (free or paid)

Location: The location of the user

Method: The HTTP request method (GET or PUT)

Page: The page the user is currently on

Registration: Timestamp of when the user registered

SessionID: The current session ID

Song: The title of the song currently being listened to

Status: The HTTP status code of the request

TS: Timestamp of the action that has occurred

UserAgent: The user agent string of the client (browser and operating system)

UserID: The unique ID for each user

One example of what the data looks like in the dataset is:

Only a handful of columns seemed useful: gender, level (whether the user is paying for the service or using it for free), and page.

All of the possible different pages are:

Two key things to notice here are that Cancellation Confirmation and Downgrade are both recorded as pages. These will be very helpful in building the model later.
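One way these pages can feed the model is as the churn label. The plain-Python sketch below treats events as dictionaries (keys userId and page are assumed) and marks every event belonging to a user who ever reached the Cancellation Confirmation page. The function name label_churn is hypothetical; in Spark the same idea is usually expressed with a window over userId or a join against the set of churned users.

```python
def label_churn(events, churn_page="Cancellation Confirmation"):
    """Return copies of events with churn=1 for users who ever hit churn_page."""
    churned_users = {e["userId"] for e in events if e["page"] == churn_page}
    return [{**e, "churn": int(e["userId"] in churned_users)} for e in events]


events = [
    {"userId": "1", "page": "NextSong"},
    {"userId": "1", "page": "Cancellation Confirmation"},
    {"userId": "2", "page": "Home"},
]
print([e["churn"] for e in label_churn(events)])  # [1, 1, 0]
```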

Let’s get an idea of how many customers were flagged as having cancelled their service:

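This data is quite lopsided: far fewer users churned than stayed. With imbalanced classes, a model that simply predicts “no churn” for everyone can still score a high accuracy, so F1 is probably the better metric for assessing our model.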
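A quick worked example shows why accuracy misleads on lopsided data. Assuming a hypothetical test set where 1 of 10 users churned, a model that always predicts "no churn" still gets 90% accuracy, while its F1 score for the churn class is 0. The numbers are illustrative only.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1_churn(y_true, y_pred, positive=1):
    """F1 for the churn class, using F1 = 2TP / (2TP + FP + FN)."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fp = sum(t != positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0
    return 2 * tp / (2 * tp + fp + fn)


y_true = [0] * 9 + [1]  # hypothetical: 9 users stayed, 1 churned
y_pred = [0] * 10       # majority-class "model": nobody churns
print(accuracy(y_true, y_pred))  # 0.9
print(f1_churn(y_true, y_pred))  # 0.0
```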

Now let’s compare the number of users who churned across gender and subscription level.

Each category on the plot corresponds to an index of the dataframe above. This bar chart reveals some interesting information; in particular, male and female users appear to have different churn counts.
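The comparison behind the chart can be sketched as a group-by. This plain-Python version assumes one record per user with hypothetical keys gender, level, and churn (0 or 1), and treats level as a single value per user even though a user can switch levels over time. It reports the share of churned users in each (gender, level) group, since rates are easier to compare than raw counts when groups differ in size. In PySpark this would be a groupBy on the same columns with an aggregation.

```python
from collections import Counter


def churn_rate_by_group(users, keys=("gender", "level")):
    """Share of churned users in each group defined by the given keys."""
    totals, churned = Counter(), Counter()
    for user in users:
        group = tuple(user[k] for k in keys)
        totals[group] += 1
        churned[group] += user["churn"]
    return {group: churned[group] / totals[group] for group in totals}


users = [  # hypothetical users, one level per user
    {"gender": "F", "level": "paid", "churn": 1},
    {"gender": "F", "level": "paid", "churn": 0},
    {"gender": "M", "level": "free", "churn": 0},
    {"gender": "M", "level": "free", "churn": 0},
]
print(churn_rate_by_group(users))  # {('F', 'paid'): 0.5, ('M', 'free'): 0.0}
```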


In this step I clean the data and get it into a usable form for the machine learning model.

I first removed all entries with no userID, because their actions cannot be traced back to a specific user. Keeping these rows could weaken the model, and there is not enough information to fill in the missing IDs (which would otherwise be the better approach), so removing them is the best option.
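The filtering step can be sketched in plain Python as below, assuming each event is a dictionary and that a missing ID shows up either as an absent key or as an empty string. In PySpark, a filter such as df.filter(df.userId != "") has a similar effect; rows where userId is null are dropped too, because the comparison evaluates to null for them.

```python
def has_user_id(event):
    """True if the event carries a non-empty userId."""
    user_id = event.get("userId")
    return user_id is not None and str(user_id).strip() != ""


events = [
    {"userId": "10", "page": "NextSong"},
    {"userId": "", "page": "Home"},  # empty ID: cannot be traced to a user
    {"page": "Help"},                # userId missing entirely
]
cleaned = [e for e in events if has_user_id(e)]
print(len(cleaned))  # 1
```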

I also created a column called time that converts the ts column into datetime objects. The date and time of each action are much easier to read in this format than as raw ts values.
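The conversion itself is a one-liner. This sketch assumes ts holds milliseconds since the Unix epoch (check this against your own data) and returns a timezone-aware UTC datetime; the helper name is hypothetical. Spark offers built-in functions for the same conversion, such as from_unixtime, which expects seconds, so ts would need to be divided by 1000 first.

```python
from datetime import datetime, timezone


def ts_to_datetime(ts_ms):
    """Convert a Unix timestamp in milliseconds to a UTC datetime."""
    return datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)


print(ts_to_datetime(1538352117000).isoformat())  # 2018-10-01T00:01:57+00:00
```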

I converted the level column into two dummy columns, called paid and level, so that the model can use this information.

I did the same for the gender column.
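Dummy encoding can be sketched as a small helper that turns one categorical value into 0/1 columns. The column names produced here (level_free, level_paid, gender_F, gender_M) are hypothetical and do not necessarily match the names used in the project. Spark ML also provides StringIndexer and OneHotEncoder for this job.

```python
def one_hot(value, categories, prefix):
    """Return {prefix_category: 1 if value == category else 0}."""
    return {f"{prefix}_{category}": int(value == category) for category in categories}


row = {"level": "paid", "gender": "F"}
features = {
    **one_hot(row["level"], ["free", "paid"], "level"),
    **one_hot(row["gender"], ["F", "M"], "gender"),
}
print(features)  # {'level_free': 0, 'level_paid': 1, 'gender_F': 1, 'gender_M': 0}
```

With only two categories, one column is just the complement of the other, so a single dummy per variable carries the same information.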

Finally, I created a downgrade column that flags whether a user has downgraded their service.
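A per-user downgrade flag can be sketched as "did any of this user's events land on the downgrade page?", which is the same as taking the maximum of a 0/1 indicator over that user's events. This plain-Python version assumes the page is named Downgrade, as mentioned earlier; the function name is hypothetical. In PySpark the equivalent is usually a groupBy on userId with a max aggregation.

```python
def downgrade_flags(events, page_name="Downgrade"):
    """Return {userId: 1 if the user ever visited page_name, else 0}."""
    flags = {}
    for event in events:
        hit = int(event["page"] == page_name)
        # max() keeps the flag at 1 once any event has hit the page
        flags[event["userId"]] = max(flags.get(event["userId"], 0), hit)
    return flags


events = [
    {"userId": "1", "page": "NextSong"},
    {"userId": "1", "page": "Downgrade"},
    {"userId": "2", "page": "Home"},
    {"userId": "1", "page": "Home"},
]
print(downgrade_flags(events))  # {'1': 1, '2': 0}
```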

Many of the columns contain strings. I only converted the columns that would eventually be used to build the model, since there is no need to convert columns the model will never use.



Based on the data and the analysis so far, the model will use three columns: Downgrade, Gender, and Level.

I assembled the features into vectors and then normalized those vectors so that differences in scale would not significantly affect the model.
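A plain-Python sketch of this step: collect the chosen 0/1 columns into one vector per row, then scale that vector to unit L2 norm. This mirrors Spark's VectorAssembler followed by Normalizer (whose default is the L2 norm), but whether the project used exactly these transformers is an assumption, and the column names are hypothetical. Normalizer rescales each row; to rescale each feature across rows, StandardScaler or MinMaxScaler are the usual choices.

```python
import math


def assemble(row, feature_cols):
    """Collect the named columns of a row into a list of floats."""
    return [float(row[col]) for col in feature_cols]


def l2_normalize(vector):
    """Scale a vector to unit Euclidean length; leave all-zero vectors unchanged."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0:
        return list(vector)
    return [x / norm for x in vector]


row = {"downgrade": 1, "gender_M": 0, "level_paid": 1}  # hypothetical columns
vec = assemble(row, ["downgrade", "gender_M", "level_paid"])
print(vec)                       # [1.0, 0.0, 1.0]
print(l2_normalize([3.0, 4.0]))  # [0.6, 0.8]
```

Decision trees pick split thresholds on each feature independently, so they are not sensitive to feature scale; this step does no harm but is not essential for the tree.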


Now it is time to split the data into training and testing sets. This lets me train the model on one set and then test how well it performs on the other.

I chose a 90/10 split for training and testing. This gives the model as much training data as possible while still leaving a reasonable amount of data to test with.
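The split can be sketched as sending each row to the training set with probability 0.9, using a fixed seed so the split is reproducible. Spark's DataFrame.randomSplit([0.9, 0.1], seed=...) also assigns rows randomly, so the resulting sizes are only approximately 90/10. The function name and seed value below are illustrative.

```python
import random


def random_split(rows, train_fraction=0.9, seed=42):
    """Assign each row to train with probability train_fraction, else to test."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_fraction else test).append(row)
    return train, test


train, test = random_split(list(range(1000)))
print(len(train), len(test))
```

If the rows are individual events rather than one row per user, splitting on userId instead keeps one user's events from appearing in both sets.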

Selecting the Model

This part of the project is less concrete and relies more on intuition and design. I chose a Decision Tree over other models because I felt that providing enough binary data through gender, level, and downgrade would let the tree split the data in a way that separates users who have churned from those who haven’t.

A Decision Tree splits and categorizes data efficiently and effectively, which made me feel it would be the best model for solving our churn problem.
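To illustrate how a tree decides where to split on binary features, here is a plain-Python sketch that scores each 0/1 feature by the weighted Gini impurity of the two groups it creates and keeps the lowest. Gini is the default impurity measure of Spark's DecisionTreeClassifier; the data and feature names below are hypothetical, and a real tree repeats this choice recursively down to maxDepth.

```python
def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for y in labels:
        counts[y] = counts.get(y, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_binary_split(rows, labels, feature_names):
    """Return (feature, weighted impurity) for the best single 0/1 split."""
    n = len(labels)
    best = None
    for j, name in enumerate(feature_names):
        left = [y for x, y in zip(rows, labels) if x[j] == 0]
        right = [y for x, y in zip(rows, labels) if x[j] == 1]
        impurity = (len(left) * gini(left) + len(right) * gini(right)) / n
        if best is None or impurity < best[1]:
            best = (name, impurity)
    return best


rows = [[1, 0], [1, 1], [0, 0], [0, 1]]  # hypothetical [downgrade, gender_M]
labels = [1, 1, 0, 0]                    # 1 = churned
print(best_binary_split(rows, labels, ["downgrade", "gender_M"]))  # ('downgrade', 0.0)
```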

When coding the Decision Tree, I ran into several complications, such as setting up the features correctly so that the model could be fit to the training data. Another issue was setting up cross validation and tuning the parameters in a time-efficient way.


A parameter grid was used to find the best parameters and optimize the model. I chose to tune only one parameter, because every parameter added to the grid multiplies the number of models that must be trained, so the search takes longer with each one.

I decided to tune the maxDepth parameter of the Decision Tree. The two options I passed along were 2 and 10.
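A parameter grid is the Cartesian product of the candidate values for each parameter, which is why every extra parameter multiplies the number of models to train. A plain-Python sketch (the helper name is hypothetical); in PySpark the corresponding call is ParamGridBuilder().addGrid(dt.maxDepth, [2, 10]).build(), where dt is the DecisionTreeClassifier.

```python
from itertools import product


def build_grid(param_values):
    """Return one {name: value} dict per combination of candidate values."""
    names = list(param_values)
    combos = product(*(param_values[n] for n in names))
    return [dict(zip(names, combo)) for combo in combos]


print(build_grid({"maxDepth": [2, 10]}))  # [{'maxDepth': 2}, {'maxDepth': 10}]
print(len(build_grid({"maxDepth": [2, 10], "maxBins": [16, 32]})))  # 4
```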

I then fed everything into a cross validator with 3 folds to find the best model.

Then I fit the model.

With 3 folds, the cross validator splits the training data into three parts. For each maxDepth value (2 and 10), it trains a model on two of the parts, computes the F1 score on the remaining part, repeats this so that each part is held out once, and averages the three F1 scores. It then selects the maxDepth with the higher average F1 score and refits a final model with that value on the full training set.

The fitted cross validator automatically uses the best model for further predictions and analysis. Nothing is printed while it compares the candidates, since it handles the comparison internally (the averaged scores can be inspected afterwards through the avgMetrics attribute of the fitted model), so there is no need to train and compare the models separately.
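The procedure can be sketched in plain Python as k-fold cross validation over a parameter grid, followed by a refit of the winning parameters on all the training data (mirroring what Spark's CrossValidator exposes as bestModel and avgMetrics). This is an illustration, not Spark's implementation: folds here are contiguous slices, whereas Spark assigns rows to folds at random, so shuffle the data first. The toy threshold rule in the usage example is a hypothetical stand-in for a decision tree.

```python
def k_fold_indices(n, k):
    """Split indices 0..n-1 into k contiguous folds of nearly equal size."""
    folds, start = [], 0
    for i in range(k):
        size = n // k + (1 if i < n % k else 0)
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def cross_validate(data, grid, train_fn, score_fn, k=3):
    """Grid search with k-fold CV; returns (best_model, best_params, avg_scores)."""
    folds = k_fold_indices(len(data), k)
    avg_scores = []
    for params in grid:
        scores = []
        for held_out in folds:
            held = set(held_out)
            train = [row for i, row in enumerate(data) if i not in held]
            valid = [data[i] for i in held_out]
            scores.append(score_fn(train_fn(train, params), valid))
        avg_scores.append(sum(scores) / k)
    best_params = grid[max(range(len(grid)), key=avg_scores.__getitem__)]
    # Refit the winning parameters on all of the training data.
    return train_fn(data, best_params), best_params, avg_scores


# Toy stand-in model: predict 1 when x >= threshold (ignores the training rows).
def train_threshold(train, params):
    return params["threshold"]


def score_accuracy(threshold, rows):
    return sum(int(x >= threshold) == y for x, y in rows) / len(rows)


data = [(0, 0), (1, 0), (2, 0), (3, 1), (4, 1), (5, 1)]
grid = [{"threshold": 1}, {"threshold": 3}, {"threshold": 5}]
model, best, avg = cross_validate(data, grid, train_threshold, score_accuracy)
print(best)  # {'threshold': 3}
```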


The scalability of the model shows up when you check the training time, which indicates whether the model can scale efficiently and effectively.

This time suggests that the model will be able to handle larger datasets. Note that the reported time covers every model trained during cross validation (each of the two maxDepth values is trained once per fold, plus a final refit on the full training set), so it is inflated compared with training a single, already-tuned model.
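The number of fits behind that timing follows from simple arithmetic: grid combinations times folds, plus one final refit. A tiny sketch with a hypothetical helper name makes this explicit.

```python
from math import prod


def count_fits(grid_sizes, num_folds, refit=True):
    """Models trained by grid search with k-fold CV, plus an optional final refit."""
    return prod(grid_sizes) * num_folds + (1 if refit else 0)


print(count_fits([2], 3))     # 7: two maxDepth values x 3 folds + 1 refit
print(count_fits([2, 2], 3))  # 13: adding a second two-value parameter
```

So the timing for this setup covers seven decision-tree fits, not one or two.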


I used the model to transform the testing data, saved the predictions, and then computed the accuracy and F1 score.

Since relatively few users have churned, the F1 score is the better metric for assessing the model. The model achieved an F1 score of 1.0, which means it made no classification errors on the testing set, so it was very accurate at our initial problem of identifying users who are likely to churn.

This means Sparkify can take action, such as offering discounts on the service, to prevent users from churning. In the long run, this will benefit the company by capturing income that may have gone unrealized.


I improved the model by running cross validation, which handled the tuning automatically. After tuning, these are the parameters that produced the most effective model for predicting churn.

This shows that a maxDepth of 2 worked better than 10: it produced a higher F1 score during cross validation, and therefore a better model than a maxDepth of 10.


These results indicate that my model predicts user churn effectively and that the tuning worked out well. The accuracy and F1 score on the held-out testing set suggest that the model generalizes well to data it has not seen before.


In conclusion, I have walked step by step through loading, cleaning, exploring, and visualizing the data, creating features, modeling, and evaluating the model. I started with the problem of accurately predicting whether a user is likely to churn, which meant finding variables that would predict it accurately.

I first removed any data that might introduce inconsistencies and hurt the model. Then, after cleaning the data, I explored it to learn as much as I could about the various variables.

After analyzing the variables, I decided to go with gender, level, and if a user has downgraded. I felt that these variables would provide sufficient information when deciding churn.

Then came the fun part: feature engineering and modeling. I needed to create dummy variables for each of the variables I wanted to use, because all of those columns were strings and the model needs numeric input.

After creating my new features, I modeled the data using a Decision Tree. I then used cross validation to tune and refine the model so that it performs efficiently and effectively.

Finally, I analyzed the model using my testing set. I analyzed the accuracy and F1 score metrics produced by the testing set and came to the conclusion that my model predicts churn extremely well.


In this project, we used Apache Spark and the PySpark machine learning library to work with this large dataset. These tools gave us the flexibility to scale with the size of the data and still produce effective and efficient models. The goal of this project was to predict churn, and we did an amazing job at it. The steps taken to achieve this goal were loading, cleaning, EDA, feature extraction, modeling, evaluating, and tuning. We also had to find variables that would help us predict our desired outcome.

Lessons Learned

I learned a lot about building models for scalability. Previously, I mostly designed models with scikit-learn. This project opened my eyes to that library’s limitations when it comes to scaling, and at the same time introduced me to Spark. I learned many valuable lessons, such as how to use Spark effectively and how Spark actually operates.

I feel that knowing Spark is very helpful when working for a company, because, as we all know, large companies have even larger datasets. Spark is a valuable tool for navigating these giant data lakes.


There are improvements that could be made to this model. For example, one potentially important variable is the amount of time the user has spent at each level of service. It might reveal that users who have had the service for a long time are less likely to churn than those who have just been introduced to it. My prediction is that this variable would make the model better than it currently is.
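A sketch of how such a feature could be computed, in plain Python: sort each user's events by ts and credit the gap between consecutive events to the level recorded on the earlier one. This assumes events carry userId, level, and ts in milliseconds; time before the first logged event or after the last one is not counted, and the function name is hypothetical. In Spark, a window partitioned by userId and ordered by ts, combined with lead(), would be a natural way to get each next timestamp.

```python
from collections import defaultdict


def time_per_level(events):
    """Return {userId: {level: milliseconds}} from consecutive-event gaps."""
    by_user = defaultdict(list)
    for event in events:
        by_user[event["userId"]].append(event)
    result = {}
    for user, user_events in by_user.items():
        user_events.sort(key=lambda e: e["ts"])
        totals = defaultdict(int)
        # Credit each gap to the level the user had at the start of the gap.
        for current, nxt in zip(user_events, user_events[1:]):
            totals[current["level"]] += nxt["ts"] - current["ts"]
        result[user] = dict(totals)
    return result


events = [  # hypothetical events, ts in milliseconds
    {"userId": "1", "level": "free", "ts": 0},
    {"userId": "1", "level": "free", "ts": 1000},
    {"userId": "1", "level": "paid", "ts": 3000},
    {"userId": "1", "level": "paid", "ts": 6000},
    {"userId": "2", "level": "free", "ts": 500},
]
print(time_per_level(events))  # {'1': {'free': 3000, 'paid': 3000}, '2': {}}
```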

This was an interesting and fun challenge to predict churn for a company. It was also fun using Spark because it allows for scalable solutions that can be implemented by major companies. Churning is something that all companies deal with and this model might get them one step closer to minimizing this problem.
