
PySpark Decision Tree – How to Build and Evaluate Decision Tree Model for Classification using PySpark MLlib

In this tutorial, you will learn how to build and evaluate a Decision Tree model for classification using PySpark’s MLlib library. Decision Trees are widely used for classification problems because they are simple, interpretable, and easy to use.

PySpark’s MLlib library provides an array of tools and algorithms that make it easier to build, train, and evaluate machine learning models on distributed data.

We will cover the following topics in this tutorial:

  1. Setting up the PySpark environment

  2. Importing the necessary libraries

  3. Loading and preparing the data

  4. Building and training the Decision Tree model

  5. Evaluating the model performance

  6. Creating a Pipeline & Hyperparameter Tuning

  7. Example code

1. Import required libraries and initialize SparkSession

First, let’s import the necessary libraries and create a SparkSession, the entry point for working with DataFrames and MLlib in PySpark.

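# findspark makes a local Spark installation importable; it is not needed if PySpark was installed with pip
import findspark
findspark.init()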

from pyspark import SparkFiles

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

spark = SparkSession.builder.appName("Decision Tree Model").getOrCreate()

2. Load the dataset

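For this example, we will use the classic Iris dataset: 150 flowers from three species (Iris-setosa, Iris-versicolor, Iris-virginica), each described by four measurements in centimeters. The CSV is downloaded with sparkContext.addFile and read from the local path returned by SparkFiles.get.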

url = "https://raw.githubusercontent.com/selva86/datasets/master/Iris.csv"
spark.sparkContext.addFile(url)

df = spark.read.csv(SparkFiles.get("Iris.csv"), header=True, inferSchema=True)
df.show(5)
+---+-------------+------------+-------------+------------+-----------+
| Id|SepalLengthCm|SepalWidthCm|PetalLengthCm|PetalWidthCm|    Species|
+---+-------------+------------+-------------+------------+-----------+
|  1|          5.1|         3.5|          1.4|         0.2|Iris-setosa|
|  2|          4.9|         3.0|          1.4|         0.2|Iris-setosa|
|  3|          4.7|         3.2|          1.3|         0.2|Iris-setosa|
|  4|          4.6|         3.1|          1.5|         0.2|Iris-setosa|
|  5|          5.0|         3.6|          1.4|         0.2|Iris-setosa|
+---+-------------+------------+-------------+------------+-----------+
only showing top 5 rows

3. Prepare the data

First, convert the categorical labels in the ‘Species’ column to numeric indices with StringIndexer.

Before building the model, we need to assemble the input features into a single feature vector using the VectorAssembler class. Then, we will split the dataset into a training set (80%) and a testing set (20%).

# Convert the categorical labels in the 'Species' column to numerical values
label_indexer = StringIndexer(inputCol="Species", outputCol="label")
data = label_indexer.fit(df).transform(df)

# Assemble the feature columns into a single vector column
assembler = VectorAssembler(inputCols=["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"], outputCol="features")
data = assembler.transform(data)

# Split data into training and testing sets
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)
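To see what these two preprocessing steps produce, they can be mimicked in plain Python. This is a minimal sketch, not Spark’s implementation. It assumes StringIndexer’s default stringOrderType="frequencyDesc", under which the most frequent label gets index 0.0 and, according to the StringIndexer documentation, frequency ties are broken alphabetically. The helper names string_index and assemble are hypothetical.

```python
from collections import Counter

def string_index(labels):
    """Mimic StringIndexer's default "frequencyDesc" order: most frequent label -> 0.0,
    with ties broken alphabetically."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda lab: (-counts[lab], lab))
    return {lab: float(i) for i, lab in enumerate(ordered)}

def assemble(row, input_cols):
    """Mimic VectorAssembler: copy the chosen numeric columns, in order, into one list."""
    return [float(row[c]) for c in input_cols]

# Iris has 50 rows per species, so all frequencies tie and the order is alphabetical.
species = ["Iris-setosa"] * 50 + ["Iris-versicolor"] * 50 + ["Iris-virginica"] * 50
mapping = string_index(species)
print(mapping)
# {'Iris-setosa': 0.0, 'Iris-versicolor': 1.0, 'Iris-virginica': 2.0}

# First row of the CSV shown earlier.
row = {"Id": 1, "SepalLengthCm": 5.1, "SepalWidthCm": 3.5,
       "PetalLengthCm": 1.4, "PetalWidthCm": 0.2, "Species": "Iris-setosa"}
cols = ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"]
print(assemble(row, cols), mapping[row["Species"]])
# [5.1, 3.5, 1.4, 0.2] 0.0
```

In Spark the assembled "features" value is an ML Vector (dense or sparse) rather than a Python list, but the column order is the same, which is why the tree’s "feature 2" refers to PetalLengthCm.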

4. Building the DecisionTreeClassifier model

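Create an instance of DecisionTreeClassifier, tell it which columns hold the label and the feature vector, and fit it on the training data. All other hyperparameters keep their default values.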

dt_classifier = DecisionTreeClassifier(labelCol="label", featuresCol="features")

model = dt_classifier.fit(train_data)
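DecisionTreeClassifier defaults to impurity="gini": at each node it chooses the split that most reduces Gini impurity. The pure-Python sketch below computes that criterion for a toy node; it ignores the fact that Spark only considers a limited set of candidate thresholds per continuous feature (controlled by maxBins), and the helper names are hypothetical.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def split_gain(parent, left, right):
    """Impurity decrease: parent impurity minus the size-weighted child impurities."""
    n = len(parent)
    weighted = len(left) / n * gini(left) + len(right) / n * gini(right)
    return gini(parent) - weighted

# Toy node with 50 rows of each class. The split sends all class-0 rows left,
# much like the first split of this tutorial's tree isolates Iris-setosa.
parent = [0] * 50 + [1] * 50 + [2] * 50
left, right = [0] * 50, [1] * 50 + [2] * 50
print(round(gini(parent), 4))                     # 0.6667
print(round(split_gain(parent, left, right), 4))  # 0.3333
```

The left child is pure (impurity 0), so the gain is large; the tree keeps splitting the right child until nodes are pure or a stopping rule such as maxDepth is reached.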

5. Evaluating the model on test data

Use the trained model to predict labels for the test data, then measure accuracy with MulticlassClassificationEvaluator.

predictions = model.transform(test_data)

evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print(f"Test Accuracy: {accuracy:.2f}")
Test Accuracy: 0.92
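Accuracy is only one of the metrics MulticlassClassificationEvaluator can report; changing metricName to "f1" (its default), "weightedPrecision", or "weightedRecall" gives others. The sketch below computes accuracy and a support-weighted F1 by hand on hypothetical predictions (not this tutorial’s output) to show how the two can differ.

```python
from collections import Counter

def accuracy(labels, preds):
    """Fraction of rows where the prediction equals the label."""
    return sum(l == p for l, p in zip(labels, preds)) / len(labels)

def weighted_f1(labels, preds):
    """Per-class F1, averaged with weights equal to each class's share of true labels."""
    n = len(labels)
    total = 0.0
    for cls, support in Counter(labels).items():
        tp = sum(l == cls and p == cls for l, p in zip(labels, preds))
        fp = sum(l != cls and p == cls for l, p in zip(labels, preds))
        fn = support - tp
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
        total += support / n * f1
    return total

# Hypothetical label/prediction pairs for 10 test rows.
labels = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]
preds = [0, 0, 1, 1, 1, 2, 1, 2, 2, 2]
print(accuracy(labels, preds))               # 0.8
print(round(weighted_f1(labels, preds), 3))  # 0.797
```

On a balanced dataset like Iris the two numbers stay close; on imbalanced data, F1 is usually the more informative of the two.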

6. Feature Importance

Analyze the feature importance of the decision tree model to understand the key factors contributing to the classification task.

feature_importance = model.featureImportances.toArray()

# Show feature importance
for i, column in enumerate(assembler.getInputCols()):
    print(f"Feature '{column}': {feature_importance[i]:.2f}")
Feature 'SepalLengthCm': 0.00
Feature 'SepalWidthCm': 0.02
Feature 'PetalLengthCm': 0.53
Feature 'PetalWidthCm': 0.45
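The four importances add up to 1.00 because Spark normalizes them. Per the featureImportances documentation, each feature’s importance is the sum of the gain over all nodes that split on it, scaled by the number of instances passing through the node, then normalized to sum to 1. A minimal sketch of that calculation, using hypothetical split statistics rather than values read from the model above:

```python
def feature_importances(splits, num_features):
    """Sum (rows at node * impurity gain) per feature, then normalize to sum to 1.

    splits: list of (feature_index, rows_at_node, impurity_gain), one per internal node.
    """
    raw = [0.0] * num_features
    for feature, rows, gain in splits:
        raw[feature] += rows * gain
    total = sum(raw)
    return [r / total for r in raw] if total > 0 else raw

# Hypothetical internal nodes of a 4-feature tree; feature 0 is never used to split.
splits = [(2, 120, 0.33), (3, 80, 0.30), (2, 40, 0.10), (1, 5, 0.20)]
imp = feature_importances(splits, 4)
print([round(x, 3) for x in imp])
```

A feature that never appears in a split, like SepalLengthCm in this tutorial’s tree, therefore gets an importance of exactly 0.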

7. Visualize the Decision Tree

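DecisionTreeClassificationModel has no built-in plot, but its toDebugString property prints the learned split rules as text. Feature indices follow the order of the VectorAssembler input columns (feature 0 = SepalLengthCm, 1 = SepalWidthCm, 2 = PetalLengthCm, 3 = PetalWidthCm), and the predicted values are the label indices produced by StringIndexer.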

print(model.toDebugString)
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_5e5d7ac37be8, depth=4, numNodes=13, numClasses=3, numFeatures=4
  If (feature 2 <= 2.45)
   Predict: 0.0
  Else (feature 2 > 2.45)
   If (feature 3 <= 1.65)
    If (feature 2 <= 4.95)
     Predict: 1.0
    Else (feature 2 > 4.95)
     If (feature 3 <= 1.55)
      Predict: 2.0
     Else (feature 3 > 1.55)
      Predict: 1.0
   Else (feature 3 > 1.65)
    If (feature 2 <= 4.85)
     If (feature 1 <= 3.05)
      Predict: 2.0
     Else (feature 1 > 3.05)
      Predict: 1.0
    Else (feature 2 > 4.85)
     Predict: 2.0
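The printed rules are easier to read with column names instead of indices. A small sketch using a regular expression; name_features is a hypothetical helper, and with the real model you would call it as name_features(model.toDebugString, assembler.getInputCols()).

```python
import re

def name_features(debug_string, feature_names):
    """Replace each 'feature <i>' in a toDebugString dump with the i-th input column name."""
    return re.sub(r"feature (\d+)",
                  lambda m: feature_names[int(m.group(1))],
                  debug_string)

names = ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"]
# First few lines of the tree printed above.
snippet = """  If (feature 2 <= 2.45)
   Predict: 0.0
  Else (feature 2 > 2.45)
   If (feature 3 <= 1.65)"""
print(name_features(snippet, names))
```

The header line ("numFeatures=4") is left untouched because the pattern requires the lowercase word "feature" followed by a space.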

8. Save and load the model (optional)

If you want to reuse the model in the future, you can save it to disk and load it back when needed.

# Save the model (overwrite() replaces an existing directory instead of raising an error)
model.write().overwrite().save("Dtree_model")

# Load the model
from pyspark.ml.classification import DecisionTreeClassificationModel
loaded_model = DecisionTreeClassificationModel.load("Dtree_model")

9. Example code

Here is the complete example code:

import findspark
findspark.init()

from pyspark import SparkFiles

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Create a Spark session
spark = SparkSession.builder.appName("Decision Tree Model").getOrCreate()

# Load the Iris dataset
url = "https://raw.githubusercontent.com/selva86/datasets/master/Iris.csv"
spark.sparkContext.addFile(url)

df = spark.read.csv(SparkFiles.get("Iris.csv"), header=True, inferSchema=True)

# Preprocessing: StringIndexer for categorical labels
label_indexer = StringIndexer(inputCol="Species", outputCol="label")
data = label_indexer.fit(df).transform(df)

# Preprocessing: VectorAssembler for feature columns
assembler = VectorAssembler(inputCols=["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"], outputCol="features")
data = assembler.transform(data)

# Split data into training and testing sets
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

# Create a Decision Tree Classifier instance
dt_classifier = DecisionTreeClassifier(labelCol="label", featuresCol="features")

# Train the model
model = dt_classifier.fit(train_data)

# Make predictions on the test data
predictions = model.transform(test_data)

# Evaluate the model performance
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print(f"Test Accuracy: {accuracy:.2f}")
Test Accuracy: 0.92

10. Improve the model (optional)

Now that we have a working Decision Tree classifier, let’s look at ways to improve it and at where this approach can be applied. If the model’s performance does not meet your expectations, try the following strategies:

Hyperparameter Tuning: The Decision Tree model in this example uses default hyperparameters (for example, maxDepth=5 and impurity="gini"). To improve performance, search for better values with grid search or random search. PySpark MLlib provides the ParamGridBuilder and CrossValidator classes for cross-validated hyperparameter tuning.

Feature Selection: Experiment with feature selection techniques, such as recursive feature elimination or dropping features with near-zero importance, to keep only the most informative features and reduce model complexity. This can improve generalization to unseen data.

Ensemble Methods: Combining many decision trees into an ensemble, such as a Random Forest or Gradient-Boosted Trees, often improves performance. PySpark MLlib provides RandomForestClassifier and GBTClassifier, which can take the place of DecisionTreeClassifier in this workflow. Note that GBTClassifier supports only binary labels, so it cannot be applied directly to the three-class Iris data.

Handling Imbalanced Data: In some real-world applications, you may encounter imbalanced datasets, where some classes are under-represented. To address this issue, you can apply techniques such as resampling, assigning class weights, or using cost-sensitive learning.

Real-world Applications: Decision Trees and their ensemble counterparts can be applied to a wide range of classification tasks, such as fraud detection, customer churn prediction, medical diagnosis, and natural language processing tasks like sentiment analysis.
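Two of these strategies, feature selection and class weighting, come down to simple arithmetic. The sketch below keeps features whose importance clears a threshold (0.05 is an arbitrary illustrative choice) and computes "balanced" class weights with the common n / (num_classes * class_count) convention, which is a widespread heuristic rather than something Spark computes for you. In PySpark you would store the weight in a DataFrame column and pass its name through DecisionTreeClassifier’s weightCol parameter (Spark 3.0 and later). Helper names are hypothetical.

```python
from collections import Counter

def select_by_importance(names, importances, threshold):
    """Keep the feature names whose importance is at least `threshold`."""
    return [n for n, imp in zip(names, importances) if imp >= threshold]

def balanced_class_weights(labels):
    """Weight each class by n / (num_classes * class_count) so rare classes count more."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {cls: n / (k * c) for cls, c in counts.items()}

# Importances printed earlier in this tutorial for the four Iris features.
names = ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"]
importances = [0.00, 0.02, 0.53, 0.45]
print(select_by_importance(names, importances, threshold=0.05))
# ['PetalLengthCm', 'PetalWidthCm']

# Hypothetical imbalanced labels: 90 rows of class 0, 10 rows of class 1.
weights = balanced_class_weights([0] * 90 + [1] * 10)
print({cls: round(w, 3) for cls, w in weights.items()})
# {0: 0.556, 1: 5.0}
```

To retrain on the selected features, pass the shorter list as inputCols to VectorAssembler and check test accuracy again; with so few test rows, expect small differences either way.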

11. Creating a Pipeline & Hyperparameter Tuning

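Assemble all the steps (label indexing, feature assembling, and model building) into a single Pipeline, so the same preprocessing runs automatically whenever the pipeline is fit or used for prediction. Because the pipeline does its own preprocessing, the raw DataFrame is split directly.

Then perform hyperparameter tuning with CrossValidator and ParamGridBuilder to find the best-performing combination of parameters for the decision tree model.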

import findspark
findspark.init()

from pyspark import SparkFiles

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Create a Spark session
spark = SparkSession.builder.appName("Decision Tree Model").getOrCreate()

# Load the Iris dataset
url = "https://raw.githubusercontent.com/selva86/datasets/master/Iris.csv"
spark.sparkContext.addFile(url)

df = spark.read.csv(SparkFiles.get("Iris.csv"), header=True, inferSchema=True)

# Preprocessing: StringIndexer for categorical labels
stringIndexer = StringIndexer(inputCol="Species", outputCol="label")

# Preprocessing: VectorAssembler for feature columns
assembler = VectorAssembler(inputCols=["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"], outputCol="features")

# Split data into training and testing sets
train_data, test_data = df.randomSplit([0.8, 0.2], seed=42)

# Create a Decision Tree Classifier instance
dt = DecisionTreeClassifier(labelCol='label', featuresCol='features')

# Assemble all the steps (indexing, assembling, and model building) into a pipeline.
pipeline = Pipeline(stages=[stringIndexer, assembler, dt])

paramGrid = ParamGridBuilder() \
    .addGrid(dt.maxDepth, [3, 5, 7]) \
    .addGrid(dt.minInstancesPerNode, [1, 3, 5]) \
    .build()
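It helps to know how much work this grid implies. ParamGridBuilder produces every combination of the listed values, CrossValidator fits the pipeline once per combination per fold, and then refits the best combination on the full training set. The sketch mirrors the grid above with plain dicts keyed by parameter name (in PySpark the keys are Param objects such as dt.maxDepth) and also shows a simple form of random search: trying only a random subset of the combinations.

```python
import itertools
import random

# Same candidate values as the ParamGridBuilder above, keyed by parameter name.
grid = {"maxDepth": [3, 5, 7], "minInstancesPerNode": [1, 3, 5]}

# Full grid search: every combination of the candidate values.
combos = [dict(zip(grid, values)) for values in itertools.product(*grid.values())]

num_folds = 5
# One fit per combination per fold, plus one final refit of the best combination.
total_fits = len(combos) * num_folds + 1
print(len(combos), total_fits)  # 9 46

# Random search over the same grid: evaluate only a sampled subset of combinations.
rng = random.Random(42)
sampled = rng.sample(combos, k=4)
print(len(sampled))  # 4
```

Because paramGrid is an ordinary Python list of param maps, the same idea works in PySpark by passing estimatorParamMaps=random.sample(paramGrid, k=4) to CrossValidator, which cuts the number of fits roughly in half for this grid.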

Cross Validation:

Fit the CrossValidator to the training data. This will train multiple models with different hyperparameter combinations and select the best one.

crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=MulticlassClassificationEvaluator(
                              labelCol='label', predictionCol='prediction', metricName='accuracy'),
                          numFolds=5)

cvModel = crossval.fit(train_data)

best_model = cvModel.bestModel

predictions = best_model.transform(test_data)

# Evaluate the model performance
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print(f"Test Accuracy: {accuracy:.2f}")
Test Accuracy: 0.96

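After cross-validation, test accuracy rose from 0.92 to 0.96 in this run. With only about 30 rows in the 20% test split, that gap amounts to roughly one or two predictions, so treat it as indicative rather than conclusive.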
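The cross-validation step itself works like this: with numFolds=5, each parameter combination is trained on four folds of the training data and scored on the fifth, five times over, and the averaged scores (exposed in PySpark as cvModel.avgMetrics) decide which combination wins. A pure-Python sketch of that averaging, with a hypothetical majority-class "model" standing in for the pipeline; unlike this sketch, Spark assigns rows to folds by random sampling, so its folds are only approximately equal in size.

```python
import random
from collections import Counter

def k_fold_indices(n, k, seed=42):
    """Shuffle row indices and deal them into k roughly equal folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]

def majority_class_accuracy(labels, train_idx, test_idx):
    """Toy stand-in for a model: predict the most common training label everywhere."""
    majority = Counter(labels[j] for j in train_idx).most_common(1)[0][0]
    return sum(labels[j] == majority for j in test_idx) / len(test_idx)

def cross_validate(labels, k):
    """Train on k-1 folds, score on the held-out fold, and average over all k folds."""
    folds = k_fold_indices(len(labels), k)
    scores = []
    for i, test_idx in enumerate(folds):
        train_idx = [j for f, fold in enumerate(folds) if f != i for j in fold]
        scores.append(majority_class_accuracy(labels, train_idx, test_idx))
    return sum(scores) / len(scores), scores

# Hypothetical labels: 30 rows of class 0 and 90 rows of class 1.
labels = [0] * 30 + [1] * 90
avg, per_fold = cross_validate(labels, k=5)
print(round(avg, 3))  # 0.75
```

The test set is never touched during this loop; it is only used once, afterwards, to evaluate cvModel.bestModel.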

Clean up

Don’t forget to stop the Spark session once you’re done.

spark.stop()

Conclusion

PySpark’s MLlib library offers a scalable and efficient way to build and evaluate Decision Tree models for classification.

By following the steps outlined in this tutorial and exploring the additional improvements and applications mentioned above, you can leverage the power of PySpark and Decision Trees to solve complex classification problems on large, distributed datasets.
