
PySpark Missing Data Imputation – How to handle missing values in PySpark

Handling missing data is an essential step in the data preprocessing pipeline. Let’s explore various methods to impute missing values in PySpark, a popular distributed data processing framework.

We will discuss different techniques, such as mean, median, and mode imputation, as well as using machine learning algorithms to fill in missing values. By the end of this post, you will know how to drop or impute missing values in PySpark and how to choose the most suitable approach for your use case.

Missing data is broadly classified into three types.

Types of missing values

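  • Missing completely at random (MCAR): The probability that a value is missing is unrelated to any variable in the data, observed or missing. The observed rows are effectively a random sample of all rows.

  • Missing at random (MAR): The probability that a value is missing depends only on other observed variables, not on the missing value itself, so the missing value can often be predicted from those other variables.

  • Not missing at random (NMAR): The probability that a value is missing depends on the missing value itself (for example, high earners skipping an income question). Handling it usually requires studying the root cause of the missingness.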
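To see why the mechanism matters, here is a small simulation on synthetic data (hypothetical values, not the churn dataset used later). It compares the mean of the observed values under MCAR, where every value has the same chance of being missing, and under NMAR, where larger values are more likely to go missing. MAR is left out because it needs a second, observed variable that drives the missingness.

```python
import random
import statistics

# Hypothetical synthetic "income" values, not taken from the churn dataset
random.seed(0)
true_values = [random.gauss(50_000, 10_000) for _ in range(10_000)]

def observed(values, keep):
    """Return only the values that survive a missingness rule keep(v) -> bool."""
    return [v for v in values if keep(v)]

# MCAR: every value has the same 30% chance of being missing, regardless of its size
mcar = observed(true_values, lambda v: random.random() > 0.3)

# NMAR: values above 50,000 go missing far more often (missingness depends on v itself)
mnar = observed(true_values, lambda v: random.random() > (0.6 if v > 50_000 else 0.1))

print("true mean: ", round(statistics.mean(true_values)))
print("MCAR mean: ", round(statistics.mean(mcar)))
print("NMAR mean: ", round(statistics.mean(mnar)))
```

Under MCAR the observed mean stays close to the true mean, while under NMAR it is pulled downward, which is why simply dropping or mean-imputing NMAR data can bias the results.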

Common reasons for missing data

  1. Data might not be available for the complete time period of analysis

  2. Non-occurrence of events. For example, a student’s exam mark may be missing because he/she did not take the test.

  3. Skipped response for some questions of the survey

  4. Non-applicability of questions

  5. Values missing purely at random

What to do when the data is missing?

When there are missing values in data, you have four options:

  • Approach 1: Drop the row that has missing values.
  • Approach 2: Drop the entire column if most of its values are missing.
  • Approach 3: Impute the missing data, that is, fill in the missing values with appropriate values (like the mean, median, or mode).
  • Approach 4: Use an ML algorithm that handles missing values internally.

In this post I cover approaches 1, 2, and 3.

Different ways to Impute Missing Values:

1. Mean, Median, and Mode Imputation: The simplest way to fill in missing values is by using the mean, median, or mode of the available data. These techniques work best when values are missing completely at random and are relatively few, since they ignore relationships with other variables. We can use PySpark’s DataFrame API along with the Imputer class from the pyspark.ml.feature module to achieve this.

2. K-Nearest Neighbors Imputation: K-Nearest Neighbors (KNN) imputation finds the K rows most similar to the row with the missing value (measured on the other features) and fills in the missing value with the average of those neighbors’ values. Note that pyspark.ml.feature does not include a KNN imputer (KNNImputer belongs to scikit-learn’s sklearn.impute module), so in PySpark you need to implement it yourself or apply scikit-learn’s KNNImputer to data small enough to fit on one machine.

3. Regression Imputation: Regression imputation trains a regression model on the rows where the feature is present and uses it to predict the missing values from the other features in the dataset. This method is useful when there’s a strong correlation between the missing feature and the other features. We can use the LinearRegression class from the pyspark.ml.regression module.
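Neither KNN nor regression imputation appears in the PySpark walkthrough below, so here is a minimal single-machine sketch of each in plain Python on small hypothetical records. knn_impute and regression_impute are hypothetical helpers, not PySpark APIs. In PySpark, regression imputation would typically mean fitting pyspark.ml.regression.LinearRegression (with features assembled by VectorAssembler) on the rows where the target is present and predicting the rest.

```python
import math

def knn_impute(rows, target, k=2):
    """Fill None in `target` with the mean of that column over the k nearest rows
    where it is present. Distance is Euclidean over the remaining columns, which are
    assumed to be fully observed and on comparable scales (scale them first in practice)."""
    features = [c for c in rows[0] if c != target]
    donors = [r for r in rows if r[target] is not None]
    result = []
    for r in rows:
        filled = dict(r)  # copy so the input rows are not modified
        if r[target] is None:
            point = [r[c] for c in features]
            nearest = sorted(donors, key=lambda d: math.dist(point, [d[c] for c in features]))[:k]
            filled[target] = sum(d[target] for d in nearest) / len(nearest)
        result.append(filled)
    return result

def regression_impute(xs, ys):
    """Fit y = intercept + slope * x by least squares on rows where y is present,
    then predict y wherever it is None (one predictor, for brevity)."""
    train = [(x, y) for x, y in zip(xs, ys) if y is not None]
    mean_x = sum(x for x, _ in train) / len(train)
    mean_y = sum(y for _, y in train) / len(train)
    slope = (sum((x - mean_x) * (y - mean_y) for x, y in train)
             / sum((x - mean_x) ** 2 for x, _ in train))
    intercept = mean_y - slope * mean_x
    return [intercept + slope * x if y is None else y for x, y in zip(xs, ys)]

# Hypothetical records: the last row is missing its balance
rows = [
    {"age": 30, "tenure": 2, "balance": 1000.0},
    {"age": 31, "tenure": 3, "balance": 1200.0},
    {"age": 60, "tenure": 9, "balance": 9000.0},
    {"age": 62, "tenure": 8, "balance": 9400.0},
    {"age": 30, "tenure": 3, "balance": None},
]
print(knn_impute(rows, "balance", k=2)[-1])  # balance filled from the two closest rows
print(regression_impute([1, 2, 3, 4, 5], [10.0, 20.0, None, 40.0, 50.0]))
```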

1. Import required libraries and initialize SparkSession

First, let’s import the necessary libraries and create a SparkSession, the entry point to use PySpark.

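# Optional: only needed when PySpark is not already importable (e.g. a standalone Spark install)
import findspark
findspark.init()

from pyspark import SparkFiles
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct, count, when, isnan, isnull
from pyspark.sql.types import IntegerType, StringType, NumericType

# Create (or reuse) the SparkSession, the entry point to the DataFrame API
spark = SparkSession.builder.appName("ImputeMissingValues").getOrCreate()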

2. Preparing the Sample Data

To demonstrate how to handle missing values, we’ll use a sample churn dataset. First, let’s load the data into a DataFrame:

url = "https://raw.githubusercontent.com/selva86/datasets/master/Churn_Modelling_m.csv"
spark.sparkContext.addFile(url)

df = spark.read.csv(SparkFiles.get("Churn_Modelling_m.csv"), header=True, inferSchema=True)
df.show(2, truncate=False)
+---------+----------+--------+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
|RowNumber|CustomerId|Surname |CreditScore|Geography|Gender|Age|Tenure|Balance |NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+---------+----------+--------+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
|1        |15634602  |Hargrave|619        |France   |Female|42 |2     |0.0     |1            |1        |1             |101348.88      |1     |
|2        |15647311  |Hill    |608        |Spain    |Female|41 |1     |83807.86|1            |0        |1             |112542.58      |0     |
+---------+----------+--------+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
only showing top 2 rows

3. Approach 1: Drop the row that has missing values

When to drop rows?

1. Size of the dataset: If you have a large dataset, dropping rows with missing values might not have a significant impact on your analysis. However, if your dataset is small, you might want to preserve as many rows as possible.

2. Proportion of missing values: Calculate the percentage of missing values in each row or column. As a rule of thumb, if a column has more than 50 to 60% missing values, it might be a good idea to drop the column, as it may not provide much useful information.

However, this threshold may vary depending on the data and problem you’re trying to solve. For rows, you can use a similar threshold or drop rows with missing values only if it doesn’t significantly reduce your sample size.

3. Examine the nature of missing values: First, try to understand why the data is missing. Is it random, or is there a pattern? If the missing data is systematic (i.e., not missing at random), you may introduce bias by simply dropping rows or columns.

Custom function to drop rows that have missing values

Let’s write a custom function that computes the percentage of missing values in each row and adds it as a new column.

After identifying the percentage of missing values, the next step is to drop the rows whose percentage of missing values exceeds the desired threshold.

from pyspark.sql.functions import when, isnan, col, round
from pyspark.sql.types import FloatType, DoubleType

def find_row_missing_percentage(df):
    """
    This function takes a PySpark DataFrame as input and returns it with two extra
    columns: missing_count (null/NaN values in the row) and percent_missing.

    :param df: PySpark DataFrame
    :return: PySpark DataFrame with the missing count and percentage for each row,
             fraction of rows that have at least one missing value
    """
    # Count the number of columns
    total_columns = len(df.columns)

    # NaN can only occur in float/double columns, so isnan() is applied only there
    float_cols = {f.name for f in df.schema.fields if isinstance(f.dataType, (FloatType, DoubleType))}

    def is_missing(c):
        if c in float_cols:
            return col(c).isNull() | isnan(col(c))
        return col(c).isNull()

    # Count missing values in each row (Python's built-in sum adds the Column expressions)
    missing_values_df = df.select("*",
        sum(when(is_missing(c), 1).otherwise(0) for c in df.columns).alias("missing_count"))

    # Percentage of missing values for each row, rounded to a whole number
    missing_percent_df = missing_values_df.withColumn(
        "percent_missing", round(col("missing_count") / total_columns * 100, 0))

    # Fraction of rows in the DataFrame that have any missing value
    no_of_miss_rows = missing_percent_df.filter(col("missing_count") > 0).count()
    percent_miss_rows = no_of_miss_rows / df.count()

    return missing_percent_df, percent_miss_rows
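To check the per-row logic on a few records without a cluster, here is the same calculation in plain Python. Rows are dicts (a hypothetical stand-in for Spark rows) in which None plays the role of Spark’s null, and NaN is counted as missing only for floats, since isnan() is only meaningful for floating-point values. row_missing_stats is a hypothetical helper name.

```python
import math

def is_missing(v):
    # None stands in for Spark's null; NaN can only occur in float values
    return v is None or (isinstance(v, float) and math.isnan(v))

def row_missing_stats(rows):
    """Return (per-row percent missing, fraction of rows with any missing value)."""
    counts = [sum(is_missing(v) for v in row.values()) for row in rows]
    # Note: Python's round() rounds halves to even, while Spark's round() rounds
    # halves up, so exact .5 percentages can differ between the two.
    percents = [round(c / len(row) * 100) for c, row in zip(counts, rows)]
    frac_rows = sum(c > 0 for c in counts) / len(rows)
    return percents, frac_rows

rows = [
    {"a": 1, "b": None, "c": "x"},
    {"a": float("nan"), "b": None, "c": None},
    {"a": 2, "b": 3, "c": "y"},
    {"a": 4, "b": 5, "c": "z"},
]
percents, frac_rows = row_missing_stats(rows)
print(percents, "{:.2%}".format(frac_rows))
```

The fraction of rows is computed from the raw counts rather than the rounded percentages, so a row with one missing value out of many columns is still counted even if its percentage rounds to 0.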

Now that the function is ready, let’s execute it and see the results.

# Call the function
df_misscount, Percent_miss_rows = find_row_missing_percentage(df)

# Print Percentage Missing Rows in the DataFrame
print("% Missing Rows:", "{:.2%}".format(Percent_miss_rows))
% Missing Rows: 0.92%

Find the maximum percentage of missing values across all rows, and use it to choose the threshold for dropping rows.

# Print the highest per-row % missing
max_percent_missing = df_misscount.agg({"percent_missing": "max"}).collect()[0][0]
print("Row value with max % missing:", max_percent_missing)
Row value with max % missing: 14.0
df_misscount.show(2)
+---------+----------+--------+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+-------------+---------------+
|RowNumber|CustomerId| Surname|CreditScore|Geography|Gender|Age|Tenure| Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|missing_count|percent_missing|
+---------+----------+--------+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+-------------+---------------+
|        1|  15634602|Hargrave|        619|   France|Female| 42|     2|     0.0|            1|        1|             1|      101348.88|     1|            0|            0.0|
|        2|  15647311|    Hill|        608|    Spain|Female| 41|     1|83807.86|            1|        0|             1|      112542.58|     0|            0|            0.0|
+---------+----------+--------+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+-------------+---------------+
only showing top 2 rows

Use the output of find_row_missing_percentage and keep only the rows whose percentage of missing values is below missing_threshold. Since the highest per-row value in this dataset is 14%, a 60% threshold keeps every row here.

# Drop rows with percent_missing >= missing_threshold (in %), then remove the helper columns
missing_threshold = 60
filtered_df = (df_misscount
               .filter(col("percent_missing") < missing_threshold)
               .drop("missing_count", "percent_missing"))
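Spark’s DataFrame.dropna() can also do this filtering directly: its thresh argument drops rows that have fewer than thresh non-null values (Spark’s Scala documentation for the equivalent na.drop describes this as counting non-null and non-NaN values). The small helper below, a hypothetical name of our own, converts a maximum missing percentage into that count using the same strict "< threshold" rule as the filter.

```python
def dropna_thresh(n_columns: int, max_missing_pct: int) -> int:
    """Smallest number of non-missing values a row must have so that
    missing / n_columns * 100 < max_missing_pct holds (hypothetical helper)."""
    if not 0 <= max_missing_pct <= 100:
        raise ValueError("max_missing_pct must be between 0 and 100")
    # Integer arithmetic avoids float error at the boundary:
    # 100 * missing < n_columns * pct  <=>  missing <= (n_columns * pct - 1) // 100
    max_missing = (n_columns * max_missing_pct - 1) // 100
    return n_columns - max_missing

# The churn DataFrame has 14 columns, so a 60% limit means at least 6 non-missing values
print(dropna_thresh(14, 60))
```

With that, df.dropna(thresh=dropna_thresh(len(df.columns), 60)) should keep the same rows as the percent_missing filter, apart from the rounding applied to percent_missing.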

4. Approach 2: Drop the entire column if most of its values are missing

from pyspark.sql.functions import col, count, when, isnan, round
from pyspark.sql.types import FloatType, DoubleType

def find_column_missing_percentage(df):
    """
    This function takes a PySpark DataFrame as input and returns a DataFrame listing
    the columns that have missing values and the percentage of missing values in each.

    :param df: PySpark DataFrame
    :return: DataFrame with columns Variables and Percent_Missing, sorted descending
    """
    # Calculate the total number of rows in the DataFrame
    total_rows = df.count()

    # NaN can only occur in float/double columns, so isnan() is applied only there
    float_cols = {f.name for f in df.schema.fields if isinstance(f.dataType, (FloatType, DoubleType))}

    def is_missing(c):
        if c in float_cols:
            return col(c).isNull() | isnan(col(c))
        return col(c).isNull()

    # count() skips nulls, so count(when(cond, 1)) counts the rows where cond is true
    missing_count_df = df.agg(*[count(when(is_missing(c), 1)).alias(c) for c in df.columns])

    # Collect the result into a dictionary
    missing_count = missing_count_df.collect()[0].asDict()

    # Percentage of missing values per column, keeping only columns that have any
    missing_percentage = [(column, value / total_rows * 100)
                          for column, value in missing_count.items() if value > 0]

    # Explicit schema so this also works when no column has missing values
    missing_percent = spark.createDataFrame(missing_percentage, "Variables string, Percent_Missing double")
    missing_percent = missing_percent.withColumn("Percent_Missing", round(col("Percent_Missing"), 2))
    missing_percent = missing_percent.sort("Percent_Missing", ascending=False)

    return missing_percent

Let’s execute the find_column_missing_percentage and verify the missing values in each column.

In the returned output, the Percent_Missing column is sorted in descending order.

# Example usage on the churn DataFrame 'df'
missing_values_col = find_column_missing_percentage(df)

missing_values_col.show()
+---------------+---------------+
|      Variables|Percent_Missing|
+---------------+---------------+
|            Age|            0.4|
|        Balance|           0.37|
|         Gender|           0.14|
|    CreditScore|           0.01|
|EstimatedSalary|           0.01|
+---------------+---------------+

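Find the names of the columns whose percentage of missing values exceeds missing_threshold. Percent_Missing is expressed in percent, and since no column here comes close to the 50 to 60% rule of thumb, we use a deliberately low threshold of 0.3% for demonstration.

# Find the names of the columns having missing values more than 'missing_threshold'

missing_threshold = 0.3  # in percent, i.e. 0.3%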
Miss_Vars_df = missing_values_col.filter(missing_values_col.Percent_Missing > missing_threshold).select("Variables")
Miss_Vars_df.show()
+---------+
|Variables|
+---------+
|      Age|
|  Balance|
+---------+

Drop the columns whose percentage of missing values exceeds the threshold.

# Collect the column names into a Python list
Miss_Vars_list = [row.Variables for row in Miss_Vars_df.collect()]

# Drop the columns with too many missing values
df2 = df.drop(*Miss_Vars_list)
df2.show(5)
+---------+----------+--------+-----------+---------+------+------+-------------+---------+--------------+---------------+------+
|RowNumber|CustomerId| Surname|CreditScore|Geography|Gender|Tenure|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+---------+----------+--------+-----------+---------+------+------+-------------+---------+--------------+---------------+------+
|        1|  15634602|Hargrave|        619|   France|Female|     2|            1|        1|             1|      101348.88|     1|
|        2|  15647311|    Hill|        608|    Spain|Female|     1|            1|        0|             1|      112542.58|     0|
|        3|  15619304|    Onio|        502|   France|  null|     8|            3|        1|             0|      113931.57|     1|
|        4|  15701354|    Boni|        699|   France|  null|     1|            2|        0|             0|       93826.63|     0|
|        5|  15737888|Mitchell|        850|    Spain|Female|     2|            1|        1|             1|        79084.1|     0|
+---------+----------+--------+-----------+---------+------+------+-------------+---------+--------------+---------------+------+
only showing top 5 rows
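The same column-level steps (computing each column’s percentage of missing values, selecting those above a threshold, and dropping them) look like this in plain Python on a few hypothetical records. columns_to_drop and drop_columns are illustrative helpers, not PySpark APIs.

```python
import math

def is_missing(v):
    # None stands in for Spark's null; NaN can only occur in float values
    return v is None or (isinstance(v, float) and math.isnan(v))

def columns_to_drop(rows, threshold_pct):
    """Return (column, percent_missing) pairs with percent_missing > threshold_pct, highest first."""
    n = len(rows)
    pct = {c: sum(is_missing(r[c]) for r in rows) / n * 100 for c in rows[0]}
    over = [(c, p) for c, p in pct.items() if p > threshold_pct]
    return sorted(over, key=lambda cp: cp[1], reverse=True)

def drop_columns(rows, columns):
    """Return copies of the rows without the given columns."""
    return [{c: v for c, v in r.items() if c not in columns} for r in rows]

rows = [
    {"Age": None, "Balance": 1.0, "Tenure": 2},
    {"Age": 41, "Balance": float("nan"), "Tenure": 1},
    {"Age": None, "Balance": 3.0, "Tenure": 8},
    {"Age": 39, "Balance": 4.0, "Tenure": 1},
]
to_drop = columns_to_drop(rows, threshold_pct=20)
print(to_drop)
print(drop_columns(rows, {c for c, _ in to_drop}))
```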

5. Approach 3: Impute the missing data, that is, fill in the missing values with appropriate values

Mean, median, and mode imputation: the simplest way to fill in missing values is to use the mean, median, or mode of the available data.

We can use PySpark’s DataFrame API along with the Imputer class from pyspark.ml.feature to fill in missing values with the mean, median, or mode. Imputer currently supports only numeric columns (it does not handle categorical features), so before using it, let’s find the numeric columns in the DataFrame.

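from pyspark.sql.types import IntegerType, FloatType, DoubleType

# Keep only numeric columns; string columns such as Gender are skipped because Imputer cannot handle them
numeric_column_names = [column.name for column in df.schema.fields
                        if isinstance(column.dataType, (IntegerType, FloatType, DoubleType))]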

Create an instance of the Imputer class by specifying the input and output columns, and the strategy for handling missing values.

from pyspark.ml.feature import Imputer

# Initialize the Imputer
imputer = Imputer(
    inputCols=numeric_column_names,   # columns to impute
    outputCols=numeric_column_names,  # same names, so the imputed values overwrite the originals
    strategy="mean"                   # or "median", or "mode" (Spark 3.1+)
)

Fit the Imputer instance on the dataset to compute the imputation statistics (mean, median, or most frequent value) for each specified column

Use the fitted Imputer model to transform the dataset and fill in the missing values.
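Conceptually, fitting computes one surrogate value per column from the non-missing entries, and transforming replaces null and NaN entries with it. The plain-Python sketch below (fit_surrogates and transform are hypothetical helpers, not the Imputer API) illustrates that two-step idea on a single column. Note that Spark computes the median approximately (via approxQuantile), so on large data its result can differ slightly from an exact median.

```python
import math
import statistics

def is_missing(v):
    # Imputer treats nulls as missing, and NaN is its default missingValue
    return v is None or (isinstance(v, float) and math.isnan(v))

def fit_surrogates(columns, strategy="mean"):
    """'Fit' step: compute one replacement value per column from observed entries only."""
    # statistics.mode breaks ties by first occurrence (Python 3.8+); Spark may differ
    stat = {"mean": statistics.mean, "median": statistics.median, "mode": statistics.mode}[strategy]
    return {name: stat([v for v in values if not is_missing(v)]) for name, values in columns.items()}

def transform(columns, surrogates):
    """'Transform' step: replace each missing entry with its column's surrogate."""
    return {name: [surrogates[name] if is_missing(v) else v for v in values]
            for name, values in columns.items()}

# Hypothetical Age values with one null and one NaN
data = {"Age": [42, 41, None, 39, float("nan"), 43, 41]}
for strategy in ("mean", "median", "mode"):
    surrogates = fit_surrogates(data, strategy)
    print(strategy, transform(data, surrogates)["Age"])
```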

# Fit the Imputer (computes one surrogate value per input column)
model = imputer.fit(df)

# Transform the dataset, replacing missing values with the surrogates
imputed_df = model.transform(df)

imputed_df.show(5)
+---------+----------+--------+-----------+---------+------+---+------+-----------------+-------------+---------+--------------+---------------+------+
|RowNumber|CustomerId| Surname|CreditScore|Geography|Gender|Age|Tenure|          Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+---------+----------+--------+-----------+---------+------+---+------+-----------------+-------------+---------+--------------+---------------+------+
|        1|  15634602|Hargrave|        619|   France|Female| 42|     2|              0.0|            1|        1|             1|      101348.88|     1|
|        2|  15647311|    Hill|        608|    Spain|Female| 41|     1|         83807.86|            1|        0|             1|      112542.58|     0|
|        3|  15619304|    Onio|        502|   France|  null| 38|     8|         159660.8|            3|        1|             0|      113931.57|     1|
|        4|  15701354|    Boni|        699|   France|  null| 39|     1|              0.0|            2|        0|             0|       93826.63|     0|
|        5|  15737888|Mitchell|        850|    Spain|Female| 43|     2|76432.45615176117|            1|        1|             1|        79084.1|     0|
+---------+----------+--------+-----------+---------+------+---+------+-----------------+-------------+---------+--------------+---------------+------+
only showing top 5 rows
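In the output above, Gender is still null for rows 3 and 4 because Imputer only handles numeric columns. One common option for a categorical column is to fill it with its most frequent category. In PySpark that could be done by computing the mode with df.groupBy("Gender").count() (ignoring the null group) and passing it to df.fillna({"Gender": ...}); the plain-Python sketch below shows the logic with a hypothetical helper, fill_with_mode.

```python
from collections import Counter

def fill_with_mode(values):
    """Replace None with the most frequent non-missing category.
    Ties go to the category seen first (Counter.most_common keeps first-encountered order)."""
    counts = Counter(v for v in values if v is not None)
    if not counts:
        return list(values)  # nothing observed, so nothing to fill with
    mode_value = counts.most_common(1)[0][0]
    return [mode_value if v is None else v for v in values]

genders = ["Female", "Female", None, None, "Female", "Male", "Male"]
print(fill_with_mode(genders))
```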

Conclusion

In this blog post, we explored ways to handle missing values in PySpark: dropping rows, dropping columns, and imputing with the mean, median, or mode using Imputer. We also briefly discussed K-Nearest Neighbors and regression imputation. Each of these techniques has its own strengths and weaknesses and should be chosen based on the specific characteristics of the data and the problem you are trying to solve.

Remember that imputing missing values is only one of the many steps in the data preprocessing pipeline, and it should be complemented with other data cleaning, transformation, and feature engineering techniques to build robust and accurate machine learning models.
