
PySpark GroupBy() – Mastering PySpark GroupBy with Advanced Examples and Complex Aggregations

In this post, we’ll take a deeper dive into PySpark’s GroupBy functionality, exploring more advanced and complex use cases. With the help of detailed examples, you’ll learn how to perform multiple aggregations, group by multiple columns, and even apply custom aggregation functions. Let’s dive in!

What is PySpark GroupBy?

As a quick reminder, PySpark GroupBy is a powerful operation that allows you to perform aggregations on your data. It groups the rows of a DataFrame based on one or more columns and then applies an aggregation function to each group. Common aggregation functions include sum, count, mean, min, and max.

Here’s a general structure of a GroupBy operation:

Syntax:

dataFrame.groupBy("column_name").agg(aggregation_function)

Common aggregation functions (see the short sketch after this list):

count() – returns the number of rows in each group

max() – returns the maximum value for each group

min() – returns the minimum value for each group

sum() – returns the sum of values for each group

avg() – returns the average of values for each group
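Each of these is also available as a shortcut method on the grouped data, so you don't always need agg(). A minimal sketch, assuming a DataFrame df with a "Product" column and a numeric "Price" column (like the sample created below):

# Shortcut aggregation methods on GroupedData
df.groupBy("Product").count()        # number of rows per product
df.groupBy("Product").max("Price")   # maximum price per product
df.groupBy("Product").min("Price")   # minimum price per product
df.groupBy("Product").sum("Price")   # total price per product
df.groupBy("Product").avg("Price")   # average price per product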

To illustrate the power of PySpark GroupBy, let’s work with a sample dataset. We’ll use a dataset containing sales data with the following columns: ‘OrderID’, ‘Product’, ‘Category’, ‘Quantity’, ‘Price’, and ‘Date’.

Importing necessary libraries and creating a sample DataFrame

# findspark locates your local Spark installation (optional if PySpark was installed via pip)
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import *

# Create a Spark session
spark = SparkSession.builder \
    .appName("PySpark GroupBy Example") \
    .getOrCreate()

# Sample data
data = [("1001", "Laptop", "Electronics", 1, 1000, "2023-01-01"),
        ("1002", "Mouse", "Electronics", 2, 50, "2023-01-02"),
        ("1003", "Laptop", "Electronics", 1, 1200, "2023-01-03"),
        ("1004", "Mouse", "Electronics", 3, 30, "2023-01-04"),
        ("1005", "Smartphone", "Electronics", 1, 700, "2023-01-05")]

# Create DataFrame
columns = ["OrderID", "Product", "Category", "Quantity", "Price", "Date"]
df = spark.createDataFrame(data, columns)

df.show()
+-------+----------+-----------+--------+-----+----------+
|OrderID|   Product|   Category|Quantity|Price|      Date|
+-------+----------+-----------+--------+-----+----------+
|   1001|    Laptop|Electronics|       1| 1000|2023-01-01|
|   1002|     Mouse|Electronics|       2|   50|2023-01-02|
|   1003|    Laptop|Electronics|       1| 1200|2023-01-03|
|   1004|     Mouse|Electronics|       3|   30|2023-01-04|
|   1005|Smartphone|Electronics|       1|  700|2023-01-05|
+-------+----------+-----------+--------+-----+----------+

1. GroupBy operation on a single column

Let’s say we want to find the total sales amount for each product. We can achieve this using the GroupBy operation with the “Product” column and applying the “sum” aggregation function to the “Price” column.

# GroupBy and aggregate

result = df.groupBy("Product").agg(sum("Price").alias("Total_Sales"))

# Show results
result.show()
+----------+-----------+
|   Product|Total_Sales|
+----------+-----------+
|    Laptop|       2200|
|     Mouse|         80|
|Smartphone|        700|
+----------+-----------+

2. GroupBy operation on multiple columns

Now, let’s say we want to find the total sales amount for each product by category. We can achieve this by grouping by both “Product” and “Category” columns.

# GroupBy and aggregate

result = df.groupBy(["Product", "Category"]) \
    .agg(sum("Price").alias("Total_Sales"))

# Show results
result.show()
+----------+-----------+-----------+
|   Product|   Category|Total_Sales|
+----------+-----------+-----------+
|    Laptop|Electronics|       2200|
|     Mouse|Electronics|         80|
|Smartphone|Electronics|        700|
+----------+-----------+-----------+
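If you want the groups in a particular order, you can sort the aggregated result before displaying it. A small sketch, reusing the result DataFrame from above:

# Sort the aggregated result by total sales, highest first
result.orderBy(col("Total_Sales").desc()).show()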

3. GroupBy operation with multiple aggregations

Let’s say we want to find the total sales amount and the total quantity sold for each product. We can achieve this by passing multiple aggregation functions to agg().

# GroupBy and aggregate

result = df.groupBy("Product") \
    .agg(sum("Price").alias("Total_Sales"),
         sum("Quantity").alias("Total_Quantity"))

# Show results
result.show()
+----------+-----------+--------------+
|   Product|Total_Sales|Total_Quantity|
+----------+-----------+--------------+
|    Laptop|       2200|             2|
|     Mouse|         80|             5|
|Smartphone|        700|             1|
+----------+-----------+--------------+
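agg() also accepts a dictionary that maps column names to aggregation function names. A quick sketch of the same aggregation in that form (note that the resulting columns are auto-named, e.g. "sum(Price)"):

# Dictionary form: {column name: aggregation function name}
result_dict = df.groupBy("Product").agg({"Price": "sum", "Quantity": "sum"})
result_dict.show()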

4. Filter aggregated data using a where condition

You can combine groupBy() with where() (the DataFrame equivalent of the SQL WHERE clause) to keep only the groups that satisfy a condition on the aggregated values.

# GroupBy and aggregate using where condition

result = df.groupBy("Product") \
    .agg(avg("Price").alias("Avg_Price"),
         sum("Quantity").alias("Total_Quantity")) \
    .where(col("Total_Quantity") >= 2)

# Show results
result.show()
+-------+---------+--------------+
|Product|Avg_Price|Total_Quantity|
+-------+---------+--------------+
| Laptop|   1100.0|             2|
|  Mouse|     40.0|             5|
+-------+---------+--------------+
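Note that filter() is an alias for where(), so the same post-aggregation condition can also be written as follows; this is the DataFrame equivalent of SQL's HAVING clause:

# Same result using filter() instead of where()
result = df.groupBy("Product") \
    .agg(avg("Price").alias("Avg_Price"),
         sum("Quantity").alias("Total_Quantity")) \
    .filter(col("Total_Quantity") >= 2)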

5. Custom Aggregation Functions

In some cases, you may need to apply a custom aggregation function. For this example, let’s calculate the median price for each product category.

import pandas as pd
from pyspark.sql.types import FloatType
from pyspark.sql.functions import pandas_udf

# Series-to-scalar pandas UDF: receives all values of a group as a
# pandas Series and returns a single float (the group's median)
@pandas_udf(FloatType())
def median(column: pd.Series) -> float:
    return float(column.median())

In the code snippet above, we define a custom aggregation function (a Series-to-scalar pandas UDF) using the pandas_udf decorator. It receives a pandas Series containing all values of a group and returns a single value, the median, with the return type declared as FloatType().

Apply the custom aggregation function

Now that we have defined our custom aggregation function, we can apply it to our DataFrame to compute the median price for each product category.

# GroupBy and aggregate

result = df.groupBy("Category") \
    .agg(median("Price").alias("Median_Price"))

# Show results
result.show()
+-----------+------------+
|   Category|Median_Price|
+-----------+------------+
|Electronics|       700.0|
+-----------+------------+

In this example, since we only have one category (Electronics), the output shows a single row: the median of the prices 30, 50, 700, 1000 and 1200, which is 700.0.
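If an approximate median is good enough, you may not need a pandas UDF at all. On Spark 3.1 and later, the built-in percentile_approx function can compute it directly; a minimal sketch:

from pyspark.sql.functions import percentile_approx

# Approximate median of Price per category, computed natively by Spark
approx = df.groupBy("Category") \
    .agg(percentile_approx("Price", 0.5).alias("Approx_Median_Price"))
approx.show()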

Conclusion

Mastering PySpark’s GroupBy functionality opens up a world of possibilities for data analysis and aggregation.

By understanding how to perform multiple aggregations, group by multiple columns, and even apply custom aggregation functions, you can efficiently analyze your data and draw valuable insights.

Keep exploring and experimenting with different GroupBy operations to unlock the full potential of PySpark!
