Predicting User Churn for a Streaming Service using Spark

A machine learning approach to predict churn from user logs using Spark.

Photo by Zarak Khan on Unsplash

Over the last 15 years, streaming services such as Netflix and Spotify have changed the way we watch movies and listen to music. Many new service providers have since emerged and are trying to poach customers from the established ones.

For a company offering a streaming service, it is important to have an effective customer retention strategy that stops users from switching to another provider. How can a company predict ahead of time when a user will stop using its service?

In this article I will walk you through a project that predicts which users are about to churn from a streaming service using only their recorded activity. The project is structured following CRISP-DM, the Cross-Industry Standard Process for Data Mining.

Before working with the data itself, it is important to understand the business context and, therefore, the goal of the project. Why would a company spend resources on predicting user churn ahead of time in the first place?

Business Understanding

The dataset contains user interactions with a fictional music streaming service called Sparkify, whose business model is similar to Spotify's. The service can be used at two levels: a free tier or a premium tier.

Both levels of service generate revenue for Sparkify. The free tier is financed by advertisements played between songs, while the premium tier charges a monthly subscription fee for listening without advertisements. At any moment, a user can upgrade from free to premium, downgrade from premium to free, or cancel the service completely. The goal of this project is to predict which users will cancel the service completely.

Defining Churn

Customer churn is generally defined as a customer unsubscribing from a service, ceasing to purchase a product or no longer engaging with a service [1]. For the music streaming service in this project, churn is defined as a user cancelling the service completely by deleting their user account. This can happen for both paid and free tier users.

Acquiring new customers is usually more expensive for a business than retaining existing ones [2]. To prevent likely churners from leaving, companies offer them special discounts or other costly incentives, which typically lower the revenue per customer. The goal is therefore to identify users who are about to churn ahead of time, with high precision, and to target only them with marketing campaigns.

Churn prediction is an important classification use case for streaming services such as Netflix, Spotify or Apple Music. Companies that can identify customers who are about to churn ahead of time can implement more effective retention strategies, and machine learning algorithms are well suited to this kind of classification problem.

In this project I will test three machine learning algorithms for the classification problem of user churn: Logistic Regression, Random Forest Classifier and Gradient-Boosted Tree Classifier.

Having established the business context and the importance of predicting churn ahead of time, the next step is to explore the available data. The aim is to provide an overview of the data and its quality.

Data Understanding

The basis of the project is a 12 GB user log containing all user interactions with the streaming service. The data is stored in JSON format in an Amazon Web Services (AWS) Simple Storage Service (S3) bucket. Datasets of this size are challenging to process on a single computer and can therefore be referred to as big data.

Spark for Big Data

Apache Spark is a tool for large-scale data processing and is used here to work with the dataset. It efficiently distributes data and computations across a network of computers called a cluster. The computers in a cluster, called nodes, perform the computations in parallel.

To reduce the required computation, the data is explored on a small subset of the full dataset. The full dataset will be processed afterwards on AWS with an Elastic MapReduce (EMR) cluster of three m5.xlarge machines.

MapReduce is a technique developed by Google to process data in parallel on distributed machines. It works by first dividing a large dataset into parts and distributing them across a cluster. In the map step, each data subset is analysed and converted into key-value pairs. These key-value pairs are then shuffled across the cluster so that all pairs with the same key end up on the same machine. In the final reduce step, the values sharing a key are combined.
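The three steps can be illustrated with a minimal, single-machine Python sketch that counts page events. The partitions and records below are hypothetical stand-ins for data spread across cluster nodes; a real cluster would run the map step on every node in parallel and move data over the network during the shuffle.

```python
from collections import defaultdict
from itertools import chain

# Hypothetical log partitions, standing in for data distributed across cluster nodes
partitions = [
    [{"userId": "1", "page": "NextSong"}, {"userId": "1", "page": "Home"}],
    [{"userId": "2", "page": "NextSong"}, {"userId": "2", "page": "Thumbs Down"}],
    [{"userId": "3", "page": "NextSong"}],
]

def map_step(partition):
    """Map: emit a (key, value) pair for every event in one partition."""
    return [(event["page"], 1) for event in partition]

def shuffle_step(mapped_partitions):
    """Shuffle: group all values sharing a key, as if they were moved to one machine."""
    groups = defaultdict(list)
    for key, value in chain.from_iterable(mapped_partitions):
        groups[key].append(value)
    return groups

def reduce_step(groups):
    """Reduce: combine the values of each key into a single result."""
    return {key: sum(values) for key, values in groups.items()}

mapped = [map_step(p) for p in partitions]  # on a cluster, this runs in parallel
page_counts = reduce_step(shuffle_step(mapped))
print(page_counts)
```

In PySpark, the same aggregation on a DataFrame would simply be `df.groupBy("page").count()`, and Spark plans the map, shuffle and reduce stages itself.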

Exploring a subset of the data

The Sparkify data is a user log formatted as a table with 18 columns. The small subset contains 286,500 rows. Each row represents an API event, such as a login or playing the next song. The columns are a mix of numerical and categorical data. The screenshot below shows the columns of the user log table.

Wrangling the data gives a deeper understanding of it. For example, the “userId” and “gender” columns show that there are 225 unique user Ids, of which 121 belong to male and 104 to female users.

The values in the column “itemInSession” count the interactions of one user within the same session Id. The type of user interaction (API call) is described by the values in the “page” column.

The graph below shows the possible user interactions with the service and how often each appears in the dataset. The most frequent page event is “NextSong”, which is the main function of the streaming service. The “NextSong” page seems to be loaded automatically once a song ends.

The “Home” page is the page a user enters when starting a streaming session and is the second most frequent page event. There are exactly as many “Cancel” events as “Cancellation Confirmation” events. The “Cancel” event therefore seems to be part of the churning process itself, so it is not used for predicting churn in this project.

The value in column “length” represents a song’s duration and therefore is null for all page events other than “NextSong”. Using time passed between a “NextSong” event and the following event, it would be possible to calculate the time a user spent listening to a song.
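The listening-time idea can be sketched in plain Python for a single user's events sorted by timestamp. The events are hypothetical, and "ts" is assumed to be in milliseconds since the Unix epoch. The result is only an approximation: a non-song event such as a Thumbs Up can occur while a song is still playing, which would cut the measured time short.

```python
# Hypothetical events of one user in one session, already sorted by "ts"
# ("ts" is assumed to be milliseconds since the Unix epoch)
events = [
    {"ts": 1_000_000, "page": "NextSong"},
    {"ts": 1_200_000, "page": "NextSong"},
    {"ts": 1_290_000, "page": "Logout"},
]

def listening_seconds(events):
    """Approximate listening time: seconds between each NextSong event and the next event.

    A NextSong event without a following event is skipped, since its end is unknown.
    """
    durations = []
    for current, following in zip(events, events[1:]):
        if current["page"] == "NextSong":
            durations.append((following["ts"] - current["ts"]) / 1000)
    return durations

print(listening_seconds(events))
```

On the full dataset, the same "time until the next event" could be computed in PySpark with a window function such as `F.lead("ts").over(Window.partitionBy("userId", "sessionId").orderBy("ts"))`.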

In the dataset, each user always accesses the streaming service from the same location. This could mean that users only access the service from home, a hypothesis supported by the fact that the “userAgent” column contains no entries for mobile devices.

The user log appears to cover a period of about two months: the minimum value of the timestamp column “ts” is October 1st and the maximum value is December 3rd.
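Reading the period from the “ts” column requires converting its raw values into dates. A small sketch, assuming the timestamps are milliseconds since the Unix epoch in UTC; the two values below are hypothetical, and the year 2018 is chosen only for illustration since the article gives just the day and month.

```python
from datetime import datetime, timezone

# Hypothetical min/max values of "ts" (milliseconds since the Unix epoch, UTC)
ts_min, ts_max = 1538352000000, 1543795200000

def to_utc(ts_ms):
    """Convert a millisecond timestamp into a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)

start, end = to_utc(ts_min), to_utc(ts_max)
print(start.date(), end.date(), (end - start).days, "days")
```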

After getting an understanding of the data, the next step of CRISP-DM is to prepare the dataset for training the model.

Data Preparation

The first step of data preparation is to clean the data of invalid or missing records. In this project, invalid data includes records without user Ids or session Ids. After that, an exploratory data analysis is conducted to find possible features for the churn prediction.

Data Cleaning

Some user Id values are empty strings. These empty user Ids appear, for instance, when a user enters the streaming service without logging in. All 8,346 records with an empty user Id are dropped, resulting in 278,154 rows in the cleaned dataset. These records belong to users with the authentication status “Logged Out” (8,249) and “Guest” (97).
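The cleaning rule can be sketched in plain Python on a few hypothetical records: rows with an empty user Id are dropped, and the dropped rows are counted by authentication status, mirroring the breakdown above.

```python
from collections import Counter

# Hypothetical records; an empty userId marks a visitor who is not logged in
records = [
    {"userId": "30", "auth": "Logged In", "page": "NextSong"},
    {"userId": "", "auth": "Logged Out", "page": "Home"},
    {"userId": "", "auth": "Guest", "page": "Register"},
    {"userId": "9", "auth": "Logged In", "page": "Home"},
]

def has_user_id(record):
    """True if the record belongs to an identified user."""
    return record["userId"].strip() != ""

cleaned = [r for r in records if has_user_id(r)]
dropped_by_auth = Counter(r["auth"] for r in records if not has_user_id(r))
print(len(cleaned), dict(dropped_by_auth))
```

With a PySpark DataFrame, the equivalent filter would be `df.filter(F.col("userId") != "")`.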

The goal of the following exploratory data analysis is to observe differences in behaviour between customers who stayed and customers who churned. One way to do this is to compare aggregates for these two groups of users, such as how often they performed a specific action per unit of time or per song played.

Labeling data

First, a column “churned” is created to label whether a customer churned or stayed with the service. This column is later used as the label for training the supervised machine learning model. The “Cancellation Confirmation” events, which occur for both paid and free users, define the exact moment of churn.
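The labeling logic amounts to: a user is churned if any of their events is a “Cancellation Confirmation”, and every row of that user inherits the label. A plain-Python sketch with hypothetical records:

```python
# Hypothetical event records of two users
records = [
    {"userId": "10", "page": "NextSong"},
    {"userId": "10", "page": "Cancel"},
    {"userId": "10", "page": "Cancellation Confirmation"},
    {"userId": "20", "page": "NextSong"},
    {"userId": "20", "page": "Home"},
]

# Users with at least one cancellation confirmation count as churned
churned_users = {r["userId"] for r in records if r["page"] == "Cancellation Confirmation"}

# Propagate the user-level label to every row of that user
labeled = [{**r, "churned": r["userId"] in churned_users} for r in records]
print([(r["userId"], r["churned"]) for r in labeled])
```

In PySpark, one way to propagate such a flag is a window aggregate, for example `F.max("cancellation_event").over(Window.partitionBy("userId"))`.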

The table below shows the last eight interactions of a user called Adriel before churning. After listening to five songs, Adriel visits the “Downgrade” page, then enters the “Cancel” page and deletes his account.

|firstName| page | artist|
| Adriel|NextSong | Tonic|
| Adriel|NextSong | Arch Enemy|
| Adriel|NextSong |Les Ogres De Barback|
| Adriel|NextSong |The Notorious B.I.G.|
| Adriel|NextSong | Nickelback|
| Adriel|Downgrade | None|
| Adriel|Cancel | None|
| Adriel|Cancellation Confirmation| None|

Check for imbalance in Data

A new column “cancellation_event” is created to mark the exact event of the cancellation confirmation. With this column, it is possible to check whether the dataset is balanced in terms of users who eventually churn and users who stay. This matters because, if substantially fewer users churn than stay, the model has fewer examples from which to learn how churning users behave.

The dataset is imbalanced with respect to users who churned versus those who stayed: of a total of 225 users, only 52 eventually churned, which equals a churn rate of 23.11%.

How does this imbalance in the number of users translate to the number of interactions? Only 16.13% of user interactions come from churned customers, so the interaction data available for comparing the behaviour of users who stayed and users who churned is clearly imbalanced as well.

Imbalanced training data can lead to naive behaviour in the predictions of a supervised machine learning model. Since 76.89% of users do not churn, a model can achieve an accuracy of 76.89% by simply always predicting “not churned” [3].
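The arithmetic behind the naive baseline can be checked directly with the user counts from the text (52 churned out of 225). The "model" here is a constant predictor that never flags anyone, so it reaches the baseline accuracy while detecting none of the churned users.

```python
total_users, churned_count = 225, 52

labels = [1] * churned_count + [0] * (total_users - churned_count)  # 1 = churned
predictions = [0] * total_users  # naive model: always "not churned"

churn_rate = churned_count / total_users
accuracy = sum(p == y for p, y in zip(predictions, labels)) / total_users
detected_churners = sum(p == 1 and y == 1 for p, y in zip(predictions, labels))
print(f"churn rate {churn_rate:.2%}, naive accuracy {accuracy:.2%}, churners found {detected_churners}")
```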

There are different ways to handle imbalanced data before feeding it into machine learning algorithms. One is to manipulate the input data by undersampling the data of loyal users, oversampling the data of churned users or generating synthetic data. In this project, the imbalance is instead handled by creating additional features and choosing appropriate performance metrics.

Feature Creation

The next step is feature creation for the machine learning model. Features are created from the available data so that the model can distinguish between users who churn and those who do not. A useful feature exposes differences in the behaviour of loyal users and users who are likely to churn.

One possible feature is the time since registration. The mean duration from registration to the last interaction with the streaming service is 57.8 days for users who eventually churn and 87.1 days for users who stay with the service. It seems intuitive that loyal users, on average, stay with the service longer.

|churned|mean days from registration to last interaction|
| true| 57.80769230769231|
| false| 87.14450867052022|
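A minimal sketch of how the days_registrated feature and the group means could be computed, assuming the "registration" and "ts" values are millisecond timestamps. The four users and their numbers are hypothetical and chosen to make the arithmetic easy to follow.

```python
from statistics import mean

MS_PER_DAY = 24 * 60 * 60 * 1000

# Hypothetical users: registration and last-event timestamps in milliseconds
users = [
    {"userId": "1", "registration": 0, "last_ts": 40 * MS_PER_DAY, "churned": True},
    {"userId": "2", "registration": 5 * MS_PER_DAY, "last_ts": 65 * MS_PER_DAY, "churned": True},
    {"userId": "3", "registration": 0, "last_ts": 90 * MS_PER_DAY, "churned": False},
    {"userId": "4", "registration": 10 * MS_PER_DAY, "last_ts": 120 * MS_PER_DAY, "churned": False},
]

def days_registrated(user):
    """Days between registration and the user's last recorded interaction."""
    return (user["last_ts"] - user["registration"]) / MS_PER_DAY

# Mean feature value per churn label, as in the table above
mean_days = {
    label: mean(days_registrated(u) for u in users if u["churned"] == label)
    for label in (True, False)
}
print(mean_days)
```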

The violin plot below shows that this difference in mean is also visible in the distribution of the values for users who churned versus users who did not churn.

Therefore, the time passed since a user registered with the streaming service is a feature that could be useful for predicting which users are prone to churn.

Feature Selection

After exploring possible features for the prediction model, the next step is to select the features which should be used for the machine learning model to decide if a user will churn or not.

The resulting feature set consists of 18 numerical features and two binary features. The binary features are the user's gender and service level. The numerical features include:

  • percentage_active_day: the percentage of days in their registration period on which a user actually accessed the service
  • streaming_per_active_day: the accumulated length of the songs the user listened to per active day
  • songs_per_homevisit: the number of songs a user listened to between visits to the home page
  • days_registrated: the number of days passed since the user registered
  • event_per_songs_played: for each possible event, such as visiting the home page or giving a Thumbs Down, a feature representing the number of occurrences of this event for a user relative to the number of songs played by the same user
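The event_per_songs_played family of features can be sketched for one user as "event count divided by songs played". The page events are hypothetical, the generated feature names are illustrative rather than the project's exact column names, and songs_per_homevisit is approximated here as songs played divided by home-page visits.

```python
from collections import Counter

# Hypothetical page events of a single user
pages = ["NextSong"] * 40 + ["Thumbs Down"] * 4 + ["Thumbs Up"] * 10 + ["Home"] * 8

counts = Counter(pages)
songs_played = counts["NextSong"]  # assumed > 0; guard against zero on real data

# One feature per event type: occurrences relative to songs played (illustrative names)
event_per_songs_played = {
    f"{page.lower().replace(' ', '_')}_per_song": count / songs_played
    for page, count in counts.items()
    if page != "NextSong"
}

# Approximation of songs_per_homevisit: songs played per home-page visit
songs_per_homevisit = songs_played / counts["Home"]
print(event_per_songs_played, songs_per_homevisit)
```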

Check for Multicollinearity in Features

If the model is based on algorithms like Logistic Regression or Linear Regression, the features have to be checked for multicollinearity. Multicollinearity is present when features are highly correlated, so that one feature can be predicted from the others. It can make the estimated coefficients unstable and lead to misleading interpretations of how each feature affects the label.

Decision tree and boosted tree algorithms are largely robust to multicollinearity: when deciding on a split, a tree will choose only one of several perfectly correlated features [3].

The following graph shows the pairwise Pearson correlation among the created features. The Pearson correlation coefficient is the covariance of two features divided by the product of their standard deviations. No pair of features has a Pearson coefficient higher than +0.62 or lower than -0.63.

Furthermore, the bottom row of the correlation matrix above shows the correlation between the label “churned” and each of the created features. The feature with the highest positive Pearson correlation with the label is the number of Thumbs Down given by a user per song played. The feature with the most negative correlation is the number of days since registering with the service. This already suggests that the number of Thumbs Down given by a user and the time since registration are useful features for predicting which users eventually churn.

Data Modeling

In the following section, three binary machine learning classifiers are compared to find the model best suited for predicting user churn with the available data: Logistic Regression, Random Forest Classifier and Gradient-Boosted Tree Classifier.

Preprocessing Data for Modeling

To train the machine learning models provided by the Spark library, the data first has to be preprocessed.

The categorical data is transformed into numerical values using a StringIndexer. The numerical data is standardised with a StandardScaler so that all features are on a similar scale: with the option withMean enabled, the scaler zero-centres each feature and then divides it by its standard deviation (by default, Spark's StandardScaler only divides by the standard deviation without centring). Finally, all features have to be assembled into a single feature vector, for example with Spark's VectorAssembler.
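The three preprocessing steps can be mimicked in plain Python to show what each does. This is a sketch, not Spark's implementation: the index assignment follows Spark's default of ordering categories by descending frequency (ties are broken alphabetically here), the scaler uses the sample standard deviation, and the two feature columns are hypothetical.

```python
from collections import Counter
from statistics import mean, stdev

def string_index(values):
    """Map categories to 0.0, 1.0, ... by descending frequency (ties broken alphabetically)."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    index = {value: float(i) for i, value in enumerate(ordered)}
    return [index[v] for v in values]

def standard_scale(values, with_mean=True):
    """Optionally zero-centre, then divide by the sample standard deviation."""
    centre = mean(values) if with_mean else 0.0
    sd = stdev(values)
    return [(v - centre) / sd for v in values]

# Hypothetical columns for four users
gender = ["M", "F", "M", "M"]
days_registrated = [40.0, 60.0, 90.0, 110.0]

gender_idx = string_index(gender)
days_scaled = standard_scale(days_registrated)

# "Assemble" one feature vector per user, like a VectorAssembler output column
feature_vectors = [list(row) for row in zip(gender_idx, days_scaled)]
print(feature_vectors)
```

Tree-based models do not need the scaling step, but it matters for Logistic Regression, whose regularised coefficients depend on the scale of each feature.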

Before training, the data is split into a training set and a test set so that model performance can be measured on unseen data: the models are trained on 80% of the available data and evaluated on the remaining 20%.

Training Machine Learning Models

In a first step, the models are trained with Spark's default parameters and their predictions are evaluated to choose one algorithm for further investigation. For this algorithm, the goal is then to maximise prediction performance by tuning its parameters.

Logistic Regression is a statistical model used to solve classification problems. Like any other supervised learning method, it tries to learn a function that predicts the label from the feature values; in classification, the label takes discrete values. Using maximum likelihood estimation, Logistic Regression fits a sigmoid function to the available data, which models the probability of the positive class.
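At prediction time, a fitted Logistic Regression model computes a weighted sum of the features, passes it through the sigmoid to obtain a probability, and compares that probability with a threshold. A minimal sketch; the coefficients, bias and feature values are hypothetical, with signs chosen to match the correlations described earlier (Thumbs Down positive, days since registration negative).

```python
import math

def sigmoid(z):
    """Numerically stable logistic function 1 / (1 + exp(-z))."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def predict_proba(weights, bias, features):
    """Probability of the positive class ("churned") for one feature vector."""
    z = bias + sum(w * x for w, x in zip(weights, features))
    return sigmoid(z)

# Hypothetical coefficients for two scaled features:
# thumbs_down_per_song (positive weight) and days_registrated (negative weight)
weights, bias = [1.5, -2.0], -1.0
p = predict_proba(weights, bias, [1.0, -0.5])
label = 1 if p >= 0.5 else 0  # 0.5 is a common default threshold
print(round(p, 3), label)
```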

The Random Forest Classifier is an ensemble method that fits a number of decision tree classifiers on various sub-samples of the dataset and aggregates their predictions to improve predictive accuracy and control overfitting. First, the data is bootstrapped: samples are drawn at random with replacement, so the same data point can be picked multiple times. Then, a decision tree is grown on each bootstrap sample, considering only a random subset of the features when choosing its splits. The final prediction is made by aggregating the predictions of all trees.
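The two ingredients named above, bootstrapping and aggregation, can be sketched in a few lines. The rows and the tree votes are hypothetical, and a simple majority vote is shown as one common way to aggregate; Spark's implementation may combine the trees' outputs differently.

```python
import random
from collections import Counter

def bootstrap_sample(rows, rng):
    """Draw len(rows) rows uniformly at random *with replacement*."""
    return [rng.choice(rows) for _ in rows]

def majority_vote(predictions):
    """Aggregate the class predictions of several trees by majority vote."""
    return Counter(predictions).most_common(1)[0][0]

rng = random.Random(0)  # fixed seed for reproducibility
rows = list(range(10))  # hypothetical row ids
sample = bootstrap_sample(rows, rng)

tree_votes = [1, 0, 1, 1, 0]  # hypothetical predictions of five trees for one user
print(sample, majority_vote(tree_votes))
```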

The Gradient-Boosted Tree Classifier is an ensemble method that boosts decision trees. Boosting is a method of combining many weak learners into a strong learner. Gradient-boosting algorithms are similar to AdaBoost, but they differ in how they account for the errors of the previous iteration: AdaBoost re-weights the incorrectly classified instances so that the next weak learner focuses on them, whereas gradient boosting fits each new tree directly to the remaining errors of the current ensemble.

The classifier begins by fitting a first tree to the original dataset. After evaluating the ensemble so far, it fits another tree whose purpose is to correct the shortcomings of the previous trees, and this process is repeated for a specified number of iterations. Because each tree only needs to capture what the previous trees got wrong, the individual trees can be small; in Spark their size is limited by the maxDepth parameter rather than being fixed to single-split stumps. The prediction of the final ensemble is the weighted sum of the predictions of all trees, where the weight of the later trees is given by a learning rate (Spark's stepSize parameter).

Gradient boosting identifies these shortcomings using the gradient of a loss function: each new tree is fitted to the negative gradient of the loss (the pseudo-residuals) with respect to the current predictions. The loss function measures how well the model's predictions fit the underlying data.
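The core loop of gradient boosting can be illustrated with a small sketch. To keep it short, it uses squared loss on a single numeric feature, for which the negative gradient is simply the residual, and one-split regression stumps as weak learners. This is an illustration of the principle only: Spark's GBTClassifier uses logistic loss and deeper trees (maxDepth), and the data below is hypothetical.

```python
from statistics import mean

def fit_stump(xs, residuals):
    """Weak learner: the single split on x that minimises squared error on the residuals.

    Assumes xs contains at least two distinct values.
    """
    best = None
    for threshold in sorted(set(xs))[:-1]:
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_value, right_value = mean(left), mean(right)
        sse = sum((r - left_value) ** 2 for r in left) + sum((r - right_value) ** 2 for r in right)
        if best is None or sse < best[0]:
            best = (sse, threshold, left_value, right_value)
    _, threshold, left_value, right_value = best
    return lambda x: left_value if x <= threshold else right_value

def gradient_boost(xs, ys, n_iter=10, step_size=0.1):
    """Gradient boosting with squared loss L = 0.5 * (y - F(x)) ** 2."""
    base = mean(ys)
    trees = []

    def predict(x):
        return base + step_size * sum(tree(x) for tree in trees)

    for _ in range(n_iter):
        # Negative gradient of the squared loss w.r.t. F(x) is the residual y - F(x)
        residuals = [y - predict(x) for x, y in zip(xs, ys)]
        trees.append(fit_stump(xs, residuals))
    return predict

def mse(model, xs, ys):
    return mean((y - model(x)) ** 2 for x, y in zip(xs, ys))

# Hypothetical one-feature training data
xs = [1, 2, 3, 4, 5, 6]
ys = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
for n in (1, 5, 50):
    print(n, round(mse(gradient_boost(xs, ys, n_iter=n), xs, ys), 4))
```

The training error shrinks as trees are added, and a smaller step_size makes each correction more cautious, which is the same trade-off controlled by Spark's stepSize and maxIter.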

Performance Metrics

As mentioned earlier, the imbalance in the data regarding churned users makes it possible for a model to achieve an accuracy of 76.89% by simply always predicting “not churned”. Accuracy is therefore not a suitable evaluation metric.

Thus, in this project the following two metrics will be used to measure the performance of the models: AUC and F1-Score.

AUC (area under the ROC curve) is a binary classification metric that works well for imbalanced datasets. A ROC curve (Receiver Operating Characteristic curve) is a graph showing the performance of a classification model across all classification thresholds. The threshold defines the boundary at which output probabilities are interpreted as positive or negative class predictions.

The ROC curve plots the true positive rate against the false positive rate as the threshold varies, and the AUC measures the degree of separability: it tells how well the model distinguishes between the classes, with 0.5 corresponding to random guessing and 1.0 to perfect separation. This works well during development, but deploying the model requires choosing an exact threshold for interpreting the predicted probabilities.
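A useful way to read the AUC is as the probability that a randomly chosen churned user receives a higher score than a randomly chosen non-churned user. The sketch below computes it that way by comparing all positive/negative pairs, which is fine for small hypothetical examples; in Spark, `BinaryClassificationEvaluator(metricName="areaUnderROC")` computes the metric on the prediction DataFrame.

```python
def roc_auc(labels, scores):
    """Probability that a random positive is scored above a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Hypothetical labels (1 = churned) and predicted churn probabilities
labels = [1, 1, 0, 0, 0]
scores = [0.9, 0.4, 0.5, 0.2, 0.1]
print(roc_auc(labels, scores))
```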

To understand the F1-Score, it is important to first understand precision and recall. In this case, precision is the fraction of users predicted as churned who actually churned (the denominator includes users wrongly predicted as churned). Recall is the fraction of all users who churned that were correctly predicted as churned.

The F1-Score is the harmonic mean of precision and recall, so it is only high when both precision and recall are high. This makes it an evaluation metric well suited to imbalanced data.
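The definitions above can be written out for binary labels (1 = churned) as a short sketch; the labels and predictions are hypothetical. Note that Spark's `MulticlassClassificationEvaluator` with `metricName="f1"` reports an F1-Score weighted across both classes, which can differ from the churned-class F1 computed here.

```python
def precision_recall_f1(y_true, y_pred):
    """Precision, recall and F1-Score for the positive class (1 = churned)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Hypothetical example: 4 churners, of which 2 are found, plus 1 false alarm
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 1, 0, 0, 0, 0, 0]
print(precision_recall_f1(y_true, y_pred))

# The naive "never churned" model scores zero on all three metrics
print(precision_recall_f1(y_true, [0] * len(y_true)))
```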

For Sparkify, both high precision and high recall are important when predicting churn. High recall is needed to identify as many of the users who are about to churn as possible, and high precision is needed to avoid giving discounts to users who are unlikely to churn.

Evaluation of the Modelling Results

In the first step, all three machine learning models are trained on the preprocessed data with their default parameters. The Gradient-Boosted Tree (GBT) Classifier performs best on the chosen evaluation metrics: as the table below shows, it achieved an F1-Score of 67% with less than three minutes of training time on the local machine.

|Metric |Logistic Regression |Random Forest |Gradient-Boosted Tree|
|Training time |95 s |145 s |156 s |
|F1-Score |65% |57% |67% |
|Area under ROC|48% |69% |70% |

Model Selection and Parameter Tuning

After selecting the GBT Classifier for its best performance with the default parameters (maxDepth=5, maxIter=20, maxBins=32 and minInfoGain=0.0), the next step is to tune its parameters.

The following parameters were tuned:

  • maxDepth: maximum depth of each tree, where depth 0 means one leaf node and depth 1 means one internal node plus two leaf nodes.
  • maxIter: maximum number of iterations, i.e. the number of trees in the ensemble.
  • maxBins: maximum number of bins used for discretising continuous features.
  • minInfoGain: minimum information gain required for a split to be considered at a tree node.

Using grid search on the GBT Classifier with 3-fold cross-validation, it was possible to increase the F1-Score to 72% with the following parameter values:

maxDepth=5, maxIter=15, maxBins=32, minInfoGain=0.1

Changing the parameter minInfoGain from 0.0 to 0.1 increased the F1-Score, presumably by reducing overfitting. Reducing the maximum number of iterations from 20 to 15 had a similar effect.

|Metric | Gradient Boost|
|F1-score | 72% |
|Area under ROC| 58% |

Running grid search with 3-fold cross-validation was a challenge for the local machine. With three parameters varied over two values each, training took about one hour (156 s × 3 folds × 2 × 2 × 2 = 3,744 s).
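The cost estimate grows multiplicatively with every parameter added to the grid. The sketch below enumerates a grid with `itertools.product`; the candidate values are hypothetical, since the article only reports the best combination, and the 156 s per fit is the default GBT training time from the table above. In PySpark the grid would be built with `ParamGridBuilder` and passed to `CrossValidator(numFolds=3)`, which also refits the best combination on the full training data, so the real wall-clock time is somewhat higher.

```python
from itertools import product

# Hypothetical grid: three GBT parameters with two candidate values each
param_grid = {
    "maxDepth": [3, 5],
    "maxIter": [15, 20],
    "minInfoGain": [0.0, 0.1],
}
combinations = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]

folds = 3
seconds_per_fit = 156  # default GBT training time measured on the local machine
total_seconds = len(combinations) * folds * seconds_per_fit
print(len(combinations), total_seconds, round(total_seconds / 60), "minutes")
```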

The GBT Classifier and the Random Forest Classifier both report feature importances, which show how much each feature contributed to predicting the probability that a user leaves the service. The graph below shows that the time since registration is the most important feature available to the model. The numbers of Thumbs Up and Thumbs Down per song played also seem to have helped the model make correct predictions. The feature importance graph and the earlier correlation matrix show similar tendencies.


An F1-Score of 72% on this relatively small subset of the data is acceptable. Hyperparameter tuning and cross-validation did not increase the F1-Score substantially, probably because of the small sample size. With the complete dataset, and therefore more training data, the F1-Score is likely to improve. Another option would be to try further supervised classification approaches such as XGBoost, LightGBM or a custom ensemble method.

Furthermore, the analysis showed which features are most relevant for identifying users who will churn. Sparkify could use the Gradient-Boosted Tree Classifier and, for example, configure it to maximise precision at the cost of lower recall. This would mean few false positives in the predictions, while not every user who eventually churns would be detected. Sparkify could then lower its churn rate while running retention campaigns only for customers who are very likely to churn.
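The precision/recall trade-off is controlled by the probability threshold above which a user is flagged as a churner. A sketch with hypothetical labels and predicted probabilities shows how raising the threshold increases precision while lowering recall.

```python
def precision_recall_at(threshold, y_true, probs):
    """Precision and recall when users with probability >= threshold are flagged as churners."""
    preds = [1 if p >= threshold else 0 for p in probs]
    tp = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, preds) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, preds) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Hypothetical churn labels and predicted churn probabilities for six users
y_true = [1, 1, 1, 0, 0, 0]
probs = [0.9, 0.7, 0.4, 0.6, 0.3, 0.1]

for threshold in (0.5, 0.8):
    print(threshold, precision_recall_at(threshold, y_true, probs))
```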


In this article I analysed user interactions of a streaming service using Spark. After understanding the business importance of churn, I explored a small subset of the data to understand its content and quality. Then, I cleaned the data of missing values and created features that allow a machine learning model to find differences in behaviour between users who stay with the service and those who churn. With the prepared data, I trained different classifiers and evaluated their performance with metrics suited to imbalanced data. The result of the project is a Gradient-Boosted Tree Classifier that identifies users who will churn with an F1-Score of 72% on the small subset, balancing precision and recall.

The next step is to apply the model to the full dataset on AWS. When working with the full dataset, factors besides model accuracy become important too, such as configuring runtime memory usage and optimising resources in general. Once the model is deployed to run periodically, the frequency at which it runs has to be defined based on costs, data latency and business requirements. After deployment, both the model results and the operational costs have to be monitored continuously to make sure that the model brings value to the business.

Find the complete project data on GitHub.

Data Analyst and Machine Learning Enthusiast working at BMW Group. A mechanical engineer transitioned to software engineer.
