
PySpark Random Forest – Building and Evaluating Random Forest Models using PySpark MLlib: A Step-By-Step Guide

How to build and evaluate Random Forest models using PySpark MLlib and cover key aspects such as hyperparameter tuning and variable selection, providing example code to help you along the way.

Written by Jagdeesh | 4 min read

Let's walk through building and evaluating a Random Forest model with PySpark MLlib, including hyperparameter tuning and feature importance, with example code at each step.

Random Forest is an ensemble machine learning algorithm that can be used for both classification and regression tasks.
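To make the "ensemble" idea concrete, here is a minimal pure-Python sketch of the textbook voting step for classification: each tree predicts a class and the forest returns the most common one. The `majority_vote` helper and the five votes are hypothetical, and Spark MLlib's own aggregation may differ in detail, so treat this as an illustration rather than a description of its internals.

```python
from collections import Counter

def majority_vote(tree_predictions):
    """Return the class predicted by the most trees (ties go to the class seen first)."""
    return Counter(tree_predictions).most_common(1)[0][0]

# Hypothetical predictions from five trees for a single flower
votes = ["Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-virginica"]
forest_prediction = majority_vote(votes)
print(forest_prediction)
```

For regression, the same idea applies with an average of the tree outputs instead of a vote.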

PySpark is the Python library for Apache Spark, an open-source big data processing framework that can process large-scale data in parallel. PySpark MLlib is a machine learning library built on top of PySpark that provides various algorithms and tools for building scalable machine learning models.

We will cover the following steps in this post:

  1. Importing libraries and initializing a SparkSession

  2. Loading the dataset

  3. Preparing the data

  4. Creating a pipeline

  5. Hyperparameter tuning and model selection

  6. Analyzing feature importance

  7. Evaluating the model

  8. Saving and loading the model

1. Import required libraries and initialize SparkSession

First, let’s import the necessary libraries and create a SparkSession, the entry point for working with PySpark.

python
import findspark
findspark.init()

from pyspark import SparkFiles
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

spark = SparkSession.builder.appName("RandomForestExample").getOrCreate()

2. Load the dataset

For this example, we will use the “Iris” dataset. The following code downloads the CSV file from a URL with addFile and then loads it into a PySpark DataFrame.

python
url = "https://raw.githubusercontent.com/selva86/datasets/master/Iris.csv"
spark.sparkContext.addFile(url)

df = spark.read.csv(SparkFiles.get("Iris.csv"), header=True, inferSchema=True)
df.show(5)

Output:
+---+-------------+------------+-------------+------------+-----------+
| Id|SepalLengthCm|SepalWidthCm|PetalLengthCm|PetalWidthCm|    Species|
+---+-------------+------------+-------------+------------+-----------+
|  1|          5.1|         3.5|          1.4|         0.2|Iris-setosa|
|  2|          4.9|         3.0|          1.4|         0.2|Iris-setosa|
|  3|          4.7|         3.2|          1.3|         0.2|Iris-setosa|
|  4|          4.6|         3.1|          1.5|         0.2|Iris-setosa|
|  5|          5.0|         3.6|          1.4|         0.2|Iris-setosa|
+---+-------------+------------+-------------+------------+-----------+
only showing top 5 rows

3. Prepare the data

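Before building the model, we need to convert the string labels in the Species column to numeric indices using the StringIndexer class and assemble the four input features into a single feature vector using the VectorAssembler class. We also define the Random Forest classifier here. Then, we will split the dataset into a training set (about 70%) and a testing set (about 30%).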

python
# Preprocessing: StringIndexer for categorical labels
stringIndexer  = StringIndexer(inputCol="Species", outputCol="label")

# Define the feature and label columns & Assemble the feature vector
assembler = VectorAssembler(inputCols=["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"], outputCol="features")

rf = RandomForestClassifier(labelCol="label", featuresCol="features")

# Split the data into training and test sets
train_data, test_data = df.randomSplit([0.7, 0.3], seed=42)
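As a rough picture of what the two preprocessing stages produce, the pure-Python sketch below mimics them on a handful of rows. It assumes StringIndexer's default ordering (most frequent label gets index 0.0); the alphabetical tie-break and the helper names `index_labels` and `assemble` are assumptions made for this illustration, not part of the Spark API.

```python
from collections import Counter

def index_labels(labels):
    """Map each distinct label to 0.0, 1.0, ... in descending frequency order.

    Ties are broken alphabetically here, which is an assumption of this sketch.
    """
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda lab: (-counts[lab], lab))
    return {lab: float(i) for i, lab in enumerate(ordered)}

def assemble(row, input_cols):
    """Collect the chosen columns of one row into a single feature list."""
    return [row[col] for col in input_cols]

# Hypothetical sample of Species values from a training split
species = ["Iris-virginica", "Iris-setosa", "Iris-virginica",
           "Iris-versicolor", "Iris-virginica", "Iris-setosa"]
mapping = index_labels(species)

row = {"SepalLengthCm": 5.1, "SepalWidthCm": 3.5, "PetalLengthCm": 1.4, "PetalWidthCm": 0.2}
features = assemble(row, ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"])

print(mapping)
print(features)
```

Because the index depends on label frequencies in the data the indexer is fitted on, the number assigned to each species can change between runs with a different split.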

4. Create a Pipeline

We will now create a pipeline that chains the label indexer, the feature assembler, and the Random Forest classifier.

python
pipeline = Pipeline(stages=[stringIndexer, assembler, rf])

5. Hyperparameter Tuning and Model Selection

We will use cross-validation to select the best model based on hyperparameter tuning. We will tune the numTrees and maxDepth parameters of the Random Forest model.

python
# Define the hyperparameter grid
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 20, 30]) \
    .addGrid(rf.maxDepth, [5, 10, 15]) \
    .build()

# Create the cross-validator
cross_validator = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=MulticlassClassificationEvaluator(labelCol="label", metricName="accuracy"),
                          numFolds=5, seed=42)

# Run cross-validation over the grid; the best combination is then refit on the full training set
cv_model = cross_validator.fit(train_data)
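It helps to know how much work the grid above implies before running it on a large dataset. The sketch below simply counts the fits for the grid and fold values used in this post; the variable names are made up for the illustration.

```python
from itertools import product

# Values taken from the parameter grid and CrossValidator settings above
num_trees_options = [10, 20, 30]
max_depth_options = [5, 10, 15]
num_folds = 5

# Every (numTrees, maxDepth) pair is one candidate model
param_combinations = list(product(num_trees_options, max_depth_options))

# Each candidate is trained once per fold during cross-validation
fits_during_cv = len(param_combinations) * num_folds

# Plus one final fit of the best candidate on the whole training set
total_fits = fits_during_cv + 1

print(len(param_combinations), fits_during_cv, total_fits)  # 9 45 46
```

Adding a third hyperparameter with three values would triple this count, so keep the grid small when the data is large.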

6. Analyze feature importance

To understand the importance of each variable in the model, we can examine the featureImportances attribute of the best Random Forest model obtained from cross-validation.

python
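# The last pipeline stage is the fitted Random Forest model
best_rf_model = cv_model.bestModel.stages[-1]
importances = best_rf_model.featureImportances
feature_list = ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"]

print("Feature Importances:")
# Convert the importance vector to a plain array so it pairs cleanly with the names
for feature, importance in zip(feature_list, importances.toArray()):
    print(f"{feature}: {importance:.4f}")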

Output:
Feature Importances:
SepalLengthCm: 0.0887
SepalWidthCm: 0.0590
PetalLengthCm: 0.3873
PetalWidthCm: 0.4650

This will display the importance of each feature in the best Random Forest model. Features with higher importance values contribute more to the model’s decision-making process.
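A quick way to read these numbers is to rank them and look at group totals. The sketch below works on the values printed above, copied into a plain dictionary; it does not need Spark, and the variable names are chosen for this illustration only.

```python
# Importance values copied from the output above
importances = {
    "SepalLengthCm": 0.0887,
    "SepalWidthCm": 0.0590,
    "PetalLengthCm": 0.3873,
    "PetalWidthCm": 0.4650,
}

# Rank features from most to least important
ranked = sorted(importances.items(), key=lambda kv: kv[1], reverse=True)

# MLlib normalizes importances so that they sum to 1
total = sum(importances.values())

# Combined share of the two petal measurements
petal_share = importances["PetalLengthCm"] + importances["PetalWidthCm"]

for name, value in ranked:
    print(f"{name}: {value:.4f}")
print(f"Total: {total:.4f}, petal share: {petal_share:.4f}")
```

Here the two petal measurements account for about 85% of the total importance, which suggests the sepal measurements add comparatively little for this model.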

7. Evaluating the model

To evaluate the performance of our Random Forest model on the test set, we’ll use the MulticlassClassificationEvaluator class from PySpark MLlib with accuracy as the metric.

python
# Make predictions on the test data
predictions = cv_model.transform(test_data)

evaluator = MulticlassClassificationEvaluator(labelCol="label", metricName="accuracy")

# Evaluate the model
accuracy = evaluator.evaluate(predictions)
print("Test set accuracy = {:.2f}".format(accuracy))

Output:
Test set accuracy = 0.93
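Accuracy is just the share of rows where the prediction equals the label, and a single number can hide which classes get confused. The pure-Python sketch below computes accuracy and per-pair counts from a few made-up (label, prediction) pairs; the data and helper names are hypothetical. On the real predictions DataFrame, a similar table could be obtained with predictions.groupBy("label", "prediction").count().

```python
from collections import Counter

def accuracy(pairs):
    """Fraction of (label, prediction) pairs where both values match."""
    correct = sum(1 for label, prediction in pairs if label == prediction)
    return correct / len(pairs)

def confusion_counts(pairs):
    """Count how often each (label, prediction) combination occurs."""
    return Counter(pairs)

# Hypothetical (label, prediction) pairs, not taken from the model above
pairs = [(0.0, 0.0), (0.0, 0.0), (1.0, 1.0), (1.0, 2.0), (2.0, 2.0), (2.0, 2.0)]

acc = accuracy(pairs)
confusion = confusion_counts(pairs)

print(f"Accuracy = {acc:.2f}")
for (label, prediction), count in sorted(confusion.items()):
    print(f"label={label} prediction={prediction} count={count}")
```

In this toy data the only error is one class 1.0 row predicted as class 2.0, which the accuracy figure alone would not reveal.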

8. Save and load the model (optional)

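If you want to reuse the model in the future, you can save it to disk and load it back when needed. Note that the code below saves only the fitted Random Forest stage, so new data must already contain the features column before the loaded model can score it. The save call will also fail if the target path already exists.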

python
# Save the model
best_rf_model.save("rf_model")

# Load the model
from pyspark.ml.classification import RandomForestClassificationModel
loaded_model = RandomForestClassificationModel.load("rf_model")

Conclusion

We have demonstrated how to build and evaluate a Random Forest model using PySpark MLlib. We covered important aspects such as hyperparameter tuning, feature importance analysis, and model evaluation.

With this knowledge, you can now apply the Random Forest algorithm to your own datasets using PySpark and gain valuable insights from your data.
