On Dataset Shift in Machine Learning
Best Practices Presented by Peter Prettenhofer.
During the North Star AI conference, powered by Proekspert, Peter Prettenhofer, VP of Engineering at DataRobot, presented one of the most fundamental questions in machine learning: What happens if the data applied to a model looks quite different from the data that the model was initially trained on?
This phenomenon is known as dataset shift. One common form, often called ‘concept shift,’ occurs when the relationship between input features and outcomes changes, which can quietly erode model accuracy once the model is in production.
The obvious question that naturally follows, noted Prettenhofer, is: Does the model, in this case, need to be retrained?
Introduction
In the ever-evolving field of machine learning, one of the most significant challenges is ensuring that models remain accurate and reliable over time. This is where the concept of dataset shift comes into play. Dataset shift occurs when the data distribution changes between the training phase and the deployment phase, leading to potential degradation in model performance. Understanding and addressing dataset shift is crucial for building robust machine learning systems that can adapt to real-world changes.
In Supervised Machine Learning
In supervised machine learning, the goal is to train a model on a labeled dataset to make accurate predictions on unseen data. However, in real-world scenarios, the data distribution can change over time, leading to a phenomenon known as data shift. Data shift occurs when the training data and the data on which the model is deployed have different distributions. This can significantly impact the performance and reliability of machine learning models. For instance, a model trained on historical sales data may struggle to make accurate predictions if consumer behavior shifts due to economic changes or new market trends.
In supervised machine learning, which was the focus of Prettenhofer’s talk, you are given a set of samples, pairs of X (a feature vector) and Y (a class label), as well as a loss function that can be used to assess how well your model is making predictions. The goal, explained Prettenhofer, is to find a function h that maps from X to Y and has minimal error on new, unseen data.
How do you find the function h? Empirical Risk Minimization tells you to pick the one that minimizes the error on the training data. Like most machine learning techniques, it assumes that training and test data are drawn independently from the same fixed distribution, the familiar independent and identically distributed (IID) assumption. "Identically distributed" means that the distribution generating the data stays fixed over time; "independent" means that drawing one sample has no effect on the next. In other words, you accept the idea that all data is generated by some underlying probabilistic process.
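In symbols (the notation below is ours, not taken verbatim from the talk, with l the loss, H the hypothesis class, and D the unknown data distribution), the goal and the empirical risk minimization recipe look like this:

```latex
% True risk: expected loss on new data drawn from the unknown distribution D
R(h) = \mathbb{E}_{(x, y) \sim D} \big[ \ell(h(x), y) \big]

% Empirical risk minimization: pick the hypothesis that minimizes
% the average loss over the n training pairs (x_i, y_i)
\hat{h} = \arg\min_{h \in \mathcal{H}} \; \frac{1}{n} \sum_{i=1}^{n} \ell\big(h(x_i), y_i\big)
```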
Relaxing the IID Assumption
In dataset shift, said Prettenhofer, you relax this fundamental assumption of machine learning. You move from the classical view that training and test data are drawn from the same fixed but unknown distribution, to a view where training data comes from a source distribution and test data from a target distribution. The question is whether learning is possible at all under such conditions.
In the extreme, if the two distributions have nothing in common, learning is not possible: if the future shares nothing with the past, you cannot learn from it. Prettenhofer concluded that supervised learning techniques are therefore negatively affected by dataset shift. Fortunately, there are simple methods you can use to detect this kind of change.
What is Data Shift?
Data shift refers to a mismatch between the distribution of the data used to train a model and the distribution of the data it encounters during deployment. It can occur for various reasons, such as changes in user behavior, changes in the environment, or changes in the data collection process. Understanding data shift is crucial for building reliable and robust machine learning systems: when it occurs, the assumptions made during training no longer hold, and model predictions can become inaccurate.
Characterizing and Identifying Dataset Shift
Prettenhofer presented a taxonomy of methods that can be used to identify whether dataset shift is a problem you may have. In general, there are two methodologies for detecting dataset shift: supervised and unsupervised. Prettenhofer noted that he would focus on unsupervised methods because they are more common and practical.
1) Statistical Distance
The Statistical Distance method is useful for detecting whether your model predictions change over time, and it works by creating and comparing histograms. By building histograms, you can detect not only whether your model predictions change over time, but also whether your most important features change. Simply put, you form histograms of your training data, keep track of them over time, and compare them to spot changes; changes in the joint distribution of inputs and outputs can be detected in the same way. This method, noted Prettenhofer, is used most commonly by financial institutions on credit-scoring models. Several metrics can be used to monitor the change in model predictions over time, including the Population Stability Index (PSI), the Kolmogorov-Smirnov statistic, Kullback-Leibler divergence, and histogram intersection.
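As a concrete illustration, here is a minimal sketch of the PSI computation on two sets of model scores; the bin count, the smoothing epsilon, and the simulated score distributions are illustrative choices of ours, not values from the talk.

```python
import numpy as np

def population_stability_index(expected, actual, bins=10, eps=1e-6):
    """PSI between a reference sample (e.g. training-time scores)
    and a current sample (e.g. this week's production scores)."""
    # Bin edges come from the reference sample only.
    edges = np.histogram_bin_edges(expected, bins=bins)
    # Clip current values into the reference range so out-of-range
    # points land in the outermost bins instead of being dropped.
    actual = np.clip(actual, edges[0], edges[-1])
    expected_pct = np.histogram(expected, bins=edges)[0] / len(expected)
    actual_pct = np.histogram(actual, bins=edges)[0] / len(actual)
    # Avoid log(0) and division by zero for empty bins.
    expected_pct = np.clip(expected_pct, eps, None)
    actual_pct = np.clip(actual_pct, eps, None)
    return float(np.sum((actual_pct - expected_pct) * np.log(actual_pct / expected_pct)))

# Example: production scores drift upward relative to training-time scores.
rng = np.random.default_rng(0)
train_scores = rng.normal(0.4, 0.1, 10_000)
live_scores = rng.normal(0.5, 0.1, 10_000)
print(population_stability_index(train_scores, live_scores))
```

A commonly cited rule of thumb treats a PSI below 0.1 as stable, 0.1 to 0.25 as moderate shift, and above 0.25 as significant shift, though these thresholds are conventions rather than hard rules.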
2) Novelty Detection
A method that is more amenable to fairly complex domains, such as computer vision, is Novelty Detection. The idea is to build a model of the source distribution; given a new data point, you test how likely it is that the point was drawn from that distribution. This is particularly useful because monitoring raw input data directly is often hard: the data arrives from various sources, in various formats and structures, and ML engineers may have restricted access to it once it has been processed by a data platform team. For this method you can use techniques such as a one-class support vector machine, available in most common libraries. If you are in a regime of homogeneous but very complex interactions, this is a method worth looking into, because the histogram method won’t be effective there.
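As a minimal sketch of this idea, the snippet below fits scikit-learn’s OneClassSVM to source data and flags new points that do not look like it; the kernel and the nu parameter are illustrative defaults, not recommendations from the talk.

```python
import numpy as np
from sklearn.svm import OneClassSVM

rng = np.random.default_rng(0)

# "Source" data the model was originally trained on.
X_source = rng.multivariate_normal([0, 0], [[1.0, 0.6], [0.6, 1.0]], size=2_000)

# Fit a one-class model of the source distribution.
# nu roughly bounds the fraction of training points treated as outliers.
detector = OneClassSVM(kernel="rbf", gamma="scale", nu=0.05).fit(X_source)

# New production data drawn from a shifted distribution.
X_new = rng.multivariate_normal([2, 2], [[1.0, 0.0], [0.0, 1.0]], size=500)

# predict() returns +1 for points consistent with the source data, -1 otherwise.
novelty_rate = np.mean(detector.predict(X_new) == -1)
print(f"Fraction of new points flagged as novel: {novelty_rate:.1%}")
```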
3) Discriminative Distance
The Discriminative Distance method is less common, but it can nonetheless be effective. The intuition is that you train a classifier to detect whether an example comes from the source or the target domain, and use its training error as a proxy for the distance between the two distributions: the higher the error, the closer they are, because a classifier that cannot tell the domains apart is seeing essentially the same data from both. Discriminative distance is widely applicable and works even in high-dimensional settings. Though it takes time and can be complicated, it is a useful technique if you are doing domain adaptation.
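Here is a minimal sketch of discriminative distance, with logistic regression standing in as the domain classifier; the model choice and the synthetic data are ours, not prescribed by the talk.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
X_source = rng.normal(0.0, 1.0, size=(1_000, 5))   # training-time features
X_target = rng.normal(0.5, 1.0, size=(1_000, 5))   # deployment-time features

# Label each row by its domain: 0 = source, 1 = target.
X = np.vstack([X_source, X_target])
y = np.concatenate([np.zeros(len(X_source)), np.ones(len(X_target))])

# Cross-validated accuracy of the domain classifier.
accuracy = cross_val_score(LogisticRegression(max_iter=1_000), X, y, cv=5).mean()
error = 1.0 - accuracy

# Error near 0.5: domains are hard to tell apart, so the distributions are close.
# Error near 0.0: domains are easily separated, a sign of strong dataset shift.
print(f"Domain-classifier error: {error:.3f}")
```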
Another important concept to consider is ‘prior probability shift,’ where the distribution of the target variable changes while the class-conditional distribution of the input features, P(x | y), stays the same. This kind of shift can significantly impact model performance even though the inputs within each class look unchanged.
Understanding Covariate Shift
Covariate shift is a type of data shift that occurs when the input distribution changes, but the conditional probability of the output given the input remains the same. In other words, the input features change, but the relationship between the input and output variables remains the same. Covariate shift can be caused by changes in the state of latent variables, which can be temporal, spatial, or less obvious. For example, in a credit risk assessment model, the input features such as income and credit score may change over time, but the relationship between these features and the output variable (credit risk) remains the same. Addressing covariate shift involves techniques like importance reweighting to ensure the model remains accurate despite changes in input distribution.
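In notation (ours, not from the talk), covariate shift means the input marginal moves while the conditional of the output given the input stays fixed:

```latex
% Covariate shift: input distribution changes, conditional stays the same
P_{\text{train}}(x) \neq P_{\text{test}}(x),
\qquad
P_{\text{train}}(y \mid x) = P_{\text{test}}(y \mid x)
```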
Understanding Concept Drift
Concept drift is a type of data shift that occurs when the relationship between the input and output variables changes over time. In other words, the input features and output variable may remain the same, but the underlying relationship between them changes. Concept drift can be caused by changes in the underlying data distribution, such as changes in population demographics or changes in user behavior. For example, in a recommender system, the user preferences may change over time, leading to a change in the relationship between the input features (user demographics) and the output variable (recommended products). Detecting and adapting to concept drift is essential for maintaining the relevance and accuracy of machine learning models in dynamic environments.
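In the same (assumed) notation, concept drift is roughly the mirror image: the conditional relationship changes over time, even if the input distribution stays put:

```latex
% Concept drift: the conditional changes, even if inputs look the same
P_{\text{train}}(y \mid x) \neq P_{\text{test}}(y \mid x),
\qquad
P_{\text{train}}(x) \approx P_{\text{test}}(x)
```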
Correcting Dataset Shift
How do you correct dataset shift? The short answer, said Prettenhofer, is that you don’t — you retrain your model. If possible, you should always retrain. Of course, in some situations, it may not be possible, for example, if there are latency problems with retraining. In such cases, there are two easy techniques for correcting dataset shift.
1) Importance Reweighting
The main idea behind Importance Reweighting is to upweight the training instances that are most similar to your test instances. Essentially, you reweight your training set so that it looks as if it had been drawn from the test distribution. The only thing this method requires is unlabeled examples from the test domain.
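One practical way to obtain such weights, sketched below under our own assumptions rather than as a recipe from the talk, is the classifier trick: train a model to distinguish unlabeled test inputs from training inputs and turn its predicted probabilities into an estimate of the density ratio p_test(x) / p_train(x). The sketch assumes equally sized samples and uses logistic regression purely for illustration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.default_rng(0)

# Labeled training data, plus unlabeled inputs from the test (deployment) domain.
X_train = rng.normal(0.0, 1.0, size=(2_000, 3))
y_train = (X_train[:, 0] + 0.5 * X_train[:, 1] > 0).astype(int)
X_test_unlabeled = rng.normal(0.7, 1.0, size=(2_000, 3))

# Domain classifier: predicts the probability that an input comes from the test domain.
domain_X = np.vstack([X_train, X_test_unlabeled])
domain_y = np.concatenate([np.zeros(len(X_train)), np.ones(len(X_test_unlabeled))])
domain_clf = LogisticRegression(max_iter=1_000).fit(domain_X, domain_y)

# Density-ratio weights w(x) = p_test(x) / p_train(x), approximated by p / (1 - p).
# (With unequal sample sizes, multiply by n_train / n_test.)
p = domain_clf.predict_proba(X_train)[:, 1]
weights = np.clip(p / (1.0 - p), 0.0, 20.0)  # cap extreme weights to limit variance

# Refit the task model on the reweighted training data.
model = GradientBoostingClassifier().fit(X_train, y_train, sample_weight=weights)
```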
2) Change Representations
You can also change the representation of your data so that the two distributions look more similar. As long as you can still make good predictions on the training data in the new representation, the mapping will help you during knowledge transfer. An example of this is feature selection: a feature that differs a lot between training and test, but does not provide much predictive power, should always be dropped. This is one of the most important lessons in understanding dataset shift, said Prettenhofer.
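A minimal sketch of that lesson: score each feature by how much its marginal distribution shifts between training and test (here with a two-sample Kolmogorov-Smirnov statistic) and by how much predictive power it carries, then flag features that shift heavily while contributing little. The thresholds below are placeholders of ours, not values from the talk.

```python
import numpy as np
from scipy.stats import ks_2samp
from sklearn.ensemble import RandomForestClassifier

def shifted_low_value_features(X_train, y_train, X_test,
                               shift_thresh=0.2, imp_thresh=0.05):
    """Indices of features that shift a lot between train and test
    but carry little predictive power: candidates for dropping."""
    rf = RandomForestClassifier(n_estimators=200, random_state=0).fit(X_train, y_train)
    importances = rf.feature_importances_
    drop = []
    for j in range(X_train.shape[1]):
        # KS statistic in [0, 1]: how different the two marginals are.
        shift = ks_2samp(X_train[:, j], X_test[:, j]).statistic
        if shift > shift_thresh and importances[j] < imp_thresh:
            drop.append(j)
    return drop

# Example: feature 2 drifts between train and test but is pure noise for the label.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(2_000, 3))
y_train = (X_train[:, 0] > 0).astype(int)
X_test = rng.normal(size=(2_000, 3))
X_test[:, 2] += 1.5                      # shifted, uninformative feature
print(shifted_low_value_features(X_train, y_train, X_test))   # expected: [2]
```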