
KL Divergence – What is it and mathematical details explained

At its core, KL (Kullback-Leibler) Divergence is a statistical measure that quantifies the dissimilarity between two probability distributions.

Think of it like a mathematical ruler that tells us the “distance” or difference between two probability distributions.

Remember, in data science, we’re often working with probabilities – the chances of events happening.

So, if we have two models giving us different probability distributions for the same event, KL Divergence helps us figure out how different these two models are.

In this post we will look at:

  1. Simple Illustration
  2. Why is KL Divergence Important?
  3. How to Calculate KL Divergence?
  4. How to compute KL Divergence from scratch?
  5. Common Pitfalls and how to address them

Simple Illustration

Imagine you have two bags of marbles. One bag has 8 red marbles and 2 blue marbles, while the other has 5 of each.

Now, if I asked you, “How different are these bags?” you’d likely think about the proportion of marbles in each bag. KL Divergence does the same but in a mathematical way.

Why is KL Divergence Important?

In the machine learning realm, we often want our models to make predictions as close as possible to the real-world outcomes. Sometimes, we have a target probability distribution (like the real outcomes) and a predicted distribution from our model. KL Divergence helps us understand how much our model’s predictions deviate from the target.

It can also be used after deployment, during model monitoring, to check how far the distributions of the predictors and the target have drifted from their baseline (training-time) distributions.
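As a rough sketch of that monitoring use case (the feature histograms and the 0.1 threshold below are made-up values, purely for illustration), you can bin a feature the same way for the training data and for recent production data, normalise the counts into probabilities, and compute KL Divergence between the two:

import numpy as np
from scipy.stats import entropy

# Hypothetical histograms of one feature: baseline (training) vs. recent production data
train_counts = np.array([400, 300, 200, 100])   # assumed baseline bin counts
live_counts = np.array([250, 350, 250, 150])    # assumed production bin counts

# Normalise counts into probability distributions
p = train_counts / train_counts.sum()
q = live_counts / live_counts.sum()

# KL(P || Q) as a drift score
drift_score = entropy(p, q)
print(drift_score)

# Illustrative rule of thumb: flag the feature if the drift score crosses a chosen threshold
if drift_score > 0.1:
    print("Distribution drift detected - investigate this feature")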

How to Calculate KL Divergence?

The formula for KL Divergence is:

KL(P||Q) = ∑ P(i) × log(P(i)/Q(i))

Where:

  • P is the true probability distribution.
  • Q is the probability distribution of the model.

However, P and Q do not have to be a true distribution and a model's distribution; the formula applies to any two probability distributions defined over the same set of events.

This might look a bit complex, but let’s simplify it:

Simple Explanation:

Remember our two bags of marbles?

Let’s assume Bag 1 is our true distribution and Bag 2 is our model’s predicted distribution.

For every marble color, we compare its proportion in Bag 1 with its proportion in Bag 2. The bigger the mismatch between the proportions, the more that color adds to the total divergence score.
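To make that concrete, here is the bag comparison in numbers, treating Bag 1 (8 red, 2 blue) as P and Bag 2 (5 red, 5 blue) as Q. This is just a worked illustration of the formula above:

import numpy as np

# Bag 1 (true distribution): 8 red, 2 blue  -> [0.8, 0.2]
# Bag 2 (other distribution): 5 red, 5 blue -> [0.5, 0.5]
p = np.array([0.8, 0.2])
q = np.array([0.5, 0.5])

# KL(P || Q) = sum over colors of P(i) * log(P(i) / Q(i))
kl = np.sum(p * np.log(p / q))
print(kl)  # roughly 0.19 - a larger divergence than in the weather example below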

Real-world Python Example:

Suppose we’re working on a weather prediction model. We have real data (actual outcomes) for whether it rained (1) or not (0) on certain days, and our model’s predictions for the same days.

import numpy as np
from scipy.stats import entropy

# True data: 0.9 probability of no rain and 0.1 probability of rain
p = np.array([0.9, 0.1])

# Model's prediction: 0.8 probability of no rain and 0.2 probability of rain
q = np.array([0.8, 0.2])

# Calculate KL Divergence
kl_divergence = entropy(p, q)
print(kl_divergence)
0.036690014034750584

If you look at the formula for KL Divergence, swapping P and Q changes the result: KL(P||Q) != KL(Q||P).

This means KL Divergence is not a symmetric measure.

# KL Divergence of Q||P
entropy(q, p)
0.04440300758688234

How to compute KL Divergence from scratch?

Let’s create a Python function to do the calculation.

def kl_divergence(p, q):
    """Compute KL Divergence between two distributions."""
    return np.sum(p * np.log(p / q))

# Using our previously defined p and q
print(kl_divergence(p, q))
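As a quick sanity check, this hand-rolled version should agree with scipy.stats.entropy(p, q) from the earlier example (about 0.0367 for the same p and q):

# Compare the from-scratch value with scipy's result
print(np.isclose(kl_divergence(p, q), entropy(p, q)))  # expected: True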

Common Pitfalls and how to address them

  1. Non-Symmetry: KL Divergence is not symmetric, i.e. KL(P||Q) != KL(Q||P), so it cannot be used as a true distance metric.

    As an alternative, the Jensen-Shannon divergence can be used (see the sketch after this list), which is simply:

    JS(P||Q) = (0.5 * KL(P||M)) + (0.5 * KL(Q||M))

    Where, M = (P + Q) / 2

  2. Undefined Values: If Q assigns zero probability to an event that has non-zero probability under P, the ratio P(i)/Q(i) blows up and the divergence becomes infinite (or undefined in code). Make sure Q has no zero entries for events with non-zero probability in P, for example by adding a small smoothing constant (see the sketch after this list).
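Here is a minimal sketch of the Jensen-Shannon divergence mentioned in the first pitfall, built directly from the formula above and reusing the kl_divergence function we wrote earlier (the helper name js_divergence is just for illustration):

def js_divergence(p, q):
    """Jensen-Shannon divergence: symmetric, based on KL against the mixture M."""
    m = (p + q) / 2
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

print(js_divergence(p, q))  # same value as js_divergence(q, p), since JS is symmetric

For the second pitfall, one common workaround (an assumption on my part, not the only option) is to add a small epsilon to q, re-normalise it, and skip the terms where p is zero:

def kl_divergence_smoothed(p, q, eps=1e-10):
    """KL Divergence with small-epsilon smoothing so zero entries in q do not blow up."""
    q = (q + eps) / (q + eps).sum()
    # Terms where p is 0 contribute nothing, so only sum over p > 0
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))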
