101 Python datatable Exercises (pydatatable)

Python datatable is a relatively new package for data manipulation and analysis in Python. It carries the spirit of R’s data.table with similar syntax. It is very fast, often much faster than pandas, and it can work with out-of-memory data. Given its performance, it is on track to become a must-use package for data manipulation in Python.

101 Python datatable Exercises (pydatatable). Photo by Jet Kim.

1. How to import datatable package and check the version?

Difficulty Level: L1

Show Solution
import datatable as dt
dt.__version__
'0.8.0'

You need to import datatable as dt for the rest of the code in these exercises to work.

2. How to create a datatable Frame from a list, numpy array, pandas dataframe?

Difficulty Level: L1

Question: Create a datatable Frame from a list, numpy array and pandas dataframe.

Input:

import pandas as pd
import numpy as np

my_list = list('abcedfghijklmnopqrstuvwxyz')
my_arr = np.arange(26)
my_df = pd.DataFrame(dict(col1=my_list, col2=my_arr))

Desired Output:

Show Solution
import pandas as pd
import numpy as np
import datatable as dt

# Inputs
my_list = list('abcedfghijklmnopqrstuvwxyz')
my_arr  = np.arange(26)
my_df   = pd.DataFrame(dict(col1=my_list, col2=my_arr))


# Solution
dt_df1  = dt.Frame(my_list)
dt_df2  = dt.Frame(my_arr)
dt_df3  = dt.Frame(my_df)
dt_df4  = dt.Frame(A=my_arr, B= my_list)

3. How to import csv file as a pydatatable Frame?

Difficulty Level: L1

Question: Read files as datatable Frame.

Show Solution

Input: BostonHousing dataset

# Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
df.head(5)


4. How to read the first 5 rows of a pydatatable Frame?

Difficulty Level: L1

Question: Read first 5 rows of datatable Frame.

Input URL for CSV file: https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv

Show Solution
# Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv', max_nrows= 5)
df


5. How to add new column in pydatatable Frame from a list?

Difficulty Level: L1

Question: Read first 5 rows of datatable Frame and add a new column of length 5.

Input URL for CSV file: https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv', max_nrows= 5)

# Solution
df[:,"new_column"] = dt.Frame([1,2,3,4,5])
df


6. How to add two existing columns to get a new column in a pydatatable Frame?

Difficulty Level: L1

Question: Add the age and rad columns to get a new column in the datatable Frame.

Show Solution

Input: BostonHousing dataset

# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')

# Solution
df[:,"new_column"] = df[:, dt.f.age + dt.f.rad]

7. How to get the int value of a float column in a pydatatable Frame?

Difficulty Level: L1

Question: Get the int value of a float column dis in datatable Frame.

Input: BostonHousing dataset

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')

# Solution
df[:, "new_column"] = df[:, dt.int32(dt.f.dis)]
df.head(5)

8. How to create a new column based on a condition in a datatable Frame?

Difficulty Level: L2

Question: Create a new column having value as ‘Old’ if age greater than 60 else ‘New’ in a `datatable` Frame.

Input: BostonHousing dataset

Show Solution
import datatable as dt
import numpy as np

df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
df[:, "new_column"] = dt.Frame(np.where(df[:, dt.f.age > 60], 'Old', 'New'))
df.head(5)

9. How to left join two datatable Frames?

Difficulty Level: L1

Question: Left join the two Frames.

Input:

import datatable as dt
df1 = dt.Frame(A=[1,2,3,4],B=["a", "b", "c", "d"])
df2 = dt.Frame(A=[1,2,3,4,5],C=["a2", "b2", "c2", "d2", "e2"])

Primary Key : A

Show Solution
import datatable as dt
df1 = dt.Frame(A=[1,2,3,4],B=["a", "b", "c", "d"])
df2 = dt.Frame(A=[1,2,3,4,5],C=["a2", "b2", "c2", "d2", "e2"])
df2.key = "A"
output = df1[:, :, dt.join(df2)]
output

10. How to rename a column in a pydatatable Frame?

Difficulty Level: L1

Question: Rename column zn to zn_new in a datatable Frame.

Input: BostonHousing dataset

Show Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
df.names = {'zn': 'zn_new'}
df.head(5)

11. How to import every 50th row from a csv file to create a datatable Frame?

Difficulty Level: L2

Question: Import every 50th row of the BostonHousing dataset as a datatable Frame.

Input: BostonHousing dataset

Show Solution
# Solution: Use csv reader. Unfortunately there isn't an option to do it directly using fread()
import datatable as dt
import csv          
with open('local/path/to/BostonHousing.csv', 'r') as f:
    reader = csv.reader(f)
    for i, row in enumerate(reader):
        row = [[x] for x in row]
        # 1st row
        if i == 0:  
            df = dt.Frame(row)
            header = [x[0] for x in df[0,:].to_list()]
            df.names =  header
            del df[0,:]  
        # Every 50th row
        elif i%50 ==0:
            df_temp = dt.Frame(row)
            df_temp.names = header
            df.rbind(df_temp)

df.head(5)
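
If loading the entire file up front is acceptable, a simpler alternative (a sketch that reuses the row-slicing shown in later exercises) is to read everything with fread() and then keep every 50th row:

# Alternative sketch: read the full file, then slice every 50th row (starting from row 0)
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
df_every_50th = df[::50, :]
df_every_50th.head(5)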

12. How to change column values when importing csv to a Python datatable Frame?

Difficulty Level: L2

Question: Import the Boston housing dataset, but while importing change the 'medv' (median house value) column so that values < 25 become ‘Low’ and values > 25 become ‘High’.

Input: BostonHousing dataset

Show Solution
# Solution: Use csv reader (the URL is opened with urllib, since open() only works with local files)
import datatable as dt
import csv
import io
import urllib.request

url = 'https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv'
with urllib.request.urlopen(url) as response:
    reader = csv.reader(io.TextIOWrapper(response, encoding='utf-8'))
    for i, row in enumerate(reader):
        row = [[x] for x in row]
        # Header row
        if i == 0:
            df = dt.Frame(row)
            header = [x[0] for x in df[0,:].to_list()]
            df.names =  header
            del df[0,:]
        # Data rows: recode 'medv' (14th column, index 13)
        else:
            row[13] = ['High'] if float(row[13][0]) > 25 else ['Low']
            df_temp = dt.Frame(row)
            df_temp.names = header
            df.rbind(df_temp)

df.head(5)

13. How to change value at particular row and column in a Python datatable Frame?

Difficulty Level: L1

Question: Change the value at row number 2 and column number 1 to 5 in a datatable Frame.

Input: BostonHousing dataset

Show Solution
# Solution: datatable follows [row, column] indexing. No need to use "loc" or "iloc" as in pandas
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
df[2,1] = 5
df.head(5)

14. How to delete a specific cell, row, column, or rows per condition in a datatable Frame?

Difficulty Level: L2

Questions:

  1. Delete the cell at position 2,1.

  2. Delete the 3rd row.

  3. Delete the chas column.

  4. Delete rows where column zn has the value 0.

Input: BostonHousing dataset

Show Solution
# Solution: datatable follows [row, column] indexing. No need to use "loc" or "iloc" as in pandas
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')

# Delete the cell at position `2,1`.
del df[2,1]

# Delete the `3rd` row.
del df[3,:]

# Delete the `chas` column.
del df[:,"chas"]

# Delete rows where column `zn` has the value 0.
del df[dt.f.zn == 0,:]

df.head(5)

15. How to convert datatable Frame to pandas, numpy, dictionary, list, tuples, csv files?

Difficulty Level: L1

Question: Convert datatable Frame to pandas, numpy, dictionary, list, tuples, csv files.

Input: BostonHousing dataset

Show Solution
# Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')

# to pandas df
pd_df = df.to_pandas()

# to numpy arrays
np_arrays = df.to_numpy()

# to dictionary
dic = df.to_dict()

# to list
list_ = df[:,"indus"].to_list()

# to tuple
tuples_ = df[:,"indus"].to_tuples()

# to csv 
df.to_csv("BostonHousing.csv")

16. How to get data types of all the columns in the datatable Frame?

Difficulty Level: L1

Question: Get data types of all the columns in the datatable Frame.

Input: BostonHousing dataset

Desired Output:

crim : stype.float64
zn : stype.float64
indus : stype.float64
chas : stype.bool8
nox : stype.float64
rm : stype.float64
age : stype.float64
dis : stype.float64
rad : stype.int32
tax : stype.int32
ptratio : stype.float64
b : stype.float64
lstat : stype.float64
medv : stype.float64
Show Solution
# Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
for i in range(len(df.names)):
    print(df.names[i], ":", df.stypes[i])
crim : stype.float64
zn : stype.float64
indus : stype.float64
chas : stype.bool8
nox : stype.float64
rm : stype.float64
age : stype.float64
dis : stype.float64
rad : stype.int32
tax : stype.int32
ptratio : stype.float64
b : stype.float64
lstat : stype.float64
medv : stype.float64
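
A slightly more compact version of the same loop (a sketch producing the same output) pairs the names and stypes tuples with zip:

# df.names and df.stypes are parallel tuples, so they can be zipped directly
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
for name, stype in zip(df.names, df.stypes):
    print(name, ":", stype)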

17. How to get summary stats of each column in datatable Frame?

Difficulty Level: L1

Questions:

For each column:

  1. Get the sum of the column values.

  2. Get the max of the column values.

  3. Get the min of the column values.

  4. Get the mean of the column values.

  5. Get the standard deviation of the column values.

  6. Get the mode of the column values.

  7. Get the number of times the modal value appears in each column.

  8. Get the number of unique values in column.

Input: BostonHousing dataset

Show Solution
# Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
df.sum()
df.max()
df.min()
df.mean()
df.sd()
df.mode()
df.nmodal()
df.nunique()

18. How to get the column stats of particular column of the datatable Frame?

Difficulty Level: L1

Question: Get the max value of zn column of the datatable Frame

Input: BostonHousing dataset

Desired Output: 100

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv')
df[:,dt.max(dt.f.zn)]

19. How to apply group by functions in datatable Frame?

Difficulty Level: L1

Question: Find the mean price for every manufacturer using Cars93 dataset.

Input:
Cars93

Desired Output:

     Manufacturer         C0
0            None  28.550000
1           Acura  15.900000
2            Audi  33.400000
3             BMW  30.000000
4           Buick  21.625000
5        Cadillac  37.400000
..
..

30     Volkswagen  18.025000
31          Volvo  22.700000
Show Solution
# Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
df[:, dt.mean(dt.f.Price), dt.by("Manufacturer")].head(5)

20. How to arrange a datatable Frame in ascending order by column value?

Difficulty Level: L1

Question: Arrange datatable Frame in ascending order by Price.

Input:
Cars93

Desired Output:

Manufacturer    Model     Type  Min.Price  Price  Max.Price  MPG.city  \ 
0       Saturn       SL    Small        9.2    NaN       12.9       NaN   
1       Toyota    Camry  Midsize       15.2    NaN       21.2      22.0   
2         Ford  Festiva    Small        6.9    7.4        7.9      31.0   
3      Hyundai    Excel    Small        6.8    8.0        9.2      29.0   
4        Mazda      323    Small        7.4    8.3        9.1      29.0   


   Width  Turn.circle Rear.seat.room  Luggage.room  Weight   Origin  \
0   68.0         40.0           26.5           NaN  2495.0      USA   
1   70.0         38.0           28.5          15.0  3030.0  non-USA   
2   63.0         33.0           26.0          12.0  1845.0      USA   
3   63.0         35.0           26.0          11.0  2345.0  non-USA   
4   66.0         34.0           27.0          16.0  2325.0  non-USA   

            Make  
0      Saturn SL  
1   Toyota Camry  
2   Ford Festiva  
3  Hyundai Excel  
4      Mazda 323  
Show Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution1
df.sort("Price")

# Solution2
df[:,:, dt.sort(dt.f.Price)].head(5)

21. How to arrange a datatable Frame in descending order by column value?

Difficulty Level: L1

Question: Arrange datatable Frame in descending order by Price.

Input:
Cars93

Desired Output:

   Manufacturer     Model     Type  Min.Price  Price  Max.Price  MPG.city  \
0  Mercedes-Benz      300E  Midsize       43.8   61.9       80.0      19.0   
1       Infiniti       Q45  Midsize       45.4   47.9        NaN      17.0   
2       Cadillac   Seville  Midsize       37.5   40.1       42.7      16.0   
3      Chevrolet  Corvette   Sporty       34.6   38.0       41.5      17.0   
4           Audi       100  Midsize        NaN   37.7       44.6      19.0   

   MPG.highway             AirBags DriveTrain  ... Passengers  Length  \
0         25.0  Driver & Passenger       Rear  ...        5.0     NaN   
1         22.0                None       Rear  ...        5.0   200.0   
2         25.0  Driver & Passenger      Front  ...        5.0   204.0   
3         25.0         Driver only       Rear  ...        2.0   179.0   
4         26.0  Driver & Passenger       None  ...        6.0   193.0   

   Wheelbase  Width  Turn.circle Rear.seat.room  Luggage.room  Weight  \
0      110.0   69.0         37.0            NaN          15.0  3525.0   
1      113.0   72.0         42.0           29.0          15.0  4000.0   
2      111.0   74.0         44.0           31.0           NaN  3935.0   
3       96.0   74.0         43.0            NaN           NaN  3380.0   
4      106.0    NaN         37.0           31.0          17.0  3405.0   

    Origin                Make  
0  non-USA  Mercedes-Benz 300E  
1  non-USA        Infiniti Q45  
2      USA    Cadillac Seville  
3     None  Chevrolet Corvette  
4  non-USA            Audi 100  
Show Solution
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution
df[::-1,:, dt.sort(dt.f.Price)].head()

22. How to repeat (append) the same data in a datatable Frame?

Difficulty Level: L1

Question: Repeat (append) the same data 5 times in a datatable Frame.

Input:
Cars93

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution
dt.repeat(df, 5)

23. How to replace string with another string in entire datatable Frame?

Difficulty Level: L1

Question: Replace Audi with My Dream Car in entire datatable Frame.

Input:
Cars93

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution
df.replace("Audi", "My Dream Car")
df.head(5)

24. How to extract the details of a particular cell with a given criterion?

Difficulty Level: L1

Question: Extract which manufacturer, model and type has the highest Price.

Input:
Cars93

Desired Output:

 Manufacturer  Model     Type
 Mercedes-Benz  300E  Midsize
Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution

# Get the highest price
print("Highest Price : ", df[:,dt.f.Price].max()[0,0])

# Get Manufacturer with highest price
df[dt.f.Price ==  df[:,dt.f.Price].max()[0,0], ['Manufacturer', 'Model', 'Type']]
Highest Price :  61.9

25. How to rename a specific column in a dataframe?

Difficulty Level: L2

Question: Rename the column Model as Car Model.

Input:
Cars93

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution
old_col_name = "Model"
new_col_name = "Car Model"
df.names = [new_col_name if x == old_col_name else x for x in df.names]
df.head(5)

26. How to count NA values in every column of a datatable Frame?

Difficulty Level: L1

Question: Count NA values in every column of a datatable Frame.

Input:
Cars93

Desired Output:

Manufacturer  Model  Type  Min.Price  Price  Max.Price  MPG.city  \
0             4      1     3          7      2          5         9   

   MPG.highway  AirBags  DriveTrain  ...  Passengers  Length  Wheelbase  \
0            2        6           7  ...           2       4          1   

   Width  Turn.circle  Rear.seat.room  Luggage.room  Weight  Origin  Make  
0      6            5               4            19       7       5     3
Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution
df.countna()

27. How to get a specific column from a datatable Frame as a datatable Frame instead of a series?

Difficulty Level: L1

Question: Get the column (Model) in the datatable Frame as a datatable Frame (rather than as a Series).

Input:
Cars93

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution
df[:,"Model"].head(5)
   Model
0  Integra
1  Legend
2  90
3  100
4  535i

28. How to reverse the order of columns of a datatable Frame?

Difficulty Level: L1

Question : Reverse the order of columns in Cars93 datatable Frame.

Input:
Cars93

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution 1
df.head()
df[:,::-1].head(5)

29. How to format or suppress scientific notations in Python datatable Frame?

Difficulty Level: L2

Question: Suppress scientific notations like ‘e-03’ in df and print up to 6 digits after the decimal.

Input

import datatable as dt
import numpy as np

df = dt.Frame(random=np.random.random(4)**10)
df
         random
0  3.518290e-04
1  5.104371e-02
2  5.895886e-06
3  1.274671e-09

Desired Output

         random   random2
0  3.518290e-04  0.000352
1  5.104371e-02  0.051044
2  5.895886e-06  0.000006
3  1.274671e-09  0.000000
Show Solution
# Solution
import datatable as dt
import numpy as np

df = dt.Frame(random=np.random.random(4)**10)
df[:,"random2"] = dt.Frame(['%.6f' % x for x in df[:,"random"].to_list()[0]])
df

30. How to filter every nth row in a pydatatable?

Difficulty Level: L1

Question: From df, filter the 'Manufacturer', 'Model' and 'Type' for every 20th row starting from 1st (row 0).

Input:
Cars93

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution
df[::20, ['Manufacturer', 'Model', 'Type']]

31. How to reverse the rows of a python datatable Frame?

Difficulty Level: L2

Question: Reverse all the rows.

Input:
Cars93

Show Solution
# Input
import datatable as dt
df = dt.fread('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')

# Solution
df[::-1,:]

32. How to find out which column contains the highest number of row-wise maximum values?

Difficulty Level: L2

Question: What is the column name with the highest number of row-wise maximums?

Input:
BostonHousing dataset

Desired Output:
tax

Show Solution
# Input
import datatable as dt
df = dt.fread("https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv")

# Solution
for i in range(len(df.names)):
    if df.sum()[0:1,:].to_list()[i] == max(df.sum()[0:1,:].to_list()):
        print(df.names[i])
tax

33. How to normalize all columns in a dataframe?

Difficulty Level: L2

Questions:

  1. Normalize all columns of df by subtracting the column mean and dividing by the standard deviation.
  2. Range all columns of df such that the minimum value in each column is 0 and max is 1.

Don’t use external packages like sklearn.

Input:
BostonHousing dataset

Desired Output:

       crim    zn     indus  chas       nox        rm       age       dis  \
0  0.000000  0.18  0.067815   0.0  0.314815  0.577505  0.641607  0.269203   
1  0.000236  0.00  0.242302   0.0  0.172840  0.547998  0.782698  0.348962   
2  0.000236  0.00  0.242302   0.0  0.172840  0.694386  0.599382  0.348962   
3  0.000293  0.00  0.063050   0.0  0.150206  0.658555  0.441813  0.448545   
4  0.000705  0.00  0.063050   0.0  0.150206  0.687105  0.528321  0.448545   

        rad       tax   ptratio         b     lstat      medv  
0  0.000000  0.208015  0.287234  1.000000  0.089680  0.422222  
1  0.043478  0.104962  0.553191  1.000000  0.204470  0.368889  
2  0.043478  0.104962  0.553191  0.989737  0.063466  0.660000  
3  0.086957  0.066794  0.648936  0.994276  0.033389  0.631111  
4  0.086957  0.066794  0.648936  1.000000  0.099338  0.693333
Show Solution
# Input
import datatable as dt
df = dt.fread("BostonHousing.csv")

# Solution
for i in df.names:
    df[:,i] = df[:,(dt.f[i] - df[:,dt.min(dt.f[i])][0,0])/(df[:,dt.max(dt.f[i])][0,0] - df[:,dt.min(dt.f[i])][0,0])]
df.head(5)
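
The loop above covers part 2 (min-max ranging). A sketch for part 1, standardizing each column by subtracting the column mean and dividing by the standard deviation, follows the same pattern (run it on a freshly read frame, since the loop above has already rescaled df):

# Part 1 sketch: standardize each column to zero mean and unit standard deviation
import datatable as dt
df = dt.fread("https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv")
for i in df.names:
    col_mean = df[:, dt.mean(dt.f[i])][0,0]
    col_sd   = df[:, dt.sd(dt.f[i])][0,0]
    df[:, i] = df[:, (dt.f[i] - col_mean) / col_sd]
df.head(5)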

34. How to compute grouped mean on datatable Frame and keep the grouped column as another column?

Difficulty Level: L1

Question: In df, Compute the mean price of every fruit, while keeping the fruit as another column instead of an index.

Input

df = dt.Frame(fruit = ['apple', 'banana', 'orange'] * 3,
             rating =  np.random.rand(9),
             price  =  np.random.randint(0, 15, 9))

Desired Output:

    fruit        C0
0   apple  7.666667
1  banana  5.000000
2  orange  8.333333
Show Solution
# Input
import datatable as dt
import numpy as np
df = dt.Frame(fruit = ['apple', 'banana', 'orange'] * 3,
             rating =  np.random.rand(9),
             price  =  np.random.randint(0, 15, 9))
df[:, dt.mean(dt.f.price), dt.by("fruit")]
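
If you want the aggregate to have a meaningful name instead of the default C0, or want several aggregates at once, datatable also accepts a dict in the j position; a small sketch on the same frame (the output column names here are just illustrative):

# Named aggregates per group
df[:, {"mean_price":  dt.mean(dt.f.price),
       "mean_rating": dt.mean(dt.f.rating)}, dt.by("fruit")]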

35. How to join two datatable Frames by 2 columns?

Difficulty Level: L2

Question: Join dataframes df1 and df2 by ‘A’ and ‘B’.

Input

df1 = dt.Frame(A=[1, 2, 3, 4],
               B=["a", "b", "c", "d"],
               D=[1, 2, 3, 4])

df2 = dt.Frame(A=[1, 2, 4, 5],
               B=["a", "b", "d", "e"],
               C=["a2", "b2", "d2", "e2"])

Desired Output:

   A  B  D   C
0  1  a  1  a2
1  2  b  2  b2
2  3  c  3  
3  4  d  4  d2
Show Solution
# Input
import datatable as dt
df1 = dt.Frame(A=[1, 2, 3, 4], B=["a", "b", "c", "d"], D=[1, 2, 3, 4])
df2 = dt.Frame(A=[1, 2, 4, 5], B=["a", "b", "d", "e"], C=["a2", "b2", "d2", "e2"])

# Solution
df2.key = ["A","B"]
output = df1[:, :, dt.join(df2)]
output

36. How to create leads (column shifted up by 1 row) of a column in a datatable Frame?

Difficulty Level: L2

Question: Create new column in df, which is a lead1 (shift column A up by 1 row).

Input:

df = dt.Frame(A=[1,2,3,4],B=["a", "b", "c", "d"],d=[1,2,3,4])

Desired Output:

   A  B  d  A.1
0  1  a  1    2
1  2  b  2    3
2  3  c  3    4
3  4  d  4  NaN
Show Solution
# Input
import datatable as dt
df = dt.Frame(A=[1,2,3,4],B=["a", "b", "c", "d"],d=[1,2,3,4])

# Solution
# cbind with force=True pads the shorter (shifted) column with NA at the end
dt.cbind(df, df[1:, "A"], force=True)

Machine Learning Exercise

37. How to use the FTRL model to calculate the probability of a person having diabetes?

Difficulty Level: L3

Question 1: Use Follow the Regularized Leader (Ftrl) Model to calculate the probability of a person having diabetes.

Question 2: Find the feature importance of the features used in model.

Input:

Training Data : pima_indian_diabetes_training_data.csv

Testing Data : pima_indian_diabetes_testing_data.csv

Show Solution
import datatable as dt
import numpy as np
from datatable.models import Ftrl

# Import data
train_df = dt.fread('pima_indian_diabetes_training_data.csv')
test_df = dt.fread('pima_indian_diabetes_testing_data.csv')

# Create Ftrl model
ftrl_model = Ftrl()

#  add parameter values while creating model
ftrl_model = Ftrl(alpha = 0.1, lambda1 = 0.5, lambda2 = 0.6)

# change parameters of an existing model
ftrl_model.alpha = 0.1
ftrl_model.lambda1 = 0.5
ftrl_model.lambda2 = 0.6

# Prepare training and test dataset
train_df[:,"diabetes"] = dt.Frame(np.where(train_df[:, dt.f["diabetes"] == "pos"], 1,0))
test_df[:,"diabetes"] = dt.Frame(np.where(test_df[:, dt.f["diabetes"] == "pos"], 1,0))

x_train = train_df[:, ["pregnant", "glucose", "pressure", "mass", "pedigree", "age"]]
y_train = train_df[:, ["diabetes"]]

x_test = test_df[:, ["pregnant", "glucose", "pressure", "mass", "pedigree", "age"]]
y_test = test_df[:, ["diabetes"]]

# training the model
ftrl_model.fit(x_train,y_train)

# predictions of the model
targets = ftrl_model.predict(x_test)
print(targets.head(5))

# feature importance
fi = ftrl_model.feature_importances
fi
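
As an optional sanity check (not part of the original exercise), you could threshold the predicted probabilities at 0.5 and compare them with the test labels; a minimal sketch assuming targets and y_test from the code above:

# Rough accuracy check: threshold predicted probabilities at 0.5
pred   = (targets.to_numpy().ravel() > 0.5).astype(int)
actual = y_test.to_numpy().ravel()
print("Accuracy:", (pred == actual).mean())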

To be continued..

Author: Ajay Kumar

Vector Autoregression (VAR) – Comprehensive Guide with Examples in Python

Vector Autoregression (VAR) is a forecasting algorithm that can be used when two or more time series influence each other. That is, the relationship between the time series involved is bi-directional. In this post, we will cover the concepts and intuition behind VAR models and walk through a comprehensive, correct method to train and forecast VAR models in Python using statsmodels.

Vector Autoregression (VAR) – Comprehensive Guide with Examples in Python. Photo by Kyran Low.

Content

  1. Introduction
  2. Intuition behind VAR Model Formula
  3. Building a VAR model in Python
  4. Import the datasets
  5. Visualize the Time Series
  6. Testing Causation using Granger’s Causality Test
  7. Cointegration Test
  8. Split the Series into Training and Testing Data
  9. Check for Stationarity and Make the Time Series Stationary
  10. How to Select the Order (P) of VAR model
  11. Train the VAR Model of Selected Order(p)
  12. Check for Serial Correlation of Residuals (Errors) using Durbin Watson Statistic
  13. How to Forecast VAR model using statsmodels
  14. Train the VAR Model of Selected Order(p)
  15. Invert the transformation to get the real forecast
  16. Plot of Forecast vs Actuals
  17. Evaluate the Forecasts
  18. Conclusion

1. Introduction

First, what is Vector Autoregression (VAR) and when to use it?

Vector Autoregression (VAR) is a multivariate forecasting algorithm that is used when two or more time series influence each other.

That means, the basic requirements in order to use VAR are:

  1. You need at least two time series (variables)
  2. The time series should influence each other.

Alright. So why is it called ‘Autoregressive’?

It is considered an Autoregressive model because each variable (time series) is modeled as a function of its past values; that is, the predictors are nothing but the lags (time-delayed values) of the series.

Ok, so how is VAR different from other Autoregressive models like AR, ARMA or ARIMA?

The primary difference is that those models are uni-directional: the predictors influence the Y but not vice-versa. Vector Autoregression (VAR), by contrast, is bi-directional: the variables influence one another.

We will go more in detail in the next section.

In this article you will gain a clear understanding of:

  • Intuition behind VAR Model formula
  • How to check the bi-directional relationship using Granger Causality
  • Procedure to build a VAR model in Python
  • How to determine the right order of VAR model
  • Interpreting the results of VAR model
  • How to generate forecasts to original scale of time series

2. Intuition behind VAR Model Formula

If you remember from Autoregression models, the time series is modeled as a linear combination of its own lags. That is, the past values of the series are used to forecast the current and future values.

A typical AR(p) model equation looks something like this:

Y_{t} = α + β_1*Y_{t-1} + β_2*Y_{t-2} + ... + β_p*Y_{t-p} + ε_{t}

where α is the intercept (a constant) and β_1, β_2 through β_p are the coefficients of the lags of Y up to order p.

Order ‘p’ means up to p lags of Y are used as predictors in the equation. The ε_{t} is the error term, which is considered white noise.

Alright. So, how does a VAR model’s formula look like?

In the VAR model, each variable is modeled as a linear combination of past values of itself and the past values of other variables in the system. Since you have multiple time series that influence each other, it is modeled as a system of equations with one equation per variable (time series).

That is, if you have 5 time series that influence each other, we will have a system of 5 equations.

Well, how is the equation exactly framed?

Let’s suppose, you have two variables (Time series) Y1 and Y2, and you need to forecast the values of these variables at time (t).

To calculate Y1(t), VAR will use the past values of both Y1 as well as Y2. Likewise, to compute Y2(t), the past values of both Y1 and Y2 will be used.

For example, the system of equations for a VAR(1) model with two time series (variables `Y1` and `Y2`) is as follows:

Y_{1,t} = α_1 + β_{11,1}*Y_{1,t-1} + β_{12,1}*Y_{2,t-1} + ε_{1,t}
Y_{2,t} = α_2 + β_{21,1}*Y_{1,t-1} + β_{22,1}*Y_{2,t-1} + ε_{2,t}

Where, Y{1,t-1} and Y{2,t-1} are the first lag of time series Y1 and Y2 respectively.

The above system of equations is referred to as a VAR(1) model, because each equation is of order 1; that is, it contains up to one lag of each of the predictors (Y1 and Y2).

Since the Y terms in the equations are interrelated, the Y’s are considered as endogenous variables, rather than as exogenous predictors.

Likewise, the second order VAR(2) model for two variables would include up to two lags for each variable (Y1 and Y2).

Y_{1,t} = α_1 + β_{11,1}*Y_{1,t-1} + β_{12,1}*Y_{2,t-1} + β_{11,2}*Y_{1,t-2} + β_{12,2}*Y_{2,t-2} + ε_{1,t}
Y_{2,t} = α_2 + β_{21,1}*Y_{1,t-1} + β_{22,1}*Y_{2,t-1} + β_{21,2}*Y_{1,t-2} + β_{22,2}*Y_{2,t-2} + ε_{2,t}

Can you imagine what a second order VAR(2) model with three variables (Y1, Y2 and Y3) would look like?

VAR(2) model with three Y's

As you increase the number of time series (variables) in the model, the system of equations becomes larger.

3. Building a VAR model in Python

The procedure to build a VAR model involves the following steps:

  1. Analyze the time series characteristics
  2. Test for causation amongst the time series
  3. Test for stationarity
  4. Transform the series to make it stationary, if needed
  5. Find optimal order (p)
  6. Prepare training and test datasets
  7. Train the model
  8. Roll back the transformations, if any.
  9. Evaluate the model using test set
  10. Forecast to future
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Import Statsmodels
from statsmodels.tsa.api import VAR
from statsmodels.tsa.stattools import adfuller
from statsmodels.tools.eval_measures import rmse, aic

4. Import the datasets

For this article let’s use the time series used in Yash P Mehra’s 1994 article: “Wage Growth and the Inflation Process: An Empirical Approach”.

This dataset has the following 8 quarterly time series:

1. rgnp  : Real GNP.
2. pgnp  : Potential real GNP.
3. ulc   : Unit labor cost.
4. gdfco : Fixed weight deflator for personal consumption expenditure excluding food and energy.
5. gdf   : Fixed weight GNP deflator.
6. gdfim : Fixed weight import deflator.
7. gdfcf : Fixed weight deflator for food in personal consumption expenditure.
8. gdfce : Fixed weight deflator for energy in personal consumption expenditure.

Let’s import the data.

filepath = 'https://raw.githubusercontent.com/selva86/datasets/master/Raotbl6.csv'
df = pd.read_csv(filepath, parse_dates=['date'], index_col='date')
print(df.shape)  # (123, 8)
df.tail()

Input Time Series Data for VAR model

5. Visualize the Time Series

# Plot
fig, axes = plt.subplots(nrows=4, ncols=2, dpi=120, figsize=(10,6))
for i, ax in enumerate(axes.flatten()):
    data = df[df.columns[i]]
    ax.plot(data, color='red', linewidth=1)
    # Decorations
    ax.set_title(df.columns[i])
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    ax.spines["top"].set_alpha(0)
    ax.tick_params(labelsize=6)

plt.tight_layout();
Actual Multi Dimensional Time Series for VAR model

Each of the series has a fairly similar trend pattern over the years, except for gdfce and gdfim, where a different pattern emerges starting around 1980.

Alright, next step in the analysis is to check for causality amongst these series. The Granger’s Causality test and the Cointegration test can help us with that.

6. Testing Causation using Granger’s Causality Test

The basis behind Vector AutoRegression is that each of the time series in the system influences each other. That is, you can predict the series with past values of itself along with other series in the system.

Using Granger’s Causality Test, it’s possible to test this relationship before even building the model.

So what does Granger’s Causality really test?

Granger’s causality tests the null hypothesis that the coefficients of the past values of one series (X) in the regression equation for another series (Y) are zero.

In simpler terms, the null hypothesis is that the past values of the time series (X) do not cause the other series (Y). So, if the p-value obtained from the test is less than the significance level of 0.05, you can safely reject the null hypothesis.

The below code implements the Granger’s Causality test for all possible combinations of the time series in a given dataframe and stores the p-values of each combination in the output matrix.

from statsmodels.tsa.stattools import grangercausalitytests
maxlag=12
test = 'ssr_chi2test'
def grangers_causation_matrix(data, variables, test='ssr_chi2test', verbose=False):    
    """Check Granger Causality of all possible combinations of the Time series.
    The rows are the response variable, columns are predictors. The values in the table 
    are the P-Values. P-Values less than the significance level (0.05) imply that the
    Null Hypothesis (that the coefficients of the corresponding past values are zero,
    i.e. X does not cause Y) can be rejected.

    data      : pandas dataframe containing the time series variables
    variables : list containing names of the time series variables.
    """
    df = pd.DataFrame(np.zeros((len(variables), len(variables))), columns=variables, index=variables)
    for c in df.columns:
        for r in df.index:
            test_result = grangercausalitytests(data[[r, c]], maxlag=maxlag, verbose=False)
            p_values = [round(test_result[i+1][0][test][1],4) for i in range(maxlag)]
            if verbose: print(f'Y = {r}, X = {c}, P Values = {p_values}')
            min_p_value = np.min(p_values)
            df.loc[r, c] = min_p_value
    df.columns = [var + '_x' for var in variables]
    df.index = [var + '_y' for var in variables]
    return df

grangers_causation_matrix(df, variables = df.columns)        

Grangers Causality Test - Results p-value Matrix

So how to read the above output?

The rows are the Response (Y) and the columns are the predictor series (X).

For example, if you take the value 0.0003 in (row 1, column 2), it refers to the p-value of pgnp_x causing rgnp_y. Whereas, the 0.000 in (row 2, column 1) refers to the p-value of rgnp_y causing pgnp_x.

So, how to interpret the p-values?

If a given p-value is < significance level (0.05), then, the corresponding X series (column) causes the Y (row).

For example, the P-Value of 0.0003 at (row 1, column 2) represents the p-value of the Grangers Causality test for pgnp_x causing rgnp_y, which is less than the significance level of 0.05.

So, you can reject the null hypothesis and conclude pgnp_x causes rgnp_y.

Looking at the P-Values in the above table, you can pretty much observe that all the variables (time series) in the system are interchangeably causing each other.

This makes this system of multi time series a good candidate for using VAR models to forecast.

Next, let’s do the Cointegration test.

7. Cointegration Test

Cointegration test helps to establish the presence of a statistically significant connection between two or more time series.

But, what does Cointegration mean?

To understand that, you first need to know what is ‘order of integration’ (d).

Order of integration(d) is nothing but the number of differencing required to make a non-stationary time series stationary.

Now, when you have two or more time series, and there exists a linear combination of them that has an order of integration (d) less than that of the individual series, then the collection of series is said to be cointegrated.

Ok?

When two or more time series are cointegrated, it means they have a long run, statistically significant relationship.

This is the basic premise on which Vector Autoregression (VAR) models are based. So, it’s fairly common to implement the cointegration test before starting to build VAR models.

Alright, So how to do this test?

Søren Johansen, in his 1991 paper, devised a procedure to implement the cointegration test.

It is fairly straightforward to implement in python’s statsmodels, as you can see below.

from statsmodels.tsa.vector_ar.vecm import coint_johansen

def cointegration_test(df, alpha=0.05): 
    """Perform Johanson's Cointegration Test and Report Summary"""
    out = coint_johansen(df,-1,5)
    d = {'0.90':0, '0.95':1, '0.99':2}
    traces = out.lr1
    cvts = out.cvt[:, d[str(1-alpha)]]
    def adjust(val, length= 6): return str(val).ljust(length)

    # Summary
    print('Name   ::  Test Stat > C(95%)    =>   Signif  \n', '--'*20)
    for col, trace, cvt in zip(df.columns, traces, cvts):
        print(adjust(col), ':: ', adjust(round(trace,2), 9), ">", adjust(cvt, 8), ' =>  ' , trace > cvt)

cointegration_test(df)

Results:

Name   ::  Test Stat > C(95%)    =>   Signif  
 ----------------------------------------
rgnp   ::  248.0     > 143.6691  =>   True
pgnp   ::  183.12    > 111.7797  =>   True
ulc    ::  130.01    > 83.9383   =>   True
gdfco  ::  85.28     > 60.0627   =>   True
gdf    ::  55.05     > 40.1749   =>   True
gdfim  ::  31.59     > 24.2761   =>   True
gdfcf  ::  14.06     > 12.3212   =>   True
gdfce  ::  0.45      > 4.1296    =>   False

8. Split the Series into Training and Testing Data

Splitting the dataset into training and test data.

The VAR model will be fitted on df_train and then used to forecast the next 4 observations. These forecasts will be compared against the actuals present in test data.

To do the comparisons, we will use multiple forecast accuracy metrics, as seen later in this article.

nobs = 4
df_train, df_test = df[0:-nobs], df[-nobs:]

# Check size
print(df_train.shape)  # (119, 8)
print(df_test.shape)  # (4, 8)

9. Check for Stationarity and Make the Time Series Stationary

Since the VAR model requires the time series you want to forecast to be stationary, it is customary to check all the time series in the system for stationarity.

Just to refresh, a stationary time series is one whose characteristics, like mean and variance, do not change over time.

So, how to test for stationarity?

There is a suite of tests called unit-root tests. The popular ones are:

  1. Augmented Dickey-Fuller Test (ADF Test)
  2. KPSS test
  3. Phillips-Perron test

Let’s use the ADF test for our purpose.

By the way, if a series is found to be non-stationary, you make it stationary by differencing it once and repeating the test until it becomes stationary.

Since differencing reduces the length of the series by 1, and since all the time series have to be of the same length, you need to difference all the series in the system if you choose to difference at all.

Got it?

Let’s implement the ADF Test.

First, we implement a function (adfuller_test()) that writes out the results of the ADF test for a given time series, and then apply this function to each series one by one.

def adfuller_test(series, signif=0.05, name='', verbose=False):
    """Perform ADFuller to test for Stationarity of given series and print report"""
    r = adfuller(series, autolag='AIC')
    output = {'test_statistic':round(r[0], 4), 'pvalue':round(r[1], 4), 'n_lags':round(r[2], 4), 'n_obs':r[3]}
    p_value = output['pvalue'] 
    def adjust(val, length= 6): return str(val).ljust(length)

    # Print Summary
    print(f'    Augmented Dickey-Fuller Test on "{name}"', "\n   ", '-'*47)
    print(f' Null Hypothesis: Data has unit root. Non-Stationary.')
    print(f' Significance Level    = {signif}')
    print(f' Test Statistic        = {output["test_statistic"]}')
    print(f' No. Lags Chosen       = {output["n_lags"]}')

    for key,val in r[4].items():
        print(f' Critical value {adjust(key)} = {round(val, 3)}')

    if p_value <= signif:
        print(f" => P-Value = {p_value}. Rejecting Null Hypothesis.")
        print(f" => Series is Stationary.")
    else:
        print(f" => P-Value = {p_value}. Weak evidence to reject the Null Hypothesis.")
        print(f" => Series is Non-Stationary.")    

Call the adfuller_test() on each series.

# ADF Test on each column
for name, column in df_train.iteritems():
    adfuller_test(column, name=column.name)
    print('\n')

Results:

    Augmented Dickey-Fuller Test on "rgnp" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = 0.5428
 No. Lags Chosen       = 2
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.9861. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "pgnp" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = 1.1556
 No. Lags Chosen       = 1
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.9957. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "ulc" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = 1.2474
 No. Lags Chosen       = 2
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.9963. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "gdfco" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = 1.1954
 No. Lags Chosen       = 3
 Critical value 1%     = -3.489
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.996. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "gdf" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = 1.676
 No. Lags Chosen       = 7
 Critical value 1%     = -3.491
 Critical value 5%     = -2.888
 Critical value 10%    = -2.581
 => P-Value = 0.9981. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "gdfim" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -0.0799
 No. Lags Chosen       = 1
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.9514. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "gdfcf" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = 1.4395
 No. Lags Chosen       = 8
 Critical value 1%     = -3.491
 Critical value 5%     = -2.888
 Critical value 10%    = -2.581
 => P-Value = 0.9973. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "gdfce" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -0.3402
 No. Lags Chosen       = 8
 Critical value 1%     = -3.491
 Critical value 5%     = -2.888
 Critical value 10%    = -2.581
 => P-Value = 0.9196. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.

The ADF test confirms none of the time series is stationary. Let’s difference all of them once and check again.

# 1st difference
df_differenced = df_train.diff().dropna()

Re-run ADF test on each differenced series.

# ADF Test on each column of 1st Differences Dataframe
for name, column in df_differenced.iteritems():
    adfuller_test(column, name=column.name)
    print('\n')
    Augmented Dickey-Fuller Test on "rgnp" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -5.3448
 No. Lags Chosen       = 1
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "pgnp" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -1.8282
 No. Lags Chosen       = 0
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.3666. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "ulc" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -3.4658
 No. Lags Chosen       = 1
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.0089. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "gdfco" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -1.4385
 No. Lags Chosen       = 2
 Critical value 1%     = -3.489
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.5637. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "gdf" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -1.1289
 No. Lags Chosen       = 2
 Critical value 1%     = -3.489
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.7034. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "gdfim" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -4.1256
 No. Lags Chosen       = 0
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.0009. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "gdfcf" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -2.0545
 No. Lags Chosen       = 7
 Critical value 1%     = -3.491
 Critical value 5%     = -2.888
 Critical value 10%    = -2.581
 => P-Value = 0.2632. Weak evidence to reject the Null Hypothesis.
 => Series is Non-Stationary.


    Augmented Dickey-Fuller Test on "gdfce" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -3.1543
 No. Lags Chosen       = 7
 Critical value 1%     = -3.491
 Critical value 5%     = -2.888
 Critical value 10%    = -2.581
 => P-Value = 0.0228. Rejecting Null Hypothesis.
 => Series is Stationary.

After the first difference, several series (pgnp, gdfco, gdf and gdfcf) are still not stationary; their test statistics do not cross even the 10% critical value.

All of the series in the VAR model should have the same number of observations.

So, we are left with one of two choices.

That is, either proceed with 1st differenced series or difference all the series one more time.

# Second Differencing
df_differenced = df_differenced.diff().dropna()

Re-run ADF test again on each second differenced series.

# ADF Test on each column of 2nd Differences Dataframe
for name, column in df_differenced.iteritems():
    adfuller_test(column, name=column.name)
    print('\n')

Results:

    Augmented Dickey-Fuller Test on "rgnp" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -9.0123
 No. Lags Chosen       = 2
 Critical value 1%     = -3.489
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "pgnp" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -10.9813
 No. Lags Chosen       = 0
 Critical value 1%     = -3.488
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "ulc" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -8.769
 No. Lags Chosen       = 2
 Critical value 1%     = -3.489
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "gdfco" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -7.9102
 No. Lags Chosen       = 3
 Critical value 1%     = -3.49
 Critical value 5%     = -2.887
 Critical value 10%    = -2.581
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "gdf" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -10.0351
 No. Lags Chosen       = 1
 Critical value 1%     = -3.489
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "gdfim" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -9.4059
 No. Lags Chosen       = 1
 Critical value 1%     = -3.489
 Critical value 5%     = -2.887
 Critical value 10%    = -2.58
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "gdfcf" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -6.922
 No. Lags Chosen       = 5
 Critical value 1%     = -3.491
 Critical value 5%     = -2.888
 Critical value 10%    = -2.581
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.


    Augmented Dickey-Fuller Test on "gdfce" 
    -----------------------------------------------
 Null Hypothesis: Data has unit root. Non-Stationary.
 Significance Level    = 0.05
 Test Statistic        = -5.1732
 No. Lags Chosen       = 8
 Critical value 1%     = -3.492
 Critical value 5%     = -2.889
 Critical value 10%    = -2.581
 => P-Value = 0.0. Rejecting Null Hypothesis.
 => Series is Stationary.

All the series are now stationary.

Next, let’s select the order (p) of the VAR model.

10. How to Select the Order (P) of VAR model

To select the right order of the VAR model, we iteratively fit VAR models of increasing order and pick the order that gives the model with the lowest AIC.

Though the usual practice is to look at the AIC, you can also check the other goodness-of-fit estimates: BIC, FPE and HQIC.

model = VAR(df_differenced)
for i in [1,2,3,4,5,6,7,8,9]:
    result = model.fit(i)
    print('Lag Order =', i)
    print('AIC : ', result.aic)
    print('BIC : ', result.bic)
    print('FPE : ', result.fpe)
    print('HQIC: ', result.hqic, '\n')

Results:

Lag Order = 1
AIC :  -1.3679402315450664
BIC :  0.3411847146588838
FPE :  0.2552682517347198
HQIC:  -0.6741331335699554 

Lag Order = 2
AIC :  -1.621237394447824
BIC :  1.6249432095295848
FPE :  0.2011349437137139
HQIC:  -0.3036288826795923 

Lag Order = 3
AIC :  -1.7658008387012791
BIC :  3.0345473163767833
FPE :  0.18125103746164364
HQIC:  0.18239143783963296 

Lag Order = 4
AIC :  -2.000735164470318
BIC :  4.3712151376540875
FPE :  0.15556966521481097
HQIC:  0.5849359332771069 

Lag Order = 5
AIC :  -1.9619535608363954
BIC :  5.9993645622420955
FPE :  0.18692794389114886
HQIC:  1.268206331178333 

Lag Order = 6
AIC :  -2.3303386524829053
BIC :  7.2384526890885805
FPE :  0.16380374017443664
HQIC:  1.5514371669548073 

Lag Order = 7
AIC :  -2.592331352347129
BIC :  8.602387254937796
FPE :  0.1823868583715414
HQIC:  1.9483069621146551 

Lag Order = 8
AIC :  -3.317261976458205
BIC :  9.52219581032303
FPE :  0.15573163248209088
HQIC:  1.8896071386220985 

Lag Order = 9
AIC :  -4.804763125958631
BIC :  9.698613139231597
FPE :  0.08421466682671915
HQIC:  1.0758291640834052 

In the above output, the AIC drops steadily up to lag 4, rises at lag 5, and then continues to drop at higher lag orders.

Let’s go with the lag 4 model.

An alternate method to choose the order(p) of the VAR models is to use the model.select_order(maxlags) method.

The selected order(p) is the order that gives the lowest ‘AIC’, ‘BIC’, ‘FPE’ and ‘HQIC’ scores.

x = model.select_order(maxlags=12)
x.summary()

How to select the order of the VAR model

According to FPE and HQIC, the optimal lag is observed at a lag order of 3.

I, however, don’t have an explanation for why the AIC and BIC values computed explicitly via result.aic differ from those reported by model.select_order().

Since the explicitly computed AIC is the lowest at lag 4, I choose the selected order as 4.

11. Train the VAR Model of Selected Order(p)

model_fitted = model.fit(4)
model_fitted.summary()

Results:

  Summary of Regression Results   
==================================
Model:                         VAR
Method:                        OLS
Date:           Sat, 18, May, 2019
Time:                     11:35:15
--------------------------------------------------------------------
No. of Equations:         8.00000    BIC:                    4.37122
Nobs:                     113.000    HQIC:                  0.584936
Log likelihood:          -905.679    FPE:                   0.155570
AIC:                     -2.00074    Det(Omega_mle):       0.0200322
--------------------------------------------------------------------
Results for equation rgnp
===========================================================================
              coefficient       std. error           t-stat            prob
---------------------------------------------------------------------------
const            2.430021         2.677505            0.908           0.364
L1.rgnp         -0.750066         0.159023           -4.717           0.000
L1.pgnp         -0.095621         4.938865           -0.019           0.985
L1.ulc          -6.213996         4.637452           -1.340           0.180
L1.gdfco        -7.414768        10.184884           -0.728           0.467
L1.gdf         -24.864063        20.071245           -1.239           0.215
L1.gdfim         1.082913         4.309034            0.251           0.802
L1.gdfcf        16.327252         5.892522            2.771           0.006
L1.gdfce         0.910522         2.476361            0.368           0.713
L2.rgnp         -0.568178         0.163971           -3.465           0.001
L2.pgnp         -1.156201         4.931931           -0.234           0.815
L2.ulc         -11.157111         5.381825           -2.073           0.038
L2.gdfco         3.012518        12.928317            0.233           0.816
L2.gdf         -18.143523        24.090598           -0.753           0.451
L2.gdfim        -4.438115         4.410654           -1.006           0.314
L2.gdfcf        13.468228         7.279772            1.850           0.064
L2.gdfce         5.130419         2.805310            1.829           0.067
L3.rgnp         -0.514985         0.152724           -3.372           0.001
L3.pgnp        -11.483607         5.392037           -2.130           0.033
L3.ulc         -14.195308         5.188718           -2.736           0.006
L3.gdfco       -10.154967        13.105508           -0.775           0.438
L3.gdf         -15.438858        21.610822           -0.714           0.475
L3.gdfim        -6.405290         4.292790           -1.492           0.136
L3.gdfcf         9.217402         7.081652            1.302           0.193
L3.gdfce         5.279941         2.833925            1.863           0.062
L4.rgnp         -0.166878         0.138786           -1.202           0.229
L4.pgnp          5.329900         5.795837            0.920           0.358
L4.ulc          -4.834548         5.259608           -0.919           0.358
L4.gdfco        10.841602        10.526530            1.030           0.303
L4.gdf         -17.651510        18.746673           -0.942           0.346
L4.gdfim        -1.971233         4.029415           -0.489           0.625
L4.gdfcf         0.617824         5.842684            0.106           0.916
L4.gdfce        -2.977187         2.594251           -1.148           0.251
===========================================================================

Results for equation pgnp
===========================================================================
              coefficient       std. error           t-stat            prob
---------------------------------------------------------------------------
const            0.094556         0.063491            1.489           0.136
L1.rgnp         -0.004231         0.003771           -1.122           0.262
L1.pgnp          0.082204         0.117114            0.702           0.483
L1.ulc          -0.097769         0.109966           -0.889           0.374

(... TRUNCATED because of long output....)

Correlation matrix of residuals
             rgnp      pgnp       ulc     gdfco       gdf     gdfim     gdfcf     gdfce
rgnp     1.000000  0.248342 -0.668492 -0.160133 -0.047777  0.084925  0.009962  0.205557
pgnp     0.248342  1.000000 -0.148392 -0.167766 -0.134896  0.007830 -0.169435  0.032134
ulc     -0.668492 -0.148392  1.000000  0.268127  0.327761  0.171497  0.135410 -0.026037
gdfco   -0.160133 -0.167766  0.268127  1.000000  0.303563  0.232997 -0.035042  0.184834
gdf     -0.047777 -0.134896  0.327761  0.303563  1.000000  0.196670  0.446012  0.309277
gdfim    0.084925  0.007830  0.171497  0.232997  0.196670  1.000000 -0.089852  0.707809
gdfcf    0.009962 -0.169435  0.135410 -0.035042  0.446012 -0.089852  1.000000 -0.197099
gdfce    0.205557  0.032134 -0.026037  0.184834  0.309277  0.707809 -0.197099  1.000000

12. Check for Serial Correlation of Residuals (Errors) using Durbin Watson Statistic

Checking the serial correlation of residuals tells you whether there is any leftover pattern in the residuals (errors).

What does this mean to us?

If there is any correlation left in the residuals, then there is some pattern in the time series that is still left to be explained by the model. In that case, the typical course of action is to either increase the order of the model, introduce more predictors into the system, or look for a different algorithm to model the time series.

So, checking for serial correlation is to ensure that the model is sufficiently able to explain the variances and patterns in the time series.

Alright, coming back to topic.

A common way of checking for serial correlation of errors is the Durbin Watson statistic.

Durbin Watson Statistic - Formula
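In symbols (the original post renders this as an image), with residuals $e_t$, $t = 1, \dots, T$:

$$ DW = \frac{\sum_{t=2}^{T} (e_t - e_{t-1})^2}{\sum_{t=1}^{T} e_t^2} $$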

The value of this statistic can vary between 0 and 4. The closer it is to 2, the less the serial correlation; a value closer to 0 implies positive serial correlation, while a value closer to 4 implies negative serial correlation.

from statsmodels.stats.stattools import durbin_watson
out = durbin_watson(model_fitted.resid)

for col, val in zip(df.columns, out):
    print(adjust(col), ':', round(val, 2))

Results:

rgnp   : 2.09
pgnp   : 2.02
ulc    : 2.17
gdfco  : 2.05
gdf    : 2.25
gdfim  : 1.99
gdfcf  : 2.2
gdfce  : 2.17

The serial correlation seems quite alright. Let’s proceed with the forecast.

13. How to Forecast VAR model using statsmodels

In order to forecast, the VAR model requires at least as many observations from the past data as its lag order.

This is because the terms in the VAR model are essentially the lags of the various time series in the dataset, so you need to provide it as many of the previous values as indicated by the lag order used by the model.

# Get the lag order
lag_order = model_fitted.k_ar
print(lag_order)  #> 4

# Input data for forecasting
forecast_input = df_differenced.values[-lag_order:]
forecast_input
4

array([[ 13.5,   0.1,   1.4,   0.1,   0.1,  -0.1,   0.4,  -2. ],
       [-23.6,   0.2,  -2. ,  -0.5,  -0.1,  -0.2,  -0.3,  -1.2],
       [ -3.3,   0.1,   3.1,   0.5,   0.3,   0.4,   0.9,   2.2],
       [ -3.9,   0.2,  -2.1,  -0.4,   0.2,  -1.5,   0.9,  -0.3]])

Let’s forecast.

# Forecast
fc = model_fitted.forecast(y=forecast_input, steps=nobs)
df_forecast = pd.DataFrame(fc, index=df.index[-nobs:], columns=df.columns + '_2d')
df_forecast

Raw forecasts from the VAR model

The forecasts are generated, but they are on the scale of the (differenced) training data used by the model. So, to bring them back to the original scale, you need to de-difference them as many times as you differenced the original input data.

In this case it is two times.

14. Invert the transformation to get the real forecast

def invert_transformation(df_train, df_forecast, second_diff=False):
    """Revert back the differencing to get the forecast to original scale."""
    df_fc = df_forecast.copy()
    columns = df_train.columns
    for col in columns:        
        # Roll back 2nd Diff
        if second_diff:
            df_fc[str(col)+'_1d'] = (df_train[col].iloc[-1]-df_train[col].iloc[-2]) + df_fc[str(col)+'_2d'].cumsum()
        # Roll back 1st Diff
        df_fc[str(col)+'_forecast'] = df_train[col].iloc[-1] + df_fc[str(col)+'_1d'].cumsum()
    return df_fc
df_results = invert_transformation(train, df_forecast, second_diff=True)        
df_results.loc[:, ['rgnp_forecast', 'pgnp_forecast', 'ulc_forecast', 'gdfco_forecast',
                   'gdf_forecast', 'gdfim_forecast', 'gdfcf_forecast', 'gdfce_forecast']]

VAR Forecasts

The forecasts are back to the original scale. Let’s plot the forecasts against the actuals from test data.

15. Plot of Forecast vs Actuals

fig, axes = plt.subplots(nrows=int(len(df.columns)/2), ncols=2, dpi=150, figsize=(10,10))
for i, (col,ax) in enumerate(zip(df.columns, axes.flatten())):
    df_results[col+'_forecast'].plot(legend=True, ax=ax).autoscale(axis='x',tight=True)
    df_test[col][-nobs:].plot(legend=True, ax=ax);
    ax.set_title(col + ": Forecast vs Actuals")
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    ax.spines["top"].set_alpha(0)
    ax.tick_params(labelsize=6)

plt.tight_layout();
Forecast vs Actuals comparison of VAR model

16. Evaluate the Forecasts

To evaluate the forecasts, let’s compute a comprehensive set of metrics, namely, the MAPE, ME, MAE, MPE, RMSE, corr and minmax.

import numpy as np
def forecast_accuracy(forecast, actual):
    mape = np.mean(np.abs(forecast - actual)/np.abs(actual))  # MAPE
    me = np.mean(forecast - actual)             # ME
    mae = np.mean(np.abs(forecast - actual))    # MAE
    mpe = np.mean((forecast - actual)/actual)   # MPE
    rmse = np.mean((forecast - actual)**2)**.5  # RMSE
    corr = np.corrcoef(forecast, actual)[0,1]   # corr
    mins = np.amin(np.hstack([forecast[:,None], 
                              actual[:,None]]), axis=1)
    maxs = np.amax(np.hstack([forecast[:,None], 
                              actual[:,None]]), axis=1)
    minmax = 1 - np.mean(mins/maxs)             # minmax
    return({'mape':mape, 'me':me, 'mae': mae, 
            'mpe': mpe, 'rmse':rmse, 'corr':corr, 'minmax':minmax})

# Compute the forecast accuracy for each series
for i, col in enumerate(df.columns):
    print(('' if i == 0 else '\n') + 'Forecast Accuracy of: ' + col)
    accuracy_prod = forecast_accuracy(df_results[col + '_forecast'].values, df_test[col])
    for k, v in accuracy_prod.items():
        print(adjust(k), ': ', round(v, 4))
Forecast Accuracy of: rgnp
mape   :  0.0192
me     :  79.1031
mae    :  79.1031
mpe    :  0.0192
rmse   :  82.0245
corr   :  0.9849
minmax :  0.0188

Forecast Accuracy of: pgnp
mape   :  0.0005
me     :  2.0432
mae    :  2.0432
mpe    :  0.0005
rmse   :  2.146
corr   :  1.0
minmax :  0.0005

Forecast Accuracy of: ulc
mape   :  0.0081
me     :  -1.4947
mae    :  1.4947
mpe    :  -0.0081
rmse   :  1.6856
corr   :  0.963
minmax :  0.0081

Forecast Accuracy of: gdfco
mape   :  0.0033
me     :  0.0007
mae    :  0.4384
mpe    :  0.0
rmse   :  0.5169
corr   :  0.9407
minmax :  0.0032

Forecast Accuracy of: gdf
mape   :  0.0023
me     :  0.2554
mae    :  0.29
mpe    :  0.002
rmse   :  0.3392
corr   :  0.9905
minmax :  0.0022

Forecast Accuracy of: gdfim
mape   :  0.0097
me     :  -0.4166
mae    :  1.06
mpe    :  -0.0038
rmse   :  1.0826
corr   :  0.807
minmax :  0.0096

Forecast Accuracy of: gdfcf
mape   :  0.0036
me     :  -0.0271
mae    :  0.4604
mpe    :  -0.0002
rmse   :  0.5286
corr   :  0.9713
minmax :  0.0036

Forecast Accuracy of: gdfce
mape   :  0.0177
me     :  0.2577
mae    :  1.72
mpe    :  0.0031
rmse   :  2.034
corr   :  0.764
minmax :  0.0175

17. Conclusion

In this article we covered VAR from scratch: the intuition behind it, interpreting the formula, causality tests, finding the optimal order of the VAR model, preparing the data for forecasting, building the model, checking for serial autocorrelation, inverting the transformation to get the actual forecasts, plotting the results and computing the accuracy metrics.

Hope you enjoyed reading this as much as I did writing it. I will see you in the next one.

Mahalanobis Distance – Understanding the math with examples (python)

Mahalanobis distance is an effective multivariate distance metric that measures the distance between a point and a distribution. It is an extremely useful metric with excellent applications in multivariate anomaly detection, classification on highly imbalanced datasets and one-class classification. This post explains the intuition and the math with practical examples on three machine learning use cases.

Mahalanobis Distance – Understanding the Math and Applications. Photo by Greg Nunes.

Content

  1. Introduction
  2. What’s wrong with using Euclidean Distance for Multivariate data?
  3. What is Mahalanobis Distance?
  4. The math and intuition behind Mahalanobis Distance
  5. How to compute Mahalanobis Distance in Python
  6. Usecase 1: Multivariate outlier detection using Mahalanobis distance
  7. Usecase 2: Mahalanobis Distance for Classification Problems
  8. Usecase 3: One-Class Classification
  9. Conclusion

1. Introduction

Mahalanobis distance is an effective multivariate distance metric that measures the distance between a point (vector) and a distribution.

It has excellent applications in multivariate anomaly detection, classification on highly imbalanced datasets and one-class classification and more untapped use cases.

Despite its extremely useful applications, this metric is seldom discussed or used in stats or ML workflows. This post explains why and when to use Mahalanobis distance, and then walks through the intuition and the math with useful applications.

2. What’s wrong with using Euclidean Distance for Multivariate data?

Let’s start with the basics. Euclidean distance is the commonly used straight line distance between two points.

If the two points are in a two-dimensional plane (meaning, you have two numeric columns (p) and (q) in your dataset), then the Euclidean distance between the two points (p1, q1) and (p2, q2) is:

Simple Euclidean Distance Formula

This formula may be extended to as many dimensions as you want:

Multivariate Euclidean Distance Formula
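As a quick illustration (my own, not part of the original post), the same formula in numpy for points with any number of coordinates:

import numpy as np

p = np.array([1.0, 2.0, 3.0])
q = np.array([4.0, 6.0, 3.0])

# square root of the sum of squared coordinate-wise differences
print(np.linalg.norm(p - q))  #> 5.0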

Well, Euclidean distance will work fine as long as the dimensions are equally weighted and are independent of each other.

What do I mean by that?

Let’s consider the following tables:

Scales Comparison Table

The two tables above show the ‘area’ and ‘price’ of the same objects. Only the units of the variables change.

Since both tables represent the same entities, the distance between any two rows, point A and point B should be the same. But Euclidean distance gives a different value even though the distances are technically the same in physical space.

This can technically be overcome by scaling the variables, for example by computing the z-score ((x – mean) / std) or by rescaling each column to a fixed range, say between 0 and 1 (see the quick sketch below).
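To make this concrete, here is a minimal sketch (my own illustration, with made-up numbers): the Euclidean distance between the same two objects changes when ‘area’ is expressed in different units, but becomes identical once each column is z-scored.

import numpy as np
import pandas as pd

# the same two objects, with 'area' in square metres vs square feet
df_m  = pd.DataFrame({'area': [55.0, 60.0],   'price': [310.0, 320.0]})
df_ft = pd.DataFrame({'area': [592.0, 646.0], 'price': [310.0, 320.0]})

def euclid(d):
    return np.linalg.norm(d.iloc[0] - d.iloc[1])

def euclid_zscored(d):
    z = (d - d.mean()) / d.std()          # z-score each column
    return np.linalg.norm(z.iloc[0] - z.iloc[1])

print(euclid(df_m), euclid(df_ft))                  # different values for the same objects
print(euclid_zscored(df_m), euclid_zscored(df_ft))  # identical after scaling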

But there is another major drawback.

That is, if the dimensions (columns in your dataset) are correlated to one another, which is typically the case in real-world datasets, the Euclidean distance between a point and the center of the points (distribution) can give little or misleading information about how close a point really is to the cluster.

Mahalanobis Distance Usecase

The above image (on the right) is a simple scatterplot of two variables that are positively correlated with each other. That is, as the value of one variable (x-axis) increases, so does the value of the other variable (y-axis).

The two points above are equally distant (Euclidean) from the center. But only one of them (blue) is actually closer to the cluster, even though, technically, their Euclidean distances from the center are equal.

This is because, Euclidean distance is a distance between two points only. It does not consider how the rest of the points in the dataset vary. So, it cannot be used to really judge how close a point actually is to a distribution of points.

What we need here is a more robust distance metric that is an accurate representation of how distant a point is from a distribution.

3. What is Mahalanobis Distance?

Mahalanobis distance is the distance between a point and a distribution, not between two distinct points. It is effectively a multivariate equivalent of the Euclidean distance.

It was introduced by Prof. P. C. Mahalanobis in 1936 and has been used in various statistical applications ever since. However, it’s not so well known or used in the machine learning practice. Well, let’s get into it.

So computationally, how is Mahalanobis distance different from Euclidean distance?

  1. It transforms the columns into uncorrelated variables
  2. It scales the columns to make their variance equal to 1
  3. Finally, it calculates the Euclidean distance.

The above three steps are meant to address the problems with Euclidean distance we just talked about. But how?

Let’s look at the formula and try to understand its components.

4. The math and intuition behind Mahalanobis Distance

The formula to compute Mahalanobis distance is as follows:

Mahalanobis Distance Formula

where, 
 - D^2        is the square of the Mahalanobis distance. 
 - x          is the vector of the observation (row in a dataset), 
 - m          is the vector of mean values of independent variables (mean of each column), 
 - C^(-1)     is the inverse covariance matrix of independent variables. 

So, how to understand the above formula?

Let’s take the (x – m)^T . C^(-1) term.

(x – m) is essentially the distance of the vector from the mean. We then divide this by the covariance matrix (or multiply by the inverse of the covariance matrix).

If you think about it, this is essentially a multivariate equivalent of the regular standardization (z = (x – mu)/sigma). That is, z = ((x vector) – (mean vector)) / (covariance matrix).

So, What is the effect of dividing by the covariance?

If the variables in your dataset are strongly correlated, then, the covariance will be high. Dividing by a large covariance will effectively reduce the distance.

Likewise, if the X’s are not correlated, then the covariance is not high and the distance is not reduced much.

So effectively, it addresses both the problems of scale as well as the correlation of the variables that we talked about in the introduction.
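Here is a small numerical sketch of my own (not from the article) that ties these ideas together: computing D² directly from the formula gives the same number as first decorrelating and scaling the coordinates (a ‘whitening’ step) and then taking the plain squared Euclidean distance.

import numpy as np

rng = np.random.default_rng(0)

# two positively correlated columns
data = rng.multivariate_normal(mean=[0, 0], cov=[[1, 0.8], [0.8, 1]], size=500)

x = np.array([1.5, -1.5])      # an observation
m = data.mean(axis=0)          # mean vector
C = np.cov(data.T)             # covariance matrix
C_inv = np.linalg.inv(C)

# Direct formula: D^2 = (x - m)^T C^-1 (x - m)
d2_direct = (x - m) @ C_inv @ (x - m)

# Whitening view: decorrelate and scale using the Cholesky factor of C,
# then take the squared Euclidean distance
L = np.linalg.cholesky(C)      # C = L @ L.T
z = np.linalg.solve(L, x - m)  # uncorrelated, unit-variance coordinates
d2_whitened = z @ z

print(round(d2_direct, 6), round(d2_whitened, 6))  # the two values agree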

5. How to compute Mahalanobis Distance in Python

import pandas as pd
import scipy as sp
import numpy as np

filepath = 'https://raw.githubusercontent.com/selva86/datasets/master/diamonds.csv'
df = pd.read_csv(filepath).iloc[:, [0,4,6]]
df.head()

Diamonds Dataset

Let’s write the function to calculate Mahalanobis Distance.

def mahalanobis(x=None, data=None, cov=None):
    """Compute the Mahalanobis Distance between each row of x and the data  
    x    : vector or matrix of data with, say, p columns.
    data : ndarray of the distribution from which Mahalanobis distance of each observation of x is to be computed.
    cov  : covariance matrix (p x p) of the distribution. If None, will be computed from data.
    """
    x_minus_mu = x - np.mean(data)
    if not cov:
        cov = np.cov(data.values.T)
    inv_covmat = sp.linalg.inv(cov)
    left_term = np.dot(x_minus_mu, inv_covmat)
    mahal = np.dot(left_term, x_minus_mu.T)
    return mahal.diagonal()

df_x = df[['carat', 'depth', 'price']].head(500)
df_x['mahala'] = mahalanobis(x=df_x, data=df[['carat', 'depth', 'price']])
df_x.head()

Mahalanobis Distance Computed

6. Usecase 1: Multivariate outlier detection using Mahalanobis distance

Assuming that the test statistic follows a chi-square distribution with ‘n’ degrees of freedom, the critical value at a 0.01 significance level and 2 degrees of freedom is computed as:

# Critical values for two degrees of freedom
from scipy.stats import chi2
chi2.ppf((1-0.01), df=2)
#> 9.21

That means an observation can be considered extreme if its Mahalanobis distance exceeds 9.21.

If you prefer to use p-values to determine whether an observation is extreme or not, they can be computed as follows:

# Compute the P-Values
df_x['p_value'] = 1 - chi2.cdf(df_x['mahala'], 2)

# Extreme values with a significance level of 0.01
df_x.loc[df_x.p_value < 0.01].head(10)

Mahalanobis Distance P-Values

If you compare the above observations against the rest of the dataset, they are clearly extreme.

7. Usecase 2: Mahalanobis Distance for Classification Problems

Mahalanobis distance can be used for classification problems. A naive implementation of a Mahalanobis classifier is coded below. The intuition is that an observation is assigned the class to which it is closest, based on the Mahalanobis distance.

Let’s see an example implementation on the BreastCancer dataset, where the objective is to determine if a tumour is benign or malignant.

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/BreastCancer.csv', 
                 usecols=['Cl.thickness', 'Cell.size', 'Marg.adhesion', 
                          'Epith.c.size', 'Bare.nuclei', 'Bl.cromatin', 'Normal.nucleoli', 
                          'Mitoses', 'Class'])

df.dropna(inplace=True)  # drop missing values.
df.head()
Breast Cancer Dataset
Breast Cancer Dataset

Let’s split the dataset into Train and Test in a 70:30 ratio. The training dataset is then split into homogeneous groups of ‘pos’ (1) and ‘neg’ (0) classes. To predict the class of the test dataset, we measure the Mahalanobis distances between a given observation (row) and both the positive (xtrain_pos) and negative (xtrain_neg) datasets.

Then that observation is assigned the class based on the group it is closest to.

from sklearn.model_selection import train_test_split
xtrain, xtest, ytrain, ytest = train_test_split(df.drop('Class', axis=1), df['Class'], test_size=.3, random_state=100)

# Split the training data as pos and neg
xtrain_pos = xtrain.loc[ytrain == 1, :]
xtrain_neg = xtrain.loc[ytrain == 0, :]

Let’s build the MahalanobisBinaryClassifier. To do that, you need to define the predict_proba() and the predict() methods. This classifier does not require a separate fit() (training) method.

class MahalanobisBinaryClassifier():
    def __init__(self, xtrain, ytrain):
        self.xtrain_pos = xtrain.loc[ytrain == 1, :]
        self.xtrain_neg = xtrain.loc[ytrain == 0, :]

    def predict_proba(self, xtest):
        pos_neg_dists = [(p,n) for p, n in zip(mahalanobis(xtest, self.xtrain_pos), mahalanobis(xtest, self.xtrain_neg))]
        return np.array([(1-n/(p+n), 1-p/(p+n)) for p,n in pos_neg_dists])

    def predict(self, xtest):
        return np.array([np.argmax(row) for row in self.predict_proba(xtest)])


clf = MahalanobisBinaryClassifier(xtrain, ytrain)        
pred_probs = clf.predict_proba(xtest)
pred_class = clf.predict(xtest)

# Pred and Truth
pred_actuals = pd.DataFrame([(pred, act) for pred, act in zip(pred_class, ytest)], columns=['pred', 'true'])
print(pred_actuals[:5])        

Output:

   pred  true
0     0     0
1     1     1
2     0     0
3     0     0
4     0     0

Let’s see how the classifier performed on the test dataset.

from sklearn.metrics import classification_report, accuracy_score, roc_auc_score, confusion_matrix
truth = pred_actuals.loc[:, 'true']
pred = pred_actuals.loc[:, 'pred']
scores = np.array(pred_probs)[:, 1]
print('AUROC: ', roc_auc_score(truth, scores))
print('\nConfusion Matrix: \n', confusion_matrix(truth, pred))
print('\nAccuracy Score: ', accuracy_score(truth, pred))
print('\nClassification Report: \n', classification_report(truth, pred))

Output:

AUROC:  0.9909743589743589

Confusion Matrix: 
 [[113  17]
 [  0  75]]

Accuracy Score:  0.9170731707317074

Classification Report: 
              precision    recall  f1-score   support

          0       1.00      0.87      0.93       130
          1       0.82      1.00      0.90        75

avg / total       0.93      0.92      0.92       205

Mahalanobis distance alone was able to achieve this much accuracy (92%).

8. Usecase 3: One-Class Classification

One Class classification is a type of algorithm where the training dataset contains observations belonging to only one class.

With only that information known, the objective is to figure out if a given observation in a new (or test) dataset belongs to that class.

You might wonder when such a situation would occur. Well, it’s quite a common problem in Data Science.

For example consider the following situation: You have a large dataset containing millions of records that are NOT yet categorized as 1’s and 0’s. But you also have with you a small sample dataset containing only positive (1’s) records. By learning the information in this sample dataset, you want to classify all the records in the large dataset as 1’s and 0’s.

Based on the information from the sample dataset, it is possible to tell if any given sample is a 1 or 0 by viewing only the 1’s (and having no knowledge of the 0’s at all).

This can be done using Mahalanobis Distance.

Let’s try this on the BreastCancer dataset, only this time we will consider only the malignant observations (class column=1) in the training data.

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/BreastCancer.csv', 
                 usecols=['Cl.thickness', 'Cell.size', 'Marg.adhesion', 
                          'Epith.c.size', 'Bare.nuclei', 'Bl.cromatin', 'Normal.nucleoli', 
                          'Mitoses', 'Class'])

df.dropna(inplace=True)  # drop missing values.

Splitting 50% of the dataset into training and test. Only the 1’s are retained in the training data.

from sklearn.model_selection import train_test_split
xtrain, xtest, ytrain, ytest = train_test_split(df.drop('Class', axis=1), df['Class'], test_size=.5, random_state=100)

# Split the training data as pos and neg
xtrain_pos = xtrain.loc[ytrain == 1, :]

Let’s build the MahalanobisOneclassClassifier and get the Mahalanobis distance of each datapoint in xtest from the training set (xtrain_pos).

class MahalanobisOneclassClassifier():
    def __init__(self, xtrain, significance_level=0.01):
        self.xtrain = xtrain
        self.critical_value = chi2.ppf((1-significance_level), df=xtrain.shape[1]-1)
        print('Critical value is: ', self.critical_value)

    def predict_proba(self, xtest):
        mahalanobis_dist = mahalanobis(xtest, self.xtrain)
        self.pvalues = 1 - chi2.cdf(mahalanobis_dist, 2)
        return mahalanobis_dist

    def predict(self, xtest):
        return np.array([int(i) for i in self.predict_proba(xtest) > self.critical_value])

clf = MahalanobisOneclassClassifier(xtrain_pos, significance_level=0.05)
mahalanobis_dist = clf.predict_proba(xtest)

# Pred and Truth
mdist_actuals = pd.DataFrame([(m, act) for m, act in zip(mahalanobis_dist, ytest)], columns=['mahal', 'true_class'])
print(mdist_actuals[:5])            
Critical value is:  14.067140449340169
       mahal  true_class
0  13.104716           0
1  14.408570           1
2  14.932236           0
3  14.588622           0
4  15.471064           0

We have the Mahalanobis distance and the actual class of each observation.

I would expect those observations with low Mahalanobis distance to be 1’s.

So, I sort the mdist_actuals by Mahalanobis distance and quantile cut the rows into 10 equal sized groups. The observations in the lower quantiles (smaller distances) should have more 1’s compared to the ones in the higher quantiles. Let’s see.

# quantile cut in 10 pieces
mdist_actuals['quantile'] = pd.qcut(mdist_actuals['mahal'], 
                                    q=[0, .10, .20, .3, .4, .5, .6, .7, .8, .9, 1], 
                                    labels=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

# sort by mahalanobis distance
mdist_actuals.sort_values('mahal', inplace=True)
perc_truths = mdist_actuals.groupby('quantile').agg({'mahal': np.mean, 'true_class': np.sum}).rename(columns={'mahal':'avg_mahaldist', 'true_class':'sum_of_trueclass'})
print(perc_truths)
          avg_mahaldist  sum_of_trueclass
quantile                                 
1              3.765496                33
2              6.511026                32
3              9.272944                30
4             12.209504                20
5             14.455050                 4
6             15.684493                 4
7             17.368633                 3
8             18.840714                 1
9             21.533159                 2
10            23.524055                 1

If you notice above, nearly 90% of the 1’s (malignant cases) fall within the first 40th percentile of the Mahalanobis distance. Incidentally, all of these are lower than the critical value of 14.07. So, let’s use the critical value as the cutoff and mark those observations with a Mahalanobis distance less than the cutoff as positive.

from sklearn.metrics import classification_report, accuracy_score, roc_auc_score, confusion_matrix

# Positive if mahalanobis 
pred_actuals = pd.DataFrame([(int(p), y) for y, p in zip(ytest, clf.predict_proba(xtest) < clf.critical_value)], columns=['pred', 'true'])

# Accuracy Metrics
truth = pred_actuals.loc[:, 'true']
pred = pred_actuals.loc[:, 'pred']
print('\nConfusion Matrix: \n', confusion_matrix(truth, pred))
print('\nAccuracy Score: ', accuracy_score(truth, pred))
print('\nClassification Report: \n', classification_report(truth, pred))
Confusion Matrix: 
 [[183  29]
 [ 15 115]]

Accuracy Score:  0.8713450292397661

Classification Report: 
              precision    recall  f1-score   support

          0       0.92      0.86      0.89       212
          1       0.80      0.88      0.84       130

avg / total       0.88      0.87      0.87       342

So, without the knowledge of the benign class, we are able to accurately predict the class of 87% of the observations.

9. Conclusion

In this post, we covered nearly everything about Mahalanobis distance: the intuition behind the formula, the actual calculation in python and how it can be used for multivariate anomaly detection, binary classification, and one-class classification. It is known to perform really well when you have a highly imbalanced dataset.

Hope you found it useful. Please leave your comments below and I will see you in the next one.

datetime in Python – Simplified Guide with Clear Examples

datetime is the standard module for working with dates in python. It provides 4 main objects for date and time operations: datetime, date, time and timedelta. In this post you will learn how to do all sorts of operations with these objects and solve date-time related practice problems (easy to hard) in Python.

datetime in Python – Simplified Guide with Clear Examples. Photo by Sergio.

Content

  1. Introduction to datetime
  2. How to get the current date and the time in Python
  3. How to create the datetime object
  4. How to parse a string to datetime in python?
  5. How to format the datetime object into any date format?
  6. Useful datetime functions
  7. When and how to use the datetime.time() class?
  8. When and how to use the datetime.timedelta() class?
  9. Timezones
  10. 14 Practice Exercises with Solutions

1. Introduction to datetime

datetime is the main module for working with dates in python. Whenever you need to work with dates in python, the datetime module provides the necessary tools.

datetime is part of python’s standard library, which means, you don’t need to install it separately.

You can simply import as is.

import datetime

If you have to learn only one thing about handling dates in the datetime module, it is the datetime.datetime() class.

Inside the datetime module, the most important and most commonly used object is the datetime class. Notice, I am talking about the datetime class inside the datetime module.

Since both the module and the class have the same name, pay attention to what object you are using.

Alright, besides the datetime.datetime class, there is also the:

  • date class
  • time class
  • timedelta class

Each of these classes has its own purpose.

We’ll cover all of these in this post, along with a more advanced parser (not in `datetime`) that can help parse virtually any date string.

2. How to get the current date and the time in Python

The datetime.datetime.now() method gives the current datetime.

datetime.datetime.now()
#> datetime.datetime(2019, 2, 15, 18, 54, 58, 291224)

The output is a nice datetime.datetime object with the current date and time in local time zone. The output is in the following order: ‘year’, ‘month’, ‘date’, ‘hour’, ‘minute’, ‘seconds’, ‘microseconds’.

To get the date alone, use the datetime.date.today() instead.

datetime.date.today()
#> datetime.date(2019, 2, 15)

It returns a datetime.date object and not datetime.datetime. Why? That’s because, today() is a method of the datetime.date class and does not contain time information.
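As a small aside (my own illustration, not from the original), if you want both pieces, you can take the date and time components off a single datetime.datetime object:

import datetime

now = datetime.datetime.now()
print(now.date())   # just the date part -> a datetime.date
print(now.time())   # just the time part -> a datetime.time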

Good.

But the above notation is hard to read. Printing it out will show it in a nice YYYY-mm-dd format.

print(datetime.date.today())
#> 2019-02-15

We will see how to format datetime to many more formats shortly.

3. How to create the datetime object

We saw how to create the datetime object for the current time. But how do you create one for any given date and time? Say, for the following time: 2001-01-31::10:51:00

You can pass it in the same order to datetime.datetime(). (I will show an easier method in the next section.)

datetime.datetime(2001, 1, 31, 10, 51, 0)
#> datetime.datetime(2001, 1, 31, 10, 51)

You can also create a datetime from a unixtimestamp. A unixtimestamp is nothing but the number of seconds since the epoch date: ‘Jan 01, 1970’

mydatetime = datetime.datetime.fromtimestamp(528756281)
mydatetime
#> datetime.datetime(1986, 10, 4, 2, 14, 41)

You can convert the datetime back to a unixtimestamp as follows:

mydatetime.timestamp()
#> 528756281.0

4. How to parse a string to datetime in python?

The above method requires you to manually key in the year, month, etc. to create a datetime object. But it is not convenient when working with datasets or spreadsheet columns containing date strings.

We need a way to automatically parse a given date string, in whatever format, to a datetime object.

Why is this needed?

Because datasets containing dates are often imported as strings. Secondly, the date can be in any arbitrary string format, like ‘2010 Jan 31’ or ‘January 31, 2010’ or even ‘31-01-2010’.

So, how do you convert a date string to a datetime?

The parser module from dateutil lets you parse pretty much any date string to a datetime object.

from dateutil.parser import parse
parse('January 31, 2010')
#> datetime.datetime(2010, 1, 31, 0, 0)

5. Example 1 – Parsing a date string to datetime

Parse the following date string to a datetime object: ’31, March 31, 2010, 10:51pm’

Solution:

from dateutil.parser import parse
parse('31, March 31, 2010, 10:51pm')


6. How to format the datetime object into any date format?

You can convert any datetime object to nearly any representation of date format using its strftime() method. You need to pass the right symbol representation of the date format as an argument.

dt = datetime.datetime(2001, 1, 31, 10, 51, 0)

print(dt.strftime('%Y-%m-%d::%H-%M'))
#> 2001-01-31::10-51

7. Example 2 – Formatting a datetime object

Format the following datetime object to this representation: ’31 January, 2001, Wednesday’

# Input
dt = datetime.datetime(2001, 1, 31)
Show Solution

Solution:

dt.strftime('%d %B, %Y, %A')

8. Useful datetime functions

The datetime object contains a number of useful date-time related methods.

# create a datetime obj
dt = datetime.datetime(2019, 2, 15)

# 1. Get the day of the month
dt.day  #> 15

# 2. Get the day of the week
dt.isoweekday()  #> 5 --> Friday

# 3. Get the month of the year
dt.month  #> 2 --> February

# 4. Get the year
dt.year  #> 2019

9. When and how to use the datetime.time() class?

The datetime.time() is used to represent the time component alone, without the date. The default output format is: hours, minutes, seconds and microseconds.

# hours, minutes, seconds, microseconds
tm = datetime.time(10,40,10,102301)
tm
#> datetime.time(10, 40, 10, 102301)

10. When and how to use the datetime.timedelta() class?

‘TimeDeltas’ represent a duration of time, rather than a particular time instance. You can think of them simply as the difference between two dates or times.

It is normally used to add or remove a certain duration of time from datetime objects.

To create a datetime.timedelta object, you need to pass a specified duration to the class constructor. The arguments can be in weeks, days (default), hours, minutes, seconds or microseconds.

td = datetime.timedelta(days=30)
td

Now I have a `timedelta` object that represents a duration of 30 days. Let’s compute the date 30 days from now.

print(datetime.date.today() + td)
#> 2019-03-17

Likewise, you can subtract timedeltas as well.
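For example, here is a quick sketch (my own) of computing the date 30 days in the past:

import datetime
print(datetime.date.today() - datetime.timedelta(days=30))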

Another convenience with timedeltas is that you can create arbitrary combinations of durations expressed in weeks, days, hours, etc. The result is simplified into days and seconds.

td = datetime.timedelta(weeks=1, days=30, hours=2, minutes=40)
td 
#> datetime.timedelta(days=37, seconds=9600)

If you subtract two datetime objects, you will get a timedelta object that represents the duration.

dt1 = datetime.datetime(2002, 1, 31, 10, 10, 0)
dt2 = datetime.datetime(2001, 1, 31, 10, 10, 0)
dt1 - dt2
#> datetime.timedelta(days=365)

Likewise, you can subtract two time deltas to get another timedelta object.

td1 = datetime.timedelta(days=30)  # 30 days
td2 = datetime.timedelta(weeks=1)  # 1 week
td1 - td2
#> datetime.timedelta(days=23)

11. How to work with timezones?

For time zones, python recommends the pytz module, which is not a standard built-in library. You need to install it separately (enter `pip install pytz` in a terminal or command prompt).

So how do you set the time zone of a particular datetime?

Simply pass the respective pytz timezone object to the tzinfo parameter when you create the datetime. Then, that datetime will become timezone aware. Let’s create a datetime object that belongs to the UTC timezone.

import pytz
datetime.datetime(2001, 1, 31, 10, 10, 0, tzinfo=pytz.UTC)

UTC is a direct attribute of the pytz module. So, how do you set a different timezone?

Look up pytz.all_timezones for your timezone of choice. Then use pytz.timezone() to create the respective timezone object to pass to the tzinfo argument.

# See available time zones
pytz.all_timezones[:5]
#> ['Africa/Abidjan',
#>  'Africa/Accra',
#>  'Africa/Addis_Ababa',
#>  'Africa/Algiers',
#>  'Africa/Asmara']
# Set to particular timezone
dt_in = datetime.datetime(2001, 1, 31, 3, 30, 0, 0, tzinfo=pytz.timezone('Asia/Tokyo'))
dt_in
#> datetime.datetime(2001, 1, 31, 3, 30, tzinfo=<DstTzInfo 'Asia/Tokyo' LMT+9:19:00 STD>)

You can find out the corresponding time in another timezone by converting the datetime to the respective target timezone.

tgt_timezone = pytz.timezone('Africa/Addis_Ababa')
dt_in.astimezone(tgt_timezone)

12. Practice Examples

Rules for the challenges:

  1. No looking at the calendar
  2. Solve the problems with python code even if it is possible to compute it mentally

Exercise 1: How to parse date strings to datetime format?

Parse the following date strings to datetime format (easy)

# Input
s1 = "2010 Jan 1"
s2 = '31-1-2000' 
s3 = 'October10, 1996, 10:40pm'

# Desired Output
#> 2010-01-01 00:00:00
#> 2000-01-31 00:00:00
#> 2019-10-10 22:40:00
Show Solution
# Input
s1 = "2010 Jan 1"
s2 = '31-1-2000' 
s3 = 'October10,1996, 10:40pm'

# Solution
from dateutil.parser import parse
print(parse(s1))
print(parse(s2))
print(parse(s3))
2010-01-01 00:00:00
2000-01-31 00:00:00
2019-10-10 22:40:00

Exercise 2: How many days has it been since you were born?

How many days has it been since you were born? (easy)

# Input
bday = 'Oct 2, 1869'  # use bday
Show Solution
# Input
bday = 'Oct 2, 1869'

import datetime
from dateutil.parser import parse

# Solution
td = datetime.datetime.now() - parse(bday)
td.days
54558

Exercise 3: How to count the number of saturdays between two dates?

Count the number of saturdays between two dates (medium)

# Input
import datetime
d1 = datetime.date(1869, 1, 2)
d2 = datetime.date(1869, 10, 2)

# Desired Output
#> 40
Show Solution
# Input
import datetime
d1 = datetime.date(1869, 1, 2)
d2 = datetime.date(1869, 10, 2)

# Solution
delta = d2 - d1  # timedelta

# Get all dates 
dates_btw_d1d2 = [(d1 + datetime.timedelta(i)) for i in range(delta.days + 1)]

n_saturdays = 0
for d in dates_btw_d1d2:
    n_saturdays += int(d.isoweekday() == 6)

print(n_saturdays)    
40

Exercise 4: How many days is it until your next birthday this year?

How many days is it until your next birthday this year? (easy)

# Input
bday = 'Oct 2, 1869'  # use your bday
Show Solution
# Input
bday = 'Oct 2, 1869'  # Enter birthday here

import datetime
from dateutil.parser import parse

# Solution
bdate = parse(bday)
current_bdate = datetime.date(year=datetime.date.today().year, month=bdate.month, day=bdate.day) 
td = current_bdate - datetime.date.today()
td.days
228

Exercise 5: How to count the number of days between successive days in an irregular sequence?

Count the number of days between successive days in the following list. (medium)

# Input
['Oct, 2, 1869', 'Oct, 10, 1869', 'Oct, 15, 1869', 'Oct, 20, 1869', 'Oct, 23, 1869']

# Desired Output
#> [8, 5, 5, 3]
Show Solution
# Input
datestrings = ['Oct, 2, 1869', 'Oct, 10, 1869', 'Oct, 15, 1869', 'Oct, 20, 1869', 'Oct, 23, 1869']

# Solution
import datetime
from dateutil.parser import parse
import numpy as np

dates = [parse(d) for d in datestrings]

print([d.days for d in np.diff(dates)])
[8, 5, 5, 3]

Exercise 6: How to convert number of days to seconds?

Convert the number of days till your next birthday to seconds (easy)

# Input
import datetime
bdate = datetime.date(1869, 10, 2)
td = datetime.date.today() - bdate
Show Solution
# Input
import datetime
bdate = datetime.date(1869, 10, 2)
td = datetime.date.today() - bdate

# Solution
td.total_seconds()
4713811200.0

Exercise 7: How to convert a given date to a datetime set at the beginning of the day?

Convert a given date to a datetime set at the beginning of the day (easy)

# Input
import datetime
date = datetime.date(1869, 10, 2)

# Desired Output
#> 1869-10-02 00:00:00
Show Solution
from datetime import date, datetime
d = date(1869, 10, 2)
print(datetime.combine(d, datetime.min.time()))
#> 1869-10-02 00:00:00
1869-10-02 00:00:00

Exercise 8: How to get the last day of the month for any given date in python?

Get the last day of the month for the below given date in python (easy)

# Input
import datetime
dt = datetime.date(1952, 2, 12)

# Desired Output
#> 29
Show Solution
# Input
import datetime
dt = datetime.date(1952, 2, 12)

# Solution
import calendar
calendar.monthrange(dt.year,dt.month)[1]
29

Exercise 9: How many Sundays does the month of February 1948 have?

Count the number of Sundays in the month of February 1948. (medium)

Show Solution
import datetime
from calendar import monthrange

d1 = datetime.date(1948, 2, 1)
n_days = monthrange(1948, 2)

# Get all dates 
dates_btw_d1d2 = [(d1 + datetime.timedelta(i)) for i in range(n_days[1])]

n_sundays = 0
for d in dates_btw_d1d2:
    n_sundays += int(d.isoweekday() == 7)  # isoweekday() == 7 --> Sunday

print(n_sundays)    #> 5
5

Exercise 10: How to format a given date to “mmm-dd, YYYY” format?

Format a given date to the “mmm-dd, YYYY” format. (easy)

# input
import datetime
d1 = datetime.date(2010, 9, 28)

# Desired output
#> 'Sep-28, 2010'
Show Solution
# Input
import datetime
d1 = datetime.date(2010, 9, 28)

# Solution
d1.strftime('%b-%d, %Y')
'Sep-28, 2010'

Exercise 11: How to convert datetime to Year-Qtr format?

Convert the below datetime to Year-Qtr format? (easy)

# input
import datetime
d1 = datetime.datetime(2010, 9, 28, 10, 40, 59)

# Desired output
#> '2010-Q3'
Show Solution
# input
import datetime
d1 = datetime.datetime(2010, 9, 28, 10, 40, 59)

# Solution
f'{d1.year}-Q{(d1.month - 1)//3 + 1}'
'2010-Q3'

Exercise 12: How to convert unix timestamp to a readable date?

Convert the below unix timestamp to a readable date (medium)

# Input
unixtimestamp = 528756281

# Desired Output
#> 04-October-1986
Show Solution
# Input
unixtimestamp = 528756281

# Solution
import datetime
dt = datetime.datetime.fromtimestamp(528756281)
dt.strftime('%d-%B-%Y')
'04-October-1986'

Exercise 13: How to get the time in a different timezone?

If it is ‘2001-01-31::3:30:0’ in ‘Asia/Tokyo’. What time is it in ‘Asia/Kolkata’? (medium)

import datetime
import pytz

dt_in = datetime.datetime(2001, 1, 31, 3, 30, 0, 0, tzinfo=pytz.timezone('Asia/Tokyo'))

# Desired Solution
#> datetime.datetime(2001, 1, 30, 23, 41, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>)
Show Solution
import datetime
import pytz

dt_in = datetime.datetime(2001, 1, 31, 3, 30, 0, 0, tzinfo=pytz.timezone('Asia/Tokyo'))

# Solution
india_tz = pytz.timezone('Asia/Kolkata')
dt_in.astimezone(india_tz)
datetime.datetime(2001, 1, 30, 23, 41, tzinfo=<DstTzInfo 'Asia/Kolkata' IST+5:30:00 STD>)

Exercise 14: How to fill up missing dates in a given irregular sequence of dates?

Fill up missing dates in a given irregular sequence of dates? (hard)

# Input
['Oct 2, 1869', 'Oct 5, 1869', 'Oct 7, 1869', 'Oct 9, 1869']

# Desired Output
#> ['Oct 02, 1869', 'Oct 03, 1869', 'Oct 04, 1869', 'Oct 05, 1869', 
#> 'Oct 06, 1869', 'Oct 07, 1869', 'Oct 08, 1869', 'Oct 09, 1869']
Show Solution
# Input
datestrings = ['Oct 2, 1869', 'Oct 5, 1869', 'Oct 7, 1869', 'Oct 9, 1869']

# Solution
import datetime
from dateutil.parser import parse
import numpy as np

dates = [parse(d) for d in datestrings]

d1 = np.min(dates)
d2 = np.max(dates)

delta = d2 - d1  # timedelta

# Get all dates 
dates_btw_d1d2 = [(d1 + datetime.timedelta(i)).strftime('%b %d, %Y') for i in range(delta.days + 1)]
print(dates_btw_d1d2)
['Oct 02, 1869', 'Oct 03, 1869', 'Oct 04, 1869', 'Oct 05, 1869', 'Oct 06, 1869', 'Oct 07, 1869', 'Oct 08, 1869', 'Oct 09, 1869']

13. Conclusion

How many were you able to solve? Congratulations if you were able to solve 7 or more.

We covered nearly everything you will need to work with dates in python. Let me know if I have missed anything. Or if you have better answers or have more questions, please write in the comments area below. See you in the next one!

Python Logging – Simplest Guide with Full Code and Examples

The logging module lets you track events when your code runs so that when the code crashes you can check the logs and identify what caused it. Log messages have a built-in hierarchy – starting from debugging, informational, warnings, error and critical messages. You can include traceback information as well. It is designed for small to large python projects with multiple modules and is highly recommended for any modular python programming. This post is a simple and clear explanation of how to use the logging module.

Logging in Python – Simplified Guide with Full Code and Examples. Photo by Andrea Reiman.

Content

  1. Why logging?
  2. A Basic logging Example
  3. The 5 levels of logging
  4. How to log to a file instead of the console
  5. How to change the logging format
  6. Why working with the root logger for all modules isn’t the best idea
  7. How to create a new logger?
  8. What is and How to setup a File Handler and Formatter?
  9. How to include traceback information in logged messages
  10. Exercises
  11. Conclusion

1. Why logging?

When you run a python script, you want to know what part of the script is getting executed and inspect what values the variables hold.

Usually, you may just ‘print()‘ out meaningful messages so you can see them in the console. And this is probably all you need when you are developing small programs.

The problem is, when you use this approach on larger projects with multiple modules, you need something more flexible.

Why?

Because, the code could go through different stages as in development, debugging, review, testing or in production.

The type of messages you want to print out during development can be very different from what you want to see once the code goes into production. Depending on the purpose, you want the code to print out different types of messages.

This can get cumbersome with if else and print statements. Besides, you want a certain hierarchy when it comes to printing messages.

What I mean by that is, during a certain ‘testing’ run, you want to see only warnings and error messages. Whereas during ‘debugging’, you not only want to see the warnings and error messages but also the debugging-related messages. Imagine doing this with ‘if else‘ statements on a multi-module project.

If you want to print out which module and at what time the code was run, your code could easily get messier.

There is good news. All these issues are nicely addressed by the logging module.

Using logging, you can:

  1. Control message level to log only required ones
  2. Control where to show or save the logs
  3. Control how to format the logs with built-in message templates
  4. Know which module the messages are coming from

You might say ‘I see that logging can be useful but it seems too technical and a bit difficult to grasp‘. Well, yes, logging requires a bit of a learning curve, but that’s what this post is here for: to make logging easy to learn.

Without further delay, let’s get right into it.

2. A Basic logging Example

Python provides an in-built logging module which is part of the python standard library. So you don’t need to install anything.

To use logging, all you need to do is setup the basic configuration using logging.basicConfig(). Actually, this is also optional. We will see about that soon.

Then, instead of print(), you call logging.{level}(message) to show the message in console.

import logging
logging.basicConfig(level=logging.INFO)

def hypotenuse(a, b):
    """Compute the hypotenuse"""
    return (a**2 + b**2)**0.5

logging.info("Hypotenuse of {a}, {b} is {c}".format(a=3, b=4, c=hypotenuse(a,b)))
#> INFO:root:Hypotenuse of 3, 4 is 5.0

The printed log message has the following default format: {LEVEL}:{LOGGER}:{MESSAGE}.

In the above case, the level is info, because, I called logging.info().

The logger is called root, because that is the default logger and I did not create a new one, yet.

But what is a logger anyway?

A logger is like an entity you can create and configure to log different types and formats of messages.

You can configure a logger that prints to the console and another logger that sends the logs to a file, has a different logging level and is specific to a given module. More explanations and examples coming up on this.

Finally, the message is the string I passed to logging.info().

Now, what would have happened had you not setup logging.basicConfig(level=logging.INFO)?

Answer: The log would not have been printed.

Why?

To know that let’s understand the levels of logging.

3. The 5 levels of logging

The logging module has 5 different hierarchical levels of logs that a given logger may be configured to use.

Let’s see what the python docs have to say about each level:

  1. DEBUG: Detailed information, for diagnosing problems. Value=10.
  2. INFO: Confirm things are working as expected. Value=20.
  3. WARNING: Something unexpected happened, or indicative of some problem. But the software is still working as expected. Value=30.
  4. ERROR: More serious problem, the software is not able to perform some function. Value=40
  5. CRITICAL: A serious error, the program itself may be unable to continue running. Value=50
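These numeric values are exposed as constants on the logging module, so you can verify them directly (a quick aside, not part of the original post):

import logging
print(logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL)
#> 10 20 30 40 50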

Now, coming back to the previous question of what would have happened had you not setup logging.basicConfig(level=logging.INFO) in the previous example.

The answer is: the log would not have been printed because the default logger is ‘root’ and its default level is WARNING. That means only messages from logging.warning() and higher levels get logged.

So, the message of logging.info() would not be printed. And that is why the basic config was set as INFO initially (in logging.basicConfig(level=logging.INFO)).

Had I set the level as logging.ERROR instead, only messages from logging.error() and logging.critical() would be logged. Clear?

import logging
logging.basicConfig(level=logging.WARNING)

def hypotenuse(a, b):
    """Compute the hypotenuse"""
    return (a**2 + b**2)**0.5

kwargs = {'a':3, 'b':4, 'c':hypotenuse(3, 4)}

logging.debug("a = {a}, b = {b}".format(**kwargs))
logging.info("Hypotenuse of {a}, {b} is {c}".format(**kwargs))
logging.warning("a={a} and b={b} are equal".format(**kwargs))
logging.error("a={a} and b={b} cannot be negative".format(**kwargs))
logging.critical("Hypotenuse of {a}, {b} is {c}".format(**kwargs))

#> WARNING:root:a=3 and b=4 are equal
#> ERROR:root:a=3 and b=4 cannot be negative
#> CRITICAL:root:Hypotenuse of 3, 4 is 5.0

4. How to log to a file instead of the console

To send the log messages to a file from the root logger, you need to set the filename argument in logging.basicConfig().

import logging
logging.basicConfig(level=logging.INFO, filename='sample.log')

Now all subsequent log messages will go straight to the file ‘sample.log’ in your current working directory. If you want to send it to a file in a different directory, give the full file path.
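For instance, here is a small sketch (my own, not from the original post) that writes to a full path and starts a fresh file on every run; the path is hypothetical and its directory must already exist:

import logging
logging.basicConfig(level=logging.INFO,
                    filename='/tmp/myproject/sample.log',  # hypothetical full path; directory must exist
                    filemode='w')                          # 'w' overwrites the log on each run (default is 'a')
logging.info("This goes into /tmp/myproject/sample.log")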

5. How to change the logging format

The logging module provides shorthands to add various details to the logged messages. The below image from Python docs shows that list.

Logging Formats

Let’s change the log message format to show the TIME, LEVEL and the MESSAGE. To do that, just pass the format string to logging.basicConfig()‘s format argument.

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s :: %(levelname)s :: %(message)s')
logging.info("Just like that!")
#> 2019-02-17 11:40:38,254 :: INFO :: Just like that!

6. Why working with the root logger for all modules isn’t the best idea

Because they all will share the same ‘root’ logger.

But why is that bad?

Let’s look at the below code:

# 1. code inside myprojectmodule.py
import logging
logging.basicConfig(filename='module.log')

#-----------------------------

# 2. code inside main.py (imports the code from myprojectmodule.py)
import logging
import myprojectmodule  # This runs the code in myprojectmodule.py

logging.basicConfig(filename='main.log')  # No effect, because!

Imagine you have one or more modules in your project, and these modules use the root logger. Then, when importing the module (‘myprojectmodule.py‘), all of that module’s code runs and the root logger gets configured.

Once that happens, the main file (which imported the ‘myprojectmodule‘ module) will no longer be able to change the root logger settings, because logging.basicConfig(), once set, cannot be changed.

That means, if you want to log the messages from myprojectmodule to one file and the logs from the main module to another file, the root logger can’t do that.

To do that you need to create a new logger.

7. How to create a new logger?

You can create a new logger using the logging.getLogger(name) method. If a logger with the same name already exists, that logger will be used.

While you can give pretty much any name to the logger, the convention is to use the __name__ variable like this:

logger = logging.getLogger(__name__)
logger.info('my logging message')

But, why use __name__ as the name of the logger, instead of hardcoding a name?

Because the __name__ variable will hold the name of the module (python file) that called the code. So, when used inside a module, it will create a logger bearing the value provided by the module’s __name__ attribute.

By doing this, if you end up changing module name (file name) in future, you don’t have to modify the internal code.

Now, once you’ve created a new logger, you should remember to log all your messages using the new logger.info() instead of the root’s logging.info() method.
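For instance, a typical module would look something like this (a small sketch; the function name is made up for illustration):

# inside myprojectmodule.py
import logging
logger = logging.getLogger(__name__)   # a logger named 'myprojectmodule'

def do_something():
    logger.info("doing something")     # logged through the module's own logger, not the root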

Another aspect to note is that all loggers have a built-in hierarchy.

What do I mean by that?

For example, suppose you have configured the root logger to log messages to a particular file, and you also have a custom logger for which you have not configured a file handler to send messages to the console or another log file.

In this case, the custom logger will fall back and write to the file set by the root logger itself, until you configure a file handler for your custom logger.

So what is a file handler and how to set one up?

8. What is and How to set up a File Handler and Formatter?

The FileHandler() and Formatter() classes are used to setup the output file and the format of messages for loggers other than the root logger.

Do you remember how we setup the filename and the format of the message in the root logger (inside logging.basicConfig()) earlier?

We just specified the filename and format parameters in logging.basicConfig() and all subsequent logs went to that file.

However, when you create a separate logger, you need to set them up individually using the logging.FileHandler() and logging.Formatter() objects.

A FileHandler is used to make your custom logger log to a different file. Likewise, a Formatter is used to change the format of your logged messages.

import logging

# Gets or creates a logger
logger = logging.getLogger(__name__)  

# set log level
logger.setLevel(logging.WARNING)

# define file handler and set formatter
file_handler = logging.FileHandler('logfile.log')
formatter    = logging.Formatter('%(asctime)s : %(levelname)s : %(name)s : %(message)s')
file_handler.setFormatter(formatter)

# add file handler to logger
logger.addHandler(file_handler)

# Logs
logger.debug('A debug message')
logger.info('An info message')
logger.warning('Something is not right.')
logger.error('A Major error has happened.')
logger.critical('Fatal error. Cannot continue')

Notice how we set the formatter on the ‘file_handler‘ and not the ‘logger‘ directly.

Assuming the above code is run from the main program, if you look inside the working directory, a file named logfile.log will be created if it doesn’t exist and would contain the below messages.

#> 2019-02-17 12:40:14,797 : WARNING : __main__ : Something is not right.
#> 2019-02-17 12:40:14,798 : ERROR : __main__ : A Major error has happened.
#> 2019-02-17 12:40:14,798 : CRITICAL : __main__ : Fatal error. Cannot continue

Note again, the Formatter is set on the FileHandler object and not directly on the logger. Something you may want to get used to.
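
As an aside, a single logger can hold multiple handlers, each with its own formatter. A small sketch (the formats shown are just examples):

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# File handler with a detailed format
file_handler = logging.FileHandler('logfile.log')
file_handler.setFormatter(logging.Formatter('%(asctime)s : %(levelname)s : %(name)s : %(message)s'))

# Console handler with a shorter format
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter('%(levelname)s : %(message)s'))

logger.addHandler(file_handler)
logger.addHandler(console_handler)

logger.warning('Goes to the file and the console, formatted differently')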

9. How to include traceback information in logged messages?

Besides ‘debug‘, ‘info‘, ‘warning‘, ‘error‘, and ‘critical‘ messages, you can log exceptions that will include any associated traceback information.

With logger.exception, you can log traceback information should the code encounter any error. logger.exception will log the message provided in its arguments as well as the error message traceback info.

Below is a nice example.

import logging

# Create or get the logger
logger = logging.getLogger(__name__)  

# set log level
logger.setLevel(logging.INFO)

def divide(x, y):
    try:
        out = x / y
    except ZeroDivisionError:
        logger.exception("Division by zero problem")
    else:
        return out

# Logs
logger.error("Divide {x} / {y} = {c}".format(x=10, y=0, c=divide(10,0)))

#> ERROR:__main__:Division by zero problem
#> Traceback (most recent call last):
#>   File "<ipython-input-16-a010a44fdc0a>", line 12, in divide
#>     out = x / y
#> ZeroDivisionError: division by zero
#> ERROR:__main__:None
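
Note that logger.exception() always logs at the ERROR level. If you want the traceback at a different level, you can pass exc_info=True to any of the logging methods; a minimal sketch:

import logging

logger = logging.getLogger(__name__)

try:
    1 / 0
except ZeroDivisionError:
    # Same traceback as logger.exception(), but logged at WARNING level
    logger.warning("Division by zero problem", exc_info=True)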

10. Exercises

  1. Create a new project directory and a new python file named ‘example.py‘. Import the logging module and configure the root logger to the level of ‘debug’ messages. Log an ‘info’ message with the text: “This is root logger’s logging message!”.

  2. Configure the root logger to format the message “This is root logger’s logging message!” as the following:

#> 2019-03-03 17:18:32,703 :: INFO :: Module <stdin> :: Line No 1 :: This is root logger's logging message!
Show Solution
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s :: %(levelname)s :: Module %(module)s :: Line No %(lineno)s :: %(message)s')
logging.info("This is root logger's logging message!")
  3. Create another python file in the same directory called ‘mymod.py‘ and create a new logger bearing the module’s name. Configure it to the level of ‘error’ messages and make it send the log outputs to a file called “mymod_{current_date}.log”.

  4. From the ‘mymod’ logger created above, log the following ‘critical’ message to the said log file: “This is a critical message!. Don’t ignore it”. (One possible sketch is shown below.)
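
One possible sketch for exercises 3 and 4 (the date format and variable names are assumptions, not the only valid answer):

# mymod.py -- a hedged sketch, not the only possible solution
import logging
from datetime import datetime

logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)

# e.g. mymod_2019-03-03.log (the date format is an assumption)
log_file = 'mymod_{}.log'.format(datetime.now().strftime('%Y-%m-%d'))
logger.addHandler(logging.FileHandler(log_file))

logger.critical("This is a critical message!. Don't ignore it")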

11. Conclusion

Many congratulations if you were able to solve the exercises!

That was quite useful and straightforward, wasn’t it?

logging is a great tool, but it is not as popular in data science workflows as it should be. I hope the logging concepts are clear, and the next time you work on a Python based project, my kind request is that you remember to give the logging module a shot.

Happy logging!

Matplotlib Histogram – How to Visualize Distributions in Python

Matplotlib histogram is used to visualize the frequency distribution of a numeric array by splitting it into small equal-sized bins. In this article, we explore practical techniques that are extremely useful in your initial data analysis and plotting.

Content

[columnize]
  1. What is a histogram?
  2. How to plot a basic histogram in python?
  3. Histogram grouped by categories in same plot
  4. Histogram grouped by categories in separate subplots
  5. Seaborn Histogram and Density Curve on the same plot
  6. Histogram and Density Curve in Facets
  7. Difference between a Histogram and a Bar Chart
  8. Practice Exercise
  9. Conclusion
[/columnize]

1. What is a Histogram?

A histogram is a plot of the frequency distribution of a numeric array, made by splitting it into small equal-sized bins.

If you want to mathematically split a given array into bins and frequencies, use the numpy histogram() method and pretty print it like below.

import numpy as np
x = np.random.randint(low=0, high=100, size=100)

# Compute frequency and bins
frequency, bins = np.histogram(x, bins=10, range=[0, 100])

# Pretty Print
for b, f in zip(bins[1:], frequency):
    print(round(b, 1), ' '.join(np.repeat('*', f)))

The output of above code looks like this:

10.0 * * * * * * * * *
20.0 * * * * * * * * * * * * *
30.0 * * * * * * * * *
40.0 * * * * * * * * * * * * * * *
50.0 * * * * * * * * *
60.0 * * * * * * * * *
70.0 * * * * * * * * * * * * * * * *
80.0 * * * * *
90.0 * * * * * * * * *
100.0 * * * * * *

The above representation, however, won’t be practical on large arrays, in which case, you can use matplotlib histogram.

2. How to plot a basic histogram in python?

The pyplot.hist() in matplotlib lets you draw the histogram. It takes the array as its required input, and you can optionally specify the number of bins.

import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams.update({'figure.figsize':(7,5), 'figure.dpi':100})

# Plot Histogram on x
x = np.random.normal(size = 1000)
plt.hist(x, bins=50)
plt.gca().set(title='Frequency Histogram', ylabel='Frequency');
Matplotlib Histogram

3. Histogram grouped by categories in same plot

You can plot multiple histograms in the same plot. This can be useful if you want to compare the distribution of a continuous variable grouped by different categories.

Let’s use the diamonds dataset from R’s ggplot2 package.

import pandas as pd
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/diamonds.csv')
df.head()
Diamonds Table

Let’s compare the distribution of diamond depth for 3 different values of diamond cut in the same plot.

x1 = df.loc[df.cut=='Ideal', 'depth']
x2 = df.loc[df.cut=='Fair', 'depth']
x3 = df.loc[df.cut=='Good', 'depth']

kwargs = dict(alpha=0.5, bins=100)

plt.hist(x1, **kwargs, color='g', label='Ideal')
plt.hist(x2, **kwargs, color='b', label='Fair')
plt.hist(x3, **kwargs, color='r', label='Good')
plt.gca().set(title='Frequency Histogram of Diamond Depths', ylabel='Frequency')
plt.xlim(50,75)
plt.legend();
Matplotlib Multi Histogram

Well, the distributions for the 3 different cuts are distinctively different. But since the number of datapoints is larger for the Ideal cut, it dominates the plot.

So, how to rectify the dominant class and still maintain the separateness of the distributions?

You can normalize it by setting density=True and stacked=True. By doing this the total area under each distribution becomes 1.

# Normalize
kwargs = dict(alpha=0.5, bins=100, density=True, stacked=True)

# Plot
plt.hist(x1, **kwargs, color='g', label='Ideal')
plt.hist(x2, **kwargs, color='b', label='Fair')
plt.hist(x3, **kwargs, color='r', label='Good')
plt.gca().set(title='Probability Histogram of Diamond Depths', ylabel='Probability')
plt.xlim(50,75)
plt.legend();
Multi Histogram 2

4. Histogram grouped by categories in separate subplots

The histograms can be created as facets using plt.subplots().

Below I draw one histogram of diamond depth for each category of diamond cut. It’s convenient to do it in a for-loop.

# Import Data
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/diamonds.csv')

# Plot
fig, axes = plt.subplots(1, 5, figsize=(10,2.5), dpi=100, sharex=True, sharey=True)
colors = ['tab:red', 'tab:blue', 'tab:green', 'tab:pink', 'tab:olive']

for i, (ax, cut) in enumerate(zip(axes.flatten(), df.cut.unique())):
    x = df.loc[df.cut==cut, 'depth']
    ax.hist(x, alpha=0.5, bins=100, density=True, stacked=True, label=str(cut), color=colors[i])
    ax.set_title(cut)

plt.suptitle('Probability Histogram of Diamond Depths', y=1.05, size=16)
ax.set_xlim(50, 70); ax.set_ylim(0, 1);
plt.tight_layout();
Histograms Facets

5. Seaborn Histogram and Density Curve on the same plot

If you wish to have both the histogram and densities in the same plot, the seaborn package (imported as sns) allows you to do that via the distplot(). Since seaborn is built on top of matplotlib, you can use the sns and plt one after the other.

import seaborn as sns
sns.set_style("white")

# Import data
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/diamonds.csv')
x1 = df.loc[df.cut=='Ideal', 'depth']
x2 = df.loc[df.cut=='Fair', 'depth']
x3 = df.loc[df.cut=='Good', 'depth']

# Plot
kwargs = dict(hist_kws={'alpha':.6}, kde_kws={'linewidth':2})

plt.figure(figsize=(10,7), dpi= 80)
sns.distplot(x1, color="dodgerblue", label="Ideal", **kwargs)
sns.distplot(x2, color="orange", label="Fair", **kwargs)
sns.distplot(x3, color="deeppink", label="Good", **kwargs)
plt.xlim(50,75)
plt.legend();
Histograms Density

6. Histogram and Density Curve in Facets

The below example shows how to draw the histogram and densities (distplot) in facets.

# Import data
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/diamonds.csv')
x1 = df.loc[df.cut=='Ideal', ['depth']]
x2 = df.loc[df.cut=='Fair', ['depth']]
x3 = df.loc[df.cut=='Good', ['depth']]

# plot
fig, axes = plt.subplots(1, 3, figsize=(10, 3), sharey=True, dpi=100)
sns.distplot(x1 , color="dodgerblue", ax=axes[0], axlabel='Ideal')
sns.distplot(x2 , color="deeppink", ax=axes[1], axlabel='Fair')
sns.distplot(x3 , color="gold", ax=axes[2], axlabel='Good')
plt.xlim(50,75);
Histogram Density Facets

7. Difference between a Histogram and a Bar Chart

A histogram is drawn on large arrays. It computes the frequency distribution on an array and makes a histogram out of it.

On the other hand, a bar chart is used when you have both X and Y given and there are limited number of data points that can be shown as bars.

# Groupby: cutwise median
price = df[['cut', 'price']].groupby('cut').median().round(2)
price
Diamonds Cut
fig, axes = plt.subplots(figsize=(7,5), dpi=100)
plt.bar(price.index, height=price.price)
plt.title('Barplot of Median Diamond Price');
Barplots

8. Practice Exercise

Create the following density on the sepal_length of iris dataset on your Jupyter Notebook.

import seaborn as sns
df = sns.load_dataset('iris')
Show Solution
# Solution
import seaborn as sns
df = sns.load_dataset('iris')

plt.subplots(figsize=(7,6), dpi=100)
sns.distplot( df.loc[df.species=='setosa', "sepal_length"] , color="dodgerblue", label="Setosa")
sns.distplot( df.loc[df.species=='virginica', "sepal_length"] , color="orange", label="virginica")
sns.distplot( df.loc[df.species=='versicolor', "sepal_length"] , color="deeppink", label="versicolor")

plt.title('Iris Histogram')
plt.legend();
Iris Histograms

9. What next

Congratulations if you were able to reproduce the plot.

You might be interested in the matplotlib tutorial, top 50 matplotlib plots, and other plotting tutorials.

ARIMA Model – Complete Guide to Time Series Forecasting in Python

Using an ARIMA model, you can forecast a time series using the series’ past values. In this post, we build an optimal ARIMA model from scratch and extend it to Seasonal ARIMA (SARIMA) and SARIMAX models. You will also see how to build auto ARIMA models in python.

ARIMA Model – Time Series Forecasting. Photo by Cerquiera

Contents

[columnize]
  1. Introduction to Time Series Forecasting
  2. Introduction to ARIMA Models
  3. What does the p, d and q in ARIMA model mean?
  4. What are AR and MA models
  5. How to find the order of differencing (d) in ARIMA model
  6. How to find the order of the AR term (p)
  7. How to find the order of the MA term (q)
  8. How to handle if a time series is slightly under or over differenced
  9. How to build the ARIMA Model
  10. How to find the optimal ARIMA model manually using Out-of-Time Cross validation
  11. Accuracy Metrics for Time Series Forecast
  12. How to do Auto Arima Forecast in Python
  13. How to interpret the residual plots in ARIMA model
  14. How to automatically build SARIMA model in python
  15. How to build SARIMAX Model with exogenous variable
  16. Practice Exercises
  17. Conclusion
[/columnize]

1. Introduction to Time Series Forecasting

A time series is a sequence where a metric is recorded over regular time intervals.

Depending on the frequency, a time series can be yearly (ex: annual budget), quarterly (ex: expenses), monthly (ex: air traffic), weekly (ex: sales qty), daily (ex: weather), hourly (ex: stocks price), minute-wise (ex: inbound calls in a call center) and even second-wise (ex: web traffic).

We have already seen the steps involved in a previous post on Time Series Analysis. If you haven’t read it, I highly encourage you to do so.

Forecasting is the next step where you want to predict the future values the series is going to take.

But why forecast?

Because, forecasting a time series (like demand and sales) is often of tremendous commercial value.

In most manufacturing companies, it drives the fundamental business planning, procurement and production activities. Any errors in the forecasts will ripple down throughout the supply chain, or any business context for that matter. So it is important to get the forecasts accurate in order to save on costs, and accuracy is critical to success.

Not just in manufacturing, the techniques and concepts behind time series forecasting are applicable in any business.

Now forecasting a time series can be broadly divided into two types.

If you use only the previous values of the time series to predict its future values, it is called Univariate Time Series Forecasting.

And if you use predictors other than the series (a.k.a. exogenous variables) to forecast it, it is called Multivariate Time Series Forecasting.

This post focuses on a particular type of forecasting method called ARIMA modeling.

ARIMA, short for ‘AutoRegressive Integrated Moving Average’, is a forecasting algorithm based on the idea that the information in the past values of the time series can alone be used to predict the future values.

2. Introduction to ARIMA Models

So what exactly is an ARIMA model?

ARIMA, short for ‘Auto Regressive Integrated Moving Average’, is actually a class of models that ‘explains’ a given time series based on its own past values, that is, its own lags and the lagged forecast errors, so that the equation can be used to forecast future values.

Any ‘non-seasonal’ time series that exhibits patterns and is not a random white noise can be modeled with ARIMA models.

An ARIMA model is characterized by 3 terms: p, d, q

where,

p is the order of the AR term

q is the order of the MA term

d is the number of differencing required to make the time series stationary

If a time series has seasonal patterns, then you need to add seasonal terms and it becomes SARIMA, short for ‘Seasonal ARIMA’. More on that once we finish ARIMA.

So, what does the ‘order of AR term’ even mean? Before we go there, let’s first look at the ‘d’ term.

3. What does the p, d and q in ARIMA model mean

The first step to build an ARIMA model is to make the time series stationary.

Why?

Because the term ‘Auto Regressive’ in ARIMA means it is a linear regression model that uses its own lags as predictors. Linear regression models, as you know, work best when the predictors are not correlated and are independent of each other.

So how to make a series stationary?

The most common approach is to difference it. That is, subtract the previous value from the current value. Sometimes, depending on the complexity of the series, more than one differencing may be needed.

The value of d, therefore, is the minimum number of differencing needed to make the series stationary. And if the time series is already stationary, then d = 0.
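
To make that concrete, here is a minimal sketch of first and second order differencing on a toy pandas Series (the numbers are made up purely for illustration):

import pandas as pd

s = pd.Series([10, 12, 15, 19, 24])

# 1st order differencing: subtract the previous value from the current value
print(s.diff())          # NaN, 2, 3, 4, 5

# 2nd order differencing: difference the already differenced series once more
print(s.diff().diff())   # NaN, NaN, 1, 1, 1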

Next, what are the ‘p’ and ‘q’ terms?

‘p’ is the order of the ‘Auto Regressive’ (AR) term. It refers to the number of lags of Y to be used as predictors. And ‘q’ is the order of the ‘Moving Average’ (MA) term. It refers to the number of lagged forecast errors that should go into the ARIMA Model.

4. What are AR and MA models

So what are AR and MA models? What is the actual mathematical formula for the AR and MA models?

A pure Auto Regressive (AR only) model is one where Yt depends only on its own lags. That is, Yt is a function of the ‘lags of Yt’:

$$Y_t = \alpha + \beta_1 Y_{t-1} + \beta_2 Y_{t-2} + \dots + \beta_p Y_{t-p} + \epsilon_t$$

where $Y_{t-1}$ is the lag1 of the series, $\beta_1$ is the coefficient of lag1 that the model estimates and $\alpha$ is the intercept term, also estimated by the model.

Likewise, a pure Moving Average (MA only) model is one where Yt depends only on the lagged forecast errors:

$$Y_t = \alpha + \epsilon_t + \phi_1 \epsilon_{t-1} + \phi_2 \epsilon_{t-2} + \dots + \phi_q \epsilon_{t-q}$$

where the error terms are the errors of the autoregressive models of the respective lags. That is, the errors $\epsilon_t$ and $\epsilon_{t-1}$ are the errors from the following equations:

$$Y_t = \beta_1 Y_{t-1} + \beta_2 Y_{t-2} + \dots + \beta_0 Y_0 + \epsilon_t$$

$$Y_{t-1} = \beta_1 Y_{t-2} + \beta_2 Y_{t-3} + \dots + \beta_0 Y_0 + \epsilon_{t-1}$$

That was AR and MA models respectively.

So what does the equation of an ARIMA model look like?

An ARIMA model is one where the time series was differenced at least once to make it stationary and you combine the AR and the MA terms. So the equation becomes:

$$Y_t = \alpha + \beta_1 Y_{t-1} + \dots + \beta_p Y_{t-p} + \epsilon_t + \phi_1 \epsilon_{t-1} + \dots + \phi_q \epsilon_{t-q}$$

ARIMA model in words:

Predicted Yt = Constant + Linear combination Lags of Y (upto p lags) + Linear Combination of Lagged forecast errors (upto q lags)

The objective, therefore, is to identify the values of p, d and q. But how?

Let’s start with finding the ‘d’.

5. How to find the order of differencing (d) in ARIMA model

The purpose of differencing is to make the time series stationary.

But you need to be careful not to over-difference the series, because an over-differenced series may still be stationary while the extra differencing affects the model parameters.

So how to determine the right order of differencing?

The right order of differencing is the minimum differencing required to get a near-stationary series which roams around a defined mean, with an ACF plot that reaches zero fairly quickly.

If the autocorrelations are positive for a large number of lags (10 or more), then the series needs further differencing. On the other hand, if the lag 1 autocorrelation itself is too negative, then the series is probably over-differenced.

In the event you can’t really decide between two orders of differencing, go with the order that gives the least standard deviation in the differenced series, as shown in the sketch below.
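
For instance, a quick, hedged way to compare two candidate orders (assuming df is the wwwusage dataframe loaded in the ADF example below):

# Standard deviation of the 1st and 2nd order differenced series
print(df.value.diff().std())         # 1st order differencing
print(df.value.diff().diff().std())  # 2nd order differencing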

Let’s see how to do it with an example.

First, I am going to check if the series is stationary using the Augmented Dickey Fuller test (adfuller()), from the statsmodels package.

Why?

Because, you need differencing only if the series is non-stationary. Else, no differencing is needed, that is, d=0.

The null hypothesis of the ADF test is that the time series is non-stationary. So, if the p-value of the test is less than the significance level (0.05) then you reject the null hypothesis and infer that the time series is indeed stationary.

So, in our case, if P Value > 0.05 we go ahead with finding the order of differencing.

from statsmodels.tsa.stattools import adfuller
import pandas as pd

# Import data (wwwusage dataset)
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/wwwusage.csv', names=['value'], header=0)

result = adfuller(df.value.dropna())
print('ADF Statistic: %f' % result[0])
print('p-value: %f' % result[1])
ADF Statistic: -2.464240
p-value: 0.124419

Since the p-value is greater than the significance level, let’s difference the series and see what the autocorrelation plot looks like.

import numpy as np, pandas as pd
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
import matplotlib.pyplot as plt
plt.rcParams.update({'figure.figsize':(9,7), 'figure.dpi':120})

# Import data
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/wwwusage.csv', names=['value'], header=0)

# Original Series
fig, axes = plt.subplots(3, 2, sharex=True)
axes[0, 0].plot(df.value); axes[0, 0].set_title('Original Series')
plot_acf(df.value, ax=axes[0, 1])

# 1st Differencing
axes[1, 0].plot(df.value.diff()); axes[1, 0].set_title('1st Order Differencing')
plot_acf(df.value.diff().dropna(), ax=axes[1, 1])

# 2nd Differencing
axes[2, 0].plot(df.value.diff().diff()); axes[2, 0].set_title('2nd Order Differencing')
plot_acf(df.value.diff().diff().dropna(), ax=axes[2, 1])

plt.show()
Order of Differencing

For the above series, the time series reaches stationarity with two orders of differencing. But looking at the autocorrelation plot for the 2nd differencing, the lag-1 autocorrelation goes into the far negative zone fairly quickly, which indicates the series might have been over-differenced.

So, I am going to tentatively fix the order of differencing as 1 even though the series is not perfectly stationary (weak stationarity).

from pmdarima.arima.utils import ndiffs
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/wwwusage.csv', names=['value'], header=0)
y = df.value

## Adf Test
ndiffs(y, test='adf')  # 2

# KPSS test
ndiffs(y, test='kpss')  # 0

# PP test:
ndiffs(y, test='pp')  # 2
2 0 2

6. How to find the order of the AR term (p)

The next step is to identify if the model needs any AR terms. You can find out the required number of AR terms by inspecting the Partial Autocorrelation (PACF) plot.

But what is PACF?

Partial autocorrelation can be imagined as the correlation between the series and its lag, after excluding the contributions from the intermediate lags. So, PACF sort of conveys the pure correlation between a lag and the series. That way, you will know if that lag is needed in the AR term or not.

So what is the formula for PACF mathematically?

Partial autocorrelation of lag (k) of a series is the coefficient of that lag in the autoregression equation of Y.

$$Y_t = \alpha_0 + \alpha_1 Y_{t-1} + \alpha_2 Y_{t-2} + \alpha_3 Y_{t-3}$$

That is, suppose, if $Y_t$ is the current series and $Y_{t-1}$ is the lag 1 of Y, then the partial autocorrelation of lag 3 ($Y_{t-3}$) is the coefficient $\alpha_3$ of $Y_{t-3}$ in the above equation.

Good. Now, how to find the number of AR terms?

Any autocorrelation in a stationarized series can be rectified by adding enough AR terms. So, we initially take the order of the AR term to be equal to as many lags as cross the significance limit in the PACF plot.

# PACF plot of 1st differenced series
plt.rcParams.update({'figure.figsize':(9,3), 'figure.dpi':120})

fig, axes = plt.subplots(1, 2, sharex=True)
axes[0].plot(df.value.diff()); axes[0].set_title('1st Differencing')
axes[1].set(ylim=(0,5))
plot_pacf(df.value.diff().dropna(), ax=axes[1])

plt.show()
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/statsmodels/regression/linear_model.py:1283: RuntimeWarning: invalid value encountered in sqrt
  return rho, np.sqrt(sigmasq)
Order of AR Term

You can observe that the PACF lag 1 is quite significant since it is well above the significance line. Lag 2 turns out to be significant as well, slightly managing to cross the significance limit (blue region). But I am going to be conservative and tentatively fix p as 1.

7. How to find the order of the MA term (q)

Just like how we looked at the PACF plot for the number of AR terms, you can look at the ACF plot for the number of MA terms. An MA term is, technically, the error of the lagged forecast.

The ACF tells how many MA terms are required to remove any autocorrelation in the stationarized series.

Let’s see the autocorrelation plot of the differenced series.

import pandas as pd
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
import matplotlib.pyplot as plt
plt.rcParams.update({'figure.figsize':(9,3), 'figure.dpi':120})

# Import data
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/austa.csv')

fig, axes = plt.subplots(1, 2, sharex=True)
axes[0].plot(df.value.diff()); axes[0].set_title('1st Differencing')
axes[1].set(ylim=(0,1.2))
plot_acf(df.value.diff().dropna(), ax=axes[1])

plt.show()
Order of MA Term

A couple of lags are well above the significance line. So, let’s tentatively fix q as 2. When in doubt, go with the simpler model that sufficiently explains the Y.

8. How to handle if a time series is slightly under or over differenced

It may so happen that your series is slightly under-differenced, and differencing it one more time makes it slightly over-differenced.

How to handle this case?

If your series is slightly under-differenced, adding one or more additional AR terms usually fixes it. Likewise, if it is slightly over-differenced, try adding an additional MA term.

9. How to build the ARIMA Model

Now that you’ve determined the values of p, d and q, you have everything needed to fit the ARIMA model. Let’s use the ARIMA() implementation in statsmodels package.

from statsmodels.tsa.arima_model import ARIMA

# 1,1,2 ARIMA Model
model = ARIMA(df.value, order=(1,1,2))
model_fit = model.fit(disp=0)
print(model_fit.summary())
                             ARIMA Model Results                              
==============================================================================
Dep. Variable:                D.value   No. Observations:                   99
Model:                 ARIMA(1, 1, 2)   Log Likelihood                -253.790
Method:                       css-mle   S.D. of innovations              3.119
Date:                Wed, 06 Feb 2019   AIC                            517.579
Time:                        23:32:56   BIC                            530.555
Sample:                             1   HQIC                           522.829

=================================================================================
                    coef    std err          z      P>|z|      [0.025      0.975]
---------------------------------------------------------------------------------
const             1.1202      1.290      0.868      0.387      -1.409       3.649
ar.L1.D.value     0.6351      0.257      2.469      0.015       0.131       1.139
ma.L1.D.value     0.5287      0.355      1.489      0.140      -0.167       1.224
ma.L2.D.value    -0.0010      0.321     -0.003      0.998      -0.631       0.629
                                    Roots                                    
=============================================================================
                  Real          Imaginary           Modulus         Frequency
-----------------------------------------------------------------------------
AR.1            1.5746           +0.0000j            1.5746            0.0000
MA.1           -1.8850           +0.0000j            1.8850            0.5000
MA.2          545.3515           +0.0000j          545.3515            0.0000
-----------------------------------------------------------------------------

The model summary reveals a lot of information. The table in the middle is the coefficients table where the values under ‘coef’ are the weights of the respective terms.

Notice here the coefficient of the MA2 term is close to zero and the P-Value in ‘P>|z|’ column is highly insignificant. It should ideally be less than 0.05 for the respective X to be significant.

So, let’s rebuild the model without the MA2 term.

# 1,1,1 ARIMA Model
model = ARIMA(df.value, order=(1,1,1))
model_fit = model.fit(disp=0)
print(model_fit.summary())
                             ARIMA Model Results                              
==============================================================================
Dep. Variable:                D.value   No. Observations:                   99
Model:                 ARIMA(1, 1, 1)   Log Likelihood                -253.790
Method:                       css-mle   S.D. of innovations              3.119
Date:                Sat, 09 Feb 2019   AIC                            515.579
Time:                        12:16:06   BIC                            525.960
Sample:                             1   HQIC                           519.779

=================================================================================
                    coef    std err          z      P>|z|      [0.025      0.975]
---------------------------------------------------------------------------------
const             1.1205      1.286      0.871      0.386      -1.400       3.641
ar.L1.D.value     0.6344      0.087      7.317      0.000       0.464       0.804
ma.L1.D.value     0.5297      0.089      5.932      0.000       0.355       0.705
                                    Roots                                    
=============================================================================
                  Real          Imaginary           Modulus         Frequency
-----------------------------------------------------------------------------
AR.1            1.5764           +0.0000j            1.5764            0.0000
MA.1           -1.8879           +0.0000j            1.8879            0.5000
-----------------------------------------------------------------------------

The model AIC has reduced, which is good. The P Values of the AR1 and MA1 terms have improved and are highly significant (<< 0.05).

Let’s plot the residuals to ensure there are no patterns (that is, look for constant mean and variance).

# Plot residual errors
residuals = pd.DataFrame(model_fit.resid)
fig, ax = plt.subplots(1,2)
residuals.plot(title="Residuals", ax=ax[0])
residuals.plot(kind='kde', title='Density', ax=ax[1])
plt.show()
Residuals Density

The residual errors seem fine with near zero mean and uniform variance. Let’s plot the actuals against the fitted values using plot_predict().

# Actual vs Fitted
model_fit.plot_predict(dynamic=False)
plt.show()
Actual vs Fitted

When you set dynamic=False the in-sample lagged values are used for prediction.

That is, the model gets trained up until the previous value to make the next prediction. This can make the fitted forecast and actuals look artificially good.

So, we seem to have a decent ARIMA model. But is that the best?

Can’t say that at this point because we haven’t actually forecasted into the future and compared the forecast with the actual performance.

So, the real validation you need now is the Out-of-Time cross-validation.

10. How to find the optimal ARIMA model manually using Out-of-Time Cross validation

In Out-of-Time cross-validation, you take a few steps back in time and forecast into the future for as many steps as you took back. Then you compare the forecast against the actuals.

To do out-of-time cross-validation, you need to create the training and testing dataset by splitting the time series into 2 contiguous parts in approximately 75:25 ratio or a reasonable proportion based on time frequency of series.

Why am I not sampling the training data randomly you ask?

That’s because the order sequence of the time series should be intact in order to use it for forecasting.

from statsmodels.tsa.stattools import acf

# Create Training and Test
train = df.value[:85]
test = df.value[85:]

You can now build the ARIMA model on training dataset, forecast and plot it.

# Build Model
# model = ARIMA(train, order=(3,2,1))  
model = ARIMA(train, order=(1, 1, 1))  
fitted = model.fit(disp=-1)  

# Forecast
fc, se, conf = fitted.forecast(15, alpha=0.05)  # 95% conf

# Make as pandas series
fc_series = pd.Series(fc, index=test.index)
lower_series = pd.Series(conf[:, 0], index=test.index)
upper_series = pd.Series(conf[:, 1], index=test.index)

# Plot
plt.figure(figsize=(12,5), dpi=100)
plt.plot(train, label='training')
plt.plot(test, label='actual')
plt.plot(fc_series, label='forecast')
plt.fill_between(lower_series.index, lower_series, upper_series, 
                 color='k', alpha=.15)
plt.title('Forecast vs Actuals')
plt.legend(loc='upper left', fontsize=8)
plt.show()
Forecast vs Actuals

From the chart, the ARIMA(1,1,1) model seems to give a directionally correct forecast. And the actual observed values lie within the 95% confidence band. That seems fine.

But each of the predicted forecasts is consistently below the actuals. That means, by adding a small constant to our forecast, the accuracy will certainly improve. So, there is definitely scope for improvement.

So, what I am going to do is to increase the order of differencing to two, that is set d=2 and iteratively increase p to up to 5 and then q up to 5 to see which model gives least AIC and also look for a chart that gives closer actuals and forecasts.

While doing this, I keep an eye on the P values of the AR and MA terms in the model summary. They should be as close to zero as possible, ideally less than 0.05.

# Build Model
model = ARIMA(train, order=(3, 2, 1))  
fitted = model.fit(disp=-1)  
print(fitted.summary())

# Forecast
fc, se, conf = fitted.forecast(15, alpha=0.05)  # 95% conf

# Make as pandas series
fc_series = pd.Series(fc, index=test.index)
lower_series = pd.Series(conf[:, 0], index=test.index)
upper_series = pd.Series(conf[:, 1], index=test.index)

# Plot
plt.figure(figsize=(12,5), dpi=100)
plt.plot(train, label='training')
plt.plot(test, label='actual')
plt.plot(fc_series, label='forecast')
plt.fill_between(lower_series.index, lower_series, upper_series, 
                 color='k', alpha=.15)
plt.title('Forecast vs Actuals')
plt.legend(loc='upper left', fontsize=8)
plt.show()
                             ARIMA Model Results                              
==============================================================================
Dep. Variable:               D2.value   No. Observations:                   83
Model:                 ARIMA(3, 2, 1)   Log Likelihood                -214.248
Method:                       css-mle   S.D. of innovations              3.153
Date:                Sat, 09 Feb 2019   AIC                            440.497
Time:                        12:49:01   BIC                            455.010
Sample:                             2   HQIC                           446.327

==================================================================================
                     coef    std err          z      P>|z|      [0.025      0.975]
----------------------------------------------------------------------------------
const              0.0483      0.084      0.577      0.565      -0.116       0.212
ar.L1.D2.value     1.1386      0.109     10.399      0.000       0.924       1.353
ar.L2.D2.value    -0.5923      0.155     -3.827      0.000      -0.896      -0.289
ar.L3.D2.value     0.3079      0.111      2.778      0.007       0.091       0.525
ma.L1.D2.value    -1.0000      0.035    -28.799      0.000      -1.068      -0.932
                                    Roots                                    
=============================================================================
                  Real          Imaginary           Modulus         Frequency
-----------------------------------------------------------------------------
AR.1            1.1557           -0.0000j            1.1557           -0.0000
AR.2            0.3839           -1.6318j            1.6763           -0.2132
AR.3            0.3839           +1.6318j            1.6763            0.2132
MA.1            1.0000           +0.0000j            1.0000            0.0000
-----------------------------------------------------------------------------
Revised Forecast vs Actuals

The AIC has reduced to 440 from 515. Good. The P-values of the X terms are less than 0.05, which is great.

So overall it’s much better.

Ideally, you should go back multiple points in time, like, go back 1, 2, 3 and 4 quarters and see how your forecasts are performing at various points in the year.

Here’s a great practice exercise: Try to go back 27, 30, 33, 36 data points and see how the forecasts perform. The forecast performance can be judged using various accuracy metrics discussed next.

11. Accuracy Metrics for Time Series Forecast

The commonly used accuracy metrics to judge forecasts are:

  1. Mean Absolute Percentage Error (MAPE)
  2. Mean Error (ME)
  3. Mean Absolute Error (MAE)
  4. Mean Percentage Error (MPE)
  5. Root Mean Squared Error (RMSE)
  6. Lag 1 Autocorrelation of Error (ACF1)
  7. Correlation between the Actual and the Forecast (corr)
  8. Min-Max Error (minmax)

Typically, if you are comparing forecasts of two different series, the MAPE, Correlation and Min-Max Error can be used.

Why not use the other metrics?

Because the above three are relative, percentage-style errors. That way, you can judge how good the forecast is irrespective of the scale of the series.

The other error metrics are quantities. That implies, an RMSE of 100 for a series whose mean is in 1000’s is better than an RMSE of 5 for series in 10’s. So, you can’t really use them to compare the forecasts of two different scaled time series.

# Accuracy metrics
def forecast_accuracy(forecast, actual):
    mape = np.mean(np.abs(forecast - actual)/np.abs(actual))  # MAPE
    me = np.mean(forecast - actual)             # ME
    mae = np.mean(np.abs(forecast - actual))    # MAE
    mpe = np.mean((forecast - actual)/actual)   # MPE
    rmse = np.mean((forecast - actual)**2)**.5  # RMSE
    corr = np.corrcoef(forecast, actual)[0,1]   # corr
    mins = np.amin(np.hstack([forecast[:,None], 
                              actual[:,None]]), axis=1)
    maxs = np.amax(np.hstack([forecast[:,None], 
                              actual[:,None]]), axis=1)
    minmax = 1 - np.mean(mins/maxs)             # minmax
    acf1 = acf(forecast - actual)[1]            # ACF1
    return({'mape':mape, 'me':me, 'mae': mae, 
            'mpe': mpe, 'rmse':rmse, 'acf1':acf1, 
            'corr':corr, 'minmax':minmax})

forecast_accuracy(fc, test.values)

#> {'mape': 0.02250131357314834,
#>  'me': 3.230783108990054,
#>  'mae': 4.548322194530069,
#>  'mpe': 0.016421001932706705,
#>  'rmse': 6.373238534601827,
#>  'acf1': 0.5105506325288692,
#>  'corr': 0.9674576513924394,
#>  'minmax': 0.02163154777672227}

Around 2.2% MAPE implies the model is about 97.8% accurate in predicting the next 15 observations.

Now you know how to build an ARIMA model manually.

But in industrial situations, you will be given a lot of time series to be forecasted, and the forecasting exercise will be repeated regularly.

So we need a way to automate the best model selection process.

12. How to do Auto Arima Forecast in Python

Like R’s popular auto.arima() function, the pmdarima package provides auto_arima() with similar functionality.

auto_arima() uses a stepwise approach to search multiple combinations of p,d,q parameters and chooses the best model that has the least AIC.

from statsmodels.tsa.arima_model import ARIMA
import pmdarima as pm

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/wwwusage.csv', names=['value'], header=0)

model = pm.auto_arima(df.value, start_p=1, start_q=1,
                      test='adf',       # use adftest to find optimal 'd'
                      max_p=3, max_q=3, # maximum p and q
                      m=1,              # frequency of series
                      d=None,           # let model determine 'd'
                      seasonal=False,   # No Seasonality
                      start_P=0, 
                      D=0, 
                      trace=True,
                      error_action='ignore',  
                      suppress_warnings=True, 
                      stepwise=True)

print(model.summary())

#> Fit ARIMA: order=(1, 2, 1); AIC=525.586, BIC=535.926, Fit time=0.060 seconds
#> Fit ARIMA: order=(0, 2, 0); AIC=533.474, BIC=538.644, Fit time=0.005 seconds
#> Fit ARIMA: order=(1, 2, 0); AIC=532.437, BIC=540.192, Fit time=0.035 seconds
#> Fit ARIMA: order=(0, 2, 1); AIC=525.893, BIC=533.648, Fit time=0.040 seconds
#> Fit ARIMA: order=(2, 2, 1); AIC=515.248, BIC=528.173, Fit time=0.105 seconds
#> Fit ARIMA: order=(2, 2, 0); AIC=513.459, BIC=523.798, Fit time=0.063 seconds
#> Fit ARIMA: order=(3, 2, 1); AIC=512.552, BIC=528.062, Fit time=0.272 seconds
#> Fit ARIMA: order=(3, 2, 0); AIC=515.284, BIC=528.209, Fit time=0.042 seconds
#> Fit ARIMA: order=(3, 2, 2); AIC=514.514, BIC=532.609, Fit time=0.234 seconds
#> Total fit time: 0.865 seconds
#>                              ARIMA Model Results                              
#> ==============================================================================
#> Dep. Variable:                   D2.y   No. Observations:                   98
#> Model:                 ARIMA(3, 2, 1)   Log Likelihood                -250.276
#> Method:                       css-mle   S.D. of innovations              3.069
#> Date:                Sat, 09 Feb 2019   AIC                            512.552
#> Time:                        12:57:22   BIC                            528.062
#> Sample:                             2   HQIC                           518.825
#> 
#> ==============================================================================
#>                  coef    std err          z      P>|z|      [0.025      0.975]
#> ------------------------------------------------------------------------------
#> const          0.0234      0.058      0.404      0.687      -0.090       0.137
#> ar.L1.D2.y     1.1586      0.097     11.965      0.000       0.969       1.348
#> ar.L2.D2.y    -0.6640      0.136     -4.890      0.000      -0.930      -0.398
#> ar.L3.D2.y     0.3453      0.096      3.588      0.001       0.157       0.534
#> ma.L1.D2.y    -1.0000      0.028    -36.302      0.000      -1.054      -0.946
#>                                     Roots                                    
#> =============================================================================
#>                   Real          Imaginary           Modulus         Frequency
#> -----------------------------------------------------------------------------
#> AR.1            1.1703           -0.0000j            1.1703           -0.0000
#> AR.2            0.3763           -1.5274j            1.5731           -0.2116
#> AR.3            0.3763           +1.5274j            1.5731            0.2116
#> MA.1            1.0000           +0.0000j            1.0000            0.0000
#> -----------------------------------------------------------------------------

13. How to interpret the residual plots in ARIMA model

Let’s review the residual plots of the fitted model using plot_diagnostics().

model.plot_diagnostics(figsize=(7,5))
plt.show()
Residuals Chart

So how to interpret the plot diagnostics?

Top left: The residual errors seem to fluctuate around a mean of zero and have a uniform variance.

Top Right: The density plot suggests a normal distribution with mean zero.

Bottom left: All the dots should fall perfectly in line with the red line. Any significant deviations would imply the distribution is skewed.

Bottom Right: The Correlogram, aka, ACF plot shows the residual errors are not autocorrelated. Any autocorrelation would imply that there is some pattern in the residual errors which is not explained by the model. So you will need to look for more X’s (predictors) for the model.

Overall, it seems to be a good fit. Let’s forecast.

# Forecast
n_periods = 24
fc, confint = model.predict(n_periods=n_periods, return_conf_int=True)
index_of_fc = np.arange(len(df.value), len(df.value)+n_periods)

# make series for plotting purpose
fc_series = pd.Series(fc, index=index_of_fc)
lower_series = pd.Series(confint[:, 0], index=index_of_fc)
upper_series = pd.Series(confint[:, 1], index=index_of_fc)

# Plot
plt.plot(df.value)
plt.plot(fc_series, color='darkgreen')
plt.fill_between(lower_series.index, 
                 lower_series, 
                 upper_series, 
                 color='k', alpha=.15)

plt.title("Final Forecast of WWW Usage")
plt.show()
Final Forecast of WWW Usage

14. How to automatically build SARIMA model in python

The problem with plain ARIMA model is it does not support seasonality.

If your time series has defined seasonality, then, go for SARIMA which uses seasonal differencing.

Seasonal differencing is similar to regular differencing, but, instead of subtracting consecutive terms, you subtract the value from previous season.

So, the model will be represented as SARIMA(p,d,q)x(P,D,Q)m, where P, D and Q are the SAR term, the order of seasonal differencing and the SMA term respectively, and ‘m’ is the frequency (seasonal period) of the time series.

If your model has well defined seasonal patterns, then enforce D=1 for a given frequency ‘m’.

Here’s some practical advice on building SARIMA model:

As a general rule, set the model parameters such that D never exceeds one. And the total differencing ‘d + D’ never exceeds 2. Try to keep only either SAR or SMA terms if your model has seasonal components.

Let’s build an SARIMA model on 'a10' – the drug sales dataset.

# Import
data = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')

# Plot
fig, axes = plt.subplots(2, 1, figsize=(10,5), dpi=100, sharex=True)

# Usual Differencing
axes[0].plot(data[:], label='Original Series')
axes[0].plot(data[:].diff(1), label='Usual Differencing')
axes[0].set_title('Usual Differencing')
axes[0].legend(loc='upper left', fontsize=10)


# Seasonal Differencing
axes[1].plot(data[:], label='Original Series')
axes[1].plot(data[:].diff(12), label='Seasonal Differencing', color='green')
axes[1].set_title('Seasonal Differencing')
plt.legend(loc='upper left', fontsize=10)
plt.suptitle('a10 - Drug Sales', fontsize=16)
plt.show()
Seasonal Differencing

As you can clearly see, the seasonal spikes are intact after applying the usual differencing (lag 1), whereas they are rectified after seasonal differencing.

Let’s build the SARIMA model using pmdarima‘s auto_arima(). To do that, you need to set seasonal=True, set the frequency m=12 for month wise series and enforce D=1.

# !pip3 install pmdarima
import pmdarima as pm

# Seasonal - fit stepwise auto-ARIMA
smodel = pm.auto_arima(data, start_p=1, start_q=1,
                         test='adf',
                         max_p=3, max_q=3, m=12,
                         start_P=0, seasonal=True,
                         d=None, D=1, trace=True,
                         error_action='ignore',  
                         suppress_warnings=True, 
                         stepwise=True)

smodel.summary()
Fit ARIMA: order=(1, 0, 1) seasonal_order=(0, 1, 1, 12); AIC=534.818, BIC=551.105, Fit time=1.742 seconds
Fit ARIMA: order=(0, 0, 0) seasonal_order=(0, 1, 0, 12); AIC=624.061, BIC=630.576, Fit time=0.028 seconds
Fit ARIMA: order=(1, 0, 0) seasonal_order=(1, 1, 0, 12); AIC=596.004, BIC=609.034, Fit time=0.683 seconds
Fit ARIMA: order=(0, 0, 1) seasonal_order=(0, 1, 1, 12); AIC=611.475, BIC=624.505, Fit time=0.709 seconds
Fit ARIMA: order=(1, 0, 1) seasonal_order=(1, 1, 1, 12); AIC=557.501, BIC=577.046, Fit time=3.687 seconds
(...TRUNCATED...)
Fit ARIMA: order=(3, 0, 0) seasonal_order=(1, 1, 1, 12); AIC=554.570, BIC=577.372, Fit time=2.431 seconds
Fit ARIMA: order=(3, 0, 0) seasonal_order=(0, 1, 0, 12); AIC=554.094, BIC=570.381, Fit time=0.220 seconds
Fit ARIMA: order=(3, 0, 0) seasonal_order=(0, 1, 2, 12); AIC=529.502, BIC=552.305, Fit time=2.120 seconds
Fit ARIMA: order=(3, 0, 0) seasonal_order=(1, 1, 2, 12); AIC=nan, BIC=nan, Fit time=nan seconds
Total fit time: 31.613 seconds

The model has estimated the AIC and the P values of the coefficients look significant. Let’s look at the residual diagnostics plot.

The best model SARIMAX(3, 0, 0)x(0, 1, 1, 12) has an AIC of 528.6 and the P Values are significant.

Let’s forecast for the next 24 months.

# Forecast
n_periods = 24
fitted, confint = smodel.predict(n_periods=n_periods, return_conf_int=True)
index_of_fc = pd.date_range(data.index[-1], periods = n_periods, freq='MS')

# make series for plotting purpose
fitted_series = pd.Series(fitted, index=index_of_fc)
lower_series = pd.Series(confint[:, 0], index=index_of_fc)
upper_series = pd.Series(confint[:, 1], index=index_of_fc)

# Plot
plt.plot(data)
plt.plot(fitted_series, color='darkgreen')
plt.fill_between(lower_series.index, 
                 lower_series, 
                 upper_series, 
                 color='k', alpha=.15)

plt.title("SARIMA - Final Forecast of a10 - Drug Sales")
plt.show()
SARIMA - Final Forecasts

There you have a nice forecast that captures the expected seasonal demand pattern.

15. How to build SARIMAX Model with exogenous variable

The SARIMA model we built is good. I would stop here typically.

But for the sake of completeness, let’s try and force an external predictor, also called, ‘exogenous variable’ into the model. This model is called the SARIMAX model.

The only requirement to use an exogenous variable is you need to know the value of the variable during the forecast period as well.

For the sake of demonstration, I am going to use the seasonal index from the classical seasonal decomposition on the latest 36 months of data.

Why the seasonal index? Isn’t SARIMA already modeling the seasonality, you ask?

You are correct.

But also, I want to see how the model looks if we force the recent seasonality pattern into the training and forecast.

Secondly, this is a good variable for demo purpose. So you can use this as a template and plug in any of your variables into the code. The seasonal index is a good exogenous variable because it repeats every frequency cycle, 12 months in this case.

So, you will always know what values the seasonal index will hold for the future forecasts.

# Import Data
data = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')

Let’s compute the seasonal index so that it can be forced as a (exogenous) predictor to the SARIMAX model.

# Compute Seasonal Index
from statsmodels.tsa.seasonal import seasonal_decompose
from dateutil.parser import parse

# multiplicative seasonal component
result_mul = seasonal_decompose(data['value'][-36:],   # 3 years
                                model='multiplicative', 
                                extrapolate_trend='freq')

seasonal_index = result_mul.seasonal[-12:].to_frame()
seasonal_index['month'] = pd.to_datetime(seasonal_index.index).month

# merge with the base data
data['month'] = data.index.month
df = pd.merge(data, seasonal_index, how='left', on='month')
df.columns = ['value', 'month', 'seasonal_index']
df.index = data.index  # reassign the index.

The exogenous variable (seasonal index) is ready. Let’s build the SARIMAX model.

import pmdarima as pm

# SARIMAX Model
sxmodel = pm.auto_arima(df[['value']], exogenous=df[['seasonal_index']],
                           start_p=1, start_q=1,
                           test='adf',
                           max_p=3, max_q=3, m=12,
                           start_P=0, seasonal=True,
                           d=None, D=1, trace=True,
                           error_action='ignore',  
                           suppress_warnings=True, 
                           stepwise=True)

sxmodel.summary()
Fit ARIMA: order=(1, 0, 1) seasonal_order=(0, 1, 1, 12); AIC=536.818, BIC=556.362, Fit time=2.083 seconds
Fit ARIMA: order=(0, 0, 0) seasonal_order=(0, 1, 0, 12); AIC=626.061, BIC=635.834, Fit time=0.033 seconds
Fit ARIMA: order=(1, 0, 0) seasonal_order=(1, 1, 0, 12); AIC=598.004, BIC=614.292, Fit time=0.682 seconds
Fit ARIMA: order=(0, 0, 1) seasonal_order=(0, 1, 1, 12); AIC=613.475, BIC=629.762, Fit time=0.510 seconds
Fit ARIMA: order=(1, 0, 1) seasonal_order=(1, 1, 1, 12); AIC=559.530, BIC=582.332, Fit time=3.129 seconds
(...Truncated...)
Fit ARIMA: order=(3, 0, 0) seasonal_order=(0, 1, 0, 12); AIC=556.094, BIC=575.639, Fit time=0.260 seconds
Fit ARIMA: order=(3, 0, 0) seasonal_order=(0, 1, 2, 12); AIC=531.502, BIC=557.562, Fit time=2.375 seconds
Fit ARIMA: order=(3, 0, 0) seasonal_order=(1, 1, 2, 12); AIC=nan, BIC=nan, Fit time=nan seconds
Total fit time: 30.781 seconds

So, we have the model with the exogenous term. But the coefficient is very small for x1, so the contribution from that variable will be negligible. Let’s forecast it anyway.

We have effectively forced the latest seasonal effect of the latest 3 years into the model instead of the entire history.

Alright let’s forecast into the next 24 months. For this, you need the value of the seasonal index for the next 24 months.

# Forecast
n_periods = 24
fitted, confint = sxmodel.predict(n_periods=n_periods, 
                                  exogenous=np.tile(seasonal_index.value, 2).reshape(-1,1), 
                                  return_conf_int=True)

index_of_fc = pd.date_range(data.index[-1], periods = n_periods, freq='MS')

# make series for plotting purpose
fitted_series = pd.Series(fitted, index=index_of_fc)
lower_series = pd.Series(confint[:, 0], index=index_of_fc)
upper_series = pd.Series(confint[:, 1], index=index_of_fc)

# Plot
plt.plot(data['value'])
plt.plot(fitted_series, color='darkgreen')
plt.fill_between(lower_series.index, 
                 lower_series, 
                 upper_series, 
                 color='k', alpha=.15)

plt.title("SARIMAX Forecast of a10 - Drug Sales")
plt.show()
SARIMAX Forecast

16. Practice Exercises

In the AirPassengers dataset, go back 12 months in time and build the SARIMA forecast for the next 12 months.

  1. Is the series stationary? If not what sort of differencing is required?
  2. What is the order of your best model?
  3. What is the AIC of your model?
  4. What is the MAPE achieved in OOT cross-validation?
  5. What is the order of the best model predicted by auto_arima() method?

17. Conclusion

Congrats if you reached this point. Give yourself a BIG hug if you were able to solve the practice exercises.

I really hope you found this useful.

We have covered a lot of concepts starting from the very basics of forecasting, AR, MA, ARIMA, SARIMA and finally the SARIMAX model. If you have any questions please write in the comments section. Meanwhile, I will work on the next article.

Happy Learning!

Time Series Analysis in Python – A Comprehensive Guide with Examples

Time series is a sequence of observations recorded at regular time intervals. This guide walks you through the process of analyzing the characteristics of a given time series in python.

Time Series Analysis in Python – A Comprehensive Guide. Photo by Daniel Ferrandiz.

Contents

[columnize]
  1. What is a Time Series?
  2. How to import Time Series in Python?
  3. What is panel data?
  4. Visualizing a Time Series
  5. Patterns in a Time Series
  6. Additive and multiplicative Time Series
  7. How to decompose a Time Series into its components?
  8. Stationary and non-stationary Time Series
  9. How to make a Time Series stationary?
  10. How to test for stationarity?
  11. What is the difference between white noise and a stationary series?
  12. How to detrend a Time Series?
  13. How to deseasonalize a Time Series?
  14. How to test for seasonality of a Time Series?
  15. How to treat missing values in a Time Series?
  16. What is autocorrelation and partial autocorrelation functions?
  17. How to compute partial autocorrelation function?
  18. Lag Plots
  19. How to estimate the forecastability of a Time Series?
  20. Why and How to smoothen a Time Series?
  21. How to use Granger Causality test to know if one Time Series is helpful in forecasting another?
  22. What Next
[/columnize]

1. What is a Time Series?

Time series is a sequence of observations recorded at regular time intervals.

Depending on the frequency of observations, a time series may typically be hourly, daily, weekly, monthly, quarterly and annual. Sometimes, you might have seconds and minute-wise time series as well, like, number of clicks and user visits every minute etc.

Why even analyze a time series?

Because it is the preparatory step before you develop a forecast of the series.

Besides, time series forecasting has enormous commercial significance because stuff that is important to a business like demand and sales, number of visitors to a website, stock price etc are essentially time series data.

So what does analyzing a time series involve?

Time series analysis involves understanding various aspects about the inherent nature of the series so that you are better informed to create meaningful and accurate forecasts.

2. How to import time series in python?

So how to import time series data?

The data for a time series is typically stored in .csv files or other spreadsheet formats and contains two columns: the date and the measured value.

Let’s use the read_csv() in pandas package to read the time series dataset (a csv file on Australian Drug Sales) as a pandas dataframe. Adding the parse_dates=['date'] argument makes the date column get parsed as a date field.

from dateutil.parser import parse 
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
plt.rcParams.update({'figure.figsize': (10, 7), 'figure.dpi': 120})

# Import as Dataframe
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'])
df.head()
Dataframe Time Series

Alternately, you can import it as a pandas Series with the date as index. You just need to specify the index_col argument in the pd.read_csv() to do this.

ser = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')
ser.head()
Series Timeseries

Note that in the series, the ‘value’ column is placed above the date index to imply that it is a series.

3. What is panel data?

Panel data is also a time based dataset.

The difference is that, in addition to time series, it also contains one or more related variables that are measured for the same time periods.

Typically, the columns present in panel data contain explanatory variables that can be helpful in predicting the Y, provided those columns will be available at the future forecasting period.

An example of panel data is shown below.

# dataset source: https://github.com/rouseguy
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/MarketArrivals.csv')
df = df.loc[df.market=='MUMBAI', :]
df.head()
Panel Data

4. Visualizing a time series

Let’s use matplotlib to visualise the series.

# Time series data source: fpp package in R.
import matplotlib.pyplot as plt
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')

# Draw Plot
def plot_df(df, x, y, title="", xlabel='Date', ylabel='Value', dpi=100):
    plt.figure(figsize=(16,5), dpi=dpi)
    plt.plot(x, y, color='tab:red')
    plt.gca().set(title=title, xlabel=xlabel, ylabel=ylabel)
    plt.show()

plot_df(df, x=df.index, y=df.value, title='Monthly anti-diabetic drug sales in Australia from 1992 to 2008.')    
Visualizing Time Series

Since all values are positive, you can show this on both sides of the Y axis to emphasize the growth.

# Import data
df = pd.read_csv('datasets/AirPassengers.csv', parse_dates=['date'])
x = df['date'].values
y1 = df['value'].values

# Plot
fig, ax = plt.subplots(1, 1, figsize=(16,5), dpi= 120)
plt.fill_between(x, y1=y1, y2=-y1, alpha=0.5, linewidth=2, color='seagreen')
plt.ylim(-800, 800)
plt.title('Air Passengers (Two Side View)', fontsize=16)
plt.hlines(y=0, xmin=np.min(df.date), xmax=np.max(df.date), linewidth=.5)
plt.show()
Air Passengers Data – 2 Side Series

Since it’s a monthly time series and follows a certain repetitive pattern every year, you can plot each year as a separate line in the same plot. This lets you compare the year-wise patterns side-by-side.

Seasonal Plot of a Time Series

# Import Data
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')
df.reset_index(inplace=True)

# Prepare data
df['year'] = [d.year for d in df.date]
df['month'] = [d.strftime('%b') for d in df.date]
years = df['year'].unique()

# Prep Colors
np.random.seed(100)
mycolors = np.random.choice(list(mpl.colors.XKCD_COLORS.keys()), len(years), replace=False)

# Draw Plot
plt.figure(figsize=(16,12), dpi= 80)
for i, y in enumerate(years):
    if i > 0:        
        plt.plot('month', 'value', data=df.loc[df.year==y, :], color=mycolors[i], label=y)
        plt.text(df.loc[df.year==y, :].shape[0]-.9, df.loc[df.year==y, 'value'][-1:].values[0], y, fontsize=12, color=mycolors[i])

# Decoration
plt.gca().set(xlim=(-0.3, 11), ylim=(2, 30), ylabel='$Drug Sales$', xlabel='$Month$')
plt.yticks(fontsize=12, alpha=.7)
plt.title("Seasonal Plot of Drug Sales Time Series", fontsize=20)
plt.show()
Seasonal Plot of Drug Sales

There is a steep fall in drug sales every February, rising again in March, falling again in April and so on. Clearly, the pattern repeats within a given year, every year.

However, as years progress, the drug sales increase overall. You can visualize this trend and how it varies each year in a year-wise boxplot. Likewise, you can do a month-wise boxplot to visualize the monthly distributions.

Boxplot of Month-wise (Seasonal) and Year-wise (trend) Distribution

You can group the data at seasonal intervals and see how the values are distributed within a given year or month and how it compares over time.

# Import Data
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')
df.reset_index(inplace=True)

# Prepare data
df['year'] = [d.year for d in df.date]
df['month'] = [d.strftime('%b') for d in df.date]
years = df['year'].unique()

# Draw Plot
fig, axes = plt.subplots(1, 2, figsize=(20,7), dpi= 80)
sns.boxplot(x='year', y='value', data=df, ax=axes[0])
sns.boxplot(x='month', y='value', data=df.loc[~df.year.isin([1991, 2008]), :], ax=axes[1])

# Set Title
axes[0].set_title('Year-wise Box Plot\n(The Trend)', fontsize=18); 
axes[1].set_title('Month-wise Box Plot\n(The Seasonality)', fontsize=18)
plt.show()
Yearwise and Monthwise Boxplot

The boxplots make the year-wise and month-wise distributions evident. Also, in the month-wise boxplot, the months of December and January clearly have higher drug sales, which can be attributed to the holiday discount season.

So far, we have seen the similarities to identify the pattern. Now, how to find out any deviations from the usual pattern?

5. Patterns in a time series

Any time series may be split into the following components: Base Level + Trend + Seasonality + Error

A trend is observed when there is an increasing or decreasing slope observed in the time series. Whereas seasonality is observed when there is a distinct repeated pattern observed between regular intervals due to seasonal factors. It could be because of the month of the year, the day of the month, weekdays or even time of the day.

However, it is not mandatory that all time series must have a trend and/or seasonality. A time series may not have a distinct trend but have a seasonality. The opposite can also be true.

So, a time series may be imagined as a combination of the trend, seasonality and the error terms.

fig, axes = plt.subplots(1,3, figsize=(20,4), dpi=100)
pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/guinearice.csv', parse_dates=['date'], index_col='date').plot(title='Trend Only', legend=False, ax=axes[0])

pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/sunspotarea.csv', parse_dates=['date'], index_col='date').plot(title='Seasonality Only', legend=False, ax=axes[1])

pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/AirPassengers.csv', parse_dates=['date'], index_col='date').plot(title='Trend and Seasonality', legend=False, ax=axes[2])
Patterns in Time Series

Another aspect to consider is the cyclic behaviour. It happens when the rise and fall pattern in the series does not happen in fixed calendar-based intervals. Care should be taken to not confuse ‘cyclic’ effect with ‘seasonal’ effect.

So, how to differentiate between a ‘cyclic’ and a ‘seasonal’ pattern?

If the patterns are not of fixed calendar-based frequencies, then it is cyclic, because, unlike seasonality, cyclic effects are typically influenced by business and other socio-economic factors.

6. Additive and multiplicative time series

Depending on the nature of the trend and seasonality, a time series can be modeled as additive or multiplicative, wherein each observation in the series can be expressed as either a sum or a product of the components:

Additive time series:
Value = Base Level + Trend + Seasonality + Error

Multiplicative Time Series:
Value = Base Level x Trend x Seasonality x Error
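
To make the distinction concrete, here is a minimal sketch with made-up numbers (the trend, seasonality and noise below are invented purely for illustration): in the additive case the seasonal swings keep a constant amplitude, while in the multiplicative case they grow with the level of the series.

import numpy as np
import matplotlib.pyplot as plt

t = np.arange(120)                                    # 10 years of monthly periods
trend = 10 + 0.1 * t                                  # base level + upward trend
seasonality = np.sin(2 * np.pi * t / 12)              # 12-period seasonal cycle
error = np.random.normal(0, 0.2, len(t))              # random noise

additive = trend + 2 * seasonality + error                              # constant seasonal amplitude
multiplicative = trend * (1 + 0.2 * seasonality) * (1 + 0.02 * error)   # amplitude grows with the trend

fig, axes = plt.subplots(1, 2, figsize=(14, 4), dpi=100)
axes[0].plot(t, additive); axes[0].set_title('Additive')
axes[1].plot(t, multiplicative); axes[1].set_title('Multiplicative')
plt.show()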

7. How to decompose a time series into its components?

You can do a classical decomposition of a time series by considering the series as an additive or multiplicative combination of the base level, trend, seasonal index and the residual.

The seasonal_decompose in statsmodels implements this conveniently.

from statsmodels.tsa.seasonal import seasonal_decompose
from dateutil.parser import parse

# Import Data
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')

# Multiplicative Decomposition 
result_mul = seasonal_decompose(df['value'], model='multiplicative', extrapolate_trend='freq')

# Additive Decomposition
result_add = seasonal_decompose(df['value'], model='additive', extrapolate_trend='freq')

# Plot
plt.rcParams.update({'figure.figsize': (10,10)})
result_mul.plot().suptitle('Multiplicative Decompose', fontsize=22)
result_add.plot().suptitle('Additive Decompose', fontsize=22)
plt.show()
Additive and Multiplicative Decompose

Setting extrapolate_trend='freq' takes care of any missing values in the trend and residuals at the beginning of the series.

If you look at the residuals of the additive decomposition closely, it has some pattern left over. The multiplicative decomposition, however, looks quite random which is good. So ideally, multiplicative decomposition should be preferred for this particular series.

The numerical output of the trend, seasonal and residual components are stored in the result_mul output itself. Let’s extract them and put it in a dataframe.

# Extract the Components ----
# Actual Values = Product of (Seasonal * Trend * Resid)
df_reconstructed = pd.concat([result_mul.seasonal, result_mul.trend, result_mul.resid, result_mul.observed], axis=1)
df_reconstructed.columns = ['seas', 'trend', 'resid', 'actual_values']
df_reconstructed.head()

If you check, the product of the seas, trend and resid columns should exactly equal the actual_values, as the quick check below shows.
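
A minimal sketch of that sanity check, reusing the df_reconstructed frame built above:

import numpy as np

# The product of the components should reproduce the observed values (expected: True)
print(np.allclose(df_reconstructed['seas'] * df_reconstructed['trend'] * df_reconstructed['resid'],
                  df_reconstructed['actual_values']))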

8. Stationary and Non-Stationary Time Series

Stationarity is a property of a time series. A stationary series is one where the values of the series are not a function of time.

That is, the statistical properties of the series like mean, variance and autocorrelation are constant over time. Autocorrelation of the series is nothing but the correlation of the series with its previous values, more on this coming up.

A stationary time series is devoid of seasonal effects as well.

So how to identify if a series is stationary or not? Let’s plot some examples to make it clear:

Stationary and Non-Stationary Time Series

The above image is sourced from R’s TSTutorial.

So why does a stationary series matter? why am I even talking about it?

I will come to that in a bit, but understand that it is possible to make nearly any time series stationary by applying a suitable transformation. Most statistical forecasting methods are designed to work on a stationary time series. The first step in the forecasting process is typically to do some transformation to convert a non-stationary series to stationary.

9. How to make a time series stationary?

You can make a series stationary by:

  1. Differencing the series (once or more)
  2. Taking the log of the series
  3. Taking the nth root of the series
  4. A combination of the above

The most common and convenient method to stationarize the series is by differencing the series at least once until it becomes approximately stationary.

So what is differencing?

If Y_t is the value at time ‘t’, then the first difference of Y = Yt – Yt-1. In simpler terms, differencing the series is nothing but subtracting the previous value from the current value.

If the first difference doesn’t make a series stationary, you can go for the second differencing. And so on.

For example, consider the following series: [1, 5, 2, 12, 20]

First differencing gives: [5-1, 2-5, 12-2, 20-12] = [4, -3, 10, 8]

Second differencing gives: [-3-4, 10-(-3), 8-10] = [-7, 13, -2]
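
The same arithmetic, reproduced with pandas as a minimal sketch:

import pandas as pd

s = pd.Series([1, 5, 2, 12, 20])
print(s.diff().dropna().tolist())         # first difference  -> [4.0, -3.0, 10.0, 8.0]
print(s.diff().diff().dropna().tolist())  # second difference -> [-7.0, 13.0, -2.0]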

Why make a non-stationary series stationary before forecasting?

Forecasting a stationary series is relatively easy and the forecasts are more reliable.

An important reason is, autoregressive forecasting models are essentially linear regression models that utilize the lag(s) of the series itself as predictors.

We know that linear regression works best if the predictors (X variables) are not correlated with each other. So, stationarizing the series solves this problem since it removes any persistent autocorrelation, thereby making the predictors (lags of the series) in the forecasting models nearly independent.

Now that we’ve established that stationarizing the series is important, how do you check if a given series is stationary or not?

10. How to test for stationarity?

The stationarity of a series can be established by looking at the plot of the series like we did earlier.

Another method is to split the series into 2 or more contiguous parts and compute summary statistics like the mean, variance and the autocorrelation, as shown in the sketch below. If the stats are quite different across the parts, then the series is not likely to be stationary.
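
Here is a minimal sketch of that ‘split and compare’ check on the drug sales series. It is an informal eyeball check rather than a formal test:

import pandas as pd

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'])
first_half, second_half = df.value[:len(df)//2], df.value[len(df)//2:]

# Very different means/variances between the two halves suggest non-stationarity
print(first_half.mean(), second_half.mean())
print(first_half.var(), second_half.var())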

Nevertheless, you need a method to quantitatively determine if a given series is stationary or not. This can be done using statistical tests called ‘Unit Root Tests’. There are multiple variations of this, where the tests check if a time series is non-stationary and possess a unit root.

There are multiple implementations of Unit Root tests like:

  1. Augmented Dickey Fuller test (ADF Test)
  2. Kwiatkowski-Phillips-Schmidt-Shin – KPSS test (trend stationary)
  3. Phillips-Perron test (PP Test)

The most commonly used is the ADF test, where the null hypothesis is that the time series possesses a unit root and is non-stationary. So, if the P-Value in the ADF test is less than the significance level (0.05), you reject the null hypothesis.

The KPSS test, on the other hand, is used to test for trend stationarity. The null hypothesis and the P-Value interpretation are just the opposite of the ADF test. The below code implements these two tests using the statsmodels package in python.

from statsmodels.tsa.stattools import adfuller, kpss
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'])

# ADF Test
result = adfuller(df.value.values, autolag='AIC')
print(f'ADF Statistic: {result[0]}')
print(f'p-value: {result[1]}')
print('Critical Values:')
for key, value in result[4].items():
    print(f'   {key}, {value}')

# KPSS Test
result = kpss(df.value.values, regression='c')
print('\nKPSS Statistic: %f' % result[0])
print('p-value: %f' % result[1])
print('Critical Values:')
for key, value in result[3].items():
    print(f'   {key}, {value}')
ADF Statistic: 3.14518568930674
p-value: 1.0
Critical Values:
   1%, -3.465620397124192
   5%, -2.8770397560752436
   10%, -2.5750324547306476

KPSS Statistic: 1.313675
p-value: 0.010000
Critical Values:
   10%, 0.347
   5%, 0.463
   2.5%, 0.574
   1%, 0.739

11. What is the difference between white noise and a stationary series?

Like a stationary series, white noise is also not a function of time, that is, its mean and variance do not change over time. But the difference is, white noise is completely random with a mean of 0.

In white noise there is no pattern whatsoever. If you consider the sound signals in an FM radio as a time series, the blank sound you hear between the channels is white noise.

Mathematically, a sequence of completely random numbers with mean zero is a white noise.

randvals = np.random.randn(1000)
pd.Series(randvals).plot(title='Random White Noise', color='k')
Random White Noise

12. How to detrend a time series?

Detrending a time series is to remove the trend component from a time series. But how to extract the trend? There are multiple approaches.

  1. Subtract the line of best fit from the time series. The line of best fit may be obtained from a linear regression model with the time steps as the predictor. For more complex trends, you may want to use quadratic terms (x^2) in the model.

  2. Subtract the trend component obtained from time series decomposition we saw earlier.

  3. Subtract the mean

  4. Apply a filter like the Baxter-King filter (statsmodels.tsa.filters.bkfilter) or the Hodrick-Prescott filter (statsmodels.tsa.filters.hpfilter) to remove the moving average trend lines or the cyclical components.

Let’s implement the first two methods; a sketch of the filter-based approach (method 4) follows after them.

# Using scipy: Subtract the line of best fit
from scipy import signal
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'])
detrended = signal.detrend(df.value.values)
plt.plot(detrended)
plt.title('Drug Sales detrended by subtracting the least squares fit', fontsize=16)
Detrend A TimeSeries By Subtracting LeastSquaresFit
# Using statmodels: Subtracting the Trend Component.
from statsmodels.tsa.seasonal import seasonal_decompose
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')
result_mul = seasonal_decompose(df['value'], model='multiplicative', extrapolate_trend='freq')
detrended = df.value.values - result_mul.trend
plt.plot(detrended)
plt.title('Drug Sales detrended by subtracting the trend component', fontsize=16)
Detrend By Subtracting Trend Component
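
For approach 4 from the list above, here is a hedged sketch using the Hodrick-Prescott filter from statsmodels. The smoothing value lamb=129600 is a commonly suggested choice for monthly data; treat it as an assumption, not a rule.

# Using statsmodels: Subtracting the Hodrick-Prescott trend
from statsmodels.tsa.filters.hp_filter import hpfilter
import pandas as pd

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')
hp_cycle, hp_trend = hpfilter(df['value'], lamb=129600)  # returns (cycle, trend)
detrended_hp = df['value'] - hp_trend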

13. How to deseasonalize a time series?

There are multiple approaches to deseasonalize a time series as well. Below are a few:

- 1. Take a moving average with length as the seasonal window. This will smoothen the series in the process.

- 2. Seasonal difference the series (subtract the value of previous season from the current value)

- 3. Divide the series by the seasonal index obtained from STL decomposition

If dividing by the seasonal index does not work well, try taking a log of the series and then do the deseasonalizing. You can later restore to the original scale by taking an exponential.

# Deseasonalize by dividing by the seasonal index
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')

# Time Series Decomposition
result_mul = seasonal_decompose(df['value'], model='multiplicative', extrapolate_trend='freq')

# Deseasonalize
deseasonalized = df.value.values / result_mul.seasonal

# Plot
plt.plot(deseasonalized)
plt.title('Drug Sales Deseasonalized', fontsize=16)
plt.plot()
Deseasonalize Time Series

14. How to test for seasonality of a time series?

The common way is to plot the series and check for repeatable patterns in fixed time intervals. So, the types of seasonality are determined by the clock or the calendar:

  1. Hour of day
  2. Day of month
  3. Weekly
  4. Monthly
  5. Yearly

However, if you want a more definitive inspection of the seasonality, use the Autocorrelation Function (ACF) plot. More on the ACF in the upcoming sections. But when there is a strong seasonal pattern, the ACF plot usually reveals definitive repeated spikes at the multiples of the seasonal window.

For example, the drug sales time series is a monthly series with patterns repeating every year. So, you can see spikes at the 12th, 24th, 36th.. lags.

I must caution you that in real-world datasets such strong patterns are hardly noticed and can get distorted by noise, so you need a careful eye to capture them.

from pandas.plotting import autocorrelation_plot
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv')

# Draw Plot
plt.rcParams.update({'figure.figsize':(9,5), 'figure.dpi':120})
autocorrelation_plot(df.value.tolist())
Autocorrelation Plot

Alternatively, if you want a statistical test, the CHTest (Canova-Hansen test) can determine if seasonal differencing is required to stationarize the series, as sketched below.
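
A hedged sketch of that test, assuming the pmdarima package is installed: its nsdiffs helper with test='ch' wraps the Canova-Hansen test and returns the number of seasonal differences it suggests.

import pandas as pd
from pmdarima.arima import nsdiffs

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv')
D = nsdiffs(df.value, m=12, test='ch')  # m=12 for a monthly seasonal period
print(D)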

15. How to treat missing values in a time series?

Sometimes, your time series will have missing dates/times. That means the data was not captured or was not available for those periods. It could so happen that the measurement was zero on those days, in which case you may fill up those periods with zero.

Secondly, when it comes to time series, you should typically NOT replace missing values with the mean of the series, especially if the series is not stationary. What you could do instead for a quick and dirty workaround is to forward-fill the previous value.

However, depending on the nature of the series, you want to try out multiple approaches before concluding. Some effective alternatives to imputation are:

  • Backward Fill
  • Linear Interpolation
  • Quadratic interpolation
  • Mean of nearest neighbors
  • Mean of seasonal counterparts

To measure the imputation performance, I manually introduce missing values to the time series, impute it with above approaches and then measure the mean squared error of the imputed against the actual values.

# # Generate dataset
from scipy.interpolate import interp1d
from sklearn.metrics import mean_squared_error
df_orig = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date').head(100)
df = pd.read_csv('datasets/a10_missings.csv', parse_dates=['date'], index_col='date')

fig, axes = plt.subplots(7, 1, sharex=True, figsize=(10, 12))
plt.rcParams.update({'xtick.bottom' : False})

## 1. Actual -------------------------------
df_orig.plot(title='Actual', ax=axes[0], label='Actual', color='red', style=".-")
df.plot(title='Actual', ax=axes[0], label='Actual', color='green', style=".-")
axes[0].legend(["Missing Data", "Available Data"])

## 2. Forward Fill --------------------------
df_ffill = df.ffill()
error = np.round(mean_squared_error(df_orig['value'], df_ffill['value']), 2)
df_ffill['value'].plot(title='Forward Fill (MSE: ' + str(error) +")", ax=axes[1], label='Forward Fill', style=".-")

## 3. Backward Fill -------------------------
df_bfill = df.bfill()
error = np.round(mean_squared_error(df_orig['value'], df_bfill['value']), 2)
df_bfill['value'].plot(title="Backward Fill (MSE: " + str(error) +")", ax=axes[2], label='Back Fill', color='firebrick', style=".-")

## 4. Linear Interpolation ------------------
df['rownum'] = np.arange(df.shape[0])
df_nona = df.dropna(subset = ['value'])
f = interp1d(df_nona['rownum'], df_nona['value'])
df['linear_fill'] = f(df['rownum'])
error = np.round(mean_squared_error(df_orig['value'], df['linear_fill']), 2)
df['linear_fill'].plot(title="Linear Fill (MSE: " + str(error) +")", ax=axes[3], label='Linear Fill', color='brown', style=".-")

## 5. Cubic Interpolation --------------------
f2 = interp1d(df_nona['rownum'], df_nona['value'], kind='cubic')
df['cubic_fill'] = f2(df['rownum'])
error = np.round(mean_squared_error(df_orig['value'], df['cubic_fill']), 2)
df['cubic_fill'].plot(title="Cubic Fill (MSE: " + str(error) +")", ax=axes[4], label='Cubic Fill', color='red', style=".-")

# Interpolation References:
# https://docs.scipy.org/doc/scipy/reference/tutorial/interpolate.html
# https://docs.scipy.org/doc/scipy/reference/interpolate.html

## 6. Mean of 'n' Nearest Past Neighbors ------
def knn_mean(ts, n):
    out = np.copy(ts)
    for i, val in enumerate(ts):
        if np.isnan(val):
            n_by_2 = np.ceil(n/2)
            lower = np.max([0, int(i-n_by_2)])
            upper = np.min([len(ts)+1, int(i+n_by_2)])
            ts_near = np.concatenate([ts[lower:i], ts[i:upper]])
            out[i] = np.nanmean(ts_near)
    return out

df['knn_mean'] = knn_mean(df.value.values, 8)
error = np.round(mean_squared_error(df_orig['value'], df['knn_mean']), 2)
df['knn_mean'].plot(title="KNN Mean (MSE: " + str(error) +")", ax=axes[5], label='KNN Mean', color='tomato', alpha=0.5, style=".-")

## 7. Seasonal Mean ----------------------------
def seasonal_mean(ts, n, lr=0.7):
    """
    Compute the mean of corresponding seasonal periods
    ts: 1D array-like of the time series
    n: Seasonal window length of the time series
    """
    out = np.copy(ts)
    for i, val in enumerate(ts):
        if np.isnan(val):
            ts_seas = ts[i-1::-n]  # previous seasons only
            if np.isnan(np.nanmean(ts_seas)):
                ts_seas = np.concatenate([ts[i-1::-n], ts[i::n]])  # previous and forward
            out[i] = np.nanmean(ts_seas) * lr
    return out

df['seasonal_mean'] = seasonal_mean(df.value, n=12, lr=1.25)
error = np.round(mean_squared_error(df_orig['value'], df['seasonal_mean']), 2)
df['seasonal_mean'].plot(title="Seasonal Mean (MSE: " + str(error) +")", ax=axes[6], label='Seasonal Mean', color='blue', alpha=0.5, style=".-")
Missing Value Treatments

You could also consider the following approaches depending on how accurate you want the imputations to be.

  1. If you have explanatory variables use a prediction model like the random forest or k-Nearest Neighbors to predict it.
  2. If you have enough past observations, forecast the missing values.
  3. If you have enough future observations, backcast the missing values
  4. Forecast of counterparts from previous cycles.

16. What is autocorrelation and partial autocorrelation functions?

Autocorrelation is simply the correlation of a series with its own lags. If a series is significantly autocorrelated, that means, the previous values of the series (lags) may be helpful in predicting the current value.

Partial Autocorrelation also conveys similar information but it conveys the pure correlation of a series and its lag, excluding the correlation contributions from the intermediate lags.

from statsmodels.tsa.stattools import acf, pacf
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv')

# Calculate ACF and PACF upto 50 lags
# acf_50 = acf(df.value, nlags=50)
# pacf_50 = pacf(df.value, nlags=50)

# Draw Plot
fig, axes = plt.subplots(1,2,figsize=(16,3), dpi= 100)
plot_acf(df.value.tolist(), lags=50, ax=axes[0])
plot_pacf(df.value.tolist(), lags=50, ax=axes[1])
ACF and PACF

17. How to compute partial autocorrelation function?

So how to compute partial autocorrelation?

The partial autocorrelation of lag (k) of a series is the coefficient of that lag in the autoregression equation of Y. The autoregressive equation of Y is nothing but the linear regression of Y with its own lags as predictors.

For Example, if Y_t is the current series and Y_t-1 is the lag 1 of Y, then the partial autocorrelation of lag 3 (Y_t-3) is the coefficient $\alpha_3$ of Y_t-3 in the following equation:

Autoregression Equation: $Y_t = \alpha_0 + \alpha_1 Y_{t-1} + \alpha_2 Y_{t-2} + \alpha_3 Y_{t-3} + \epsilon_t$
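
A minimal sketch of this idea in code, using the a10 drug sales series: regress Y_t on its first three lags and read off the coefficient of Y_t-3, then compare it with statsmodels’ OLS-based pacf. The two estimates should be very close, though they may differ slightly depending on the estimation details.

import pandas as pd
import statsmodels.api as sm
from statsmodels.tsa.stattools import pacf

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv')
y = df.value

# Design matrix of Y_t and its first three lags
lagged = pd.concat([y.shift(k) for k in range(4)], axis=1).dropna()
lagged.columns = ['y', 'lag1', 'lag2', 'lag3']

ols_fit = sm.OLS(lagged['y'], sm.add_constant(lagged[['lag1', 'lag2', 'lag3']])).fit()
print(ols_fit.params['lag3'])             # coefficient of Y_t-3 = partial autocorrelation at lag 3
print(pacf(y, nlags=3, method='ols')[3])  # statsmodels' OLS-based estimate, for comparison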

18. Lag Plots

A Lag plot is a scatter plot of a time series against a lag of itself. It is normally used to check for autocorrelation. If there is any pattern existing in the series like the one you see below, the series is autocorrelated. If there is no such pattern, the series is likely to be random white noise.

In the below example on the Sunspots area time series, the plots get more and more scattered as the lag increases.

from pandas.plotting import lag_plot
plt.rcParams.update({'ytick.left' : False, 'axes.titlepad':10})

# Import
ss = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/sunspotarea.csv')
a10 = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv')

# Plot
fig, axes = plt.subplots(1, 4, figsize=(10,3), sharex=True, sharey=True, dpi=100)
for i, ax in enumerate(axes.flatten()[:4]):
    lag_plot(ss.value, lag=i+1, ax=ax, c='firebrick')
    ax.set_title('Lag ' + str(i+1))

fig.suptitle('Lag Plots of Sun Spots Area \n(Points get wide and scattered with increasing lag -> lesser correlation)\n', y=1.15)    

fig, axes = plt.subplots(1, 4, figsize=(10,3), sharex=True, sharey=True, dpi=100)
for i, ax in enumerate(axes.flatten()[:4]):
    lag_plot(a10.value, lag=i+1, ax=ax, c='firebrick')
    ax.set_title('Lag ' + str(i+1))

fig.suptitle('Lag Plots of Drug Sales', y=1.05)    
plt.show()
Lagplots Drugsales
Lagplots Sunspots

19. How to estimate the forecastability of a time series?

The more regular and repeatable patterns a time series has, the easier it is to forecast. The ‘Approximate Entropy’ can be used to quantify the regularity and unpredictability of fluctuations in a time series.

The higher the approximate entropy, the more difficult it is to forecast the series.

A better alternative is ‘Sample Entropy’.

Sample Entropy is similar to approximate entropy but is more consistent in estimating the complexity even for smaller time series. For example, a random time series with fewer data points can have a lower ‘approximate entropy’ than a more ‘regular’ time series, whereas, a longer random time series will have a higher ‘approximate entropy’.

Sample Entropy handles this problem nicely. See the demonstration below.

# https://en.wikipedia.org/wiki/Approximate_entropy
ss = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/sunspotarea.csv')
a10 = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv')
rand_small = np.random.randint(0, 100, size=36)
rand_big = np.random.randint(0, 100, size=136)

def ApEn(U, m, r):
    """Compute Aproximate entropy"""
    def _maxdist(x_i, x_j):
        return max([abs(ua - va) for ua, va in zip(x_i, x_j)])

    def _phi(m):
        x = [[U[j] for j in range(i, i + m - 1 + 1)] for i in range(N - m + 1)]
        C = [len([1 for x_j in x if _maxdist(x_i, x_j) <= r]) / (N - m + 1.0) for x_i in x]
        return (N - m + 1.0)**(-1) * sum(np.log(C))

    N = len(U)
    return abs(_phi(m+1) - _phi(m))

print(ApEn(ss.value, m=2, r=0.2*np.std(ss.value)))     # 0.651
print(ApEn(a10.value, m=2, r=0.2*np.std(a10.value)))   # 0.537
print(ApEn(rand_small, m=2, r=0.2*np.std(rand_small))) # 0.143
print(ApEn(rand_big, m=2, r=0.2*np.std(rand_big)))     # 0.716
0.6514704970333534
0.5374775224973489
0.0898376940798844
0.7369242960384561
# https://en.wikipedia.org/wiki/Sample_entropy
def SampEn(U, m, r):
    """Compute Sample entropy"""
    def _maxdist(x_i, x_j):
        return max([abs(ua - va) for ua, va in zip(x_i, x_j)])

    def _phi(m):
        x = [[U[j] for j in range(i, i + m - 1 + 1)] for i in range(N - m + 1)]
        C = [len([1 for j in range(len(x)) if i != j and _maxdist(x[i], x[j]) <= r]) for i in range(len(x))]
        return sum(C)

    N = len(U)
    return -np.log(_phi(m+1) / _phi(m))

print(SampEn(ss.value, m=2, r=0.2*np.std(ss.value)))      # 0.78
print(SampEn(a10.value, m=2, r=0.2*np.std(a10.value)))    # 0.41
print(SampEn(rand_small, m=2, r=0.2*np.std(rand_small)))  # 1.79
print(SampEn(rand_big, m=2, r=0.2*np.std(rand_big)))      # 2.42
0.7853311366380039
0.41887013457621214
inf
2.181224235989778


20. Why and How to smoothen a time series?

Smoothening of a time series may be useful in:

  • Reducing the effect of noise in a signal to get a fair approximation of the noise-filtered series.
  • The smoothed version of series can be used as a feature to explain the original series itself.
  • Visualize the underlying trend better

So how to smoothen a series? Let’s discuss the following methods:

  1. Take a moving average
  2. Do a LOESS smoothing (Localized Regression)
  3. Do a LOWESS smoothing (Locally Weighted Regression)

Moving average is nothing but the average of a rolling window of defined width. But you must choose the window-width wisely, because a large window-size will over-smooth the series. For example, a window-size equal to the seasonal duration (ex: 12 for a month-wise series) will effectively nullify the seasonal effect, as the sketch below shows.
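
A quick sketch of that effect on the monthly drug sales: a 12-month rolling mean largely wipes out the seasonal wiggles and leaves mostly the trend.

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'], index_col='date')
ma_12 = df['value'].rolling(window=12, center=True).mean()  # window = seasonal duration
ma_12.plot(title='12-month Moving Average (seasonality largely removed)')
plt.show()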

LOESS, short for ‘LOcalized regrESSion’ fits multiple regressions in the local neighborhood of each point. It is implemented in the statsmodels package, where you can control the degree of smoothing using frac argument which specifies the percentage of data points nearby that should be considered to fit a regression model.

from statsmodels.nonparametric.smoothers_lowess import lowess
plt.rcParams.update({'xtick.bottom' : False, 'axes.titlepad':5})

# Import
df_orig = pd.read_csv('datasets/elecequip.csv', parse_dates=['date'], index_col='date')

# 1. Moving Average
df_ma = df_orig.value.rolling(3, center=True, closed='both').mean()

# 2. Loess Smoothing (5% and 15%)
df_loess_5 = pd.DataFrame(lowess(df_orig.value, np.arange(len(df_orig.value)), frac=0.05)[:, 1], index=df_orig.index, columns=['value'])
df_loess_15 = pd.DataFrame(lowess(df_orig.value, np.arange(len(df_orig.value)), frac=0.15)[:, 1], index=df_orig.index, columns=['value'])

# Plot
fig, axes = plt.subplots(4,1, figsize=(7, 7), sharex=True, dpi=120)
df_orig['value'].plot(ax=axes[0], color='k', title='Original Series')
df_loess_5['value'].plot(ax=axes[1], title='Loess Smoothed 5%')
df_loess_15['value'].plot(ax=axes[2], title='Loess Smoothed 15%')
df_ma.plot(ax=axes[3], title='Moving Average (3)')
fig.suptitle('How to Smoothen a Time Series', y=0.95, fontsize=14)
plt.show()
Smoothen Timeseries

21. How to use Granger Causality test to know if one time series is helpful in forecasting another?

Granger causality test is used to determine if one time series will be useful to forecast another.

How does Granger causality test work?

It is based on the idea that if X causes Y, then the forecast of Y based on previous values of Y AND the previous values of X should outperform the forecast of Y based on previous values of Y alone.

So, understand that Granger causality should not be used to test if a lag of Y causes Y. Instead, it is generally used on exogenous (not Y lag) variables only.

It is nicely implemented in the statsmodels package.

It accepts a 2D array with 2 columns as the main argument. The values are in the first column and the predictor (X) is in the second column.

The null hypothesis is that the series in the second column does not Granger cause the series in the first. If the P-Values are less than a significance level (0.05), then you reject the null hypothesis and conclude that the said lag of X is indeed useful.

The second argument maxlag says how many lags of Y should be included in the test.

from statsmodels.tsa.stattools import grangercausalitytests
df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/a10.csv', parse_dates=['date'])
df['month'] = df.date.dt.month
grangercausalitytests(df[['value', 'month']], maxlag=2)
Granger Causality
number of lags (no zero) 1
ssr based F test:         F=54.7797 , p=0.0000  , df_denom=200, df_num=1
ssr based chi2 test:   chi2=55.6014 , p=0.0000  , df=1
likelihood ratio test: chi2=49.1426 , p=0.0000  , df=1
parameter F test:         F=54.7797 , p=0.0000  , df_denom=200, df_num=1

Granger Causality
number of lags (no zero) 2
ssr based F test:         F=162.6989, p=0.0000  , df_denom=197, df_num=2
ssr based chi2 test:   chi2=333.6567, p=0.0000  , df=2
likelihood ratio test: chi2=196.9956, p=0.0000  , df=2
parameter F test:         F=162.6989, p=0.0000  , df_denom=197, df_num=2

In the above case, the P-Values are zero for all tests. So the ‘month’ indeed can be used to forecast the drug sales.

22. What Next

That’s it for now. We started from the very basics and understood various characteristics of a time series. Once the analysis is done the next step is to begin forecasting.

In the next post, I will walk you through the in-depth process of building time series forecasting models using ARIMA. See you soon.

Matplotlib Tutorial – A Complete Guide to Python Plot w/ Examples

This tutorial explains matplotlib’s way of making plots in simplified parts so you gain the knowledge and a clear understanding of how to build and modify full featured matplotlib plots.

Matplotlib Tutorial – A Complete Guide to Python Plot w/ Examples. Photo by Josiah Ingels.

1. Introduction

Matplotlib is the most popular plotting library in python. Using matplotlib, you can create pretty much any type of plot. However, as your plots get more complex, the learning curve can get steeper.

The goal of this tutorial is to make you understand ‘how plotting with matplotlib works’ and make you comfortable to build full-featured plots with matplotlib.

Contents

[columnize]
  1. Introduction
  2. A Basic Scatterplot
  3. Two scatterplots in same plot
  4. Two scatterplots in different panels
  5. Object Oriented Syntax vs Matlab like Syntax
  6. Axis Ticks Positions and Labels
  7. rcParams, Colors and Plot Styles
  8. Legend
  9. Texts, Arrows and Annotations
  10. Customize subplots layout
  11. How is scatterplot drawn with plt.plot() different from plt.scatter()
  12. Histograms, Boxplots and Time Series
  13. How to Plot with two Y-Axis
  14. Introduction to Seaborn
  15. Conclusion
[/columnize]

2. A Basic Scatterplot

The following piece of code is found in pretty much any python code that has matplotlib plots.

import matplotlib.pyplot as plt
%matplotlib inline

matplotlib.pyplot is usually imported as plt. It is the core object that contains the methods to create all sorts of charts and features in a plot.

The %matplotlib inline is a jupyter notebook specific command that lets you see the plots in the notebook itself.

Suppose you want to draw a specific type of plot, say a scatterplot, the first thing you want to check out are the methods under plt (type plt and hit tab or type dir(plt) in python prompt).

Let’s begin by making a simple but full-featured scatterplot and take it from there. Let’s see what plt.plot() creates if you give it an arbitrary sequence of numbers.

import matplotlib.pyplot as plt
%matplotlib inline

# Plot
plt.plot([1,2,3,4,10])
#> [<matplotlib.lines.Line2D at 0x10edbab70>]

Matplotlib line plot

I just gave a list of numbers to plt.plot() and it drew a line chart automatically. It assumed the X-axis values start from zero and go up to the number of items in the data.

Notice the line matplotlib.lines.Line2D in code output?

That’s because Matplotlib returns the plot object itself besides drawing the plot.

If you only want to see the plot, add plt.show() at the end and execute all the lines in one shot.

Alright, notice instead of the intended scatter plot, plt.plot drew a line plot. That’s because of the default behaviour.

So how to draw a scatterplot instead?

Well to do that, let’s understand a bit more about what arguments plt.plot() expects. The plt.plot accepts 3 basic arguments in the following order: (x, y, format).

This format is a short hand combination of {color}{marker}{line}.

For example, the format 'go-' has 3 characters standing for: ‘green colored dots with solid line’. By omitting the line part (‘-‘) in the end, you will be left with only green dots (‘go’), which makes it draw a scatterplot.

Few commonly used short hand format examples are:
* 'r*--' : ‘red stars with dashed lines’
* 'ks:' : ‘black squares with dotted line’ (‘k’ stands for black)
* 'bD-.' : ‘blue diamonds with dash-dot line’.

For a complete list of colors, markers and linestyles, check out the help(plt.plot) command.
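
Before that, here is a tiny sketch of two of the shorthand formats listed above; the numbers are arbitrary demo values:

import matplotlib.pyplot as plt

x = [1, 2, 3, 4, 5]
plt.plot(x, [1, 2, 3, 4, 10], 'r*--', label='red stars, dashed line')
plt.plot(x, [2, 3, 4, 5, 11], 'bD-.', label='blue diamonds, dash-dot line')
plt.legend(loc='best')
plt.show()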

Let’s draw a scatterplot with green dots.

# 'go' stands for green dots
plt.plot([1,2,3,4,5], [1,2,3,4,10], 'go')
plt.show()

Matplotlib scatterplot

3. How to draw two sets of scatterplots in same plot

Good. Now how to plot another set of 5 points of different color in the same figure?

Simply call plt.plot() again, and it will add those points to the same picture.

You might wonder, why it does not draw these points in a new panel altogether? I will come to that in the next section.

# Draw two sets of points
plt.plot([1,2,3,4,5], [1,2,3,4,10], 'go')  # green dots
plt.plot([1,2,3,4,5], [2,3,4,5,11], 'b*')  # blue stars
plt.show()

Matplotlib double scatterplot

Looks good. Now let’s add the basic plot features: Title, Legend, X and Y axis labels. How to do that?

The plt object has corresponding methods to add each of these.

plt.plot([1,2,3,4,5], [1,2,3,4,10], 'go', label='GreenDots')
plt.plot([1,2,3,4,5], [2,3,4,5,11], 'b*', label='Bluestars')
plt.title('A Simple Scatterplot')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend(loc='best')  # legend text comes from the plot's label parameter.
plt.show()

Good. Now, how to increase the size of the plot? (The above plot would actually look small on a jupyter notebook)

The easy way to do it is by setting the figsize inside plt.figure() method.

plt.figure(figsize=(10,7)) # 10 is width, 7 is height
plt.plot([1,2,3,4,5], [1,2,3,4,10], 'go', label='GreenDots')  # green dots
plt.plot([1,2,3,4,5], [2,3,4,5,11], 'b*', label='Bluestars')  # blue stars
plt.title('A Simple Scatterplot')  
plt.xlabel('X')
plt.ylabel('Y')
plt.xlim(0, 6)
plt.ylim(0, 12)
plt.legend(loc='best')
plt.show()

Matplotlib scatterplot large

Ok, we have some new lines of code there. What does plt.figure do?

Well, every plot that matplotlib makes is drawn on something called 'figure'. You can think of the figure object as a canvas that holds all the subplots and other plot elements inside it.

And a figure can have one or more subplots inside it called axes, arranged in rows and columns. Every figure has at least one axes. (Don’t confuse this axes with the X and Y axis; they are different.)

4. How to draw 2 scatterplots in different panels

Let’s understand figure and axes in little more detail.

Suppose, I want to draw our two sets of points (green rounds and blue stars) in two separate plots side-by-side instead of the same plot. How would you do that?

You can do that by creating two separate subplots, aka, axes using plt.subplots(1, 2). This creates and returns two objects:
* the figure
* the axes (subplots) inside the figure

Matplotlib Structure

Previously, I called plt.plot() to draw the points. Since there was only one axes by default, it drew the points on that axes itself.

But now, since you want the points drawn on different subplots (axes), you have to call the plot function in the respective axes (ax1 and ax2 in below code) instead of plt.

Notice in below code, I call ax1.plot() and ax2.plot() instead of calling plt.plot() twice.

# Create Figure and Subplots
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10,4), sharey=True, dpi=120)

# Plot
ax1.plot([1,2,3,4,5], [1,2,3,4,10], 'go')  # greendots
ax2.plot([1,2,3,4,5], [2,3,4,5,11], 'b*')  # blue stars

# Title, X and Y labels, X and Y Lim
ax1.set_title('Scatterplot Greendots'); ax2.set_title('Scatterplot Bluestars')
ax1.set_xlabel('X');  ax2.set_xlabel('X')  # x label
ax1.set_ylabel('Y');  ax2.set_ylabel('Y')  # y label
ax1.set_xlim(0, 6) ;  ax2.set_xlim(0, 6)   # x axis limits
ax1.set_ylim(0, 12);  ax2.set_ylim(0, 12)  # y axis limits

# ax2.yaxis.set_ticks_position('none') 
plt.tight_layout()
plt.show()

Matplotlib double scatterplot

Setting sharey=True in plt.subplots() shares the Y axis between the two subplots.

And dpi=120 increased the number of dots per inch of the plot to make it look more sharp and clear. You will notice a distinct improvement in clarity on increasing the dpi especially in jupyter notebooks.

That sounds like a lot of functions to learn. Well, it’s quite easy to remember actually.

The ax1 and ax2 objects, like plt, have equivalent set_title, set_xlabel and set_ylabel functions. In fact, plt.title() actually calls the current axes’ set_title() to do the job.

  • plt.xlabel() → ax.set_xlabel()
  • plt.ylabel() → ax.set_ylabel()
  • plt.xlim() → ax.set_xlim()
  • plt.ylim() → ax.set_ylim()
  • plt.title() → ax.set_title()

Alternately, to save keystrokes, you can set multiple things in one go using the ax.set().

ax1.set(title='Scatterplot Greendots', xlabel='X', ylabel='Y', xlim=(0,6), ylim=(0,12))
ax2.set(title='Scatterplot Bluestars', xlabel='X', ylabel='Y', xlim=(0,6), ylim=(0,12))

5. Object Oriented Syntax vs Matlab like Syntax

A known ‘problem’ with learning matplotlib is, it has two coding interfaces:

  1. Matlab like syntax
  2. Object oriented syntax.

This is partly the reason why matplotlib doesn’t have one consistent way of achieving the same given output, making it a bit difficult to understand for newcomers.

The syntax you’ve seen so far is the Object-oriented syntax, which I personally prefer and is more intuitive and pythonic to work with.

However, since the original purpose of matplotlib was to recreate the plotting facilities of matlab in python, the matlab-like-syntax is retained and still works.

The matlab syntax is ‘stateful’.

That means, the plt keeps track of what the current axes is. So whatever you draw with plt.{anything} will reflect only on the current subplot.

Practically speaking, the main difference between the two syntaxes is, in matlab-like syntax, all plotting is done using plt methods instead of the respective axes‘s method as in object oriented syntax.

So, how to recreate the above multi-subplots figure (or any other figure for that matter) using matlab-like syntax?

The general procedure is: you manually create one subplot at a time (using plt.subplot() or fig.add_subplot()) and immediately call plt.plot() or plt.{anything} to modify that specific subplot (axes). Whatever method you call using plt will be drawn on the current axes.

The code below shows this in practice.

plt.figure(figsize=(10,4), dpi=120) # 10 is width, 4 is height

# Left hand side plot
plt.subplot(1,2,1)  # (nRows, nColumns, axes number to plot)
plt.plot([1,2,3,4,5], [1,2,3,4,10], 'go')  # green dots
plt.title('Scatterplot Greendots')  
plt.xlabel('X'); plt.ylabel('Y')
plt.xlim(0, 6); plt.ylim(0, 12)

# Right hand side plot
plt.subplot(1,2,2)
plt.plot([1,2,3,4,5], [2,3,4,5,11], 'b*')  # blue stars
plt.title('Scatterplot Bluestars')  
plt.xlabel('X'); plt.ylabel('Y')
plt.xlim(0, 6); plt.ylim(0, 12)
plt.show()

Matplotlib double scatterplot

Let’s breakdown the above piece of code.

In plt.subplot(1,2,1), the first two values, that is (1,2) specifies the number of rows (1) and columns (2) and the third parameter (1) specifies the position of current subplot. The subsequent plt functions, will always draw on this current subplot.

You can get a reference to the current (subplot) axes with plt.gca() and the current figure with plt.gcf(). Likewise, plt.cla() and plt.clf() will clear the current axes and figure respectively.
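
A minimal sketch of these helpers with arbitrary demo data: grab the current axes and figure, then modify them in place.

import matplotlib.pyplot as plt

plt.plot([1, 2, 3], [4, 5, 6], 'go')
ax = plt.gca()                          # reference to the current axes
fig = plt.gcf()                         # reference to the current figure
ax.set_title('Modified via plt.gca()')
fig.set_size_inches(8, 4)
plt.show()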

Alright, compare the matlab-style code above with the object oriented (OO) version. The OO version might look a bit confusing because it has a mix of both ax1 and plt commands.

However, there is a significant advantage with axes approach.

That is, since plt.subplots returns all the axes as separate objects, you can avoid writing repetitive code by looping through the axes.

Always remember: plt.plot() or plt.{anything} will always act on the plot in the current axes, whereas, ax.{anything} will modify the plot inside that specific ax.

# Draw multiple plots using for-loops using object oriented syntax
import numpy as np
from numpy.random import seed, randint
seed(100)

# Create Figure and Subplots
fig, axes = plt.subplots(2,2, figsize=(10,6), sharex=True, sharey=True, dpi=120)

# Define the colors and markers to use
colors = {0:'g', 1:'b', 2:'r', 3:'y'}
markers = {0:'o', 1:'x', 2:'*', 3:'p'}

# Plot each axes
for i, ax in enumerate(axes.ravel()):
    ax.plot(sorted(randint(0,10,10)), sorted(randint(0,10,10)), marker=markers[i], color=colors[i])  
    ax.set_title('Ax: ' + str(i))
    ax.yaxis.set_ticks_position('none')

plt.suptitle('Four Subplots in One Figure', verticalalignment='bottom', fontsize=16)    
plt.tight_layout()
plt.show()

Matplotlib subplots

Did you notice in above plot, the Y-axis does not have ticks?

That’s because I used ax.yaxis.set_ticks_position('none') to turn off the Y-axis ticks. This is another advantage of the object-oriented interface. You can actually get a reference to any specific element of the plot and use its methods to manipulate it.

Can you guess how to turn off the X-axis ticks?

The plt.suptitle() added a main title at the figure level. plt.title() would have done the same for the current subplot (axes).

The verticalalignment='bottom' parameter denotes the hingepoint should be at the bottom of the title text, so that the main title is pushed slightly upwards.

Alright, what you’ve learned so far is the core essence of how to create a plot and manipulate it using matplotlib. Next, let’s see how to get the reference to and modify the other components of the plot.

6. How to Modify the Axis Ticks Positions and Labels

There are 3 basic things you will probably ever need in matplotlib when it comes to manipulating axis ticks:
1. How to control the position and tick labels? (using plt.xticks() or ax.set_xticks() and ax.set_xticklabels())
2. How to control which axis’s ticks (top/bottom/left/right) should be displayed (using plt.tick_params())
3. Functional formatting of tick labels

If you are using ax syntax, you can use ax.set_xticks() and ax.set_xticklabels() to set the positions and label texts respectively. If you are using the plt syntax, you can set both the positions as well as the label text in one call using the plt.xticks().

Actually, if you look at the code of plt.xticks() method (by typing ??plt.xticks in jupyter notebook), it calls ax.set_xticks() and ax.set_xticklabels() to do the job. plt.xticks takes the ticks and labels as required parameters but you can also adjust the label’s fontsize, rotation, ‘horizontalalignment’ and ‘verticalalignment’ of the hinge points on the labels, like I’ve done in the below example.

from matplotlib.ticker import FuncFormatter

def rad_to_degrees(x, pos):
    'converts radians to degrees'
    return round(x * 57.2985, 2)

plt.figure(figsize=(12,7), dpi=100)
X = np.linspace(0,2*np.pi,1000)
plt.plot(X,np.sin(X))
plt.plot(X,np.cos(X))

# 1. Adjust x axis Ticks
plt.xticks(ticks=np.arange(0, 440/57.2985, 90/57.2985), fontsize=12, rotation=30, ha='center', va='top')  # 1 radian = 57.2985 degrees

# 2. Tick Parameters
plt.tick_params(axis='both',bottom=True, top=True, left=True, right=True, direction='in', which='major', grid_color='blue')

# 3. Format tick labels to convert radians to degrees
formatter = FuncFormatter(rad_to_degrees)
plt.gca().xaxis.set_major_formatter(formatter)

plt.grid(linestyle='--', linewidth=0.5, alpha=0.15)
plt.title('Sine and Cosine Waves\n(Notice the ticks are on all 4 sides pointing inwards, radians converted to degrees in x axis)', fontsize=14)
plt.show()

09 Modify Axis Ticks Positions Matplotlib

In the above code, plt.tick_params() is used to determine which sides of the plot (‘top’ / ‘bottom’ / ‘left’ / ‘right’) should draw the ticks and which direction (‘in’ / ‘out’) they should point.

The matplotlib.ticker module provides the FuncFormatter to determine how the final tick label should be shown.

7. Understanding the rcParams, Colors and Plot Styles

The look and feel of various components of a matplotlib plot can be set globally using rcParams. The complete list of rcParams can be viewed by typing:

mpl.rc_params()
# RcParams({'_internal.classic_mode': False,
#           'agg.path.chunksize': 0,
#           'animation.avconv_args': [],
#           'animation.avconv_path': 'avconv',
#           'animation.bitrate': -1,
#           'animation.codec': 'h264',
#           ... TRUNCATED LaRge OuTPut ...

You can adjust the params you’d like to change by updating it. The below snippet adjusts the font by setting it to ‘stix’, which looks great on plots by the way.

mpl.rcParams.update({'font.size': 18, 'font.family': 'STIXGeneral', 'mathtext.fontset': 'stix'})

After modifying a plot, you can rollback the rcParams to default setting using:

mpl.rcParams.update(mpl.rcParamsDefault)  # reset to defaults

Matplotlib comes with pre-built styles which you can look by typing:

plt.style.available
# ['seaborn-dark', 'seaborn-darkgrid', 'seaborn-ticks', 'fivethirtyeight',
#  'seaborn-whitegrid', 'classic', '_classic_test', 'fast', 'seaborn-talk',
#  'seaborn-dark-palette', 'seaborn-bright', 'seaborn-pastel', 'grayscale',
#  'seaborn-notebook', 'ggplot', 'seaborn-colorblind', 'seaborn-muted',
#  'seaborn', 'Solarize_Light2', 'seaborn-paper', 'bmh', 'tableau-colorblind10',
#  'seaborn-white', 'dark_background', 'seaborn-poster', 'seaborn-deep']
import matplotlib as mpl
mpl.rcParams.update({'font.size': 18, 'font.family': 'STIXGeneral', 'mathtext.fontset': 'stix'})

def plot_sine_cosine_wave(style='ggplot'):
    plt.style.use(style)
    plt.figure(figsize=(7,4), dpi=80)
    X = np.linspace(0,2*np.pi,1000)
    plt.plot(X,np.sin(X)); plt.plot(X,np.cos(X))
    plt.xticks(ticks=np.arange(0, 440/57.2985, 90/57.2985), labels = [r'$0$',r'$\frac{\pi}{2}$',r'$\pi$',r'$\frac{3\pi}{2}$',r'$2\pi$'])  # 1 radian = 57.2985 degrees
    plt.gca().set(ylim=(-1.25, 1.25), xlim=(-.5, 7))
    plt.title(style, fontsize=18)
    plt.show()

plot_sine_cosine_wave('seaborn-notebook')    
plot_sine_cosine_wave('ggplot')    
plot_sine_cosine_wave('bmh')    

seaborn-notebook

ggplot

bmh

I’ve just shown a few of the pre-built styles; the rest of the list is definitely worth a look.

Matplotlib also comes with pre-built colors and palettes. Type the following in your jupyter/python console to check out the available colors.

# View Colors
mpl.colors.CSS4_COLORS  # 148 colors
mpl.colors.XKCD_COLORS  # 949 colors
mpl.colors.BASE_COLORS  # 8 colors
#> {'b': (0, 0, 1),
#>  'g': (0, 0.5, 0),
#>  'r': (1, 0, 0),
#>  'c': (0, 0.75, 0.75),
#>  'm': (0.75, 0, 0.75),
#>  'y': (0.75, 0.75, 0),
#>  'k': (0, 0, 0),
#>  'w': (1, 1, 1)}
# View first 10 Palettes
dir(plt.cm)[:10]
#> ['Accent', 'Accent_r', 'Blues', 'Blues_r',
#>  'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r']

Matplotlib Colors List

8. How to Customise the Legend

The most common way to make a legend is to define the label parameter for each of the plots and finally call plt.legend().

However, sometimes you might want to construct the legend on your own. In that case, you need to pass the plot items you want to draw the legend for and the legend text as parameters to plt.legend() in the following format:

plt.legend((line1, line2, line3), ('label1', 'label2', 'label3'))

# plt.style.use('seaborn-notebook')
plt.figure(figsize=(10,7), dpi=80)
X = np.linspace(0, 2*np.pi, 1000)
sine = plt.plot(X,np.sin(X)); cosine = plt.plot(X,np.cos(X))
sine_2 = plt.plot(X,np.sin(X+.5)); cosine_2 = plt.plot(X,np.cos(X+.5))
plt.gca().set(ylim=(-1.25, 1.5), xlim=(-.5, 7))
plt.title('Custom Legend Example', fontsize=18)

# Modify legend
plt.legend([sine[0], cosine[0], sine_2[0], cosine_2[0]],   # plot items
           ['sine curve', 'cosine curve', 'sine curve 2', 'cosine curve 2'],  
           frameon=True,                                   # legend border
           framealpha=1,                                   # transparency of border
           ncol=2,                                         # num columns
           shadow=True,                                    # shadow on
           borderpad=1,                                    # thickness of border
           title='Sines and Cosines')                      # title
plt.show()

Customize Legend in Matplotlib

9. How to Add Texts, Arrows and Annotations

plt.text and plt.annotate add the texts and annotations respectively. If you have to plot multiple texts, you need to call plt.text() multiple times, typically in a for-loop.

Let’s annotate the peaks and troughs adding arrowprops and a bbox for the text.

# Texts, Arrows and Annotations Example
# ref: https://matplotlib.org/users/annotations_guide.html
plt.figure(figsize=(14,7), dpi=120)
X = np.linspace(0, 8*np.pi, 1000)
sine = plt.plot(X,np.sin(X), color='tab:blue');

# 1. Annotate with Arrow Props and bbox
plt.annotate('Peaks', xy=(90/57.2985, 1.0), xytext=(90/57.2985, 1.5),
             bbox=dict(boxstyle='square', fc='green', linewidth=0.1),
             arrowprops=dict(facecolor='green', shrink=0.01, width=0.1), 
             fontsize=12, color='white', horizontalalignment='center')

# 2. Texts at Peaks and Troughs
for angle in [440, 810, 1170]:
    plt.text(angle/57.2985, 1.05, str(angle) + "\ndegrees", transform=plt.gca().transData, horizontalalignment='center', color='green')

for angle in [270, 630, 990, 1350]:
    plt.text(angle/57.2985, -1.3, str(angle) + "\ndegrees", transform=plt.gca().transData, horizontalalignment='center', color='red')    

plt.gca().set(ylim=(-2.0, 2.0), xlim=(-.5, 26))
plt.title('Annotations and Texts Example', fontsize=18)
plt.show()

Matplotlib Annotations

Notice that all the text we plotted above was positioned in relation to the data.

That is, the x and y positions in plt.text() correspond to the values along the x and y axes. However, sometimes you might work with data of different scales on different subplots and still want to write the texts in the same position on all the subplots.

In such cases, instead of manually computing the x and y positions for each axes, you can specify the x and y values in relation to the axes (instead of the x and y axis values).

You can do this by setting transform=ax.transAxes.

The lower left corner of the axes has (x,y) = (0,0) and the top right corner will correspond to (1,1).

The below plot shows the position of texts for the same value of (x, y) = (0.50, 0.02) with respect to the Data (transData), the Axes (transAxes) and the Figure (transFigure) respectively.

# Texts, Arrows and Annotations Example
plt.figure(figsize=(14,7), dpi=80)
X = np.linspace(0, 8*np.pi, 1000)

# Text Relative to DATA
plt.text(0.50, 0.02, "Text relative to the DATA centered at : (0.50, 0.02)", transform=plt.gca().transData, fontsize=14, ha='center', color='blue')

# Text Relative to AXES
plt.text(0.50, 0.02, "Text relative to the AXES centered at : (0.50, 0.02)", transform=plt.gca().transAxes, fontsize=14, ha='center', color='blue')

# Text Relative to FIGURE
plt.text(0.50, 0.02, "Text relative to the FIGURE centered at : (0.50, 0.02)", transform=plt.gcf().transFigure, fontsize=14, ha='center', color='blue')

plt.gca().set(ylim=(-2.0, 2.0), xlim=(0, 2))
plt.title('Placing Texts Relative to Data, Axes and Figure', fontsize=18)
plt.show()

Texts Relative to Data Axes Matplotlib

10. How to customize matplotlib’s subplots layout

Matplotlib provides two convenient ways to create customized multi-subplot layouts:

  • plt.subplot2grid
  • plt.GridSpec

Both plt.subplot2grid and plt.GridSpec let you draw complex layouts. Below is a nice plt.subplot2grid example.

# Subplot2grid approach
fig = plt.figure()
ax1 = plt.subplot2grid((3,3), (0,0), colspan=2, rowspan=2) # topleft
ax3 = plt.subplot2grid((3,3), (0,2), rowspan=3)            # right
ax4 = plt.subplot2grid((3,3), (2,0))                       # bottom left
ax5 = plt.subplot2grid((3,3), (2,1))                       # bottom middle
fig.tight_layout()

Matplotlib Custom Layout with subplot2grid

Using plt.GridSpec, you can either use the plt.subplot() interface, which takes part of the grid specified by plt.GridSpec(nrow, ncol), or use ax = fig.add_subplot(g), where the GridSpec is defined by height_ratios and width_ratios.

# GridSpec Approach 1
import matplotlib.gridspec as gridspec
fig = plt.figure()
grid = plt.GridSpec(2, 3)  # 2 rows 3 cols
plt.subplot(grid[0, :2])  # top left
plt.subplot(grid[0, 2])   # top right
plt.subplot(grid[1, :1])  # bottom left
plt.subplot(grid[1, 1:])  # bottom right
fig.tight_layout()

Matplotlib Custom Layout - Gridspec

# GridSpec Approach 2
import matplotlib.gridspec as gridspec
fig = plt.figure()
gs = gridspec.GridSpec(2, 2, height_ratios=[2,1], width_ratios=[1,2])
for g in gs:
    ax = fig.add_subplot(g)    
fig.tight_layout()

Matplotlib Custom Layout

The above examples showed layouts where the subplots don't overlap. It is possible to make subplots overlap. In fact, you can draw an axes inside a larger axes using fig.add_axes(). You need to specify the x, y positions relative to the figure, as well as the width and height of the inner plot.

Below is an example of an inner plot that zooms in to a larger plot.

# Plot inside a plot
plt.style.use('seaborn-whitegrid')
fig, ax = plt.subplots(figsize=(10,6))
x = np.linspace(-0.50, 1., 1000)

# Outer Plot
ax.plot(x, x**2)
ax.plot(x, np.sin(x))
ax.set(xlim=(-0.5, 1.0), ylim=(-0.5,1.2))
fig.tight_layout()

# Inner Plot
inner_ax = fig.add_axes([0.2, 0.55, 0.35, 0.35]) # x, y, width, height
inner_ax.plot(x, x**2)
inner_ax.plot(x, np.sin(x))
inner_ax.set(title='Zoom In', xlim=(-.2, .2), ylim=(-.01, .02), 
             yticks = [-0.01, 0, 0.01, 0.02], xticks=[-0.1,0,.1])
ax.set_title("Plot inside a Plot", fontsize=20)
plt.show()
mpl.rcParams.update(mpl.rcParamsDefault)  # reset to defaults

Plot inside a plot

11. How is a scatterplot drawn with plt.plot different from plt.scatter?

The difference is that plt.plot() does not provide options to change the color and size of points dynamically (based on another array), whereas plt.scatter() allows you to do that.

By varying the size and color of points, you can create nice looking bubble plots.

Another convenience is you can directly use a pandas dataframe to set the x and y values, provided you specify the source dataframe in the data argument.

You can also set the color 'c' and size 's' of the points from one of the dataframe columns itself.

# Scatterplot with varying size and color of points
import pandas as pd
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")

# Plot
fig = plt.figure(figsize=(14, 7), dpi= 80, facecolor='w', edgecolor='k')    
plt.scatter('area', 'poptotal', data=midwest, s='dot_size', c='popdensity', cmap='Reds', edgecolors='black', linewidths=.5)
plt.title("Bubble Plot of PopTotal vs Area\n(color: 'popdensity' & size: 'dot_size' - both are numeric columns in midwest)", fontsize=16)
plt.xlabel('Area', fontsize=18)
plt.ylabel('Poptotal', fontsize=18)
plt.colorbar()
plt.show()     

Bubble plot in Matplotlib - colorbar

# Import data
import pandas as pd
midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")

# Plot
fig = plt.figure(figsize=(14, 9), dpi= 80, facecolor='w', edgecolor='k')    
colors = plt.cm.tab20.colors
categories = np.unique(midwest['category'])
for i, category in enumerate(categories):
    plt.scatter('area', 'poptotal', data=midwest.loc[midwest.category==category, :], s='dot_size', c=colors[i], label=str(category), edgecolors='black', linewidths=.5)

# Legend for size of points
for dot_size in [100, 300, 1000]:
    plt.scatter([], [], c='k', alpha=0.5, s=dot_size, label=str(dot_size) + ' TotalPop')
plt.legend(loc='upper right', scatterpoints=1, frameon=False, labelspacing=2, title='Saravana Stores', fontsize=8)
plt.title("Bubble Plot of PopTotal vs Area\n(color: 'category' - a categorical column in midwest)", fontsize=18)
plt.xlabel('Area', fontsize=16)
plt.ylabel('Poptotal', fontsize=16)
plt.show()     

Bubbleplot in Matplotlib

# Save the figure (call plt.savefig before plt.show, otherwise the saved image may be blank)
plt.savefig("bubbleplot.png", transparent=True, dpi=120)

12. How to draw Histograms, Boxplots and Time Series

The methods to draw different types of plots are present in pyplot (plt) as well as on the Axes objects. The example below shows a few of the commonly used plot types.

import pandas as pd

# Setup the subplot2grid Layout
fig = plt.figure(figsize=(10, 5))
ax1 = plt.subplot2grid((2,4), (0,0)) 
ax2 = plt.subplot2grid((2,4), (0,1)) 
ax3 = plt.subplot2grid((2,4), (0,2)) 
ax4 = plt.subplot2grid((2,4), (0,3)) 
ax5 = plt.subplot2grid((2,4), (1,0), colspan=2) 
ax6 = plt.subplot2grid((2,4), (1,2)) 
ax7 = plt.subplot2grid((2,4), (1,3)) 

# Input Arrays
n = np.array([0,1,2,3,4,5])
x = np.linspace(0,5,10)
xx = np.linspace(-0.75, 1., 100)

# Scatterplot
ax1.scatter(xx, xx + np.random.randn(len(xx)))
ax1.set_title("Scatter Plot")

# Step Chart
ax2.step(n, n**2, lw=2)
ax2.set_title("Step Plot")

# Bar Chart
ax3.bar(n, n**2, align="center", width=0.5, alpha=0.5)
ax3.set_title("Bar Chart")

# Fill Between
ax4.fill_between(x, x**2, x**3, color="steelblue", alpha=0.5);
ax4.set_title("Fill Between");

# Time Series
dates = pd.date_range('2018-01-01', periods = len(xx))
ax5.plot(dates, xx + np.random.randn(len(xx)))
ax5.set_xticks(dates[::30])
ax5.set_xticklabels(dates.strftime('%Y-%m-%d')[::30])
ax5.set_title("Time Series")

# Box Plot
ax6.boxplot(np.random.randn(len(xx)))
ax6.set_title("Box Plot")

# Histogram
ax7.hist(xx + np.random.randn(len(xx)))
ax7.set_title("Histogram")

fig.tight_layout()

Histogram - Boxplot - Timeseries - Matplotlib

What about more advanced plots?

If you want to see more data analysis oriented examples of a particular plot type, say histogram or time series, the top 50 master plots for data analysis will give you concrete examples of presentation-ready plots. This is a very useful resource, not only for constructing nice looking plots but also for drawing ideas on what type of plot to make for your data.

13. How to Plot with two Y-Axis

Plotting a line chart on the left-hand side axis is straightforward, which you’ve already seen.

So how to draw the second line on the right-hand side y-axis?

The trick is to activate the right hand side Y axis using ax.twinx() to create a second axes.

This second axes will have the Y-axis on the right activated and shares the same x-axis as the original ax. Then, whatever you draw using this second axes will be referenced to the secondary y-axis. The remaining job is to just color the axis and tick labels to match the color of the lines.

# Import Data
df = pd.read_csv("https://github.com/selva86/datasets/raw/master/economics.csv")
x = df['date']; y1 = df['psavert']; y2 = df['unemploy']

# Plot Line1 (Left Y Axis)
fig, ax1 = plt.subplots(1,1,figsize=(16,7), dpi= 80)
ax1.plot(x, y1, color='tab:red')

# Plot Line2 (Right Y Axis)
ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
ax2.plot(x, y2, color='tab:blue')

# Just Decorations!! -------------------
# ax1 (left y axis)
ax1.set_xlabel('Year', fontsize=20)
ax1.set_ylabel('Personal Savings Rate', color='tab:red', fontsize=20)
ax1.tick_params(axis='y', rotation=0, labelcolor='tab:red' )

# ax2 (right Y axis)
ax2.set_ylabel("# Unemployed (1000's)", color='tab:blue', fontsize=20)
ax2.tick_params(axis='y', labelcolor='tab:blue')
ax2.set_title("Personal Savings Rate vs Unemployed: Plotting in Secondary Y Axis", fontsize=20)
ax2.set_xticks(np.arange(0, len(x), 60))
ax2.set_xticklabels(x[::60], rotation=90, fontdict={'fontsize':10})
plt.show()

Time Series in Secondary Axis

14. Introduction to Seaborn

The more complex the charts get, the more code you've got to write. For example, in matplotlib, there is no direct method to draw a density plot or a scatterplot with a line of best fit. You get the idea.

So, what you can do instead is to use a higher level package like seaborn, and use one of its prebuilt functions to draw the plot.

We are not going in-depth into seaborn. But let’s see how to get started and where to find what you want. A lot of seaborn’s plots are suitable for data analysis and the library works seamlessly with pandas dataframes.

seaborn is typically imported as sns. Like matplotlib it comes with its own set of pre-built styles and palettes.

import seaborn as sns
sns.set_style("white")

# Import Data
df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")

# Draw Plot
plt.figure(figsize=(16,10), dpi= 80)
sns.kdeplot(df.loc[df['cyl'] == 4, "cty"], shade=True, color="g", label="Cyl=4", alpha=.7)
sns.kdeplot(df.loc[df['cyl'] == 6, "cty"], shade=True, color="dodgerblue", label="Cyl=6", alpha=.7)
sns.kdeplot(df.loc[df['cyl'] == 8, "cty"], shade=True, color="orange", label="Cyl=8", alpha=.7)

# Decoration
plt.title('Density Plot of City Mileage by n_Cylinders', fontsize=22)
plt.legend()
plt.show()

Density Plots in Matplotlib

# Load Dataset
df = sns.load_dataset('iris')

# Plot
plt.figure(figsize=(10,8), dpi= 80)
sns.pairplot(df, kind="reg", hue="species")
plt.show()
#> <Figure size 800x640 with 0 Axes>  (sns.pairplot creates its own figure, so the figure created by plt.figure() above stays empty)

Pairplot - Seaborn

This is just to give a hint of what’s possible with seaborn. Maybe I will write a separate post on it. However, the official seaborn page has good examples for you to start with.

15. Conclusion

Congratulations if you reached this far, because we literally started from scratch and covered the essential topics of making matplotlib plots.

We covered the syntax and overall structure of creating matplotlib plots, saw how to modify various components of a plot, customized subplots layout, plots styling, colors, palettes, draw different plot types etc.

If you want to get more practice, try taking up a couple of the plots listed in the top 50 plots, starting with the correlation plots, and try recreating them.

Until next time!

Topic modeling visualization – How to present the results of LDA models?

In this post, we discuss techniques to visualize the output and results from topic model (LDA) based on the gensim package.


Contents

  1. Introduction
  2. Import NewsGroups Dataset
  3. Tokenize Sentences and Clean
  4. Build the Bigram, Trigram Models and Lemmatize
  5. Build the Topic Model

  Presenting the Results:

  6. What is the Dominant topic and its percentage contribution in each document?
  7. The most representative sentences for each topic
  8. Frequency Distribution of Word Counts in Documents
  9. Word Clouds of Top N Keywords in Each Topic
  10. Word Counts of Topic Keywords
  11. Sentence Chart Colored by Topic
  12. What are the most discussed topics in the documents?
  13. t-SNE Clustering Chart
  14. pyLDAVis
  15. Conclusion

1. Introduction

In topic modeling with gensim, we followed a structured workflow to build an insightful topic model based on the Latent Dirichlet Allocation (LDA) algorithm.

In this post, we will build the topic model using gensim’s native LdaModel and explore multiple strategies to effectively visualize the results using matplotlib plots.

I will be using a portion of the 20 Newsgroups dataset since the focus is more on approaches to visualizing the results.

Let’s begin by importing the packages and the 20 News Groups dataset.

import sys
# !{sys.executable} -m spacy download en
import re, numpy as np, pandas as pd
from pprint import pprint

# Gensim
import gensim, spacy, logging, warnings
import gensim.corpora as corpora
from gensim.utils import lemmatize, simple_preprocess
from gensim.models import CoherenceModel
import matplotlib.pyplot as plt

# NLTK Stop words
from nltk.corpus import stopwords
stop_words = stopwords.words('english')
stop_words.extend(['from', 'subject', 're', 'edu', 'use', 'not', 'would', 'say', 'could', '_', 'be', 'know', 'good', 'go', 'get', 'do', 'done', 'try', 'many', 'some', 'nice', 'thank', 'think', 'see', 'rather', 'easy', 'easily', 'lot', 'lack', 'make', 'want', 'seem', 'run', 'need', 'even', 'right', 'line', 'even', 'also', 'may', 'take', 'come'])

%matplotlib inline
warnings.filterwarnings("ignore",category=DeprecationWarning)
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.ERROR)

2. Import NewsGroups Dataset

Let’s import the news groups dataset and retain only 4 of the target_names categories.

# Import Dataset
df = pd.read_json('https://raw.githubusercontent.com/selva86/datasets/master/newsgroups.json')
df = df.loc[df.target_names.isin(['soc.religion.christian', 'rec.sport.hockey', 'talk.politics.mideast', 'rec.motorcycles']) , :]
print(df.shape)  #> (2361, 3)
df.head()

3. Tokenize Sentences and Clean

We remove the emails, newline characters and single quotes, and finally split each sentence into a list of words using gensim’s simple_preprocess(). Setting the deacc=True option removes punctuation.

def sent_to_words(sentences):
    for sent in sentences:
        sent = re.sub('\S*@\S*\s?', '', sent)  # remove emails
        sent = re.sub('\s+', ' ', sent)  # remove newline chars
        sent = re.sub("\'", "", sent)  # remove single quotes
        sent = gensim.utils.simple_preprocess(str(sent), deacc=True) 
        yield(sent)  

# Convert to list
data = df.content.values.tolist()
data_words = list(sent_to_words(data))
print(data_words[:1])
# [['from', 'irwin', 'arnstein', 'subject', 're', 'recommendation', 'on', 'duc', 'summary', 'whats', 'it', 'worth', 'distribution', 'usa', 'expires', 'sat', 'may', 'gmt', ...truncated...]]

4. Build the Bigram, Trigram Models and Lemmatize

Let’s form the bigrams and trigrams using the Phrases model. These are then passed to Phraser() for faster execution.

Next, lemmatize each word to its root form, keeping only nouns, adjectives, verbs and adverbs.

We keep only these POS tags because they are the ones contributing the most to the meaning of the sentences. Here, I use spacy for lemmatization.

# Build the bigram and trigram models
bigram = gensim.models.Phrases(data_words, min_count=5, threshold=100) # higher threshold fewer phrases.
trigram = gensim.models.Phrases(bigram[data_words], threshold=100)  
bigram_mod = gensim.models.phrases.Phraser(bigram)
trigram_mod = gensim.models.phrases.Phraser(trigram)

# !python3 -m spacy download en  # run in terminal once
def process_words(texts, stop_words=stop_words, allowed_postags=['NOUN', 'ADJ', 'VERB', 'ADV']):
    """Remove Stopwords, Form Bigrams, Trigrams and Lemmatization"""
    texts = [[word for word in simple_preprocess(str(doc)) if word not in stop_words] for doc in texts]
    texts = [bigram_mod[doc] for doc in texts]
    texts = [trigram_mod[bigram_mod[doc]] for doc in texts]
    texts_out = []
    nlp = spacy.load('en', disable=['parser', 'ner'])
    for sent in texts:
        doc = nlp(" ".join(sent)) 
        texts_out.append([token.lemma_ for token in doc if token.pos_ in allowed_postags])
    # remove stopwords once more after lemmatization
    texts_out = [[word for word in simple_preprocess(str(doc)) if word not in stop_words] for doc in texts_out]    
    return texts_out

data_ready = process_words(data_words)  # processed Text Data!

5. Build the Topic Model

To build the LDA topic model using LdaModel(), you need the corpus and the dictionary. Let’s create them first and then build the model. The trained topics (keywords and weights) are printed below as well.

If you examine the topic keywords, they are nicely segregated and collectively represent the topics we initially chose: Christianity, Hockey, MidEast and Motorcycles. Nice!

# Create Dictionary
id2word = corpora.Dictionary(data_ready)

# Create Corpus: Term Document Frequency
corpus = [id2word.doc2bow(text) for text in data_ready]

# Build LDA model
lda_model = gensim.models.ldamodel.LdaModel(corpus=corpus,
                                           id2word=id2word,
                                           num_topics=4, 
                                           random_state=100,
                                           update_every=1,
                                           chunksize=10,
                                           passes=10,
                                           alpha='symmetric',
                                           iterations=100,
                                           per_word_topics=True)

pprint(lda_model.print_topics())
#> [(0,
#>   '0.017*"write" + 0.015*"people" + 0.014*"organization" + 0.014*"article" + '
#>   '0.013*"time" + 0.008*"give" + 0.008*"first" + 0.007*"tell" + 0.007*"new" + '
#>   '0.007*"question"'),
#>  (1,
#>   '0.008*"christian" + 0.008*"believe" + 0.007*"god" + 0.007*"law" + '
#>   '0.006*"state" + 0.006*"israel" + 0.006*"israeli" + 0.005*"exist" + '
#>   '0.005*"way" + 0.004*"bible"'),
#>  (2,
#>   '0.024*"armenian" + 0.012*"bike" + 0.006*"kill" + 0.006*"work" + '
#>   '0.005*"well" + 0.005*"year" + 0.005*"sumgait" + 0.005*"soldier" + '
#>   '0.004*"way" + 0.004*"ride"'),
#>  (3,
#>   '0.019*"team" + 0.019*"game" + 0.013*"hockey" + 0.010*"player" + '
#>   '0.009*"play" + 0.009*"win" + 0.009*"nhl" + 0.009*"year" + 0.009*"hawk" + '
#>   '0.009*"season"')]

6. What is the Dominant topic and its percentage contribution in each document?

In LDA models, each document is composed of multiple topics. But, typically only one of the topics is dominant. The below code extracts this dominant topic for each sentence and shows the weight of the topic and the keywords in a nicely formatted output.

This way, you will know which document belongs predominantly to which topic.

def format_topics_sentences(ldamodel=None, corpus=corpus, texts=data):
    # Init output
    sent_topics_df = pd.DataFrame()

    # Get main topic in each document
    for i, row_list in enumerate(ldamodel[corpus]):
        row = row_list[0] if ldamodel.per_word_topics else row_list            
        # print(row)
        row = sorted(row, key=lambda x: (x[1]), reverse=True)
        # Get the Dominant topic, Perc Contribution and Keywords for each document
        for j, (topic_num, prop_topic) in enumerate(row):
            if j == 0:  # => dominant topic
                wp = ldamodel.show_topic(topic_num)
                topic_keywords = ", ".join([word for word, prop in wp])
                sent_topics_df = sent_topics_df.append(pd.Series([int(topic_num), round(prop_topic,4), topic_keywords]), ignore_index=True)
            else:
                break
    sent_topics_df.columns = ['Dominant_Topic', 'Perc_Contribution', 'Topic_Keywords']

    # Add original text to the end of the output
    contents = pd.Series(texts)
    sent_topics_df = pd.concat([sent_topics_df, contents], axis=1)
    return(sent_topics_df)


df_topic_sents_keywords = format_topics_sentences(ldamodel=lda_model, corpus=corpus, texts=data_ready)

# Format
df_dominant_topic = df_topic_sents_keywords.reset_index()
df_dominant_topic.columns = ['Document_No', 'Dominant_Topic', 'Topic_Perc_Contrib', 'Keywords', 'Text']
df_dominant_topic.head(10)

7. The most representative sentence for each topic

Sometimes you want to get samples of sentences that most represent a given topic. This code gets the most exemplar sentence for each topic.

# Display setting to show more characters in column
pd.options.display.max_colwidth = 100

sent_topics_sorteddf_mallet = pd.DataFrame()
sent_topics_outdf_grpd = df_topic_sents_keywords.groupby('Dominant_Topic')

for i, grp in sent_topics_outdf_grpd:
    sent_topics_sorteddf_mallet = pd.concat([sent_topics_sorteddf_mallet, 
                                             grp.sort_values(['Perc_Contribution'], ascending=False).head(1)], 
                                            axis=0)

# Reset Index    
sent_topics_sorteddf_mallet.reset_index(drop=True, inplace=True)

# Format
sent_topics_sorteddf_mallet.columns = ['Topic_Num', "Topic_Perc_Contrib", "Keywords", "Representative Text"]

# Show
sent_topics_sorteddf_mallet.head(10)

8. Frequency Distribution of Word Counts in Documents

When working with a large number of documents, you want to know how big the documents are as a whole and by topic. Let’s plot the document word counts distribution.

doc_lens = [len(d) for d in df_dominant_topic.Text]

# Plot
plt.figure(figsize=(16,7), dpi=160)
plt.hist(doc_lens, bins = 1000, color='navy')
plt.text(750, 100, "Mean   : " + str(round(np.mean(doc_lens))))
plt.text(750,  90, "Median : " + str(round(np.median(doc_lens))))
plt.text(750,  80, "Stdev   : " + str(round(np.std(doc_lens))))
plt.text(750,  70, "1%ile    : " + str(round(np.quantile(doc_lens, q=0.01))))
plt.text(750,  60, "99%ile  : " + str(round(np.quantile(doc_lens, q=0.99))))

plt.gca().set(xlim=(0, 1000), ylabel='Number of Documents', xlabel='Document Word Count')
plt.tick_params(size=16)
plt.xticks(np.linspace(0,1000,9))
plt.title('Distribution of Document Word Counts', fontdict=dict(size=22))
plt.show()

import seaborn as sns
import matplotlib.colors as mcolors
cols = [color for name, color in mcolors.TABLEAU_COLORS.items()]  # more colors: 'mcolors.XKCD_COLORS'

fig, axes = plt.subplots(2,2,figsize=(16,14), dpi=160, sharex=True, sharey=True)

for i, ax in enumerate(axes.flatten()):    
    df_dominant_topic_sub = df_dominant_topic.loc[df_dominant_topic.Dominant_Topic == i, :]
    doc_lens = [len(d) for d in df_dominant_topic_sub.Text]
    ax.hist(doc_lens, bins = 1000, color=cols[i])
    ax.tick_params(axis='y', labelcolor=cols[i], color=cols[i])
    sns.kdeplot(doc_lens, color="black", shade=False, ax=ax.twinx())
    ax.set(xlim=(0, 1000), xlabel='Document Word Count')
    ax.set_ylabel('Number of Documents', color=cols[i])
    ax.set_title('Topic: '+str(i), fontdict=dict(size=16, color=cols[i]))

fig.tight_layout()
fig.subplots_adjust(top=0.90)
plt.xticks(np.linspace(0,1000,9))
fig.suptitle('Distribution of Document Word Counts by Dominant Topic', fontsize=22)
plt.show()

9. Word Clouds of Top N Keywords in Each Topic

Though you’ve already seen what the topic keywords in each topic are, a word cloud with the size of the words proportional to the weight is a pleasant sight. The coloring of the topics I’ve used here is carried through in the subsequent plots as well.

# 1. Wordcloud of Top N words in each topic
from matplotlib import pyplot as plt
from wordcloud import WordCloud, STOPWORDS
import matplotlib.colors as mcolors

cols = [color for name, color in mcolors.TABLEAU_COLORS.items()]  # more colors: 'mcolors.XKCD_COLORS'

cloud = WordCloud(stopwords=stop_words,
                  background_color='white',
                  width=2500,
                  height=1800,
                  max_words=10,
                  colormap='tab10',
                  color_func=lambda *args, **kwargs: cols[i],
                  prefer_horizontal=1.0)

topics = lda_model.show_topics(formatted=False)

fig, axes = plt.subplots(2, 2, figsize=(10,10), sharex=True, sharey=True)

for i, ax in enumerate(axes.flatten()):
    fig.add_subplot(ax)
    topic_words = dict(topics[i][1])
    cloud.generate_from_frequencies(topic_words, max_font_size=300)
    plt.gca().imshow(cloud)
    plt.gca().set_title('Topic ' + str(i), fontdict=dict(size=16))
    plt.gca().axis('off')


plt.subplots_adjust(wspace=0, hspace=0)
plt.axis('off')
plt.margins(x=0, y=0)
plt.tight_layout()
plt.show()

10. Word Counts of Topic Keywords

When it comes to the keywords in the topics, the importance (weights) of the keywords matters. Along with that, how frequently the words have appeared in the documents is also interesting to look at.

Let’s plot the word counts and the weights of each keyword in the same chart.

You want to keep an eye out for words that occur in multiple topics and for words whose relative frequency is higher than their weight. Often such words turn out to be less important. The chart I’ve drawn below is a result of adding several such words to the stop words list in the beginning and re-running the training process.

from collections import Counter
topics = lda_model.show_topics(formatted=False)
data_flat = [w for w_list in data_ready for w in w_list]
counter = Counter(data_flat)

out = []
for i, topic in topics:
    for word, weight in topic:
        out.append([word, i , weight, counter[word]])

df = pd.DataFrame(out, columns=['word', 'topic_id', 'importance', 'word_count'])        

# Plot Word Count and Weights of Topic Keywords
fig, axes = plt.subplots(2, 2, figsize=(16,10), sharey=True, dpi=160)
cols = [color for name, color in mcolors.TABLEAU_COLORS.items()]
for i, ax in enumerate(axes.flatten()):
    ax.bar(x='word', height="word_count", data=df.loc[df.topic_id==i, :], color=cols[i], width=0.5, alpha=0.3, label='Word Count')
    ax_twin = ax.twinx()
    ax_twin.bar(x='word', height="importance", data=df.loc[df.topic_id==i, :], color=cols[i], width=0.2, label='Weights')
    ax.set_ylabel('Word Count', color=cols[i])
    ax_twin.set_ylim(0, 0.030); ax.set_ylim(0, 3500)
    ax.set_title('Topic: ' + str(i), color=cols[i], fontsize=16)
    ax.tick_params(axis='y', left=False)
    ax.set_xticklabels(df.loc[df.topic_id==i, 'word'], rotation=30, horizontalalignment= 'right')
    ax.legend(loc='upper left'); ax_twin.legend(loc='upper right')

fig.tight_layout(w_pad=2)    
fig.suptitle('Word Count and Importance of Topic Keywords', fontsize=22, y=1.05)    
plt.show()
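
If you want to shortlist such words programmatically, below is a minimal sketch (my own addition, not part of the original workflow). It works off the df of keyword counts and weights built above and prints the words whose top-keyword entry shows up in more than one topic, as candidates for the extended stop words list.

# Sketch: words appearing in the top keywords of more than one topic
topic_spread = df.groupby('word')['topic_id'].nunique()
print(topic_spread[topic_spread > 1].index.tolist())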

11. Sentence Chart Colored by Topic

Each word in the document is representative of one of the 4 topics. Let’s color each word in the given documents by the topic id it is attributed to.
The color of the enclosing rectangle is the topic assigned to the document.

# Sentence Coloring of N Sentences
from matplotlib.patches import Rectangle

def sentences_chart(lda_model=lda_model, corpus=corpus, start = 0, end = 13):
    corp = corpus[start:end]
    mycolors = [color for name, color in mcolors.TABLEAU_COLORS.items()]

    fig, axes = plt.subplots(end-start, 1, figsize=(20, (end-start)*0.95), dpi=160)       
    axes[0].axis('off')
    for i, ax in enumerate(axes):
        if i > 0:
            corp_cur = corp[i-1] 
            topic_percs, wordid_topics, wordid_phivalues = lda_model[corp_cur]
            word_dominanttopic = [(lda_model.id2word[wd], topic[0]) for wd, topic in wordid_topics]    
            ax.text(0.01, 0.5, "Doc " + str(i-1) + ": ", verticalalignment='center',
                    fontsize=16, color='black', transform=ax.transAxes, fontweight=700)

            # Draw Rectangle
            topic_percs_sorted = sorted(topic_percs, key=lambda x: (x[1]), reverse=True)
            ax.add_patch(Rectangle((0.0, 0.05), 0.99, 0.90, fill=None, alpha=1, 
                                   color=mycolors[topic_percs_sorted[0][0]], linewidth=2))

            word_pos = 0.06
            for j, (word, topics) in enumerate(word_dominanttopic):
                if j < 14:
                    ax.text(word_pos, 0.5, word,
                            horizontalalignment='left',
                            verticalalignment='center',
                            fontsize=16, color=mycolors[topics],
                            transform=ax.transAxes, fontweight=700)
                    word_pos += .009 * len(word)  # to move the word for the next iter
                    ax.axis('off')
            ax.text(word_pos, 0.5, '. . .',
                    horizontalalignment='left',
                    verticalalignment='center',
                    fontsize=16, color='black',
                    transform=ax.transAxes)       

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.suptitle('Sentence Topic Coloring for Documents: ' + str(start) + ' to ' + str(end-2), fontsize=22, y=0.95, fontweight=700)
    plt.tight_layout()
    plt.show()

sentences_chart()    

12. What are the most discussed topics in the documents?

Let’s compute the total number of documents attributed to each topic.

# Dominant topic and its percentage contribution in each document
def topics_per_document(model, corpus, start=0, end=1):
    corpus_sel = corpus[start:end]
    dominant_topics = []
    topic_percentages = []
    for i, corp in enumerate(corpus_sel):
        topic_percs, wordid_topics, wordid_phivalues = model[corp]
        dominant_topic = sorted(topic_percs, key = lambda x: x[1], reverse=True)[0][0]
        dominant_topics.append((i, dominant_topic))
        topic_percentages.append(topic_percs)
    return(dominant_topics, topic_percentages)

dominant_topics, topic_percentages = topics_per_document(model=lda_model, corpus=corpus, end=-1)            

# Distribution of Dominant Topics in Each Document
df = pd.DataFrame(dominant_topics, columns=['Document_Id', 'Dominant_Topic'])
dominant_topic_in_each_doc = df.groupby('Dominant_Topic').size()
df_dominant_topic_in_each_doc = dominant_topic_in_each_doc.to_frame(name='count').reset_index()

# Total Topic Distribution by actual weight
topic_weightage_by_doc = pd.DataFrame([dict(t) for t in topic_percentages])
df_topic_weightage_by_doc = topic_weightage_by_doc.sum().to_frame(name='count').reset_index()

# Top 3 Keywords for each Topic
topic_top3words = [(i, topic) for i, topics in lda_model.show_topics(formatted=False) 
                                 for j, (topic, wt) in enumerate(topics) if j < 3]

df_top3words_stacked = pd.DataFrame(topic_top3words, columns=['topic_id', 'words'])
df_top3words = df_top3words_stacked.groupby('topic_id').agg(', \n'.join)
df_top3words.reset_index(level=0,inplace=True)

Let’s make two plots:

  1. The number of documents for each topic by assigning the document to the topic that has the most weight in that document.
  2. The number of documents for each topic by summing up the actual weight contribution of each topic to respective documents.

from matplotlib.ticker import FuncFormatter

# Plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4), dpi=120, sharey=True)

# Topic Distribution by Dominant Topics
ax1.bar(x='Dominant_Topic', height='count', data=df_dominant_topic_in_each_doc, width=.5, color='firebrick')
ax1.set_xticks(range(df_dominant_topic_in_each_doc.Dominant_Topic.unique().__len__()))
tick_formatter = FuncFormatter(lambda x, pos: 'Topic ' + str(x)+ '\n' + df_top3words.loc[df_top3words.topic_id==x, 'words'].values[0])
ax1.xaxis.set_major_formatter(tick_formatter)
ax1.set_title('Number of Documents by Dominant Topic', fontdict=dict(size=10))
ax1.set_ylabel('Number of Documents')
ax1.set_ylim(0, 1000)

# Topic Distribution by Topic Weights
ax2.bar(x='index', height='count', data=df_topic_weightage_by_doc, width=.5, color='steelblue')
ax2.set_xticks(range(df_topic_weightage_by_doc.index.unique().__len__()))
ax2.xaxis.set_major_formatter(tick_formatter)
ax2.set_title('Number of Documents by Topic Weightage', fontdict=dict(size=10))

plt.show()

13. t-SNE Clustering Chart

Let’s visualize the clusters of documents in a 2D space using t-SNE (t-distributed stochastic neighbor embedding) algorithm.

# Get topic weights and dominant topics ------------
from sklearn.manifold import TSNE
from bokeh.plotting import figure, output_file, show
from bokeh.models import Label
from bokeh.io import output_notebook

# Get topic weights
topic_weights = []
for i, row_list in enumerate(lda_model[corpus]):
    topic_weights.append([w for i, w in row_list[0]])

# Array of topic weights    
arr = pd.DataFrame(topic_weights).fillna(0).values

# Keep the well separated points (optional)
arr = arr[np.amax(arr, axis=1) > 0.35]

# Dominant topic number in each doc
topic_num = np.argmax(arr, axis=1)

# tSNE Dimension Reduction
tsne_model = TSNE(n_components=2, verbose=1, random_state=0, angle=.99, init='pca')
tsne_lda = tsne_model.fit_transform(arr)

# Plot the Topic Clusters using Bokeh
output_notebook()
n_topics = 4
mycolors = np.array([color for name, color in mcolors.TABLEAU_COLORS.items()])
plot = figure(title="t-SNE Clustering of {} LDA Topics".format(n_topics), 
              plot_width=900, plot_height=700)
plot.scatter(x=tsne_lda[:,0], y=tsne_lda[:,1], color=mycolors[topic_num])
show(plot)

14. pyLDAVis

Finally, pyLDAVis is the most commonly used and a nice way to visualise the information contained in a topic model. Below is the implementation for LdaModel().

import pyLDAvis.gensim
pyLDAvis.enable_notebook()
vis = pyLDAvis.gensim.prepare(lda_model, corpus, dictionary=lda_model.id2word)
vis

15. Conclusion

We started from scratch by importing, cleaning and processing the newsgroups dataset to build the LDA model. Then we saw multiple ways to visualize the outputs of topic models, including word clouds and sentence coloring, which intuitively tell you which topic is dominant in each document. The t-SNE clustering chart and pyLDAVis provide more detail on the clustering of the topics.

Where next? If you are familiar with scikit learn, you can build and grid search topic models using scikit learn as well.
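
For instance, a minimal sketch of such a grid search with scikit-learn's LatentDirichletAllocation (my own illustration, assuming the data_ready token lists from above; the parameter grid values are arbitrary) could look like this:

# Sketch: grid search an LDA topic model with scikit-learn
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.model_selection import GridSearchCV

docs = [" ".join(tokens) for tokens in data_ready]   # re-join tokens into plain strings
dtm = CountVectorizer().fit_transform(docs)          # document-term matrix

param_grid = {'n_components': [4, 6, 8], 'learning_decay': [0.5, 0.7]}
search = GridSearchCV(LatentDirichletAllocation(random_state=100), param_grid, cv=3)
search.fit(dtm)
print(search.best_params_)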

Top 50 matplotlib Visualizations – The Master Plots (with full python code)

A compilation of the Top 50 matplotlib plots most useful in data analysis and visualization. This list lets you choose what visualization to show for what situation using python’s matplotlib and seaborn libraries.

Introduction

The charts are grouped based on the 7 different purposes of your visualization objective. For example, if you want to visualize the relationship between 2 variables, check out the plots under the ‘Correlation’ section. Or if you want to show how a value changed over time, look under the ‘Change’ section, and so on.

An effective chart is one which:

  • Conveys the right and necessary information without distorting facts.
  • Is simple in design; you don’t have to strain to get it.
  • Has aesthetics that support the information rather than overshadow it.
  • Is not overloaded with information.

    Contents

    Correlation

    1. Scatter plot
    2. Bubble plot with Encircling
    3. Scatter plot with line of best fit
    4. Jittering with stripplot
    5. Counts Plot
    6. Marginal Histogram
    7. Marginal Boxplot
    8. Correlogram
    9. Pairwise Plot

    Deviation

    1. Diverging Bars
    2. Diverging Texts
    3. Diverging Dot Plot
    4. Diverging Lollipop Chart with Markers
    5. Area Chart

    Ranking

    1. Ordered Bar Chart
    2. Lollipop Chart
    3. Dot Plot
    4. Slope Chart
    5. Dumbbell Plot

    Distribution

    1. Histogram for Continuous Variable
    2. Histogram for Categorical Variable
    3. Density Plot
    4. Density Curves with Histogram
    5. Joy Plot
    6. Distributed Dot Plot
    7. Box Plot
    8. Dot + Box Plot
    9. Violin Plot
    10. Population Pyramid
    11. Categorical Plots

    Composition

    1. Waffle Chart
    2. Pie Chart
    3. Treemap
    4. Bar Chart

    Change

    1. Time Series Plot
    2. Time Series with Peaks and Troughs Annotated
    3. Autocorrelation Plot
    4. Cross Correlation Plot
    5. Time Series Decomposition Plot
    6. Multiple Time Series
    7. Plotting with different scales using secondary Y axis
    8. Time Series with Error Bands
    9. Stacked Area Chart
    10. Area Chart Unstacked
    11. Calendar Heat Map
    12. Seasonal Plot

    Groups

    1. Dendrogram
    2. Cluster Plot
    3. Andrews Curve
    4. Parallel Coordinates

    Setup

    Run this once before the plots’ code. The individual charts, however, may redefine their own aesthetics.

    # !pip install brewer2mpl
    import numpy as np
    import pandas as pd
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import seaborn as sns
    import warnings; warnings.filterwarnings(action='once')
    
    large = 22; med = 16; small = 12
    params = {'axes.titlesize': large,
              'legend.fontsize': med,
              'figure.figsize': (16, 10),
              'axes.labelsize': med,
              'axes.titlesize': med,
              'xtick.labelsize': med,
              'ytick.labelsize': med,
              'figure.titlesize': large}
    plt.rcParams.update(params)
    plt.style.use('seaborn-whitegrid')
    sns.set_style("white")
    %matplotlib inline
    
    # Version
    print(mpl.__version__)  #> 3.0.0
    print(sns.__version__)  #> 0.9.0
    

    Correlation

    The plots under correlation are used to visualize the relationship between 2 or more variables. That is, how one variable changes with respect to another.

    1. Scatter plot

    Scatterplot is a classic and fundamental plot used to study the relationship between two variables. If you have multiple groups in your data you may want to visualise each group in a different color. In matplotlib, you can conveniently do this using plt.scatter().

    Show Code
    # Import dataset 
    midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")
    
    # Prepare Data 
    # Create as many colors as there are unique midwest['category']
    categories = np.unique(midwest['category'])
    colors = [plt.cm.tab10(i/float(len(categories)-1)) for i in range(len(categories))]
    
    # Draw Plot for Each Category
    plt.figure(figsize=(16, 10), dpi= 80, facecolor='w', edgecolor='k')
    
    for i, category in enumerate(categories):
        plt.scatter('area', 'poptotal', 
                    data=midwest.loc[midwest.category==category, :], 
                    s=20, c=colors[i], label=str(category))
    
    # Decorations
    plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000),
                  xlabel='Area', ylabel='Population')
    
    plt.xticks(fontsize=12); plt.yticks(fontsize=12)
    plt.title("Scatterplot of Midwest Area vs Population", fontsize=22)
    plt.legend(fontsize=12)    
    plt.show()    
    

    Scatterplot Matplotlib

    2. Bubble plot with Encircling

    Sometimes you want to show a group of points within a boundary to emphasize their importance. In this example, you get the records from the dataframe that should be encircled and pass them to the encircle() function described in the code below.

    Show Code
    from matplotlib import patches
    from scipy.spatial import ConvexHull
    import warnings; warnings.simplefilter('ignore')
    sns.set_style("white")
    
    # Step 1: Prepare Data
    midwest = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/midwest_filter.csv")
    
    # As many colors as there are unique midwest['category']
    categories = np.unique(midwest['category'])
    colors = [plt.cm.tab10(i/float(len(categories)-1)) for i in range(len(categories))]
    
    # Step 2: Draw Scatterplot with unique color for each category
    fig = plt.figure(figsize=(16, 10), dpi= 80, facecolor='w', edgecolor='k')    
    
    for i, category in enumerate(categories):
        plt.scatter('area', 'poptotal', data=midwest.loc[midwest.category==category, :], s='dot_size', c=colors[i], label=str(category), edgecolors='black', linewidths=.5)
    
    # Step 3: Encircling
    # https://stackoverflow.com/questions/44575681/how-do-i-encircle-different-data-sets-in-scatter-plot
    def encircle(x,y, ax=None, **kw):
        if not ax: ax=plt.gca()
        p = np.c_[x,y]
        hull = ConvexHull(p)
        poly = plt.Polygon(p[hull.vertices,:], **kw)
        ax.add_patch(poly)
    
    # Select data to be encircled
    midwest_encircle_data = midwest.loc[midwest.state=='IN', :]                         
    
    # Draw polygon surrounding vertices    
    encircle(midwest_encircle_data.area, midwest_encircle_data.poptotal, ec="k", fc="gold", alpha=0.1)
    encircle(midwest_encircle_data.area, midwest_encircle_data.poptotal, ec="firebrick", fc="none", linewidth=1.5)
    
    # Step 4: Decorations
    plt.gca().set(xlim=(0.0, 0.1), ylim=(0, 90000),
                  xlabel='Area', ylabel='Population')
    
    plt.xticks(fontsize=12); plt.yticks(fontsize=12)
    plt.title("Bubble Plot with Encircling", fontsize=22)
    plt.legend(fontsize=12)    
    plt.show()    
    

    Bubble Plot in Matplotlib

    3. Scatter plot with linear regression line of best fit

    If you want to understand how two variables change with respect to each other, the line of best fit is the way to go. The below plot shows how the line of best fit differs amongst various groups in the data. To disable the groupings and to just draw one line-of-best-fit for the entire dataset, remove the hue='cyl' parameter from the sns.lmplot() call below.

    Show Code
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
    df_select = df.loc[df.cyl.isin([4,8]), :]
    
    # Plot
    sns.set_style("white")
    gridobj = sns.lmplot(x="displ", y="hwy", hue="cyl", data=df_select, 
                         height=7, aspect=1.6, robust=True, palette='tab10', 
                         scatter_kws=dict(s=60, linewidths=.7, edgecolors='black'))
    
    # Decorations
    gridobj.set(xlim=(0.5, 7.5), ylim=(0, 50))
    plt.title("Scatterplot with line of best fit grouped by number of cylinders", fontsize=20)
    plt.show()
    

    Each regression line in its own column

    Alternately, you can show the best fit line for each group in its own column. You can do this by setting the col=groupingcolumn parameter inside the sns.lmplot().

    Show Code
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
    df_select = df.loc[df.cyl.isin([4,8]), :]
    
    # Each line in its own column
    sns.set_style("white")
    gridobj = sns.lmplot(x="displ", y="hwy", 
                         data=df_select, 
                         height=7, 
                         robust=True, 
                         palette='Set1', 
                         col="cyl",
                         scatter_kws=dict(s=60, linewidths=.7, edgecolors='black'))
    
    # Decorations
    gridobj.set(xlim=(0.5, 7.5), ylim=(0, 50))
    plt.show()
    

    4. Jittering with stripplot

    Often multiple datapoints have exactly the same X and Y values. As a result, multiple points get plotted over each other and hide one another. To avoid this, jitter the points slightly so you can visually see them. This is convenient to do using seaborn’s stripplot().

    Show Code
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
    
    # Draw Stripplot
    fig, ax = plt.subplots(figsize=(16,10), dpi= 80)    
    sns.stripplot(df.cty, df.hwy, jitter=0.25, size=8, ax=ax, linewidth=.5)
    
    # Decorations
    plt.title('Use jittered plots to avoid overlapping of points', fontsize=22)
    plt.show()
    

    5. Counts Plot

    Another option to avoid the problem of points overlapping is to increase the size of the dot depending on how many points lie in that spot. So, the larger the size of the point, the greater the concentration of points around it.

    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
    df_counts = df.groupby(['hwy', 'cty']).size().reset_index(name='counts')
    
    # Draw Stripplot
    fig, ax = plt.subplots(figsize=(16,10), dpi= 80)    
    sns.stripplot(df_counts.cty, df_counts.hwy, size=df_counts.counts*2, ax=ax)
    
    # Decorations
    plt.title('Counts Plot - Size of circle is bigger as more points overlap', fontsize=22)
    plt.show()
    

    6. Marginal Histogram

    Marginal histograms have a histogram along the X and Y axis variables. This is used to visualize the relationship between the X and Y along with the univariate distribution of the X and the Y individually. This plot is often used in exploratory data analysis (EDA).

    Show Code
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
    
    # Create Fig and gridspec
    fig = plt.figure(figsize=(16, 10), dpi= 80)
    grid = plt.GridSpec(4, 4, hspace=0.5, wspace=0.2)
    
    # Define the axes
    ax_main = fig.add_subplot(grid[:-1, :-1])
    ax_right = fig.add_subplot(grid[:-1, -1], xticklabels=[], yticklabels=[])
    ax_bottom = fig.add_subplot(grid[-1, 0:-1], xticklabels=[], yticklabels=[])
    
    # Scatterplot on main ax
    ax_main.scatter('displ', 'hwy', s=df.cty*4, c=df.manufacturer.astype('category').cat.codes, alpha=.9, data=df, cmap="tab10", edgecolors='gray', linewidths=.5)
    
    # histogram at the bottom
    ax_bottom.hist(df.displ, 40, histtype='stepfilled', orientation='vertical', color='deeppink')
    ax_bottom.invert_yaxis()
    
    # histogram on the right
    ax_right.hist(df.hwy, 40, histtype='stepfilled', orientation='horizontal', color='deeppink')
    
    # Decorations
    ax_main.set(title='Scatterplot with Histograms \n displ vs hwy', xlabel='displ', ylabel='hwy')
    ax_main.title.set_fontsize(20)
    for item in ([ax_main.xaxis.label, ax_main.yaxis.label] + ax_main.get_xticklabels() + ax_main.get_yticklabels()):
        item.set_fontsize(14)
    
    xlabels = ax_main.get_xticks().tolist()
    ax_main.set_xticklabels(xlabels)
    plt.show()
    

    7. Marginal Boxplot

    Marginal boxplot serves a similar purpose as marginal histogram. However, the boxplot helps to pinpoint the median, 25th and 75th percentiles of the X and the Y.

    Show Code
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/mpg_ggplot2.csv")
    
    # Create Fig and gridspec
    fig = plt.figure(figsize=(16, 10), dpi= 80)
    grid = plt.GridSpec(4, 4, hspace=0.5, wspace=0.2)
    
    # Define the axes
    ax_main = fig.add_subplot(grid[:-1, :-1])
    ax_right = fig.add_subplot(grid[:-1, -1], xticklabels=[], yticklabels=[])
    ax_bottom = fig.add_subplot(grid[-1, 0:-1], xticklabels=[], yticklabels=[])
    
    # Scatterplot on main ax
    ax_main.scatter('displ', 'hwy', s=df.cty*5, c=df.manufacturer.astype('category').cat.codes, alpha=.9, data=df, cmap="Set1", edgecolors='black', linewidths=.5)
    
    # Add a graph in each part
    sns.boxplot(df.hwy, ax=ax_right, orient="v")
    sns.boxplot(df.displ, ax=ax_bottom, orient="h")
    
    # Decorations ------------------
    # Remove x axis name for the boxplot
    ax_bottom.set(xlabel='')
    ax_right.set(ylabel='')
    
    # Main Title, Xlabel and YLabel
    ax_main.set(title='Scatterplot with Boxplots \n displ vs hwy', xlabel='displ', ylabel='hwy')
    
    # Set font size of different components
    ax_main.title.set_fontsize(20)
    for item in ([ax_main.xaxis.label, ax_main.yaxis.label] + ax_main.get_xticklabels() + ax_main.get_yticklabels()):
        item.set_fontsize(14)
    
    plt.show()
    

    8. Correlogram

    Correlogram is used to visually see the correlation metric between all possible pairs of numeric variables in a given dataframe (or 2D array).

    # Import Dataset
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
    
    # Plot
    plt.figure(figsize=(12,10), dpi= 80)
    sns.heatmap(df.corr(), xticklabels=df.corr().columns, yticklabels=df.corr().columns, cmap='RdYlGn', center=0, annot=True)
    
    # Decorations
    plt.title('Correlogram of mtcars', fontsize=22)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.show()
    

    9. Pairwise Plot

    Pairwise plot is a favorite in exploratory analysis to understand the relationship between all possible pairs of numeric variables. It is a must-have tool for bivariate analysis.

    # Load Dataset
    df = sns.load_dataset('iris')
    
    # Plot
    plt.figure(figsize=(10,8), dpi= 80)
    sns.pairplot(df, kind="scatter", hue="species", plot_kws=dict(s=80, edgecolor="white", linewidth=2.5))
    plt.show()
    

    # Load Dataset
    df = sns.load_dataset('iris')
    
    # Plot
    plt.figure(figsize=(10,8), dpi= 80)
    sns.pairplot(df, kind="reg", hue="species")
    plt.show()
    

    Deviation

    10. Diverging Bars

    If you want to see how the items are varying based on a single metric and visualize the order and amount of this variance, diverging bars are a great tool. They help to quickly differentiate the performance of groups in your data, are quite intuitive and instantly convey the point.

    # Prepare Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
    x = df.loc[:, ['mpg']]
    df['mpg_z'] = (x - x.mean())/x.std()
    df['colors'] = ['red' if x < 0 else 'green' for x in df['mpg_z']]
    df.sort_values('mpg_z', inplace=True)
    df.reset_index(inplace=True)
    
    # Draw plot
    plt.figure(figsize=(14,10), dpi= 80)
    plt.hlines(y=df.index, xmin=0, xmax=df.mpg_z, color=df.colors, alpha=0.4, linewidth=5)
    
    # Decorations
    plt.gca().set(ylabel='$Model$', xlabel='$Mileage$')
    plt.yticks(df.index, df.cars, fontsize=12)
    plt.title('Diverging Bars of Car Mileage', fontdict={'size':20})
    plt.grid(linestyle='--', alpha=0.5)
    plt.show()
    

    11. Diverging Texts

    Diverging texts are similar to diverging bars and are preferred if you want to show the value of each item within the chart in a nice and presentable way.

    # Prepare Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
    x = df.loc[:, ['mpg']]
    df['mpg_z'] = (x - x.mean())/x.std()
    df['colors'] = ['red' if x < 0 else 'green' for x in df['mpg_z']]
    df.sort_values('mpg_z', inplace=True)
    df.reset_index(inplace=True)
    
    # Draw plot
    plt.figure(figsize=(14,14), dpi= 80)
    plt.hlines(y=df.index, xmin=0, xmax=df.mpg_z)
    for x, y, tex in zip(df.mpg_z, df.index, df.mpg_z):
        t = plt.text(x, y, round(tex, 2), horizontalalignment='right' if x < 0 else 'left', 
                     verticalalignment='center', fontdict={'color':'red' if x < 0 else 'green', 'size':14})
    
    # Decorations    
    plt.yticks(df.index, df.cars, fontsize=12)
    plt.title('Diverging Text Bars of Car Mileage', fontdict={'size':20})
    plt.grid(linestyle='--', alpha=0.5)
    plt.xlim(-2.5, 2.5)
    plt.show()
    

    12. Diverging Dot Plot

    Diverging dot plot is also similar to the diverging bars. However, compared to diverging bars, the absence of bars reduces the amount of contrast and disparity between the groups.

    Show Code
    # Prepare Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
    x = df.loc[:, ['mpg']]
    df['mpg_z'] = (x - x.mean())/x.std()
    df['colors'] = ['red' if x < 0 else 'darkgreen' for x in df['mpg_z']]
    df.sort_values('mpg_z', inplace=True)
    df.reset_index(inplace=True)
    
    # Draw plot
    plt.figure(figsize=(14,16), dpi= 80)
    plt.scatter(df.mpg_z, df.index, s=450, alpha=.6, color=df.colors)
    for x, y, tex in zip(df.mpg_z, df.index, df.mpg_z):
        t = plt.text(x, y, round(tex, 1), horizontalalignment='center', 
                     verticalalignment='center', fontdict={'color':'white'})
    
    # Decorations
    # Lighten borders
    plt.gca().spines["top"].set_alpha(.3)
    plt.gca().spines["bottom"].set_alpha(.3)
    plt.gca().spines["right"].set_alpha(.3)
    plt.gca().spines["left"].set_alpha(.3)
    
    plt.yticks(df.index, df.cars)
    plt.title('Diverging Dotplot of Car Mileage', fontdict={'size':20})
    plt.xlabel('$Mileage$')
    plt.grid(linestyle='--', alpha=0.5)
    plt.xlim(-2.5, 2.5)
    plt.show()
    

    13. Diverging Lollipop Chart with Markers

    Lollipop with markers provides a flexible way of visualizing the divergence by laying emphasis on any significant datapoints you want to bring attention to, and lets you give reasoning within the chart appropriately.

    Show Code
    # Prepare Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
    x = df.loc[:, ['mpg']]
    df['mpg_z'] = (x - x.mean())/x.std()
    df['colors'] = 'black'
    
    # color fiat differently
    df.loc[df.cars == 'Fiat X1-9', 'colors'] = 'darkorange'
    df.sort_values('mpg_z', inplace=True)
    df.reset_index(inplace=True)
    
    
    # Draw plot
    import matplotlib.patches as patches
    
    plt.figure(figsize=(14,16), dpi= 80)
    plt.hlines(y=df.index, xmin=0, xmax=df.mpg_z, color=df.colors, alpha=0.4, linewidth=1)
    plt.scatter(df.mpg_z, df.index, color=df.colors, s=[600 if x == 'Fiat X1-9' else 300 for x in df.cars], alpha=0.6)
    plt.yticks(df.index, df.cars)
    plt.xticks(fontsize=12)
    
    # Annotate
    plt.annotate('Mercedes Models', xy=(0.0, 11.0), xytext=(1.0, 11), xycoords='data', 
                fontsize=15, ha='center', va='center',
                bbox=dict(boxstyle='square', fc='firebrick'),
                arrowprops=dict(arrowstyle='-[, widthB=2.0, lengthB=1.5', lw=2.0, color='steelblue'), color='white')
    
    # Add Patches
    p1 = patches.Rectangle((-2.0, -1), width=.3, height=3, alpha=.2, facecolor='red')
    p2 = patches.Rectangle((1.5, 27), width=.8, height=5, alpha=.2, facecolor='green')
    plt.gca().add_patch(p1)
    plt.gca().add_patch(p2)
    
    # Decorate
    plt.title('Diverging Bars of Car Mileage', fontdict={'size':20})
    plt.grid(linestyle='--', alpha=0.5)
    plt.show()
    

    14. Area Chart

    By coloring the area between the axis and the lines, the area chart throws more emphasis not just on the peaks and troughs but also on the duration of the highs and lows. The longer the duration of the highs, the larger the area under the line.

    Show Code
    import numpy as np
    import pandas as pd
    
    # Prepare Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/economics.csv", parse_dates=['date']).head(100)
    x = np.arange(df.shape[0])
    y_returns = (df.psavert.diff().fillna(0)/df.psavert.shift(1)).fillna(0) * 100
    
    # Plot
    plt.figure(figsize=(16,10), dpi= 80)
    plt.fill_between(x[1:], y_returns[1:], 0, where=y_returns[1:] >= 0, facecolor='green', interpolate=True, alpha=0.7)
    plt.fill_between(x[1:], y_returns[1:], 0, where=y_returns[1:] <= 0, facecolor='red', interpolate=True, alpha=0.7)
    
    # Annotate
    plt.annotate('Peak \n1975', xy=(94.0, 21.0), xytext=(88.0, 28),
                 bbox=dict(boxstyle='square', fc='firebrick'),
                 arrowprops=dict(facecolor='steelblue', shrink=0.05), fontsize=15, color='white')
    
    
    # Decorations
    xtickvals = [str(m)[:3].upper()+"-"+str(y) for y,m in zip(df.date.dt.year, df.date.dt.month_name())]
    plt.gca().set_xticks(x[::6])
    plt.gca().set_xticklabels(xtickvals[::6], rotation=90, fontdict={'horizontalalignment': 'center', 'verticalalignment': 'center_baseline'})
    plt.ylim(-35,35)
    plt.xlim(1,100)
    plt.title("Month Economics Return %", fontsize=22)
    plt.ylabel('Monthly returns %')
    plt.grid(alpha=0.5)
    plt.show()
    

    Ranking

    15. Ordered Bar Chart

    Ordered bar chart conveys the rank order of the items effectively. By adding the value of the metric above each bar, the user gets the precise information from the chart itself.

    Show Code
    # Prepare Data
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    df = df_raw[['cty', 'manufacturer']].groupby('manufacturer').apply(lambda x: x.mean())
    df.sort_values('cty', inplace=True)
    df.reset_index(inplace=True)
    
    # Draw plot
    import matplotlib.patches as patches
    
    fig, ax = plt.subplots(figsize=(16,10), facecolor='white', dpi= 80)
    ax.vlines(x=df.index, ymin=0, ymax=df.cty, color='firebrick', alpha=0.7, linewidth=20)
    
    # Annotate Text
    for i, cty in enumerate(df.cty):
        ax.text(i, cty+0.5, round(cty, 1), horizontalalignment='center')
    
    
    # Title, Label, Ticks and Ylim
    ax.set_title('Bar Chart for Highway Mileage', fontdict={'size':22})
    ax.set(ylabel='Miles Per Gallon', ylim=(0, 30))
    plt.xticks(df.index, df.manufacturer.str.upper(), rotation=60, horizontalalignment='right', fontsize=12)
    
    # Add patches to color the X axis labels
    p1 = patches.Rectangle((.57, -0.005), width=.33, height=.13, alpha=.1, facecolor='green', transform=fig.transFigure)
    p2 = patches.Rectangle((.124, -0.005), width=.446, height=.13, alpha=.1, facecolor='red', transform=fig.transFigure)
    fig.add_artist(p1)
    fig.add_artist(p2)
    plt.show()
    

    16. Lollipop Chart

    A lollipop chart serves a similar purpose as an ordered bar chart, but in a more visually pleasing way.

    Show Code
    # Prepare Data
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    df = df_raw[['cty', 'manufacturer']].groupby('manufacturer').apply(lambda x: x.mean())
    df.sort_values('cty', inplace=True)
    df.reset_index(inplace=True)
    
    # Draw plot
    fig, ax = plt.subplots(figsize=(16,10), dpi= 80)
    ax.vlines(x=df.index, ymin=0, ymax=df.cty, color='firebrick', alpha=0.7, linewidth=2)
    ax.scatter(x=df.index, y=df.cty, s=75, color='firebrick', alpha=0.7)
    
    # Title, Label, Ticks and Ylim
    ax.set_title('Lollipop Chart for Highway Mileage', fontdict={'size':22})
    ax.set_ylabel('Miles Per Gallon')
    ax.set_xticks(df.index)
    ax.set_xticklabels(df.manufacturer.str.upper(), rotation=60, fontdict={'horizontalalignment': 'right', 'size':12})
    ax.set_ylim(0, 30)
    
    # Annotate
    for row in df.itertuples():
        ax.text(row.Index, row.cty+.5, s=round(row.cty, 2), horizontalalignment= 'center', verticalalignment='bottom', fontsize=14)
    
    plt.show()
    

    17. Dot Plot

    The dot plot conveys the rank order of the items. Since it is aligned along the horizontal axis, you can visualize how far the points are from each other more easily.

    # Prepare Data
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    df = df_raw[['cty', 'manufacturer']].groupby('manufacturer').apply(lambda x: x.mean())
    df.sort_values('cty', inplace=True)
    df.reset_index(inplace=True)
    
    # Draw plot
    fig, ax = plt.subplots(figsize=(16,10), dpi= 80)
    ax.hlines(y=df.index, xmin=11, xmax=26, color='gray', alpha=0.7, linewidth=1, linestyles='dashdot')
    ax.scatter(y=df.index, x=df.cty, s=75, color='firebrick', alpha=0.7)
    
    # Title, Label, Ticks and Ylim
    ax.set_title('Dot Plot for Highway Mileage', fontdict={'size':22})
    ax.set_xlabel('Miles Per Gallon')
    ax.set_yticks(df.index)
    ax.set_yticklabels(df.manufacturer.str.title(), fontdict={'horizontalalignment': 'right'})
    ax.set_xlim(10, 27)
    plt.show()
    

    18. Slope Chart

    Slope chart is most suitable for comparing the ‘Before’ and ‘After’ positions of a given person/item.

    Show Code
    import matplotlib.lines as mlines
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/gdppercap.csv")
    
    left_label = [str(c) + ', '+ str(round(y)) for c, y in zip(df.continent, df['1952'])]
    right_label = [str(c) + ', '+ str(round(y)) for c, y in zip(df.continent, df['1957'])]
    klass = ['red' if (y1-y2) < 0 else 'green' for y1, y2 in zip(df['1952'], df['1957'])]
    
    # draw line
    # https://stackoverflow.com/questions/36470343/how-to-draw-a-line-with-matplotlib/36479941
    def newline(p1, p2, color='black'):
        ax = plt.gca()
        l = mlines.Line2D([p1[0],p2[0]], [p1[1],p2[1]], color='red' if p1[1]-p2[1] > 0 else 'green', marker='o', markersize=6)
        ax.add_line(l)
        return l
    
    fig, ax = plt.subplots(1,1,figsize=(14,14), dpi= 80)
    
    # Vertical Lines
    ax.vlines(x=1, ymin=500, ymax=13000, color='black', alpha=0.7, linewidth=1, linestyles='dotted')
    ax.vlines(x=3, ymin=500, ymax=13000, color='black', alpha=0.7, linewidth=1, linestyles='dotted')
    
    # Points
    ax.scatter(y=df['1952'], x=np.repeat(1, df.shape[0]), s=10, color='black', alpha=0.7)
    ax.scatter(y=df['1957'], x=np.repeat(3, df.shape[0]), s=10, color='black', alpha=0.7)
    
    # Line Segments and Annotation
    for p1, p2, c in zip(df['1952'], df['1957'], df['continent']):
        newline([1,p1], [3,p2])
        ax.text(1-0.05, p1, c + ', ' + str(round(p1)), horizontalalignment='right', verticalalignment='center', fontdict={'size':14})
        ax.text(3+0.05, p2, c + ', ' + str(round(p2)), horizontalalignment='left', verticalalignment='center', fontdict={'size':14})
    
    # 'Before' and 'After' Annotations
    ax.text(1-0.05, 13000, 'BEFORE', horizontalalignment='right', verticalalignment='center', fontdict={'size':18, 'weight':700})
    ax.text(3+0.05, 13000, 'AFTER', horizontalalignment='left', verticalalignment='center', fontdict={'size':18, 'weight':700})
    
    # Decoration
    ax.set_title("Slopechart: Comparing GDP Per Capita between 1952 vs 1957", fontdict={'size':22})
    ax.set(xlim=(0,4), ylim=(0,14000), ylabel='Mean GDP Per Capita')
    ax.set_xticks([1,3])
    ax.set_xticklabels(["1952", "1957"])
    plt.yticks(np.arange(500, 13000, 2000), fontsize=12)
    
    # Lighten borders
    plt.gca().spines["top"].set_alpha(.0)
    plt.gca().spines["bottom"].set_alpha(.0)
    plt.gca().spines["right"].set_alpha(.0)
    plt.gca().spines["left"].set_alpha(.0)
    plt.show()
    

    19. Dumbbell Plot

    Dumbbell plot conveys the ‘before’ and ‘after’ positions of various items along with the rank ordering of the items. It’s very useful if you want to visualize the effect of a particular project / initiative on different objects.

    Show Code
    import matplotlib.lines as mlines
    
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/health.csv")
    df.sort_values('pct_2014', inplace=True)
    df.reset_index(inplace=True)
    
    # Func to draw line segment
    def newline(p1, p2, color='black'):
        ax = plt.gca()
        l = mlines.Line2D([p1[0],p2[0]], [p1[1],p2[1]], color='skyblue')
        ax.add_line(l)
        return l
    
    # Figure and Axes
    fig, ax = plt.subplots(1,1,figsize=(14,14), facecolor='#f7f7f7', dpi= 80)
    
    # Vertical Lines
    ax.vlines(x=.05, ymin=0, ymax=26, color='black', alpha=1, linewidth=1, linestyles='dotted')
    ax.vlines(x=.10, ymin=0, ymax=26, color='black', alpha=1, linewidth=1, linestyles='dotted')
    ax.vlines(x=.15, ymin=0, ymax=26, color='black', alpha=1, linewidth=1, linestyles='dotted')
    ax.vlines(x=.20, ymin=0, ymax=26, color='black', alpha=1, linewidth=1, linestyles='dotted')
    
    # Points
    ax.scatter(y=df['index'], x=df['pct_2013'], s=50, color='#0e668b', alpha=0.7)
    ax.scatter(y=df['index'], x=df['pct_2014'], s=50, color='#a3c4dc', alpha=0.7)
    
    # Line Segments
    for i, p1, p2 in zip(df['index'], df['pct_2013'], df['pct_2014']):
        newline([p1, i], [p2, i])
    
    # Decoration
    ax.set_facecolor('#f7f7f7')
    ax.set_title("Dumbbell Chart: Pct Change - 2013 vs 2014", fontdict={'size':22})
    ax.set(xlim=(0,.25), ylim=(-1, 27), ylabel='Mean GDP Per Capita')
    ax.set_xticks([.05, .1, .15, .20])
    ax.set_xticklabels(['5%', '10%', '15%', '20%'])
    plt.show()
    

    Distribution

    20. Histogram for Continuous Variable

    A histogram shows the frequency distribution of a given variable. The representation below groups the frequency bars based on a categorical variable, giving greater insight into the continuous variable and the categorical variable in tandem.

    Show Code
    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Prepare data
    x_var = 'displ'
    groupby_var = 'class'
    df_agg = df.loc[:, [x_var, groupby_var]].groupby(groupby_var)
    vals = [df[x_var].values.tolist() for i, df in df_agg]
    
    # Draw
    plt.figure(figsize=(16,9), dpi= 80)
    colors = [plt.cm.Spectral(i/float(len(vals)-1)) for i in range(len(vals))]
    n, bins, patches = plt.hist(vals, 30, stacked=True, density=False, color=colors[:len(vals)])
    
    # Decoration
    plt.legend({group:col for group, col in zip(np.unique(df[groupby_var]).tolist(), colors[:len(vals)])})
    plt.title(f"Stacked Histogram of ${x_var}$ colored by ${groupby_var}$", fontsize=22)
    plt.xlabel(x_var)
    plt.ylabel("Frequency")
    plt.ylim(0, 25)
    plt.xticks(ticks=bins[::3], labels=[round(b,1) for b in bins[::3]])
    plt.show()
    

    21. Histogram for Categorical Variable

    The histogram of a categorical variable shows the frequency distribution of that variable. By coloring the bars, you can visualize the distribution in connection with another categorical variable that the colors represent.

    Show Code
    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Prepare data
    x_var = 'manufacturer'
    groupby_var = 'class'
    df_agg = df.loc[:, [x_var, groupby_var]].groupby(groupby_var)
    vals = [df[x_var].values.tolist() for i, df in df_agg]
    
    # Draw
    plt.figure(figsize=(16,9), dpi= 80)
    colors = [plt.cm.Spectral(i/float(len(vals)-1)) for i in range(len(vals))]
    n, bins, patches = plt.hist(vals, df[x_var].unique().__len__(), stacked=True, density=False, color=colors[:len(vals)])
    
    # Decoration
    plt.legend({group:col for group, col in zip(np.unique(df[groupby_var]).tolist(), colors[:len(vals)])})
    plt.title(f"Stacked Histogram of ${x_var}$ colored by ${groupby_var}$", fontsize=22)
    plt.xlabel(x_var)
    plt.ylabel("Frequency")
    plt.ylim(0, 40)
    plt.xticks(ticks=bins, labels=np.unique(df[x_var]).tolist(), rotation=90, horizontalalignment='left')
    plt.show()
    

    22. Density Plot

    Density plots are a commonly used tool to visualise the distribution of a continuous variable. By grouping them by the ‘response’ variable, you can inspect the relationship between the X and the Y. The case below is for representational purposes, describing how the distribution of city mileage varies with respect to the number of cylinders.

    Show Code
    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Draw Plot
    plt.figure(figsize=(16,10), dpi= 80)
    sns.kdeplot(df.loc[df['cyl'] == 4, "cty"], shade=True, color="g", label="Cyl=4", alpha=.7)
    sns.kdeplot(df.loc[df['cyl'] == 5, "cty"], shade=True, color="deeppink", label="Cyl=5", alpha=.7)
    sns.kdeplot(df.loc[df['cyl'] == 6, "cty"], shade=True, color="dodgerblue", label="Cyl=6", alpha=.7)
    sns.kdeplot(df.loc[df['cyl'] == 8, "cty"], shade=True, color="orange", label="Cyl=8", alpha=.7)
    
    # Decoration
    plt.title('Density Plot of City Mileage by n_Cylinders', fontsize=22)
    plt.legend()
    plt.show()
    

    23. Density Curves with Histogram

    Density curve with histogram brings together the collective information conveyed by the two plots so you can have them both in a single figure instead of two.

    Show Code
    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Draw Plot
    plt.figure(figsize=(13,10), dpi= 80)
    sns.distplot(df.loc[df['class'] == 'compact', "cty"], color="dodgerblue", label="Compact", hist_kws={'alpha':.7}, kde_kws={'linewidth':3})
    sns.distplot(df.loc[df['class'] == 'suv', "cty"], color="orange", label="SUV", hist_kws={'alpha':.7}, kde_kws={'linewidth':3})
    sns.distplot(df.loc[df['class'] == 'minivan', "cty"], color="g", label="minivan", hist_kws={'alpha':.7}, kde_kws={'linewidth':3})
    plt.ylim(0, 0.35)
    
    # Decoration
    plt.title('Density Plot of City Mileage by Vehicle Type', fontsize=22)
    plt.legend()
    plt.show()
    

    24. Joy Plot

    A joy plot allows the density curves of different groups to overlap, making it a great way to visualize the distribution of a large number of groups in relation to each other. It looks pleasing to the eye and conveys just the right information clearly. It can be easily built using the joypy package, which is based on matplotlib.

    # !pip install joypy
    import joypy

    # Import Data
    mpg = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Draw Plot
    plt.figure(figsize=(16,10), dpi= 80)
    fig, axes = joypy.joyplot(mpg, column=['hwy', 'cty'], by="class", ylim='own', figsize=(14,10))
    
    # Decoration
    plt.title('Joy Plot of City and Highway Mileage by Class', fontsize=22)
    plt.show()
    

    25. Distributed Dot Plot

    Distributed dot plot shows the univariate distribution of points segmented by groups. The darker the points, the greater is the concentration of data points in that region. By coloring the median differently, the real positioning of the groups becomes apparent instantly.

    Show Code
    import matplotlib.patches as mpatches
    
    # Prepare Data
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    cyl_colors = {4:'tab:red', 5:'tab:green', 6:'tab:blue', 8:'tab:orange'}
    df_raw['cyl_color'] = df_raw.cyl.map(cyl_colors)
    
    # Mean and Median city mileage by make
    df = df_raw[['cty', 'manufacturer']].groupby('manufacturer').apply(lambda x: x.mean())
    df.sort_values('cty', ascending=False, inplace=True)
    df.reset_index(inplace=True)
    df_median = df_raw[['cty', 'manufacturer']].groupby('manufacturer').apply(lambda x: x.median())
    
    # Draw horizontal lines
    fig, ax = plt.subplots(figsize=(16,10), dpi= 80)
    ax.hlines(y=df.index, xmin=0, xmax=40, color='gray', alpha=0.5, linewidth=.5, linestyles='dashdot')
    
    # Draw the Dots
    for i, make in enumerate(df.manufacturer):
        df_make = df_raw.loc[df_raw.manufacturer==make, :]
        ax.scatter(y=np.repeat(i, df_make.shape[0]), x='cty', data=df_make, s=75, edgecolors='gray', c='w', alpha=0.5)
        ax.scatter(y=i, x='cty', data=df_median.loc[df_median.index==make, :], s=75, c='firebrick')
    
    # Annotate    
    ax.text(33, 13, "$red \; dots \; are \; the \: median$", fontdict={'size':12}, color='firebrick')
    
    # Decorations
    red_patch = plt.plot([],[], marker="o", ms=10, ls="", mec=None, color='firebrick', label="Median")
    plt.legend(handles=red_patch)
    ax.set_title('Distribution of City Mileage by Make', fontdict={'size':22})
    ax.set_xlabel('Miles Per Gallon (City)', alpha=0.7)
    ax.set_yticks(df.index)
    ax.set_yticklabels(df.manufacturer.str.title(), fontdict={'horizontalalignment': 'right'}, alpha=0.7)
    ax.set_xlim(1, 40)
    plt.xticks(alpha=0.7)
    plt.gca().spines["top"].set_visible(False)    
    plt.gca().spines["bottom"].set_visible(False)    
    plt.gca().spines["right"].set_visible(False)    
    plt.gca().spines["left"].set_visible(False)   
    plt.grid(axis='both', alpha=.4, linewidth=.1)
    plt.show()
    

    26. Box Plot

    Box plots are a great way to visualize the distribution, keeping the median, the 25th and 75th percentiles and the outliers in view. However, you need to be careful about interpreting the size of the boxes, which can be misread as the number of points contained within that group. So, manually providing the number of observations in each box can help overcome this drawback.

    For example, the first two boxes on the left are of the same size even though they have 5 and 47 observations respectively. So, writing the number of observations for each group becomes necessary.

    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Draw Plot
    plt.figure(figsize=(13,10), dpi= 80)
    sns.boxplot(x='class', y='hwy', data=df, notch=False)
    
    # Add N Obs inside boxplot (optional)
    def add_n_obs(df,group_col,y):
        medians_dict = {grp[0]:grp[1][y].median() for grp in df.groupby(group_col)}
        xticklabels = [x.get_text() for x in plt.gca().get_xticklabels()]
        n_obs = df.groupby(group_col)[y].size().values
        for (x, xticklabel), n_ob in zip(enumerate(xticklabels), n_obs):
            plt.text(x, medians_dict[xticklabel]*1.01, "#obs : "+str(n_ob), horizontalalignment='center', fontdict={'size':14}, color='white')
    
    add_n_obs(df,group_col='class',y='hwy')    
    
    # Decoration
    plt.title('Box Plot of Highway Mileage by Vehicle Class', fontsize=22)
    plt.ylim(10, 40)
    plt.show()
    

    27. Dot + Box Plot

    A dot + box plot conveys similar information as a boxplot split into groups. The dots, in addition, give a sense of how many data points lie within each group.

    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Draw Plot
    plt.figure(figsize=(13,10), dpi= 80)
    sns.boxplot(x='class', y='hwy', data=df, hue='cyl')
    sns.stripplot(x='class', y='hwy', data=df, color='black', size=3, jitter=1)
    
    for i in range(len(df['class'].unique())-1):
        plt.vlines(i+.5, 10, 45, linestyles='solid', colors='gray', alpha=0.2)
    
    # Decoration
    plt.title('Box Plot of Highway Mileage by Vehicle Class', fontsize=22)
    plt.legend(title='Cylinders')
    plt.show()
    

    28. Violin Plot

    A violin plot is a visually pleasing alternative to box plots. The shape or area of the violin depends on the number of observations it holds. However, violin plots can be harder to read and are not commonly used in professional settings.

    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Draw Plot
    plt.figure(figsize=(13,10), dpi= 80)
    sns.violinplot(x='class', y='hwy', data=df, scale='width', inner='quartile')
    
    # Decoration
    plt.title('Violin Plot of Highway Mileage by Vehicle Class', fontsize=22)
    plt.show()
    

    29. Population Pyramid

    A population pyramid can be used to show either the distribution of the groups ordered by volume, or the stage-by-stage filtering of a population, as it is used below to show how many people pass through each stage of a marketing funnel.

    # Read data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/email_campaign_funnel.csv")
    
    # Draw Plot
    plt.figure(figsize=(13,10), dpi= 80)
    group_col = 'Gender'
    order_of_bars = df.Stage.unique()[::-1]
    colors = [plt.cm.Spectral(i/float(len(df[group_col].unique())-1)) for i in range(len(df[group_col].unique()))]
    
    for c, group in zip(colors, df[group_col].unique()):
        sns.barplot(x='Users', y='Stage', data=df.loc[df[group_col]==group, :], order=order_of_bars, color=c, label=group)
    
    # Decorations    
    plt.xlabel("$Users$")
    plt.ylabel("Stage of Purchase")
    plt.yticks(fontsize=12)
    plt.title("Population Pyramid of the Marketing Funnel", fontsize=22)
    plt.legend()
    plt.show()
    

    30. Categorical Plots

    Categorical plots provided by the seaborn library can be used to visualize the counts distribution of 2 or more categorical variables in relation to each other.

    # Load Dataset
    titanic = sns.load_dataset("titanic")
    
    # Plot
    g = sns.catplot("alive", col="deck", col_wrap=4,
                    data=titanic[titanic.deck.notnull()],
                    kind="count", height=3.5, aspect=.8, 
                    palette='tab20')
    
    g.fig.suptitle('Count of Passengers by Survival Status and Deck', y=1.05)
    plt.show()
    

    # Load Dataset
    titanic = sns.load_dataset("titanic")
    
    # Plot
    sns.catplot(x="age", y="embark_town",
                hue="sex", col="class",
                data=titanic[titanic.embark_town.notnull()],
                orient="h", height=5, aspect=1, palette="tab10",
                kind="violin", dodge=True, cut=0, bw=.2)
    

    Composition

    31. Waffle Chart

    The waffle chart can be created using the pywaffle package and is used to show the compositions of groups in a larger population.

    Show Code
    #! pip install pywaffle
    # Reference: https://stackoverflow.com/questions/41400136/how-to-do-waffle-charts-in-python-square-piechart
    from pywaffle import Waffle
    
    # Import
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Prepare Data
    df = df_raw.groupby('class').size().reset_index(name='counts')
    n_categories = df.shape[0]
    colors = [plt.cm.inferno_r(i/float(n_categories)) for i in range(n_categories)]
    
    # Draw Plot and Decorate
    fig = plt.figure(
        FigureClass=Waffle,
        plots={
            '111': {
                'values': df['counts'],
                'labels': ["{0} ({1})".format(n[1], n[2]) for n in df[['class', 'counts']].itertuples()],
                'legend': {'loc': 'upper left', 'bbox_to_anchor': (1.05, 1), 'fontsize': 12},
                'title': {'label': '# Vehicles by Class', 'loc': 'center', 'fontsize':18}
            },
        },
        rows=7,
        colors=colors,
        figsize=(16, 9)
    )
    

    Show Code
    #! pip install pywaffle
    from pywaffle import Waffle
    
    # Import
    # df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Prepare Data
    # By Class Data
    df_class = df_raw.groupby('class').size().reset_index(name='counts_class')
    n_categories = df_class.shape[0]
    colors_class = [plt.cm.Set3(i/float(n_categories)) for i in range(n_categories)]
    
    # By Cylinders Data
    df_cyl = df_raw.groupby('cyl').size().reset_index(name='counts_cyl')
    n_categories = df_cyl.shape[0]
    colors_cyl = [plt.cm.Spectral(i/float(n_categories)) for i in range(n_categories)]
    
    # By Make Data
    df_make = df_raw.groupby('manufacturer').size().reset_index(name='counts_make')
    n_categories = df_make.shape[0]
    colors_make = [plt.cm.tab20b(i/float(n_categories)) for i in range(n_categories)]
    
    
    # Draw Plot and Decorate
    fig = plt.figure(
        FigureClass=Waffle,
        plots={
            '311': {
                'values': df_class['counts_class'],
                'labels': ["{1}".format(n[0], n[1]) for n in df_class[['class', 'counts_class']].itertuples()],
                'legend': {'loc': 'upper left', 'bbox_to_anchor': (1.05, 1), 'fontsize': 12, 'title':'Class'},
                'title': {'label': '# Vehicles by Class', 'loc': 'center', 'fontsize':18},
                'colors': colors_class
            },
            '312': {
                'values': df_cyl['counts_cyl'],
                'labels': ["{1}".format(n[0], n[1]) for n in df_cyl[['cyl', 'counts_cyl']].itertuples()],
                'legend': {'loc': 'upper left', 'bbox_to_anchor': (1.05, 1), 'fontsize': 12, 'title':'Cyl'},
                'title': {'label': '# Vehicles by Cyl', 'loc': 'center', 'fontsize':18},
                'colors': colors_cyl
            },
            '313': {
                'values': df_make['counts_make'],
                'labels': ["{1}".format(n[0], n[1]) for n in df_make[['manufacturer', 'counts_make']].itertuples()],
                'legend': {'loc': 'upper left', 'bbox_to_anchor': (1.05, 1), 'fontsize': 12, 'title':'Manufacturer'},
                'title': {'label': '# Vehicles by Make', 'loc': 'center', 'fontsize':18},
                'colors': colors_make
            }
        },
        rows=9,
        figsize=(16, 14)
    )
    

    32. Pie Chart

    A pie chart is a classic way to show the composition of groups. However, it’s not generally advisable nowadays because the area of the pie portions can sometimes be misleading. So, if you are to use a pie chart, it’s highly recommended to explicitly write down the percentage or numbers for each portion of the pie.

    # Import
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Prepare Data
    df = df_raw.groupby('class').size()
    
    # Make the plot with pandas
    df.plot(kind='pie', subplots=True, figsize=(8, 8), dpi= 80)
    plt.title("Pie Chart of Vehicle Class - Bad")
    plt.ylabel("")
    plt.show()
    

    Show Code
    # Import
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Prepare Data
    df = df_raw.groupby('class').size().reset_index(name='counts')
    
    # Draw Plot
    fig, ax = plt.subplots(figsize=(12, 7), subplot_kw=dict(aspect="equal"), dpi= 80)
    
    data = df['counts']
    categories = df['class']
    explode = [0,0,0,0,0,0.1,0]
    
    def func(pct, allvals):
        absolute = int(pct/100.*np.sum(allvals))
        return "{:.1f}% ({:d} )".format(pct, absolute)
    
    wedges, texts, autotexts = ax.pie(data, 
                                      autopct=lambda pct: func(pct, data),
                                      textprops=dict(color="w"), 
                                      colors=plt.cm.Dark2.colors,
                                     startangle=140,
                                     explode=explode)
    
    # Decoration
    ax.legend(wedges, categories, title="Vehicle Class", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
    plt.setp(autotexts, size=10, weight=700)
    ax.set_title("Class of Vehicles: Pie Chart")
    plt.show()
    

    33. Treemap

    A treemap is similar to a pie chart, but it does a better job of showing the contribution of each group without being misleading.

    # pip install squarify
    import squarify 
    
    # Import Data
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Prepare Data
    df = df_raw.groupby('class').size().reset_index(name='counts')
    labels = df.apply(lambda x: str(x[0]) + "\n (" + str(x[1]) + ")", axis=1)
    sizes = df['counts'].values.tolist()
    colors = [plt.cm.Spectral(i/float(len(labels))) for i in range(len(labels))]
    
    # Draw Plot
    plt.figure(figsize=(12,8), dpi= 80)
    squarify.plot(sizes=sizes, label=labels, color=colors, alpha=.8)
    
    # Decorate
    plt.title('Treemap of Vehicle Class')
    plt.axis('off')
    plt.show()
    

    34. Bar Chart

    Bar chart is a classic way of visualizing items based on counts or any given metric. In the below chart, I have used a different color for each item, but you might typically want to pick one color for all items unless you want to color them by group. The color names get stored inside all_colors in the code below. You can change the color of the bars by setting the color parameter in plt.bar().

    Show Code
    import random
    
    # Import Data
    df_raw = pd.read_csv("https://github.com/selva86/datasets/raw/master/mpg_ggplot2.csv")
    
    # Prepare Data
    df = df_raw.groupby('manufacturer').size().reset_index(name='counts')
    n = df['manufacturer'].unique().__len__()+1
    all_colors = list(plt.cm.colors.cnames.keys())
    random.seed(100)
    c = random.choices(all_colors, k=n)
    
    # Plot Bars
    plt.figure(figsize=(16,10), dpi= 80)
    plt.bar(df['manufacturer'], df['counts'], color=c, width=.5)
    for i, val in enumerate(df['counts'].values):
        plt.text(i, val, float(val), horizontalalignment='center', verticalalignment='bottom', fontdict={'fontweight':500, 'size':12})
    
    # Decoration
    plt.gca().set_xticklabels(df['manufacturer'], rotation=60, horizontalalignment= 'right')
    plt.title("Number of Vehicles by Manufacturers", fontsize=22)
    plt.ylabel('# Vehicles')
    plt.ylim(0, 45)
    plt.show()
    

    Change

    35. Time Series Plot

    Time series plot is used to visualise how a given metric changes over time. Here you can see how the Air Passenger traffic changed between 1949 and 1969.

    Show Code
    # Import Data
    df = pd.read_csv('https://github.com/selva86/datasets/raw/master/AirPassengers.csv')
    
    # Draw Plot
    plt.figure(figsize=(16,10), dpi= 80)
    plt.plot('date', 'traffic', data=df, color='tab:red')
    
    # Decoration
    plt.ylim(50, 750)
    xtick_location = df.index.tolist()[::12]
    xtick_labels = [x[-4:] for x in df.date.tolist()[::12]]
    plt.xticks(ticks=xtick_location, labels=xtick_labels, rotation=0, fontsize=12, horizontalalignment='center', alpha=.7)
    plt.yticks(fontsize=12, alpha=.7)
    plt.title("Air Passengers Traffic (1949 - 1969)", fontsize=22)
    plt.grid(axis='both', alpha=.3)
    
    # Remove borders
    plt.gca().spines["top"].set_alpha(0.0)    
    plt.gca().spines["bottom"].set_alpha(0.3)
    plt.gca().spines["right"].set_alpha(0.0)    
    plt.gca().spines["left"].set_alpha(0.3)   
    plt.show()
    

    36. Time Series with Peaks and Troughs Annotated

    The below time series plots all the peaks and troughs and annotates the occurrence of selected special events.

    Show Code
    # Import Data
    df = pd.read_csv('https://github.com/selva86/datasets/raw/master/AirPassengers.csv')
    
    # Get the Peaks and Troughs
    data = df['traffic'].values
    doublediff = np.diff(np.sign(np.diff(data)))
    peak_locations = np.where(doublediff == -2)[0] + 1
    
    doublediff2 = np.diff(np.sign(np.diff(-1*data)))
    trough_locations = np.where(doublediff2 == -2)[0] + 1
    
    # Draw Plot
    plt.figure(figsize=(16,10), dpi= 80)
    plt.plot('date', 'traffic', data=df, color='tab:blue', label='Air Traffic')
    plt.scatter(df.date[peak_locations], df.traffic[peak_locations], marker=mpl.markers.CARETUPBASE, color='tab:green', s=100, label='Peaks')
    plt.scatter(df.date[trough_locations], df.traffic[trough_locations], marker=mpl.markers.CARETDOWNBASE, color='tab:red', s=100, label='Troughs')
    
    # Annotate
    for t, p in zip(trough_locations[1::5], peak_locations[::3]):
        plt.text(df.date[p], df.traffic[p]+15, df.date[p], horizontalalignment='center', color='darkgreen')
        plt.text(df.date[t], df.traffic[t]-35, df.date[t], horizontalalignment='center', color='darkred')
    
    # Decoration
    plt.ylim(50,750)
    xtick_location = df.index.tolist()[::6]
    xtick_labels = df.date.tolist()[::6]
    plt.xticks(ticks=xtick_location, labels=xtick_labels, rotation=90, fontsize=12, alpha=.7)
    plt.title("Peak and Troughs of Air Passengers Traffic (1949 - 1969)", fontsize=22)
    plt.yticks(fontsize=12, alpha=.7)
    
    # Lighten borders
    plt.gca().spines["top"].set_alpha(.0)
    plt.gca().spines["bottom"].set_alpha(.3)
    plt.gca().spines["right"].set_alpha(.0)
    plt.gca().spines["left"].set_alpha(.3)
    
    plt.legend(loc='upper left')
    plt.grid(axis='y', alpha=.3)
    plt.show()
    

    37. Autocorrelation (ACF) and Partial Autocorrelation (PACF) Plot

    The ACF plot shows the correlation of the time series with its own lags. Each vertical line (on the autocorrelation plot) represents the correlation between the series and its lag, starting from lag 0. The blue shaded region in the plot is the significance band: lags that rise above it are the significant lags.

    So how to interpret this?

    For AirPassengers, we see that up to 14 lags have crossed the significance band and so are significant. This means the Air Passengers traffic observed up to 14 months back has an influence on the traffic seen today. A small sketch after the plot code below shows one way to check this numerically.

    PACF, on the other hand, shows the autocorrelation of any given lag against the current series, with the contributions of the lags in between removed.

    from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
    
    # Import Data
    df = pd.read_csv('https://github.com/selva86/datasets/raw/master/AirPassengers.csv')
    
    # Draw Plot
    fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(16,6), dpi= 80)
    plot_acf(df.traffic.tolist(), ax=ax1, lags=50)
    plot_pacf(df.traffic.tolist(), ax=ax2, lags=20)
    
    # Decorate
    # lighten the borders
    ax1.spines["top"].set_alpha(.3); ax2.spines["top"].set_alpha(.3)
    ax1.spines["bottom"].set_alpha(.3); ax2.spines["bottom"].set_alpha(.3)
    ax1.spines["right"].set_alpha(.3); ax2.spines["right"].set_alpha(.3)
    ax1.spines["left"].set_alpha(.3); ax2.spines["left"].set_alpha(.3)
    
    # font size of tick labels
    ax1.tick_params(axis='both', labelsize=12)
    ax2.tick_params(axis='both', labelsize=12)
    plt.show()
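
    As a rough numerical check of the interpretation above, you can compare the raw autocorrelation values against the approximate ±2/sqrt(N) significance bound (the shaded band drawn by plot_acf uses a slightly more refined formula). Below is a minimal sketch reusing the df loaded above; it is illustrative rather than part of the original recipe.

    from statsmodels.tsa.stattools import acf

    # Autocorrelations of the traffic series for the first 50 lags
    acf_vals = acf(df.traffic.tolist(), nlags=50)

    # Approximate 95% significance bound: +/- 2/sqrt(N)
    bound = 2 / np.sqrt(len(df))

    # Lags (excluding lag 0) whose autocorrelation lies outside the bound
    significant_lags = [lag for lag, val in enumerate(acf_vals) if lag > 0 and abs(val) > bound]
    print(significant_lags)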
    

    38. Cross Correlation plot

    A cross correlation plot shows the correlation between two time series at various lags.

    Show Code
    import statsmodels.tsa.stattools as stattools
    
    # Import Data
    df = pd.read_csv('https://github.com/selva86/datasets/raw/master/mortality.csv')
    x = df['mdeaths']
    y = df['fdeaths']
    
    # Compute Cross Correlations
    ccs = stattools.ccf(x, y)[:100]
    nlags = len(ccs)
    
    # Compute the Significance level
    # ref: https://stats.stackexchange.com/questions/3115/cross-correlation-significance-in-r/3128#3128
    conf_level = 2 / np.sqrt(nlags)
    
    # Draw Plot
    plt.figure(figsize=(12,7), dpi= 80)
    
    plt.hlines(0, xmin=0, xmax=100, color='gray')  # 0 axis
    plt.hlines(conf_level, xmin=0, xmax=100, color='gray')
    plt.hlines(-conf_level, xmin=0, xmax=100, color='gray')
    
    plt.bar(x=np.arange(len(ccs)), height=ccs, width=.3)
    
    # Decoration
    plt.title('$Cross\; Correlation\; Plot:\; mdeaths\; vs\; fdeaths$', fontsize=22)
    plt.xlim(0,len(ccs))
    plt.show()
    

    39. Time Series Decomposition Plot

    Time series decomposition plot shows the breakdown of the time series into its trend, seasonal and residual components. A short sketch after the code shows how to access these components individually.

    from statsmodels.tsa.seasonal import seasonal_decompose
    from dateutil.parser import parse
    
    # Import Data
    df = pd.read_csv('https://github.com/selva86/datasets/raw/master/AirPassengers.csv')
    dates = pd.DatetimeIndex([parse(d).strftime('%Y-%m-01') for d in df['date']])
    df.set_index(dates, inplace=True)
    
    # Decompose 
    result = seasonal_decompose(df['traffic'], model='multiplicative')
    
    # Plot
    plt.rcParams.update({'figure.figsize': (10,10)})
    result.plot().suptitle('Time Series Decomposition of Air Passengers')
    plt.show()
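
    If you want to work with the components rather than only plot them, the result object returned by seasonal_decompose exposes them as attributes. This is a minimal sketch continuing from the result object above, not part of the original recipe.

    # Access the individual components as pandas Series
    trend    = result.trend      # long-term movement
    seasonal = result.seasonal   # repeating yearly pattern
    resid    = result.resid      # what remains after removing trend and seasonality

    # For the multiplicative model used above, the observed series equals trend * seasonal * resid
    reconstructed = (trend * seasonal * resid).dropna()
    print(reconstructed.head())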
    

    40. Multiple Time Series

    You can plot multiple time series that measure the same value on the same chart, as shown below.

    Show Code
    # Import Data
    df = pd.read_csv('https://github.com/selva86/datasets/raw/master/mortality.csv')
    
    # Define the upper limit, lower limit, interval of Y axis and colors
    y_LL = 100
    y_UL = int(df.iloc[:, 1:].max().max()*1.1)
    y_interval = 400
    mycolors = ['tab:red', 'tab:blue', 'tab:green', 'tab:orange']    
    
    # Draw Plot and Annotate
    fig, ax = plt.subplots(1,1,figsize=(16, 9), dpi= 80)    
    
    columns = df.columns[1:]  
    for i, column in enumerate(columns):    
        plt.plot(df.date.values, df[column].values, lw=1.5, color=mycolors[i])    
        plt.text(df.shape[0]+1, df[column].values[-1], column, fontsize=14, color=mycolors[i])
    
    # Draw Tick lines  
    for y in range(y_LL, y_UL, y_interval):    
        plt.hlines(y, xmin=0, xmax=71, colors='black', alpha=0.3, linestyles="--", lw=0.5)
    
    # Decorations    
    plt.tick_params(axis="both", which="both", bottom=False, top=False,    
                    labelbottom=True, left=False, right=False, labelleft=True)        
    
    # Lighten borders
    plt.gca().spines["top"].set_alpha(.3)
    plt.gca().spines["bottom"].set_alpha(.3)
    plt.gca().spines["right"].set_alpha(.3)
    plt.gca().spines["left"].set_alpha(.3)
    
    plt.title('Number of Deaths from Lung Diseases in the UK (1974-1979)', fontsize=22)
    plt.yticks(range(y_LL, y_UL, y_interval), [str(y) for y in range(y_LL, y_UL, y_interval)], fontsize=12)    
    plt.xticks(range(0, df.shape[0], 12), df.date.values[::12], horizontalalignment='left', fontsize=12)    
    plt.ylim(y_LL, y_UL)    
    plt.xlim(-2, 80)    
    plt.show()
    

    41. Plotting with different scales using secondary Y axis

    If you want to show two time series that measure two different quantities at the same points in time, you can plot the second series against a secondary Y axis on the right.

    Show Code
    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/economics.csv")
    
    x = df['date']
    y1 = df['psavert']
    y2 = df['unemploy']
    
    # Plot Line1 (Left Y Axis)
    fig, ax1 = plt.subplots(1,1,figsize=(16,9), dpi= 80)
    ax1.plot(x, y1, color='tab:red')
    
    # Plot Line2 (Right Y Axis)
    ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
    ax2.plot(x, y2, color='tab:blue')
    
    # Decorations
    # ax1 (left Y axis)
    ax1.set_xlabel('Year', fontsize=20)
    ax1.tick_params(axis='x', rotation=0, labelsize=12)
    ax1.set_ylabel('Personal Savings Rate', color='tab:red', fontsize=20)
    ax1.tick_params(axis='y', rotation=0, labelcolor='tab:red' )
    ax1.grid(alpha=.4)
    
    # ax2 (right Y axis)
    ax2.set_ylabel("# Unemployed (1000's)", color='tab:blue', fontsize=20)
    ax2.tick_params(axis='y', labelcolor='tab:blue')
    ax2.set_xticks(np.arange(0, len(x), 60))
    ax2.set_xticklabels(x[::60], rotation=90, fontdict={'fontsize':10})
    ax2.set_title("Personal Savings Rate vs Unemployed: Plotting in Secondary Y Axis", fontsize=22)
    fig.tight_layout()
    plt.show()
    

    42. Time Series with Error Bands

    Time series with error bands can be constructed if you have a time series dataset with multiple observations for each time point (date / timestamp). Below you can see a couple of examples: one based on the orders coming in at various times of the day, and another on the number of orders arriving over a duration of 45 days.

    In this approach, the mean number of orders is denoted by the white line, and 95% confidence bands are computed and drawn around the mean.

    Show Code
    from scipy.stats import sem
    
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/user_orders_hourofday.csv")
    df_mean = df.groupby('order_hour_of_day').quantity.mean()
    df_se = df.groupby('order_hour_of_day').quantity.apply(sem).mul(1.96)
    
    # Plot
    plt.figure(figsize=(16,10), dpi= 80)
    plt.ylabel("# Orders", fontsize=16)  
    x = df_mean.index
    plt.plot(x, df_mean, color="white", lw=2) 
    plt.fill_between(x, df_mean - df_se, df_mean + df_se, color="#3F5D7D")  
    
    # Decorations
    # Lighten borders
    plt.gca().spines["top"].set_alpha(0)
    plt.gca().spines["bottom"].set_alpha(1)
    plt.gca().spines["right"].set_alpha(0)
    plt.gca().spines["left"].set_alpha(1)
    plt.xticks(x[::2], [str(d) for d in x[::2]] , fontsize=12)
    plt.title("User Orders by Hour of Day (95% confidence)", fontsize=22)
    plt.xlabel("Hour of Day")
    
    s, e = plt.gca().get_xlim()
    plt.xlim(s, e)
    
    # Draw Horizontal Tick lines  
    for y in range(8, 20, 2):    
        plt.hlines(y, xmin=s, xmax=e, colors='black', alpha=0.5, linestyles="--", lw=0.5)
    
    plt.show()
    

    Show Code
    "Data Source: https://www.kaggle.com/olistbr/brazilian-ecommerce#olist_orders_dataset.csv"
    from dateutil.parser import parse
    from scipy.stats import sem
    
    # Import Data
    df_raw = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/orders_45d.csv', 
                         parse_dates=['purchase_time', 'purchase_date'])
    
    # Prepare Data: Daily Mean and SE Bands
    df_mean = df_raw.groupby('purchase_date').quantity.mean()
    df_se = df_raw.groupby('purchase_date').quantity.apply(sem).mul(1.96)
    
    # Plot
    plt.figure(figsize=(16,10), dpi= 80)
    plt.ylabel("# Daily Orders", fontsize=16)  
    x = [d.date().strftime('%Y-%m-%d') for d in df_mean.index]
    plt.plot(x, df_mean, color="white", lw=2) 
    plt.fill_between(x, df_mean - df_se, df_mean + df_se, color="#3F5D7D")  
    
    # Decorations
    # Lighten borders
    plt.gca().spines["top"].set_alpha(0)
    plt.gca().spines["bottom"].set_alpha(1)
    plt.gca().spines["right"].set_alpha(0)
    plt.gca().spines["left"].set_alpha(1)
    plt.xticks(x[::6], [str(d) for d in x[::6]] , fontsize=12)
    plt.title("Daily Order Quantity of Brazilian Retail with Error Bands (95% confidence)", fontsize=20)
    
    # Axis limits
    s, e = plt.gca().get_xlim()
    plt.xlim(s, e-2)
    plt.ylim(4, 10)
    
    # Draw Horizontal Tick lines  
    for y in range(5, 10, 1):    
        plt.hlines(y, xmin=s, xmax=e, colors='black', alpha=0.5, linestyles="--", lw=0.5)
    
    plt.show()
    

    43. Stacked Area Chart

    A stacked area chart gives a visual representation of the extent of contribution from multiple time series, so that it is easy to compare them against each other.

    Show Code
    # Import Data
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/nightvisitors.csv')
    
    # Decide Colors 
    mycolors = ['tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:brown', 'tab:grey', 'tab:pink', 'tab:olive']      
    
    # Draw Plot and Annotate
    fig, ax = plt.subplots(1,1,figsize=(16, 9), dpi= 80)
    columns = df.columns[1:]
    labs = columns.values.tolist()
    
    # Prepare data
    x  = df['yearmon'].values.tolist()
    y0 = df[columns[0]].values.tolist()
    y1 = df[columns[1]].values.tolist()
    y2 = df[columns[2]].values.tolist()
    y3 = df[columns[3]].values.tolist()
    y4 = df[columns[4]].values.tolist()
    y5 = df[columns[5]].values.tolist()
    y6 = df[columns[6]].values.tolist()
    y7 = df[columns[7]].values.tolist()
    y = np.vstack([y0, y2, y4, y6, y7, y5, y1, y3])
    
    # Plot for each column (labels reordered to match the stacking order above)
    labs = [columns[i] for i in [0, 2, 4, 6, 7, 5, 1, 3]]
    ax.stackplot(x, y, labels=labs, colors=mycolors, alpha=0.8)
    
    # Decorations
    ax.set_title('Night Visitors in Australian Regions', fontsize=18)
    ax.set(ylim=[0, 100000])
    ax.legend(fontsize=10, ncol=4)
    plt.xticks(x[::5], fontsize=10, horizontalalignment='center')
    plt.yticks(np.arange(10000, 100000, 20000), fontsize=10)
    plt.xlim(x[0], x[-1])
    
    # Lighten borders
    plt.gca().spines["top"].set_alpha(0)
    plt.gca().spines["bottom"].set_alpha(.3)
    plt.gca().spines["right"].set_alpha(0)
    plt.gca().spines["left"].set_alpha(.3)
    
    plt.show()
    

    44. Area Chart UnStacked

    An unstacked area chart is used to visualize the progress (ups and downs) of two or more series with respect to each other. In the chart below, you can clearly see how the personal savings rate comes down as the median duration of unemployment increases. The unstacked area chart brings out this phenomenon nicely.

    Show Code
    # Import Data
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/economics.csv")
    
    # Prepare Data
    x = df['date'].values.tolist()
    y1 = df['psavert'].values.tolist()
    y2 = df['uempmed'].values.tolist()
    mycolors = ['tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:brown', 'tab:grey', 'tab:pink', 'tab:olive']      
    columns = ['psavert', 'uempmed']
    
    # Draw Plot 
    fig, ax = plt.subplots(1, 1, figsize=(16,9), dpi= 80)
    ax.fill_between(x, y1=y1, y2=0, label=columns[0], alpha=0.5, color=mycolors[1], linewidth=2)
    ax.fill_between(x, y1=y2, y2=0, label=columns[1], alpha=0.5, color=mycolors[0], linewidth=2)
    
    # Decorations
    ax.set_title('Personal Savings Rate vs Median Duration of Unemployment', fontsize=18)
    ax.set(ylim=[0, 30])
    ax.legend(loc='best', fontsize=12)
    plt.xticks(x[::50], fontsize=10, horizontalalignment='center')
    plt.yticks(np.arange(2.5, 30.0, 2.5), fontsize=10)
    plt.xlim(-10, x[-1])
    
    # Draw Tick lines  
    for y in np.arange(2.5, 30.0, 2.5):    
        plt.hlines(y, xmin=0, xmax=len(x), colors='black', alpha=0.3, linestyles="--", lw=0.5)
    
    # Lighten borders
    plt.gca().spines["top"].set_alpha(0)
    plt.gca().spines["bottom"].set_alpha(.3)
    plt.gca().spines["right"].set_alpha(0)
    plt.gca().spines["left"].set_alpha(.3)
    plt.show()
    

    45. Calendar Heat Map

    A calendar heat map is an alternate, less preferred option to visualise time-based data compared to a time series plot. Though it can be visually appealing, the numeric values are not quite evident. It is, however, effective in picturing the extreme values and holiday effects nicely.

    import matplotlib as mpl
    import calmap
    
    # Import Data
    df = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/yahoo.csv", parse_dates=['date'])
    df.set_index('date', inplace=True)
    
    # Plot
    plt.figure(figsize=(16,10), dpi= 80)
    calmap.calendarplot(df['2014']['VIX.Close'], fig_kws={'figsize': (16,10)}, yearlabel_kws={'color':'black', 'fontsize':14}, subplot_kws={'title':'Yahoo Stock Prices'})
    plt.show()
    

    46. Seasonal Plot

    The seasonal plot can be used to compare how the time series performed at the same point in previous seasons (year / month / week etc.).

    Show Code
    from dateutil.parser import parse 
    
    # Import Data
    df = pd.read_csv('https://github.com/selva86/datasets/raw/master/AirPassengers.csv')
    
    # Prepare data
    df['year'] = [parse(d).year for d in df.date]
    df['month'] = [parse(d).strftime('%b') for d in df.date]
    years = df['year'].unique()
    
    # Draw Plot
    mycolors = ['tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:brown', 'tab:grey', 'tab:pink', 'tab:olive', 'deeppink', 'steelblue', 'firebrick', 'mediumseagreen']      
    plt.figure(figsize=(16,10), dpi= 80)
    
    for i, y in enumerate(years):
        plt.plot('month', 'traffic', data=df.loc[df.year==y, :], color=mycolors[i], label=y)
        plt.text(df.loc[df.year==y, :].shape[0]-.9, df.loc[df.year==y, 'traffic'][-1:].values[0], y, fontsize=12, color=mycolors[i])
    
    # Decoration
    plt.ylim(50,750)
    plt.xlim(-0.3, 11)
    plt.ylabel('$Air Traffic$')
    plt.yticks(fontsize=12, alpha=.7)
    plt.title("Monthly Seasonal Plot: Air Passengers Traffic (1949 - 1969)", fontsize=22)
    plt.grid(axis='y', alpha=.3)
    
    # Remove borders
    plt.gca().spines["top"].set_alpha(0.0)    
    plt.gca().spines["bottom"].set_alpha(0.5)
    plt.gca().spines["right"].set_alpha(0.0)    
    plt.gca().spines["left"].set_alpha(0.5)   
    # plt.legend(loc='upper right', ncol=2, fontsize=12)
    plt.show()
    

    Groups

    47. Dendrogram

    A dendrogram groups similar points together based on a given distance metric and organizes them in tree-like links based on the points’ similarity.

    import scipy.cluster.hierarchy as shc
    
    # Import Data
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/USArrests.csv')
    
    # Plot
    plt.figure(figsize=(16, 10), dpi= 80)  
    plt.title("USArrests Dendrograms", fontsize=22)
    dend = shc.dendrogram(shc.linkage(df[['Murder', 'Assault', 'UrbanPop', 'Rape']], method='ward'), labels=df.State.values, color_threshold=100)  
    plt.xticks(fontsize=12)
    plt.show()
    

    48. Cluster Plot

    A cluster plot can be used to demarcate points that belong to the same cluster. Below is a representational example that groups the US states into 5 groups based on the USArrests dataset. This cluster plot uses the ‘murder’ and ‘assault’ columns as the X and Y axes. Alternately, you can use the first two principal components as the X and Y axes (a sketch of that variant follows the code below).

    Show Code
    from sklearn.cluster import AgglomerativeClustering
    from scipy.spatial import ConvexHull
    
    # Import Data
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/USArrests.csv')
    
    # Agglomerative Clustering
    cluster = AgglomerativeClustering(n_clusters=5, affinity='euclidean', linkage='ward')  
    cluster.fit_predict(df[['Murder', 'Assault', 'UrbanPop', 'Rape']])  
    
    # Plot
    plt.figure(figsize=(14, 10), dpi= 80)  
    plt.scatter(df.iloc[:,0], df.iloc[:,1], c=cluster.labels_, cmap='tab10')  
    
    # Encircle
    def encircle(x,y, ax=None, **kw):
        if not ax: ax=plt.gca()
        p = np.c_[x,y]
        hull = ConvexHull(p)
        poly = plt.Polygon(p[hull.vertices,:], **kw)
        ax.add_patch(poly)
    
    # Draw polygon surrounding vertices    
    encircle(df.loc[cluster.labels_ == 0, 'Murder'], df.loc[cluster.labels_ == 0, 'Assault'], ec="k", fc="gold", alpha=0.2, linewidth=0)
    encircle(df.loc[cluster.labels_ == 1, 'Murder'], df.loc[cluster.labels_ == 1, 'Assault'], ec="k", fc="tab:blue", alpha=0.2, linewidth=0)
    encircle(df.loc[cluster.labels_ == 2, 'Murder'], df.loc[cluster.labels_ == 2, 'Assault'], ec="k", fc="tab:red", alpha=0.2, linewidth=0)
    encircle(df.loc[cluster.labels_ == 3, 'Murder'], df.loc[cluster.labels_ == 3, 'Assault'], ec="k", fc="tab:green", alpha=0.2, linewidth=0)
    encircle(df.loc[cluster.labels_ == 4, 'Murder'], df.loc[cluster.labels_ == 4, 'Assault'], ec="k", fc="tab:orange", alpha=0.2, linewidth=0)
    
    # Decorations
    plt.xlabel('Murder'); plt.xticks(fontsize=12)
    plt.ylabel('Assault'); plt.yticks(fontsize=12)
    plt.title('Agglomerative Clustering of USArrests (5 Groups)', fontsize=22)
    plt.show()
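
    As noted above, you could alternately place the clusters on the first two principal components instead of the raw ‘Murder’ and ‘Assault’ columns. Below is a minimal sketch of that variant, reusing the df and cluster objects from the code above; in practice you would often standardize the columns before applying PCA.

    from sklearn.decomposition import PCA

    # Project the four numeric columns onto the first two principal components
    pca = PCA(n_components=2)
    components = pca.fit_transform(df[['Murder', 'Assault', 'UrbanPop', 'Rape']])

    # Scatter the states in PC space, colored by the cluster labels found earlier
    plt.figure(figsize=(14, 10), dpi=80)
    plt.scatter(components[:, 0], components[:, 1], c=cluster.labels_, cmap='tab10')
    plt.xlabel('PC1'); plt.ylabel('PC2')
    plt.title('Agglomerative Clusters of USArrests on the First Two Principal Components', fontsize=18)
    plt.show()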
    

    49. Andrews Curve

    Andrews curves help visualize if there are inherent groupings of the numerical features based on a given grouping. If the features (columns in the dataset) don’t help discriminate the group (cyl), then the lines will not be well segregated, as you see below.

    from pandas.plotting import andrews_curves
    
    # Import
    df = pd.read_csv("https://github.com/selva86/datasets/raw/master/mtcars.csv")
    df.drop(['cars', 'carname'], axis=1, inplace=True)
    
    # Plot
    plt.figure(figsize=(12,9), dpi= 80)
    andrews_curves(df, 'cyl', colormap='Set1')
    
    # Lighten borders
    plt.gca().spines["top"].set_alpha(0)
    plt.gca().spines["bottom"].set_alpha(.3)
    plt.gca().spines["right"].set_alpha(0)
    plt.gca().spines["left"].set_alpha(.3)
    
    plt.title('Andrews Curves of mtcars', fontsize=22)
    plt.xlim(-3,3)
    plt.grid(alpha=0.3)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.show()
    

    50. Parallel Coordinates

    Parallel coordinates help to visualize whether a feature segregates the groups effectively. If a clear segregation is achieved, that feature is likely going to be very useful in predicting that group.

    from pandas.plotting import parallel_coordinates
    
    # Import Data
    df_final = pd.read_csv("https://raw.githubusercontent.com/selva86/datasets/master/diamonds_filter.csv")
    
    # Plot
    plt.figure(figsize=(12,9), dpi= 80)
    parallel_coordinates(df_final, 'cut', colormap='Dark2')
    
    # Lighten borders
    plt.gca().spines["top"].set_alpha(0)
    plt.gca().spines["bottom"].set_alpha(.3)
    plt.gca().spines["right"].set_alpha(0)
    plt.gca().spines["left"].set_alpha(.3)
    
    plt.title('Parallel Coordinates of Diamonds', fontsize=22)
    plt.grid(alpha=0.3)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    plt.show()
    

    That’s all for now! If you encounter any error or bug, please notify us here.

    List Comprehensions in Python – My Simplified Guide

    A list comprehension is a pythonic way of expressing a ‘for loop’ that appends to a list in a single line of code. It is an intuitive, easy-to-read and very convenient way of creating lists. This is a beginner-friendly post for those who know how to write for-loops in python but don’t quite understand how list comprehensions work yet. If you are already familiar with list comprehensions and want to get some practice, try solving the practice exercises at the bottom.

    Contents

    1. Introduction
    2. Typical format of List Comprehensions
    — Type 1: Simple For-loop
    — Type 2: For-loop with conditional filtering
    — Type 3: for-loop with ‘if’ and ‘else’ condition
    — Type 4: Multiple for-loops
    — Type 5: Paired outputs
    — Type 6: Dictionary Comprehensions
    — Type 7: Tokenizing sentences into list of words
    3. Practice Exercises (increasing level of difficulty)
    4. Conclusion

    1. Introduction

    A list comprehension is a pythonic way of expressing a ‘for-loop’ that appends to a list in a single line of code.

    So how does a list comprehension look? Let’s write one to create a list of even numbers between 0 and 9:

    [i for i in range(10) if i%2 == 0]
    #> [0, 2, 4, 6, 8]
    

    That was easy to read and quite intuitive to understand. Yes?

    And below is the for-loop equivalent for the same logic:

    result = []
    for i in range(10):
        if i%2 == 0:
            result.append(i)
    

    I prefer list comprehensions because they are easier to read, require fewer keystrokes and usually run faster as well.
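
    If you want to verify the speed claim yourself, the standard timeit module makes the comparison easy; the exact numbers will vary by machine and Python version. A minimal sketch:

    import timeit

    def squares_loop():
        result = []
        for i in range(1000):
            result.append(i**2)
        return result

    def squares_comprehension():
        return [i**2 for i in range(1000)]

    # Time 1,000 runs of each version
    print('for-loop      :', timeit.timeit(squares_loop, number=1000))
    print('comprehension :', timeit.timeit(squares_comprehension, number=1000))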

    The best way to learn list comprehensions is by studying examples of converting for-loops and practicing sample problems.

    In this post, you will see how to compose list comprehensions and obtain different types of outputs. Upon solving the practice exercises that come at the bottom of the post, you will intuitively understand how to construct list comprehensions and make it a habit.

    2. Typical format of List Comprehensions

    A list comprehension typically has 3 components:

    • The output (which can be string, number, list or any object you want to put in the list.)
    • For Statements
    • Conditional filtering (optional).

    Below is a typical format of a list comprehension.

    List Comprehension Python - General Format
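
    In code form, the same general format can be sketched as follows; output_expression, item, iterable and condition are placeholder names, followed by one concrete instance.

    # General format (placeholder names, not meant to run as-is):
    # [output_expression for item in iterable if condition]

    # A concrete instance of that format:
    squares_of_evens = [i**2 for i in range(10) if i % 2 == 0]
    print(squares_of_evens)
    #> [0, 4, 16, 36, 64]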

    However, this format is not a golden rule, because some logic can require multiple ‘for’ statements and ‘if’ conditions, and they can change positions as well. The only thing that does not change, however, is the position of the output expression, which always comes at the beginning.

    Next, let’s see examples of 7 different types of problems where you can use list comprehensions instead of for-loops.

    Example Type 1: Simple for-loop

    Problem Statement: Square each number in mylist and store the result as a list.

    The ‘For Loop’ iterates over each number, squares the number and appends to a list.

    For Loop Version:

    mylist = [1,2,3,4,5]
    
    # For Loop Version
    result = []
    for i in mylist:
        result.append(i**2)
    
    print(result)
    #> [1, 4, 9, 16, 25]
    

    How to convert this to a list comprehension? Take the output expression, place it on the same line as the for statement and enclose the whole thing in a pair of [ .. ].

    List comprehension solution:

    result = [i**2 for i in [1,2,3,4,5]]
    print(result)
    #> [1, 4, 9, 16, 25]
    

    Example Type 2: for-loop with conditional filtering

    What if you have an if condition in the for loop? Say, you want to square only the even numbers:

    Problem statement: Square only the even numbers in mylist and store the result in a list.

    For Loop Version:

    mylist = [1,2,3,4,5]
    
    # For Loop Version
    result = []
    for i in mylist:
        if i%2==0:
            result.append(i**2)
    
    print(result)
    #> [4, 16]
    

    In a list comprehension, you add the ‘if’ condition after the for statement if you want to filter the items.

    List Comprehension solution:

    # List Comprehension Version
    result = [i**2 for i in [1,2,3,4,5] if i%2==0]
    print(result)
    #> [4, 16]
    

    Example Type 3: for-loop with ‘if’ and ‘else’ condition

    Let’s see a case where you have an ‘if-else’ condition in the for-loop.

    Problem Statement: In mylist, square the number if its even, else, cube it.

    For Loop Version:

    mylist = [1,2,3,4,5]
    # For Loop Version
    result = []
    for i in mylist:
        if i%2==0:
            result.append(i**2)
        else:
            result.append(i**3)
    
    print(result)
    #> [1, 4, 27, 16, 125]
    

    In the previous example, we wanted to keep only the even numbers. But in this case, there is no filtering, so the if and else go before the for statement itself.

    List Comprehension solution:

    [i**2 if i%2==0 else i**3 for i in [1,2,3,4,5]]
    #> [1, 4, 27, 16, 125]
    

    Example Type 4: Multiple for-loops

    Now let’s see a slightly complicated example that involves two For-Loops.

    Problem Statement: Flatten the matrix mat (a list of lists) keeping only the even numbers.

    For Loop Version:

    # For Loop Version
    mat = [[1,2,3,4], [5,6,7,8], [9,10,11,12], [13,14,15,16]]
    result = []
    for row in mat:
        for i in row:
            if i%2 == 0:
                result.append(i)
    
    print(result)
    #> [2, 4, 6, 8, 10, 12, 14, 16]
    

    Can you imagine what the equivalent list comprehension version would look like? It’s nearly the same as writing the lines of the for-loop one after the other.

List Comprehension solution:

    # List Comprehension version
    [i for row in mat for i in row if i%2==0]
    #> [2, 4, 6, 8, 10, 12, 14, 16]
    

    Hope you are getting a feel of list comprehensions. Let’s do one more example.

    Example Type 5: Paired outputs

    Problem Statement: For each number in list_b, get the number and its position in mylist as a list of tuples.

    For-Loop Version:

    mylist = [9, 3, 6, 1, 5, 0, 8, 2, 4, 7]
    list_b = [6, 4, 6, 1, 2, 2]
    
    result = []
    for i in list_b:
        result.append((i, mylist.index(i)))
    
    print(result)
    #> [(6, 2), (4, 8), (6, 2), (1, 3), (2, 7), (2, 7)]
    

    List Comprehension solution:

In this case, the output has 2 items instead of one. So pair both of them as a tuple and place it before the for statement.

    [(i, mylist.index(i)) for i in list_b]
    #> [(6, 2), (4, 8), (6, 2), (1, 3), (2, 7), (2, 7)]
    

    Example Type 6: Dictionary Comprehensions

Same problem as the previous example, but the output is a dictionary instead of a list of tuples.

    Problem Statement: For each number in list_b, get the number and its position in mylist as a dict.

    For Loop Version:

    mylist = [9, 3, 6, 1, 5, 0, 8, 2, 4, 7]
    list_b = [6, 4, 6, 1, 2, 2]
    
    result = {}
    for i in list_b:
        result[i]=mylist.index(i)
    
    print(result)
    #> {6: 2, 4: 8, 1: 3, 2: 7}
    

    List Comprehension solution:

To make a dictionary output, you just need to replace the square brackets with curly braces and use a : between the key and the value instead of a comma.

    {i: mylist.index(i) for i in list_b}
    #> {6: 2, 4: 8, 1: 3, 2: 7}
    

    Example Type 7: Tokenizing sentences into list of words

    This is a slightly different way of applying list comprehension.

    Problem Statement: The goal is to tokenize the following 5 sentences into words, excluding the stop words.

    Input:

    sentences = ["a new world record was set", 
                 "in the holy city of ayodhya", 
                 "on the eve of diwali on tuesday", 
                 "with over three lakh diya or earthen lamps", 
                 "lit up simultaneously on the banks of the sarayu river"]
    
    stopwords = ['for', 'a', 'of', 'the', 'and', 'to', 'in', 'on', 'with']
    

    For Loop Version:

    # For Loop Version
    results = []    
    for sentence in sentences:
        sentence_tokens = []
        for word in sentence.split(' '):
            if word not in stopwords:
                sentence_tokens.append(word)
        results.append(sentence_tokens)
    
    print(results)
    #> [['new', 'world', 'record', 'was', 'set'], 
    #> ['holy', 'city', 'ayodhya'], 
    #> ['eve', 'diwali', 'tuesday'], 
    #> ['over', 'three', 'lakh', 'diya', 'or', 'earthen', 'lamps'], 
    #> ['lit', 'up', 'simultaneously', 'banks', 'sarayu', 'river']]
    

    Before reading ahead, can you try creating the equivalent list comprehension version? (..pause and try it out in your py console..).

    List Comprehension solution:

If you wanted to flatten out the words in the sentences, the solution would be something like this:

    results = [word for sentence in sentences for word in sentence.split(' ') if word not in stopwords]
    print(results)
    #> ['new', 'world', 'record', 'was', 'set', 'holy', 'city', 'ayodhya', 'eve', 'diwali', 'tuesday', 'over', 'three', 'lakh', 'diya', 'or', 'earthen', 'lamps', 'lit', 'up', 'simultaneously', 'banks', 'sarayu', 'river']
    

But we want to keep track of which words belong to which sentence; that is, the original grouping into sentences should remain intact as a list of lists.

To achieve this, the entire second unit of the for-loop, that is, the [word for word in sentence.split(' ') if word not in stopwords] part, should be treated as the output and therefore goes at the beginning of the list comprehension.

    # List Comprehension Version
    [[word for word in sentence.split(' ') if word not in stopwords] for sentence in sentences]
    #> [['new', 'world', 'record', 'was', 'set'],
    #>  ['holy', 'city', 'ayodhya'],
    #>  ['eve', 'diwali', 'tuesday'],
    #>  ['over', 'three', 'lakh', 'diya', 'or', 'earthen', 'lamps'],
    #>  ['lit', 'up', 'simultaneously', 'banks', 'sarayu', 'river']]
    

    3. Practice Exercises (increasing level of difficulty)

    Question 1. Given a 1D list, negate all elements which are between 3 and 8, using list comprehensions

    # Input
    mylist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    
    # Desired Output
    [1, 2, -3, -4, -5, -6, -7, -8, 9, 10]
    
    Show Answer
    mylist = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    [-i if  3 <= i <= 8 else i for i in mylist]
    

     

Question 2: Make a dictionary of the 26 English letters, mapping each with the corresponding integer.

    # Desired output
    {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6,
     'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12,
     'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18,
     's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24,
     'y': 25, 'z': 26}
    
    Show Answer
    import string
    {a:i+1 for a,i in zip(string.ascii_letters[:26], range(26))}
    

     

Question 3: Replace all letters in the string ‘Lee Quan Yew’ by substituting each letter with the corresponding number, like 1 for ‘a’, 2 for ‘b’ and so on.

    Desired Output:

    [12, 5, 5, ' ', 17, 21, 1, 14, ' ', 25, 5, 23]
    
    Show Answer
    import string
    d = {a:i+1 for a,i in zip(string.ascii_lowercase, range(26))}
    [d.get(a.lower(), ' ') for a in 'Lee Quan Yew']
    

     

    Question 4: Get the unique list of words from the following sentences, excluding any stopwords.

    sentences = ["The Hubble Space telescope has spotted", 
                 "a formation of galaxies that resembles", 
                 "a smiling face in the sky"]
    
    # Desired output:
    {'face', 'formation', 'galaxies', 'has', 'hubble', 'resembles',
     'sky', 'smiling', 'space', 'spotted', 'telescope', 'that', 'the'}
    
    Show Answer
# `stopwords` as defined in the earlier tokenization example
{word.lower() for sentence in sentences for word in sentence.split(' ') if word not in stopwords}
    

     

    Question 5: Tokenize the following sentences excluding all stopwords and punctuations.

    sentences = ["The Hubble Space telescope has spotted", 
                 "a formation of galaxies that resembles", 
                 "a smiling face in the sky", 
                 "The image taken with the Wide Field Camera", 
                 "shows a patch of space filled with galaxies", 
                 "of all shapes, colours and sizes"]
    
    stopwords = ['for', 'a', 'of', 'the', 'and', 'to', 'in', 'on', 'with']
    
    # Desired Output
    #> [['the', 'hubble', 'space', 'telescope', 'has', 'spotted'],
    #>  ['formation', 'galaxies', 'that', 'resembles'],
    #>  ['smiling', 'face', 'sky'],
    #>  ['the', 'image', 'taken', 'wide', 'field', 'camera'],
    #>  ['shows', 'patch', 'space', 'filled', 'galaxies'],
    #>  ['all', 'shapes,', 'colours', 'sizes']]
    
    Show Answer
    [[word.lower() for word in sentence.split(' ') if word not in stopwords] for sentence in sentences]
    

     

    Question 6: Create a list of (word:id) pairs for all words in the following sentences, where id is the sentence index.

    # Input
    sentences = ["The Hubble Space telescope has spotted", 
                 "a formation of galaxies that resembles", 
                 "a smiling face in the sky"]
    
    # Desired output:
    [('the', 0), ('hubble', 0), ('space', 0), ('telescope', 0), ('has', 0), ('spotted', 0),
     ('a', 1), ('formation', 1), ('of', 1), ('galaxies', 1), ('that', 1), ('resembles', 1),
     ('a', 2), ('smiling', 2), ('face', 2), ('in', 2), ('the', 2), ('sky', 2)]
    
    Show Answer
    [(word.lower(), i) for i, sentence in enumerate(sentences) for word in sentence.split(' ')]
    

     

Question 7: Print the inner positions of the 64 squares in a chess board, replacing the boundary squares with the string '----'.

    # Desired Output:
    [['----', '----', '----', '----', '----', '----', '----', '----'],
     ['----', (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), '----'],
     ['----', (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 2), '----'],
     ['----', (1, 3), (2, 3), (3, 3), (4, 3), (5, 3), (6, 3), '----'],
     ['----', (1, 4), (2, 4), (3, 4), (4, 4), (5, 4), (6, 4), '----'],
     ['----', (1, 5), (2, 5), (3, 5), (4, 5), (5, 5), (6, 5), '----'],
     ['----', (1, 6), (2, 6), (3, 6), (4, 6), (5, 6), (6, 6), '----'],
     ['----', '----', '----', '----', '----', '----', '----', '----']]
    
    Show Answer
    [[(i,j) if (i not in (0, 7)) and (j not in (0, 7)) else ('----') for i in range(8)] for j in range(8)]
    

    4. Conclusion

Today, we covered what list comprehensions are and how to create various types of them through examples.

I hope you were able to solve the exercises and feel more comfortable working with list comprehensions. If you did, congratulations, and I will see you in the next one.

    Python @Property Explained – How to Use and When? (Full Examples)

A python @property decorator lets a method be accessed as an attribute instead of as a method with a '()'. Today, you will gain an understanding of when it is really needed, in what situations you can use it and how to actually use it.

    Contents

    1. Introduction
    2. What does @property do?
    3. When to use @property?
    4. The setter method – When and How to write one?
    5. The deleter method – When and How to write one?
    6. Conclusion

    Python @property – A Simplified Guide

    1. Introduction

    In well-written python code, you might have noticed a @property decorator just before the method definition.

    In this guide, you will understand clearly what exactly the python @property does, when to use it and how to use it. This guide, however, assumes that you have a basic idea about what python classes are. Because the @property is typically used inside one.

    2. What does @property do?

    So, what does the @property do?

The @property lets a method be accessed as an attribute instead of as a method with a '()'. But why is it really needed and in what situations can you use it?

    To understand this, let’s create a Person class that contains the first, last and fullname of the person as attributes and has an email() method that provides the person’s email.

    class Person():
    
        def __init__(self, firstname, lastname):
            self.first = firstname
            self.last = lastname
            self.fullname = self.first + ' '+ self.last
    
        def email(self):
            return '{}.{}@email.com'.format(self.first, self.last)
    

    Let’s create an instance of the Person ‘selva prabhakaran’ and print the attributes.

    # Create a Person object
    person = Person('selva', 'prabhakaran')
    print(person.first)  #> selva
    print(person.last)  #> prabhakaran
    print(person.fullname)  #> selva prabhakaran
print(person.email())  #> selva.prabhakaran@email.com
    

    3. When to use @property?

    So far so good.

    Now, somehow you decide to change the last name of the person.

    Here is a fun fact about python classes: If you change the value of an attribute inside a class, the other attributes that are derived from the attribute you just changed don’t automatically update.

For example: by changing self.last, you might expect the self.fullname attribute, which is derived from self.last, to update. But unexpectedly it doesn’t. This can provide potentially misleading information about the person.

However, notice that email() works as intended, even though it is also derived from self.last.

# Changing the `last` name does not change `self.fullname`, but email() works
    person.last = 'prasanna'
    print(person.last)  #> prasanna
    print(person.fullname)  #> selva prabhakaran
print(person.email())  #> selva.prasanna@email.com
    

So, a probable solution would be to convert the self.fullname attribute to a fullname() method, so it will provide the correct value like the email() method does. Let’s do it.

    # Converting fullname to a method provides the right fullname
    # But it breaks old code that used the fullname attribute without the `()`
    class Person():
    
        def __init__(self, firstname, lastname):
            self.first = firstname
            self.last = lastname
    
        def fullname(self):
            return self.first + ' '+ self.last
    
        def email(self):
            return '{}.{}@email.com'.format(self.first, self.last)
    
    person = Person('selva', 'prabhakaran')
    print(person.fullname())  #> selva prabhakaran
    
    # change last name to Prasanna
    person.last = 'prasanna'
    
    print(person.fullname())  #> selva prasanna
    

Now the convert-to-method solution works.

    But there is a problem.

Since we are now using the person.fullname() method with a '()' instead of the person.fullname attribute, it will break whatever code used the self.fullname attribute. If you are building a product/tool, chances are other developers and users of your module used it at some point, and all their code will break as well.

So a better solution (one that does not break your users’ code) is to convert the method to a property by adding a @property decorator before the method’s definition. By doing this, the fullname() method can be accessed as an attribute instead of as a method with '()'. See the example below.

    # Adding @property provides the right fullname and does not break code!
    class Person():
    
        def __init__(self, firstname, lastname):
            self.first = firstname
            self.last = lastname
    
        @property    
        def fullname(self):
            return self.first + ' '+ self.last
    
        def email(self):
            return '{}.{}@email.com'.format(self.first, self.last)
    
    # Init a Person 
    person = Person('selva', 'prabhakaran')
    print(person.fullname)  #> selva prabhakaran
    
    # Change last name to Prasanna
    person.last = 'prasanna'
    
    # Print fullname
    print(person.fullname)  # selva prasanna
    

    4. The setter method – When to use it and How to write one?

    Now you are able to access the fullname like an attribute.

    However there is one final problem.

    Your users are going to want to change the fullname property at some point. And by setting it, they expect it will change the values of the first and last names from which fullname was derived in the first place.

    But unfortunately, trying to set the value of fullname throws an AttributeError.

    person.fullname = 'raja rajan'
    
    #> ---------------------------------------------------------------------------
    #> AttributeError                            Traceback (most recent call last)
    #> <ipython-input-36-67cde7461cfc> in <module>
    #> ----> 1 person.fullname = 'raja rajan'
    
    #> AttributeError: can't set attribute
    

    How to tackle this?

We define an equivalent setter method that will be called every time a user sets a value to this property.

    Inside this setter method, you can modify the values of variables that should be changed when the value of fullname is set/changed.

    However, there are a couple of conventions you need to follow when defining a setter method:

    1. The setter method should have the same name as the equivalent method that @property decorates.
2. It accepts as an argument the value that the user assigns to the property.

Finally, you need to add a @{methodname}.setter decorator just before the method definition.

Once you add the @{methodname}.setter decorator, this method will be called every time the property (fullname in this case) is set or changed. See below.

    class Person():
    
        def __init__(self, firstname, lastname):
            self.first = firstname
            self.last = lastname
    
        @property    
        def fullname(self):
            return self.first + ' '+ self.last
    
        @fullname.setter
        def fullname(self, name):
            firstname, lastname = name.split()
            self.first = firstname
            self.last = lastname
    
        def email(self):
            return '{}.{}@email.com'.format(self.first, self.last)
    
    # Init a Person 
    person = Person('selva', 'prabhakaran')
    print(person.fullname)  #> selva prabhakaran
    print(person.first)  #> selva
    print(person.last)  #> prabhakaran
    
    # Setting fullname calls the setter method and updates person.first and person.last
    person.fullname = 'velu pillai'
    
    # Print the changed values of `first` and `last`
    print(person.fullname) #> velu pillai
print(person.first)  #> velu
    print(person.last)  #> pillai
    

There you go. We set a new value to person.fullname, and person.first and person.last got updated as well. Our Person class will now automatically update the derived attribute (the property) when one of the base attributes changes, and vice versa.

    5. The deleter method

Similar to the setter, the deleter method defines what happens when a property is deleted.

    You can create the deleter method by defining a method of the same name and adding a @{methodname}.deleter decorator. See the implementation below.

    class Person():
        
        def __init__(self, firstname, lastname):
            self.first = firstname
            self.last = lastname
            
        @property    
        def fullname(self):
            return self.first + ' '+ self.last
        
        @fullname.setter
        def fullname(self, name):
            firstname, lastname = name.split()
            self.first = firstname
            self.last = lastname
            
        @fullname.deleter
        def fullname(self):
            self.first = None
            self.last = None        
            
        def email(self):
            return '{}.{}@email.com'.format(self.first, self.last)
        
    # Init a Person 
    person = Person('selva', 'prabhakaran')
    print(person.fullname)  #> selva prabhakaran
    
    # Deleting fullname calls the deleter method, which erases self.first and self.last
    del person.fullname 
    
    # Print the changed values of `first` and `last`
    print(person.first)  #> None
    print(person.last)  #> None
    

In the above case, the person.first and person.last attributes return None once the fullname is deleted.

    6. Conclusion

    So, to summarize:

1. When to use the @property decorator?
  When an attribute is derived from other attributes in the class, so that the derived attribute updates whenever the source attributes change.
2. How to make a @property?
  Make an attribute a property by defining it as a method and adding the @property decorator just before the function definition.
3. When to define a setter method for the property?
  Typically, when you want to update the source attributes whenever the property is set. It lets you define any other changes as well.

    Hope the purpose of @property is clear and you now know when and how to use it. If you did, congratulations! I will meet you in the next one.

    How Naive Bayes Algorithm Works? (with example and full code)

Naive Bayes is a probabilistic machine learning algorithm based on the Bayes Theorem, used in a wide variety of classification tasks. In this post, you will gain a clear and complete understanding of the Naive Bayes algorithm and all the necessary concepts, so that there are no gaps in understanding.

    Contents

1. Introduction
    2. What is Conditional Probability?
    3. The Bayes Rule
    4. The Naive Bayes
    5. Naive Bayes Example by Hand
    6. What is Laplace Correction?
    7. What is Gaussian Naive Bayes?
    8. Building a Naive Bayes Classifier in R
    9. Building Naive Bayes Classifier in Python
    10. Practice Exercise: Predict Human Activity Recognition (HAR)
    11. Tips to improve the model

    1. Introduction

    Naive Bayes is a probabilistic machine learning algorithm that can be used in a wide variety of classification tasks. Typical applications include filtering spam, classifying documents, sentiment prediction etc. It is based on the works of Rev. Thomas Bayes (1702–61) and hence the name.

    But why is it called ‘Naive’?

The name naive is used because it assumes the features that go into the model are independent of each other. That is, changing the value of one feature does not directly influence or change the value of any other feature used in the algorithm.

    Alright. By the sounds of it, Naive Bayes does seem to be a simple yet powerful algorithm. But why is it so popular?

That’s because there is a significant advantage with NB. Since it is a probabilistic model, the algorithm can be coded up easily and the predictions made real quick. Real-time quick. Because of this, it is easily scalable and is traditionally the algorithm of choice for real-world applications (apps) that are required to respond to user requests instantaneously.

But before you go into Naive Bayes, you need to understand what ‘Conditional Probability’ is and what the ‘Bayes Rule’ is.

    And by the end of this tutorial, you will know:

    • How exactly Naive Bayes Classifier works step-by-step
• What Gaussian Naive Bayes is, when it is used and how it works
    • How to code it up in R and Python
    • How to improve your Naive Bayes models?

    Cool? Let’s begin.

    2. What is Conditional Probability?

Let’s start from the basics by understanding conditional probability.

    Coin Toss and Fair Dice Example

    When you flip a fair coin, there is an equal chance of getting either heads or tails. So you can say the probability of getting heads is 50%.

Similarly, what would be the probability of getting a 1 when you roll a die with 6 faces? Assuming the die is fair, the probability is 1/6 = 0.166.

    Alright, one final example with playing cards.

    Playing Cards Example

    If you pick a card from the deck, can you guess the probability of getting a queen given the card is a spade?

Well, I have already set a condition that the card is a spade. So, the denominator (eligible population) is 13 and not 52. And since there is only one queen in spades, the probability that the card is a queen given it is a spade is 1/13 = 0.077.

    This is a classic example of conditional probability. So, when you say the conditional probability of A given B, it denotes the probability of A occurring given that B has already occurred.

    Mathematically, Conditional probability of A given B can be computed as: P(A|B) = P(A AND B) / P(B)

    School Example

    Let’s see a slightly complicated example. Consider a school with a total population of 100 persons. These 100 persons can be seen either as ‘Students’ and ‘Teachers’ or as a population of ‘Males’ and ‘Females’.

With the below tabulation of the 100 people, what is the conditional probability that a certain member of the school is a ‘Teacher’ given that he is a ‘Man’?

    To calculate this, you may intuitively filter the sub-population of 60 males and focus on the 12 (male) teachers.

    So the required conditional probability P(Teacher | Male) = 12 / 60 = 0.2.

    This can be represented as the intersection of Teacher (A) and Male (B) divided by Male (B). Likewise, the conditional probability of B given A can be computed. The Bayes Rule that we use for Naive Bayes, can be derived from these two notations.
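As a quick sanity check of the formula, using the numbers from the school example (12 male teachers, 60 males, 100 people in total):

# P(Teacher | Male) = P(Teacher AND Male) / P(Male)
p_teacher_and_male = 12 / 100
p_male = 60 / 100
print(round(p_teacher_and_male / p_male, 2))
#> 0.2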

    3. The Bayes Rule

    The Bayes Rule is a way of going from P(X|Y), known from the training dataset, to find P(Y|X).

    To do this, we replace A and B in the above formula, with the feature X and response Y.
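Written in the same plain-text notation used for conditional probability above, this gives:

P(Y|X) = P(X AND Y) / P(X) = P(X|Y) * P(Y) / P(X)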

    For observations in test or scoring data, the X would be known while Y is unknown. And for each row of the test dataset, you want to compute the probability of Y given the X has already happened.

What happens if Y has more than 2 categories? We compute the probability of each class of Y and let the highest win.

    4. The Naive Bayes

    The Bayes Rule provides the formula for the probability of Y given X. But, in real-world problems, you typically have multiple X variables.

    When the features are independent, we can extend the Bayes Rule to what is called Naive Bayes.

    It is called ‘Naive’ because of the naive assumption that the X’s are independent of each other. Regardless of its name, it’s a powerful formula.

In technical jargon, the left-hand-side (LHS) of the equation is understood as the posterior probability, or simply the posterior.

    The RHS has 2 terms in the numerator.

The first term is called the ‘Likelihood of Evidence’. It is nothing but the conditional probability of each X given that Y is of a particular class ‘c’.

Since all the X’s are assumed to be independent of each other, you can just multiply the ‘likelihoods’ of all the X’s and call it the ‘Probability of likelihood of evidence’. This is computed from the training dataset by filtering records where Y=c.

    The second term is called the prior which is the overall probability of Y=c, where c is a class of Y. In simpler terms, Prior = count(Y=c) / n_Records.
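Putting the two terms together in the same plain-text notation, the score computed for a class c is:

P(Y=c | X1, ..., Xn) = [ P(X1|Y=c) * P(X2|Y=c) * ... * P(Xn|Y=c) ] * P(Y=c) / P(X1, ..., Xn)

Since the denominator is the same for every class, it can be dropped when all you need is the winning class.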

    An example is better than an hour of theory. So let’s see one.

    5. Naive Bayes Example by Hand

    Say you have 1000 fruits which could be either ‘banana’, ‘orange’ or ‘other’. These are the 3 possible classes of the Y variable.

    We have data for the following X variables, all of which are binary (1 or 0).

    • Long
    • Sweet
    • Yellow

    The first few rows of the training dataset look like this:

Fruit     Long (x1)   Sweet (x2)   Yellow (x3)
Orange        0            1             0
Banana        1            0             1
Banana        1            1             1
Other         1            1             0
...          ...          ...           ...

    For the sake of computing the probabilities, let’s aggregate the training data to form a counts table like this.

    So the objective of the classifier is to predict if a given fruit is a ‘Banana’ or ‘Orange’ or ‘Other’ when only the 3 features (long, sweet and yellow) are known.

    Let’s say you are given a fruit that is: Long, Sweet and Yellow, can you predict what fruit it is?

This is the same as predicting the Y when only the X variables in the test data are known. Let’s solve it by hand using Naive Bayes.

    The idea is to compute the 3 probabilities, that is the probability of the fruit being a banana, orange or other. Whichever fruit type gets the highest probability wins.

    All the information to calculate these probabilities is present in the above tabulation.

    Step 1: Compute the ‘Prior’ probabilities for each of the class of fruits.

That is, the proportion of each fruit class out of all the fruits in the population. You can provide the ‘Priors’ from prior information about the population. Otherwise, they can be computed from the training data.

    For this case, let’s compute from the training data. Out of 1000 records in training data, you have 500 Bananas, 300 Oranges and 200 Others. So the respective priors are 0.5, 0.3 and 0.2.

    P(Y=Banana) = 500 / 1000 = 0.50

    P(Y=Orange) = 300 / 1000 = 0.30

    P(Y=Other) = 200 / 1000 = 0.20

    Step 2: Compute the probability of evidence that goes in the denominator.

This is nothing but the product of the P(X)s for all X. This is an optional step because the denominator is the same for all the classes and so does not affect which class gets the highest score.

    P(x1=Long) = 500 / 1000 = 0.50

    P(x2=Sweet) = 650 / 1000 = 0.65

    P(x3=Yellow) = 800 / 1000 = 0.80

    Step 3: Compute the probability of likelihood of evidences that goes in the numerator.

It is the product of the conditional probabilities of the 3 features. If you refer back to the formula, it says P(X1 |Y=k). Here X1 is ‘Long’ and k is ‘Banana’. That means the probability the fruit is ‘Long’ given that it is a Banana. In the above table, you have 500 Bananas. Out of those, 400 are long. So, P(Long | Banana) = 400/500 = 0.8.

    Here, I have done it for Banana alone.

    Probability of Likelihood for Banana

    P(x1=Long | Y=Banana) = 400 / 500 = 0.80

    P(x2=Sweet | Y=Banana) = 350 / 500 = 0.70

    P(x3=Yellow | Y=Banana) = 450 / 500 = 0.90

    So, the overall probability of Likelihood of evidence for Banana = 0.8 * 0.7 * 0.9 = 0.504

    Step 4: Substitute all the 3 equations into the Naive Bayes formula, to get the probability that it is a banana.


    Similarly, you can compute the probabilities for ‘Orange’ and ‘Other fruit’. The denominator is the same for all 3 cases, so it’s optional to compute.

    Clearly, Banana gets the highest probability, so that will be our predicted class.
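For concreteness, here is a minimal Python sketch of the Banana computation, using only the numbers quoted above and the denominator from Step 2:

# Likelihoods and prior for Banana (from the counts above)
p_long_banana   = 400 / 500   # 0.80
p_sweet_banana  = 350 / 500   # 0.70
p_yellow_banana = 450 / 500   # 0.90
prior_banana    = 500 / 1000  # 0.50

# Probability of evidence (same for every class, so it only rescales the scores)
p_evidence = 0.50 * 0.65 * 0.80  # 0.26

score_banana = p_long_banana * p_sweet_banana * p_yellow_banana * prior_banana
print(round(score_banana, 3))
#> 0.252
print(round(score_banana / p_evidence, 3))
#> 0.969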

    6. What is Laplace Correction?

The value of P(Orange | Long, Sweet and Yellow) was zero in the above example because P(Long | Orange) was zero. That is, there were no ‘Long’ oranges in the training data.

It makes sense, but when you have a model with many features, the entire probability can become zero because one of the feature values was zero. To avoid this, we increase the count of the variable with a zero to a small value (usually 1) in the numerator, so that the overall probability doesn’t become zero.

    This correction is called ‘Laplace Correction’. Most Naive Bayes model implementations accept this or an equivalent form of correction as a parameter.
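As a rough sketch of what this correction looks like in code (the exact form and the smoothing parameter vary across implementations; the function below is only illustrative):

def smoothed_likelihood(feature_count, class_count, n_feature_values, alpha=1):
    # Additive (Laplace) smoothing: a zero count no longer zeroes the whole product
    return (feature_count + alpha) / (class_count + alpha * n_feature_values)

# e.g. zero 'Long' oranges out of 300 oranges, for a binary (2-valued) feature
print(round(smoothed_likelihood(0, 300, 2), 4))
#> 0.0033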

    7. What is Gaussian Naive Bayes?

    So far we’ve seen the computations when the X’s are categorical. But how to compute the probabilities when X is a continuous variable?

If we assume that X follows a particular distribution, then you can plug in the probability density function of that distribution to compute the probability of the likelihoods.

    If you assume the X’s follow a Normal (aka Gaussian) Distribution, which is fairly common, we substitute the corresponding probability density of a Normal distribution and call it the Gaussian Naive Bayes. You need just the mean and variance of the X to compute this formula.

    where mu and sigma are the mean and variance of the continuous X computed for a given class ‘c’ (of Y).
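A minimal Python sketch of this density term (the mu and var values below are purely illustrative; in practice they are estimated per feature and per class from the training data):

import math

def gaussian_likelihood(x, mu, var):
    # Normal probability density used for P(x | Y=c) in Gaussian Naive Bayes
    return (1.0 / math.sqrt(2 * math.pi * var)) * math.exp(-(x - mu) ** 2 / (2 * var))

print(round(gaussian_likelihood(5.0, mu=4.5, var=0.64), 4))
#> 0.4102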

To make the features more Gaussian-like, you might consider transforming the variables using something like the Box-Cox transformation.

    That’s it. Now, let’s build a Naive Bayes classifier.

    8. Building a Naive Bayes Classifier in R

    Understanding Naive Bayes was the (slightly) tricky part. Implementing it is fairly straightforward.

In R, the Naive Bayes classifier is implemented in packages such as e1071, klaR and bnlearn. In Python, it is implemented in scikit-learn.

    For sake of demonstration, let’s use the standard iris dataset to predict the Species of flower using 4 different features: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width

    # Import Data
    training <- read.csv('https://raw.githubusercontent.com/selva86/datasets/master/iris_train.csv')
    test <- read.csv('https://raw.githubusercontent.com/selva86/datasets/master/iris_test.csv')
    

The training data is now contained in the training dataframe and the test data in the test dataframe. Let’s load the klaR package and build the naive bayes model.

    # Using klaR for Naive Bayes
    library(klaR)
    nb_mod <- NaiveBayes(Species ~ ., data=training)
    pred <- predict(nb_mod, test)
    

Let’s see the confusion matrix.

    # Confusion Matrix
    tab <- table(pred$class, test$Species)
    caret::confusionMatrix(tab)  
    #> Confusion Matrix and Statistics
    
    #>              setosa versicolor virginica
    #>   setosa         15          0         0
    #>   versicolor      0         11         0
    #>   virginica       0          4        15
    
    #> Overall Statistics
    
    #>                Accuracy : 0.9111          
    #>                  95% CI : (0.7878, 0.9752)
    #>     No Information Rate : 0.3333          
    #>     P-Value [Acc > NIR] : 8.467e-16       
    
    #>                   Kappa : 0.8667          
    #>  Mcnemar's Test P-Value : NA              
    
    #> Statistics by Class:
    
    #>                      Class: setosa Class: versicolor Class: virginica
    #> Sensitivity                 1.0000            0.7333           1.0000
    #> Specificity                 1.0000            1.0000           0.8667
    #> Pos Pred Value              1.0000            1.0000           0.7895
    #> Neg Pred Value              1.0000            0.8824           1.0000
    #> Prevalence                  0.3333            0.3333           0.3333
    #> Detection Rate              0.3333            0.2444           0.3333
    #> Detection Prevalence        0.3333            0.2444           0.4222
    #> Balanced Accuracy           1.0000            0.8667           0.9333
    
    # Plot density of each feature using nb_mod
    opar = par(mfrow=c(2, 2), mar=c(4,0,0,0))
    plot(nb_mod, main="")  
    par(opar)
    

    # Plot the Confusion Matrix
    library(ggplot2)
    test$pred <- pred$class
    ggplot(test, aes(Species, pred, color = Species)) +
      geom_jitter(width = 0.2, height = 0.1, size=2) +
      labs(title="Confusion Matrix", 
           subtitle="Predicted vs. Observed from Iris dataset", 
           y="Predicted", 
           x="Truth",
           caption="machinelearningplus.com")
    

    9. Building Naive Bayes Classifier in Python

    # Import packages
    from sklearn.naive_bayes import GaussianNB
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import confusion_matrix
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns; sns.set()
    
    # Import data
    training = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/iris_train.csv')
    test = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/iris_test.csv')
    
    
    # Create the X, Y, Training and Test
    xtrain = training.drop('Species', axis=1)
    ytrain = training.loc[:, 'Species']
    xtest = test.drop('Species', axis=1)
    ytest = test.loc[:, 'Species']
    
    
    # Init the Gaussian Classifier
    model = GaussianNB()
    
    # Train the model 
    model.fit(xtrain, ytrain)
    
    # Predict Output 
    pred = model.predict(xtest)
    
    # Plot Confusion Matrix
    mat = confusion_matrix(pred, ytest)
    names = np.unique(pred)
    sns.heatmap(mat, square=True, annot=True, fmt='d', cbar=False,
                xticklabels=names, yticklabels=names)
    plt.xlabel('Truth')
    plt.ylabel('Predicted')
    

    10. Practice Exercise: Predict Human Activity Recognition (HAR)

The objective of this practice exercise is to predict the current human activity based on physiological activity measurements across 53 different features from the HAR dataset. The training and test datasets are provided.

    Build a Naive Bayes model, predict on the test dataset and compute the confusion matrix.

    Show R Solution
    # R Solution
    library(caret)
    training <- read.csv("har_train.csv", na.strings = c("NA", ""))
    test <- read.csv("har_validate.csv", na.strings = c("NA", ""))
    
    # Train Naive Bayes using klaR
    library(klaR)
    
    # Train
nb_mod <- NaiveBayes(classe ~ ., data=training, fL=1, usekernel = T)  # with kernel and laplace correction = 1
    
    # Predict
    pred <- suppressWarnings(predict(nb_mod, test))
    
    # Confusion Matrix
    tab <- table(pred$class, test$classe)
    caret::confusionMatrix(tab)  # Computes the accuracy metrics for all individual classes.
    
    # Plot the Confusion Matrix
    library(ggplot2)
    test$pred <- pred$class
    ggplot(test, aes(classe, pred, color = classe)) +
      geom_jitter(width = 0.3, height = 0.3, size=1) +
      labs(title="Confusion Matrix", 
           subtitle="Predicted vs. Observed from HAR dataset", 
           y="Predicted", 
           x="Truth", 
           caption="machinelearningplus.com")
    
    Confusion Matrix and Statistics
    
    
           A    B    C    D    E
      A 1222   80   66   67   27
      B   80  799   70    5  106
      C  167  170  828  175   54
      D  197   82   60  680   52
      E    8    8    2   37  843
    
    Overall Statistics
    
                   Accuracy : 0.7429         
                     95% CI : (0.7315, 0.754)
        No Information Rate : 0.2845         
        P-Value [Acc > NIR] : < 2.2e-16      
    
                      Kappa : 0.6767         
     Mcnemar's Test P-Value : < 2.2e-16      
    
    Statistics by Class:
    
                         Class: A Class: B Class: C Class: D Class: E
    Sensitivity            0.7300   0.7015   0.8070   0.7054   0.7791
    Specificity            0.9430   0.9450   0.8835   0.9205   0.9885
    Pos Pred Value         0.8358   0.7538   0.5940   0.6349   0.9388
    Neg Pred Value         0.8978   0.9295   0.9559   0.9410   0.9521
    Prevalence             0.2845   0.1935   0.1743   0.1638   0.1839
    Detection Rate         0.2076   0.1358   0.1407   0.1155   0.1432
    Detection Prevalence   0.2484   0.1801   0.2369   0.1820   0.1526
    Balanced Accuracy      0.8365   0.8232   0.8453   0.8130   0.8838
    


    Show Python Solution
    # Python Solution 
    # Import packages
    from sklearn.naive_bayes import GaussianNB
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import confusion_matrix
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns; sns.set()
    
    # Import data
    training = pd.read_csv("har_train.csv")
    test =  pd.read_csv("har_validate.csv")
    
    # Create the X and Y
    xtrain = training.drop('classe', axis=1)
    ytrain = training.loc[:, 'classe']
    
    xtest = test.drop('classe', axis=1)
    ytest = test.loc[:, 'classe']
    
    # Init the Gaussian Classifier
    model = GaussianNB()
    
    # Train the model
    model.fit(xtrain, ytrain)
    
    # Predict Output 
    pred = model.predict(xtest)
    print(pred[:5])
    
    # Plot Confusion Matrix
    mat = confusion_matrix(pred, ytest)
    names = np.unique(pred)
    sns.heatmap(mat, square=True, annot=True, fmt='d', cbar=False,
                xticklabels=names, yticklabels=names)
    plt.xlabel('Truth')
    plt.ylabel('Predicted')
    

    11. Tips to improve the model

1. Try transforming the variables using transformations like BoxCox or YeoJohnson to make the features near Normal (see the short sketch after this list).
2. Try applying the Laplace correction to handle records with zero values in the X variables.
3. Check for correlated features and try removing the highly correlated ones. Naive Bayes is based on the assumption that the features are independent.
4. Feature engineering. Combining features (a product) to form new ones that make intuitive sense might help.
    5. Try providing more realistic prior probabilities to the algorithm based on knowledge from business, instead of letting the algo calculate the priors based on the training sample.
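Here is a minimal sketch of tip 1, assuming the xtrain, ytrain, xtest and ytest objects from the Python solution above. It uses scikit-learn's PowerTransformer, which implements both the Box-Cox and Yeo-Johnson transformations:

from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import PowerTransformer

# Transform the features to be closer to Normal before fitting Gaussian NB
pt = PowerTransformer(method='yeo-johnson')   # 'box-cox' requires strictly positive values
xtrain_t = pt.fit_transform(xtrain)
xtest_t = pt.transform(xtest)

model = GaussianNB()
model.fit(xtrain_t, ytrain)
print(model.score(xtest_t, ytest))  # accuracy on the test set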

For this case, ensemble methods like bagging and boosting will help a lot by reducing the variance.

    Parallel Processing in Python – A Practical Guide with Examples

    Parallel processing is a mode of operation where the task is executed simultaneously in multiple processors in the same computer. It is meant to reduce the overall processing time. In this tutorial, you’ll understand the procedure to parallelize any typical logic using python’s multiprocessing module.

    Contents

1. Introduction
    2. How many maximum parallel processes can you run?
    3. What is Synchronous and Asynchronous execution?
    4. Problem Statement: Count how many numbers exist between a given range in each row
    Solution without parallelization
    5. How to parallelize any function?
    6. Asynchronous Parallel Processing
    7. How to Parallelize a Pandas DataFrame?
    8. Exercises
    9. Conclusion

    1. Introduction

    Parallel processing is a mode of operation where the task is executed simultaneously in multiple processors in the same computer. It is meant to reduce the overall processing time.

    However, there is usually a bit of overhead when communicating between processes which can actually increase the overall time taken for small tasks instead of decreasing it.

In python, the multiprocessing module is used to run independent parallel processes by using subprocesses (instead of threads). It allows you to leverage multiple processors on a machine (both Windows and Unix), which means the processes can be run in completely separate memory locations.

By the end of this tutorial you will know:

    • How to structure the code and understand the syntax to enable parallel processing using multiprocessing?
    • How to implement synchronous and asynchronous parallel processing?
    • How to parallelize a Pandas DataFrame?
    • Solve 3 different usecases with the multiprocessing.Pool() interface.

    2. How many maximum parallel processes can you run?

    The maximum number of processes you can run at a time is limited by the number of processors in your computer. If you don’t know how many processors are present in the machine, the cpu_count() function in multiprocessing will show it.

    import multiprocessing as mp
    print("Number of processors: ", mp.cpu_count())
    

    3. What is Synchronous and Asynchronous execution?

    In parallel processing, there are two types of execution: Synchronous and Asynchronous.

A synchronous execution is one where the processes are completed in the same order in which they were started. This is achieved by locking the main program until the respective processes are finished.

Asynchronous, on the other hand, doesn’t involve locking. As a result, the order of results can get mixed up, but it usually gets done quicker.

    There are 2 main objects in multiprocessing to implement parallel execution of a function: The Pool Class and the Process Class.

    1. Pool Class
      1. Synchronous execution
        • Pool.map() and Pool.starmap()
        • Pool.apply()
      2. Asynchronous execution
        • Pool.map_async() and Pool.starmap_async()
    • Pool.apply_async()
    2. Process Class

Let’s take up a typical problem and implement parallelization using the above techniques. In this tutorial, we stick to the Pool class, because it is the most convenient to use and serves most common practical applications.
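For reference, the Process class (not used in the rest of this tutorial) lets you create and manage one process object per task. A minimal sketch:

import multiprocessing as mp

def print_square(x):
    print(x * x)

if __name__ == '__main__':
    p = mp.Process(target=print_square, args=(4,))
    p.start()
    p.join()   # wait for the process to finish
#> 16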

    4. Problem Statement: Count how many numbers exist between a given range in each row

    The first problem is: Given a 2D matrix (or list of lists), count how many numbers are present between a given range in each row. We will work on the list prepared below.

    import numpy as np
    from time import time
    
    # Prepare data
np.random.seed(100)  # seed the RNG so the prepared data is reproducible
    arr = np.random.randint(0, 10, size=[200000, 5])
    data = arr.tolist()
    data[:5]
    

    Solution without parallelization

Let’s see how long it takes to compute it without parallelization. For this, we iterate over each row and apply the function howmany_within_range() (written below), which checks how many numbers lie within the given range and returns the count.

# Solution Without Parallelization
    
    def howmany_within_range(row, minimum, maximum):
        """Returns how many numbers lie within `maximum` and `minimum` in a given `row`"""
        count = 0
        for n in row:
            if minimum <= n <= maximum:
                count = count + 1
        return count
    
    results = []
    for row in data:
        results.append(howmany_within_range(row, minimum=4, maximum=8))
    
    print(results[:10])
    #> [3, 1, 4, 4, 4, 2, 1, 1, 3, 3]
    

    5. How to parallelize any function?

    The general way to parallelize any operation is to take a particular function that should be run multiple times and make it run parallelly in different processors.

To do this, you initialize a Pool with n number of processors and pass the function you want to parallelize to one of Pool’s parallelization methods.

    multiprocessing.Pool() provides the apply(), map() and starmap() methods to make any function run in parallel.

    Nice! So what’s the difference between apply() and map()?

Both apply() and map() take the function to be parallelized as the main argument. But the difference is that apply() takes an args argument that accepts the parameters to be passed to the ‘function-to-be-parallelized’, whereas map() can take only one iterable as an argument.

    So, map() is really more suitable for simpler iterable operations but does the job faster.

    We will get to starmap() once we see how to parallelize howmany_within_range() function with apply() and map().

    5.1. Parallelizing using Pool.apply()

    Let’s parallelize the howmany_within_range() function using multiprocessing.Pool().

    # Parallelizing using Pool.apply()
    
    import multiprocessing as mp
    
    # Step 1: Init multiprocessing.Pool()
    pool = mp.Pool(mp.cpu_count())
    
    # Step 2: `pool.apply` the `howmany_within_range()`
    results = [pool.apply(howmany_within_range, args=(row, 4, 8)) for row in data]
    
    # Step 3: Don't forget to close
    pool.close()    
    
    print(results[:10])
    #> [3, 1, 4, 4, 4, 2, 1, 1, 3, 3]
    

    5.2. Parallelizing using Pool.map()

Pool.map() accepts only one iterable as argument. So as a workaround, I modify the howmany_within_range function by setting a default for the minimum and maximum parameters to create a new howmany_within_range_rowonly() function, so it accepts only an iterable list of rows as input. I know this is not a nice usecase of map(), but it clearly shows how it differs from apply().

    # Parallelizing using Pool.map()
    import multiprocessing as mp
    
    # Redefine, with only 1 mandatory argument.
    def howmany_within_range_rowonly(row, minimum=4, maximum=8):
        count = 0
        for n in row:
            if minimum <= n <= maximum:
                count = count + 1
        return count
    
    pool = mp.Pool(mp.cpu_count())
    
    results = pool.map(howmany_within_range_rowonly, [row for row in data])
    
    pool.close()
    
    print(results[:10])
    #> [3, 1, 4, 4, 4, 2, 1, 1, 3, 3]
    

    5.3. Parallelizing using Pool.starmap()

In the previous example, we had to redefine the howmany_within_range function to make a couple of parameters take default values. Using starmap(), you can avoid doing this. How, you ask?

Like Pool.map(), Pool.starmap() also accepts only one iterable as argument, but each element in that iterable is itself an iterable. You provide the arguments to the ‘function-to-be-parallelized’ in the same order inside this inner iterable element, which will in turn be unpacked during execution.

    So effectively, Pool.starmap() is like a version of Pool.map() that accepts arguments.

    # Parallelizing with Pool.starmap()
    import multiprocessing as mp
    
    pool = mp.Pool(mp.cpu_count())
    
    results = pool.starmap(howmany_within_range, [(row, 4, 8) for row in data])
    
    pool.close()
    
    print(results[:10])
    #> [3, 1, 4, 4, 4, 2, 1, 1, 3, 3]
    

    6. Asynchronous Parallel Processing

The asynchronous equivalents apply_async(), map_async() and starmap_async() let you execute the processes in parallel asynchronously, that is, the next process can start as soon as the previous one is done, without regard for the starting order. As a result, there is no guarantee that the results will be in the same order as the input.

    6.1 Parallelizing with Pool.apply_async()

    apply_async() is very similar to apply() except that you need to provide a callback function that tells how the computed results should be stored.

However, a caveat with apply_async() is that the order of numbers in the result gets jumbled up, indicating the processes did not complete in the order they were started.

    A workaround for this is, we redefine a new howmany_within_range2() to accept and return the iteration number (i) as well and then sort the final results.

    # Parallel processing with Pool.apply_async()
    
    import multiprocessing as mp
    pool = mp.Pool(mp.cpu_count())
    
    results = []
    
    # Step 1: Redefine, to accept `i`, the iteration number
    def howmany_within_range2(i, row, minimum, maximum):
        """Returns how many numbers lie within `maximum` and `minimum` in a given `row`"""
        count = 0
        for n in row:
            if minimum <= n <= maximum:
                count = count + 1
        return (i, count)
    
    
    # Step 2: Define callback function to collect the output in `results`
    def collect_result(result):
        global results
        results.append(result)
    
    
    # Step 3: Use loop to parallelize
    for i, row in enumerate(data):
        pool.apply_async(howmany_within_range2, args=(i, row, 4, 8), callback=collect_result)
    
    # Step 4: Close Pool and let all the processes complete    
    pool.close()
    pool.join()  # postpones the execution of next line of code until all processes in the queue are done.
    
    # Step 5: Sort results [OPTIONAL]
    results.sort(key=lambda x: x[0])
    results_final = [r for i, r in results]
    
    print(results_final[:10])
    #> [3, 1, 4, 4, 4, 2, 1, 1, 3, 3]
    

It is possible to use apply_async() without providing a callback function. Only that, if you don’t provide a callback, then you get a list of pool.ApplyResult objects which contain the computed output values from each process. From this, you need to use the pool.ApplyResult.get() method to retrieve the desired final result.

    # Parallel processing with Pool.apply_async() without callback function
    
    import multiprocessing as mp
    pool = mp.Pool(mp.cpu_count())
    
    results = []
    
    # call apply_async() without callback
    result_objects = [pool.apply_async(howmany_within_range2, args=(i, row, 4, 8)) for i, row in enumerate(data)]
    
    # result_objects is a list of pool.ApplyResult objects
    results = [r.get()[1] for r in result_objects]
    
    pool.close()
    pool.join()
    print(results[:10])
    #> [3, 1, 4, 4, 4, 2, 1, 1, 3, 3]
    

    6.2 Parallelizing with Pool.starmap_async()

You saw how apply_async() works. Can you imagine and write up an equivalent version for starmap_async and map_async? The implementation is below anyway.

    # Parallelizing with Pool.starmap_async()
    
    import multiprocessing as mp
    pool = mp.Pool(mp.cpu_count())
    
    results = []
    
    results = pool.starmap_async(howmany_within_range2, [(i, row, 4, 8) for i, row in enumerate(data)]).get()
    
    # With map, use `howmany_within_range_rowonly` instead
    # results = pool.map_async(howmany_within_range_rowonly, [row for row in data]).get()
    
    pool.close()
    print(results[:10])
    #> [3, 1, 4, 4, 4, 2, 1, 1, 3, 3]
    

    7. How to Parallelize a Pandas DataFrame?

    So far you’ve seen how to parallelize a function by making it work on lists.

    But when working in data analysis or machine learning projects, you might want to parallelize Pandas Dataframes, which are the most commonly used objects (besides numpy arrays) to store tabular data.

When it comes to parallelizing a DataFrame, you can make the function-to-be-parallelized take as an input parameter:

    • one row of the dataframe
    • one column of the dataframe
    • the entire dataframe itself

    The first 2 can be done using multiprocessing module itself. But for the last one, that is parallelizing on an entire dataframe, we will use the pathos package that uses dill for serialization internally.

First, let’s create a sample dataframe and see how to do row-wise and column-wise parallelization. Something like using df.apply() on a user defined function, but in parallel.

    import numpy as np
    import pandas as pd
    import multiprocessing as mp
    
    df = pd.DataFrame(np.random.randint(3, 10, size=[5, 2]))
    print(df.head())
    #>    0  1
    #> 0  8  5
    #> 1  5  3
    #> 2  3  4
    #> 3  4  4
    #> 4  7  9
    

    We have a dataframe. Let’s apply the hypotenuse function on each row, but running 4 processes at a time.

    To do this, we exploit the df.itertuples(name=False). By setting name=False, you are passing each row of the dataframe as a simple tuple to the hypotenuse function.

    # Row wise Operation
    def hypotenuse(row):
        return round(row[1]**2 + row[2]**2, 2)**0.5
    
    with mp.Pool(4) as pool:
        result = pool.imap(hypotenuse, df.itertuples(name=False), chunksize=10)
        output = [round(x, 2) for x in result]
    
    print(output)
    #> [9.43, 5.83, 5.0, 5.66, 11.4]
    

    That was an example of row-wise parallelization. Let’s also do a column-wise parallelization. For this, I use df.iteritems() to pass an entire column as a series to the sum_of_squares function.

    # Column wise Operation
    def sum_of_squares(column):
        return sum([i**2 for i in column[1]])
    
    with mp.Pool(2) as pool:
        result = pool.imap(sum_of_squares, df.iteritems(), chunksize=10)
        output = [x for x in result]
    
    print(output) 
    #> [163, 147]
    

    Now comes the third part – Parallelizing a function that accepts a Pandas Dataframe, NumPy Array, etc. Pathos follows the multiprocessing style of: Pool > Map > Close > Join > Clear. Check out the pathos docs for more info.

    import numpy as np
    import pandas as pd
    import multiprocessing as mp
    from pathos.multiprocessing import ProcessingPool as Pool
    
    df = pd.DataFrame(np.random.randint(3, 10, size=[500, 2]))
    
    def func(df):
        return df.shape
    
    cores=mp.cpu_count()
    
    df_split = np.array_split(df, cores, axis=0)
    
    # create the multiprocessing pool
    pool = Pool(cores)
    
    # process the DataFrame by mapping function to each df across the pool
    df_out = np.vstack(pool.map(func, df_split))
    
    # close down the pool and join
    pool.close()
    pool.join()
    pool.clear()
    

    Thanks to notsoprocoder for this contribution based on pathos. If you are familiar with pandas dataframes but want to get hands-on and master it, check out these pandas exercises.

    8. Exercises

    Problem 1: Use Pool.apply() to get the row wise common items in list_a and list_b.

    list_a = [[1, 2, 3], [5, 6, 7, 8], [10, 11, 12], [20, 21]]
    list_b = [[2, 3, 4, 5], [6, 9, 10], [11, 12, 13, 14], [21, 24, 25]]
    
    Show Solution
    import multiprocessing as mp
    
    list_a = [[1, 2, 3], [5, 6, 7, 8], [10, 11, 12], [20, 21]]
    list_b = [[2, 3, 4, 5], [6, 9, 10], [11, 12, 13, 14], [21, 24, 25]]
    
    def get_commons(list_1, list_2):
        return list(set(list_1).intersection(list_2))
    
    pool = mp.Pool(mp.cpu_count())
    results = [pool.apply(get_commons, args=(l1, l2)) for l1, l2 in zip(list_a, list_b)]
    pool.close()    
    print(results[:10])
    

    Problem 2: Use Pool.map() to run the following python scripts in parallel.
    Script names: ‘script1.py’, ‘script2.py’, ‘script3.py’

    Show Solution
    import os                                                                       
    import multiprocessing as mp
    
    processes = ('script1.py', 'script2.py', 'script3.py')                      
    
    def run_python(process):                                                             
        os.system('python {}'.format(process))                                      
    
    pool = mp.Pool(processes=3)                                                        
    pool.map(run_python, processes)  
    

    Problem 3: Normalize each row of 2d array (list) to vary between 0 and 1.

    list_a = [[2, 3, 4, 5], [6, 9, 10, 12], [11, 12, 13, 14], [21, 24, 25, 26]]
    
    Show Solution
    import multiprocessing as mp
    
    list_a = [[2, 3, 4, 5], [6, 9, 10, 12], [11, 12, 13, 14], [21, 24, 25, 26]]
    
    def normalize(mylist):
        mini = min(mylist)
        maxi = max(mylist)
        return [(i - mini)/(maxi-mini) for i in mylist]
    
    pool = mp.Pool(mp.cpu_count())
    results = [pool.apply(normalize, args=(l1, )) for l1 in list_a]
    pool.close()    
    print(results[:10])
    

    9. Conclusion

    Hope you were able to solve the above exercises, congratulations if you did!

In this post, we saw the overall procedure and various ways to implement parallel processing using the multiprocessing module. The procedure described above is pretty much the same even if you work on larger machines with many more processors, where you may reap the real speed benefits of parallel processing.

    Happy coding and I’ll see you in the next one!

    Cosine Similarity – Understanding the math and how it works (with python codes)

Cosine similarity is a metric used to measure how similar two documents are, irrespective of their size. Mathematically, it measures the cosine of the angle between two vectors projected in a multi-dimensional space. The cosine similarity is advantageous because even if two similar documents are far apart by the Euclidean distance (due to the size of the documents), chances are they may still be oriented closer together. The smaller the angle, the higher the cosine similarity.

    By the end of this tutorial you will know:

1. What is cosine similarity and how does it work?
2. How to compute the cosine similarity of documents in python?
3. What is soft cosine similarity and how is it different from cosine similarity?
4. When to use soft cosine similarity and how to compute it in python?
Cosine Similarity – Understanding the math and how it works. Photo by Matt Lamers

    Contents

1. Introduction
    2. What is Cosine Similarity and why is it advantageous?
    3. Cosine Similarity Example
    4. How to Compute Cosine Similarity in Python?
    5. Soft Cosine Similarity
    6. Conclusion

    1. Introduction

    A commonly used approach to match similar documents is based on counting the maximum number of common words between the documents.

But this approach has an inherent flaw. That is, as the size of the documents increases, the number of common words tends to increase even if the documents talk about different topics.

    The cosine similarity helps overcome this fundamental flaw in the ‘count-the-common-words’ or Euclidean distance approach.

    2. What is Cosine Similarity and why is it advantageous?

    Cosine similarity is a metric used to determine how similar the documents are irrespective of their size.

    Mathematically, it measures the cosine of the angle between two vectors projected in a multi-dimensional space. In this context, the two vectors I am talking about are arrays containing the word counts of two documents.

    As a similarity metric, how does cosine similarity differ from the number of common words?

    When plotted on a multi-dimensional space, where each dimension corresponds to a word in the document, the cosine similarity captures the orientation (the angle) of the documents and not the magnitude. If you want the magnitude, compute the Euclidean distance instead.

The cosine similarity is advantageous because even if two similar documents are far apart by the Euclidean distance because of their size (say, the word ‘cricket’ appears 50 times in one document and 10 times in another), they could still have a small angle between them. The smaller the angle, the higher the similarity.
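
To make that concrete, here is a tiny illustration with made-up counts: one document uses two words 50 and 5 times, and a much shorter document uses the same two words 10 and 1 times. The Euclidean distance between the count vectors is large, but they point in the same direction, so the cosine similarity is 1.

import numpy as np

# Made-up count vectors for two documents over the same two words
d1 = np.array([50, 5])
d2 = np.array([10, 1])

print(np.linalg.norm(d1 - d2))                              # Euclidean distance: ~40.2 (large)
print(d1 @ d2 / (np.linalg.norm(d1) * np.linalg.norm(d2)))  # cosine similarity: 1.0 (same orientation)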

    3. Cosine Similarity Example

    Let’s suppose you have 3 documents based on a couple of star cricket players – Sachin Tendulkar and Dhoni. Two of the documents (A) and (B) are from the wikipedia pages on the respective players and the third document (C) is a smaller snippet from Dhoni’s wikipedia page.

The Three Documents

    As you can see, all three documents are connected by a common theme – the game of Cricket.

    Our objective is to quantitatively estimate the similarity between the documents.

    For ease of understanding, let’s consider only the top 3 common words between the documents: ‘Dhoni’, ‘Sachin’ and ‘Cricket’.

You would expect Doc A and Doc C, that is, the two documents on Dhoni, to have a higher similarity than Doc A and Doc B, because Doc C is essentially a snippet from Doc A itself.

    However, if we go by the number of common words, the two larger documents will have the most common words and therefore will be judged as most similar, which is exactly what we want to avoid.

    The results would be more congruent when we use the cosine similarity score to assess the similarity.

    Let me explain.

    Let’s project the documents in a 3-dimensional space, where each dimension is a frequency count of either: ‘Sachin’, ‘Dhoni’ or ‘Cricket’. When plotted on this space, the 3 documents would appear something like this.

3d Projection

As you can see, Doc Dhoni_Small and the main Doc Dhoni are oriented closer together in 3-D space, even though they are far apart by magnitude.

It turns out that the closer the documents are by angle, the higher the cosine similarity (cos theta).

Cosine Similarity Formula

    As you include more words from the document, it’s harder to visualize a higher dimensional space. But you can directly compute the cosine similarity using this math formula.
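
For reference, for two documents represented as word-count vectors A and B, the standard cosine similarity formula is:

\cos(\theta) \;=\; \frac{A \cdot B}{\lVert A \rVert \, \lVert B \rVert} \;=\; \frac{\sum_{i=1}^{n} A_i B_i}{\sqrt{\sum_{i=1}^{n} A_i^{2}}\,\sqrt{\sum_{i=1}^{n} B_i^{2}}}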

    Enough with the theory. Let’s compute the cosine similarity with Python’s scikit learn.

    4. How to Compute Cosine Similarity in Python?

    We have the following 3 texts:

    Doc Trump (A) : Mr. Trump became president after winning the political election. Though he lost the support of some republican friends, Trump is friends with President Putin.

    Doc Trump Election (B) : President Trump says Putin had no political interference is the election outcome. He says it was a witchhunt by political parties. He claimed President Putin is a friend who had nothing to do with the election.

    Doc Putin (C) : Post elections, Vladimir Putin became President of Russia. President Putin had served as the Prime Minister earlier in his political career.

Since Doc B has more in common with Doc A than with Doc C, I would expect the cosine similarity between A and B to be larger than that between C and B.

    # Define the documents
    doc_trump = "Mr. Trump became president after winning the political election. Though he lost the support of some republican friends, Trump is friends with President Putin"
    
    doc_election = "President Trump says Putin had no political interference is the election outcome. He says it was a witchhunt by political parties. He claimed President Putin is a friend who had nothing to do with the election"
    
    doc_putin = "Post elections, Vladimir Putin became President of Russia. President Putin had served as the Prime Minister earlier in his political career"
    
    documents = [doc_trump, doc_election, doc_putin]
    

To compute the cosine similarity, you need the word counts of the words in each document. The CountVectorizer or the TfidfVectorizer from scikit-learn lets us compute this. The output comes as a sparse matrix.

On this, I am optionally converting it to a pandas dataframe to see the word frequencies in a tabular format.

    # Scikit Learn
    from sklearn.feature_extraction.text import CountVectorizer
    import pandas as pd
    
    # Create the Document Term Matrix
# count_vectorizer = CountVectorizer(stop_words='english')  # alternative: drop English stop words
count_vectorizer = CountVectorizer()
    sparse_matrix = count_vectorizer.fit_transform(documents)
    
    # OPTIONAL: Convert Sparse Matrix to Pandas Dataframe if you want to see the word frequencies.
    doc_term_matrix = sparse_matrix.todense()
    df = pd.DataFrame(doc_term_matrix, 
                      columns=count_vectorizer.get_feature_names(), 
                      index=['doc_trump', 'doc_election', 'doc_putin'])
    df
    
Doc-Term Matrix

Even better, I could have used the TfidfVectorizer() instead of CountVectorizer(), because it would have downweighted words that occur frequently across documents.
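
If you wanted to try that, a minimal sketch is shown below (it is not used for the output that follows, which is based on the raw counts):

# Possible variant: TF-IDF weights instead of raw counts
from sklearn.feature_extraction.text import TfidfVectorizer

tfidf_matrix = TfidfVectorizer().fit_transform(documents)  # same `documents` list as above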

Then, use cosine_similarity() to get the final output. It can take the document-term matrix as a pandas dataframe or as a sparse matrix as input.

    # Compute Cosine Similarity
    from sklearn.metrics.pairwise import cosine_similarity
    print(cosine_similarity(df, df))
    #> [[ 1.          0.48927489  0.37139068]
    #>  [ 0.48927489  1.          0.38829014]
    #>  [ 0.37139068  0.38829014  1.        ]]
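
As a sanity check, the same matrix can be reproduced by hand from the document-term matrix using the cosine formula shown earlier; a minimal sketch:

# Reproduce the pairwise cosine similarities manually with NumPy
import numpy as np

A = np.asarray(doc_term_matrix, dtype=float)      # rows = documents, columns = word counts
norms = np.linalg.norm(A, axis=1, keepdims=True)  # L2 norm of each document vector
print(A @ A.T / (norms * norms.T))                # dot products scaled by the norms -> same matrix as above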
    

    5. Soft Cosine Similarity

Suppose you have another set of documents on a completely different topic, say ‘food’. You want a similarity metric that gives higher scores for documents belonging to the same topic and lower scores when comparing docs from different topics.

In such cases, the semantic meaning of the words should be considered. That is, words similar in meaning should be treated as similar. For example, ‘President’ vs ‘Prime Minister’, ‘Food’ vs ‘Dish’, ‘Hi’ vs ‘Hello’ should be considered similar. Converting the words into their respective word vectors and then computing the similarities can address this problem.

Soft Cosines

    Let’s define 3 additional documents on food items.

    # Define the documents
    doc_soup = "Soup is a primarily liquid food, generally served warm or hot (but may be cool or cold), that is made by combining ingredients of meat or vegetables with stock, juice, water, or another liquid. "
    
    doc_noodles = "Noodles are a staple food in many cultures. They are made from unleavened dough which is stretched, extruded, or rolled flat and cut into one of a variety of shapes."
    
    doc_dosa = "Dosa is a type of pancake from the Indian subcontinent, made from a fermented batter. It is somewhat similar to a crepe in appearance. Its main ingredients are rice and black gram."
    
    documents = [doc_trump, doc_election, doc_putin, doc_soup, doc_noodles, doc_dosa]
    

    To get the word vectors, you need a word embedding model. Let’s download the FastText model using gensim’s downloader api.

    import gensim
    # upgrade gensim if you can't import softcossim
    from gensim.matutils import softcossim 
    from gensim import corpora
    import gensim.downloader as api
    from gensim.utils import simple_preprocess
    print(gensim.__version__)
    #> '3.6.0'
    
    # Download the FastText model
    fasttext_model300 = api.load('fasttext-wiki-news-subwords-300')
    

    To compute soft cosines, you need the dictionary (a map of word to unique id), the corpus (word counts) for each sentence and the similarity matrix.

    # Prepare a dictionary and a corpus.
    dictionary = corpora.Dictionary([simple_preprocess(doc) for doc in documents])
    
    # Prepare the similarity matrix
    similarity_matrix = fasttext_model300.similarity_matrix(dictionary, tfidf=None, threshold=0.0, exponent=2.0, nonzero_limit=100)
    
    # Convert the sentences into bag-of-words vectors.
    sent_1 = dictionary.doc2bow(simple_preprocess(doc_trump))
    sent_2 = dictionary.doc2bow(simple_preprocess(doc_election))
    sent_3 = dictionary.doc2bow(simple_preprocess(doc_putin))
    sent_4 = dictionary.doc2bow(simple_preprocess(doc_soup))
    sent_5 = dictionary.doc2bow(simple_preprocess(doc_noodles))
    sent_6 = dictionary.doc2bow(simple_preprocess(doc_dosa))
    
    sentences = [sent_1, sent_2, sent_3, sent_4, sent_5, sent_6]
    

    If you want the soft cosine similarity of 2 documents, you can just call the softcossim() function

    # Compute soft cosine similarity
    print(softcossim(sent_1, sent_2, similarity_matrix))
    #> 0.567228632589
    

    But, I want to compare the soft cosines for all documents against each other. So, create the soft cosine similarity matrix.

    import numpy as np
    import pandas as pd
    
    def create_soft_cossim_matrix(sentences):
        len_array = np.arange(len(sentences))
        xx, yy = np.meshgrid(len_array, len_array)
        cossim_mat = pd.DataFrame([[round(softcossim(sentences[i],sentences[j], similarity_matrix) ,2) for i, j in zip(x,y)] for y, x in zip(xx, yy)])
        return cossim_mat
    
create_soft_cossim_matrix(sentences)
    
Soft cosine similarity matrix

    As one might expect, the similarity scores amongst similar documents are higher (see the red boxes).

    6. Conclusion

    Now you should clearly understand the math behind the computation of cosine similarity and how it is advantageous over magnitude based metrics like Euclidean distance.

    Soft cosines can be a great feature if you want to use a similarity metric that can help in clustering or classification of documents.

    If you want to dig in further into natural language processing, the gensim tutorial is highly recommended.

    Gensim Tutorial – A Complete Beginners Guide

    Gensim is billed as a Natural Language Processing package that does ‘Topic Modeling for Humans’. But it is practically much more than that. It is a leading and a state-of-the-art package for processing texts, working with word vector models (such as Word2Vec, FastText etc) and for building topic models.

    Gensim Tutorial – A Complete Beginners Guide. Photo by Jasmin Schreiber

    Contents

1. Introduction
    2. What is a Dictionary and a Corpus?
    3. How to create a Dictionary from a list of sentences?
    4. How to create a Dictionary from one or more text files?
    5. How to create a bag of words corpus in gensim?
    6. How to create a bag of words corpus from external text file?
    7. How to save a gensim dictionary and corpus to disk and load them back?
    8. How to create the TFIDF matrix (corpus) in gensim?
    9. How to use gensim downloader API to load datasets?
    10. How to create bigrams and trigrams using Phraser models?
    11. How to create topic models with LDA?
    12. How to interpret the LDA Topic Model’s output?
    13. How to create a LSI topic model using gensim?
    14. How to train Word2Vec model using gensim?
    15. How to update an existing Word2Vec model with new data?
    16. How to extract word vectors using pre-trained Word2Vec and FastText models?
    17. How to create document vectors using Doc2Vec?
    18. How to compute similarity metrics like cosine similarity and soft cosine similarity?
    19. How to summarize text documents?
    20. Conclusion

    1. Introduction

    What is gensim?

Gensim is billed as a Natural Language Processing package that does ‘Topic Modeling for Humans’. But it is practically much more than that.

    If you are unfamiliar with topic modeling, it is a technique to extract the underlying topics from large volumes of text. Gensim provides algorithms like LDA and LSI (which we will see later in this post) and the necessary sophistication to build high-quality topic models.

You may argue that topic models and word embeddings are available in other packages like scikit-learn, R, etc. But the width and scope of facilities to build and evaluate topic models are unparalleled in gensim, plus it offers many more convenient facilities for text processing.

    It is a great package for processing texts, working with word vector models (such as Word2Vec, FastText etc) and for building topic models.

    Also, another significant advantage with gensim is: it lets you handle large text files without having to load the entire file in memory.

This post intends to give a practical overview of nearly all the major features, explained in a simple and easy-to-understand way.

By the end of this tutorial, you will know:

    • What are the core concepts in gensim?
• What are a dictionary and a corpus, why do they matter and where to use them?
• How to create and work with a dictionary and corpus?
• How to load and work with text data from multiple text files in a memory-efficient way
    • Create topic models with LDA and interpret the outputs
    • Create TFIDF model, bigrams, trigrams, Word2Vec model, Doc2Vec model
    • Compute similarity metrics
    • And much more..

    Let’s begin.

    2. What is a Dictionary and Corpus?

    In order to work on text documents, Gensim requires the words (aka tokens) be converted to unique ids. In order to achieve that, Gensim lets you create a Dictionary object that maps each word to a unique id.

So, how to create a `Dictionary`? By converting your text/sentences to a [list of words] and passing it to the corpora.Dictionary() object.

    We will see how to actually do this in the next section.

    But why is the dictionary object needed and where can it be used?

    The dictionary object is typically used to create a ‘bag of words’ Corpus. It is this Dictionary and the bag-of-words (Corpus) that are used as inputs to topic modeling and other models that Gensim specializes in.

    Alright, what sort of text inputs can gensim handle? The input text typically comes in 3 different forms:

    1. As sentences stored in python’s native list object
    2. As one single text file, small or large.
    3. In multiple text files.

    Now, when your text input is large, you need to be able to create the dictionary object without having to load the entire text file.

    The good news is Gensim lets you read the text and update the dictionary, one line at a time, without loading the entire text file into system memory. Let’s see how to do that in the next 2 sections.

    But, before we get in, let’s understand some NLP jargon.

    A ‘token’ typically means a ‘word’. A ‘document’ can typically refer to a ‘sentence’ or ‘paragraph’ and a ‘corpus’ is typically a ‘collection of documents as a bag of words’. That is, for each document, a corpus contains each word’s id and its frequency count in that document. As a result, information of the order of words is lost.

    If everything is clear so far, let’s get our hands wet and see how to create the dictionary from a list of sentences.

    3. How to create a Dictionary from a list of sentences?

    In gensim, the dictionary contains a map of all words (tokens) to its unique id.

    You can create a dictionary from a paragraph of sentences, from a text file that contains multiple lines of text and from multiple such text files contained in a directory. For the second and third cases, we will do it without loading the entire file into memory so that the dictionary gets updated as you read the text line by line.

    Let’s start with the ‘List of sentences’ input.

When you have multiple sentences, you need to convert each sentence to a list of words. A list comprehension is a common way to do this.

    import gensim
    from gensim import corpora
    from pprint import pprint
    
    # How to create a dictionary from a list of sentences?
    documents = ["The Saudis are preparing a report that will acknowledge that", 
                 "Saudi journalist Jamal Khashoggi's death was the result of an", 
                 "interrogation that went wrong, one that was intended to lead", 
                 "to his abduction from Turkey, according to two sources."]
    
    documents_2 = ["One source says the report will likely conclude that", 
                    "the operation was carried out without clearance and", 
                    "transparency and that those involved will be held", 
                    "responsible. One of the sources acknowledged that the", 
                    "report is still being prepared and cautioned that", 
                    "things could change."]
    
    # Tokenize(split) the sentences into words
    texts = [[text for text in doc.split()] for doc in documents]
    
    # Create dictionary
    dictionary = corpora.Dictionary(texts)
    
    # Get information about the dictionary
    print(dictionary)
    #> Dictionary(33 unique tokens: ['Saudis', 'The', 'a', 'acknowledge', 'are']...)
    

As it says, the dictionary has 33 unique tokens (or words). Let’s see the unique ids for each of these tokens.

    # Show the word to id map
    print(dictionary.token2id)
    #> {'Saudis': 0, 'The': 1, 'a': 2, 'acknowledge': 3, 'are': 4, 
    #> 'preparing': 5, 'report': 6, 'that': 7, 'will': 8, 'Jamal': 9, 
    #> "Khashoggi's": 10, 'Saudi': 11, 'an': 12, 'death': 13, 
    #> 'journalist': 14, 'of': 15, 'result': 16, 'the': 17, 'was': 18, 
    #> 'intended': 19, 'interrogation': 20, 'lead': 21, 'one': 22, 
    #> 'to': 23, 'went': 24, 'wrong,': 25, 'Turkey,': 26, 'abduction': 27, 
    #> 'according': 28, 'from': 29, 'his': 30, 'sources.': 31, 'two': 32}
    

We have successfully created a Dictionary object. Gensim will use this dictionary to create a bag-of-words corpus where the words in the documents are replaced with their respective ids provided by this dictionary.

    If you get new documents in the future, it is also possible to update an existing dictionary to include the new words.

    documents_2 = ["The intersection graph of paths in trees",
                   "Graph minors IV Widths of trees and well quasi ordering",
                   "Graph minors A survey"]
    
    texts_2 = [[text for text in doc.split()] for doc in documents_2]
    
    dictionary.add_documents(texts_2)
    
    
    # If you check now, the dictionary should have been updated with the new words (tokens).
    print(dictionary)
    #> Dictionary(45 unique tokens: ['Human', 'abc', 'applications', 'computer', 'for']...)
    
    print(dictionary.token2id)
    #> {'Human': 0, 'abc': 1, 'applications': 2, 'computer': 3, 'for': 4, 'interface': 5, 
    #>  'lab': 6, 'machine': 7, 'A': 8, 'of': 9, 'opinion': 10, 'response': 11, 'survey': 12, 
    #>  'system': 13, 'time': 14, 'user': 15, 'EPS': 16, 'The': 17, 'management': 18, 
    #>  'System': 19, 'and': 20, 'engineering': 21, 'human': 22, 'testing': 23, 'Relation': 24, 
    #>  'error': 25, 'measurement': 26, 'perceived': 27, 'to': 28, 'binary': 29, 'generation': 30, 
    #>  'random': 31, 'trees': 32, 'unordered': 33, 'graph': 34, 'in': 35, 'intersection': 36, 
    #>  'paths': 37, 'Graph': 38, 'IV': 39, 'Widths': 40, 'minors': 41, 'ordering': 42, 
    #>  'quasi': 43, 'well': 44}
    

    4. How to create a Dictionary from one or more text files?

    You can also create a dictionary from a text file or from a directory of text files.

    The below example reads a file line-by-line and uses gensim’s simple_preprocess to process one line of the file at a time.

The advantage here is that it lets you read an entire text file without loading the file into memory all at once.

    Let’s use a sample.txt file to demonstrate this.

    from gensim.utils import simple_preprocess
    from smart_open import smart_open
    import os
    
# Create a gensim dictionary from a single text file
    dictionary = corpora.Dictionary(simple_preprocess(line, deacc=True) for line in open('sample.txt', encoding='utf-8'))
    
    # Token to Id map
    dictionary.token2id
    
    #> {'according': 35,
    #>  'and': 22,
    #>  'appointment': 23,
    #>  'army': 0,
    #>  'as': 43,
    #>  'at': 24,
    #>   ...
    #> }
    

    We have created a dictionary from a single text file. Nice!

    Now, how to read one-line-at-a-time from multiple files?

    Assuming you have all the text files in the same directory, you need to define a class with an __iter__ method. The __iter__() method should iterate through all the files in a given directory and yield the processed list of word tokens.

Let’s define one such class by the name ReadTxtFiles, which takes in the path to the directory containing the text files. I am using this directory of sports food docs as input.

    class ReadTxtFiles(object):
        def __init__(self, dirname):
            self.dirname = dirname
    
        def __iter__(self):
            for fname in os.listdir(self.dirname):
                for line in open(os.path.join(self.dirname, fname), encoding='latin'):
                    yield simple_preprocess(line)
    
    path_to_text_directory = "lsa_sports_food_docs"
    
    dictionary = corpora.Dictionary(ReadTxtFiles(path_to_text_directory))
    
    # Token to Id map
    dictionary.token2id
    # {'across': 0,
    #  'activity': 1,
    #  'although': 2,
    #  'and': 3,
    #  'are': 4,
    #  ...
    # }
    

    This blog post gives a nice overview to understand the concept of iterators and generators.

    5. How to create a bag of words corpus in gensim?

    Now you know how to create a dictionary from a list and from text file.

    The next important object you need to familiarize with in order to work in gensim is the Corpus (a Bag of Words). That is, it is a corpus object that contains the word id and its frequency in each document. You can think of it as gensim’s equivalent of a Document-Term matrix.

    Once you have the updated dictionary, all you need to do to create a bag of words corpus is to pass the tokenized list of words to the Dictionary.doc2bow()

Let’s create a Corpus for a simple list (my_docs) containing 2 sentences.

    # List with 2 sentences
    my_docs = ["Who let the dogs out?",
               "Who? Who? Who? Who?"]
    
    # Tokenize the docs
    tokenized_list = [simple_preprocess(doc) for doc in my_docs]
    
    # Create the Corpus
    mydict = corpora.Dictionary()
    mycorpus = [mydict.doc2bow(doc, allow_update=True) for doc in tokenized_list]
    pprint(mycorpus)
    #> [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)], [(4, 4)]]
    

    How to interpret the above corpus?

    The (0, 1) in line 1 means, the word with id=0 appears once in the 1st document.
    Likewise, the (4, 4) in the second list item means the word with id 4 appears 4 times in the second document. And so on.

    Well, this is not human readable. To convert the id’s to words, you will need the dictionary to do the conversion.

    Let’s see how to get the original texts back.

    word_counts = [[(mydict[id], count) for id, count in line] for line in mycorpus]
    pprint(word_counts)
    #> [[('dogs', 1), ('let', 1), ('out', 1), ('the', 1), ('who', 1)], [('who', 4)]]
    

Notice that the order of the words gets lost. Just the word and its frequency information is retained.

    6. How to create a bag of words corpus from a text file?

Reading words from a python list is quite straightforward because the entire text is already in memory.
However, you may have a large file that you don’t want to load entirely into memory.

    You can import such files one line at a time by defining a class and the __iter__ function that iteratively reads the file one line at a time and yields a corpus object. But how to create the corpus object?

The __iter__() from BoWCorpus reads a line from the file, processes it into a list of words using simple_preprocess() and passes that to dictionary.doc2bow(). Can you see how this is similar to, and different from, the ReadTxtFiles class we created earlier?

Also, notice that I am using smart_open() from the smart_open package because it lets you open and read large files line-by-line from a variety of sources such as S3, HDFS, WebHDFS, HTTP, or local and compressed files. That’s pretty awesome by the way!

However, if you had used open() for a file in your system, it will work perfectly fine as well.

    from gensim.utils import simple_preprocess
    from smart_open import smart_open
    import nltk
    nltk.download('stopwords')  # run once
    from nltk.corpus import stopwords
    stop_words = stopwords.words('english')
    
    
    class BoWCorpus(object):
        def __init__(self, path, dictionary):
            self.filepath = path
            self.dictionary = dictionary
    
        def __iter__(self):
            global mydict  # OPTIONAL, only if updating the source dictionary.
            for line in smart_open(self.filepath, encoding='latin'):
                # tokenize
                tokenized_list = simple_preprocess(line, deacc=True)
    
                # create bag of words
                bow = self.dictionary.doc2bow(tokenized_list, allow_update=True)
    
                # update the source dictionary (OPTIONAL)
                mydict.merge_with(self.dictionary)
    
                # lazy return the BoW
                yield bow
    
    
    # Create the Dictionary
    mydict = corpora.Dictionary()
    
    # Create the Corpus
    bow_corpus = BoWCorpus('sample.txt', dictionary=mydict)  # memory friendly
    
    # Print the token_id and count for each line.
    for line in bow_corpus:
        print(line)
    
    #> [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1), (9, 1), (10, 1), (11, 1)]
    #> [(12, 1), (13, 1), (14, 1), (15, 1), (16, 1), (17, 1)]
    #> ... truncated ...
    

    7. How to save a gensim dictionary and corpus to disk and load them back?

    This is quite straightforward. See the examples below.

    # Save the Dict and Corpus
    mydict.save('mydict.dict')  # save dict to disk
    corpora.MmCorpus.serialize('bow_corpus.mm', bow_corpus)  # save corpus to disk
    

    We have saved the dictionary and corpus objects. Let’s load them back.

    # Load them back
    loaded_dict = corpora.Dictionary.load('mydict.dict')
    
    corpus = corpora.MmCorpus('bow_corpus.mm')
    for line in corpus:
        print(line)
    

    8. How to create the TFIDF matrix (corpus) in gensim?

The Term Frequency – Inverse Document Frequency (TF-IDF) model is also a bag-of-words model, but unlike the regular corpus, TF-IDF downweights tokens (words) that appear frequently across documents.

    How is TFIDF computed?

    Tf-Idf is computed by multiplying a local component like term frequency (TF) with a global component, that is, inverse document frequency (IDF) and optionally normalizing the result to unit length.

    As a result of this, the words that occur frequently across documents will get downweighted.
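
For reference, one common formulation, close to what gensim’s default scheme computes, multiplies the raw term frequency by a log-scaled inverse document frequency and then normalizes each document vector to unit length:

w_{ij} = \mathrm{tf}_{ij} \times \log_2\!\left(\frac{N}{\mathrm{df}_i}\right), \qquad \hat{w}_{ij} = \frac{w_{ij}}{\sqrt{\sum_{k} w_{kj}^{2}}}

where tf_ij is the count of term i in document j, df_i is the number of documents containing term i, and N is the total number of documents.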

There are multiple variations of the TF and IDF formulas. Gensim uses the SMART Information Retrieval system, which can be used to implement these variations. You can specify which formula to use through the smartirs parameter in the TfidfModel. See help(models.TfidfModel) for more details.

    So, how to get the TFIDF weights?

Train the corpus with models.TfidfModel(), then apply the corpus within the square brackets of the trained tfidf model. See the example below.

    from gensim import models
    import numpy as np
    
    documents = ["This is the first line",
                 "This is the second sentence",
                 "This third document"]
    
    # Create the Dictionary and Corpus
    mydict = corpora.Dictionary([simple_preprocess(line) for line in documents])
    corpus = [mydict.doc2bow(simple_preprocess(line)) for line in documents]
    
    # Show the Word Weights in Corpus
    for doc in corpus:
        print([[mydict[id], freq] for id, freq in doc])
    
    # [['first', 1], ['is', 1], ['line', 1], ['the', 1], ['this', 1]]
    # [['is', 1], ['the', 1], ['this', 1], ['second', 1], ['sentence', 1]]
    # [['this', 1], ['document', 1], ['third', 1]]
    
    # Create the TF-IDF model
    tfidf = models.TfidfModel(corpus, smartirs='ntc')
    
    # Show the TF-IDF weights
    for doc in tfidf[corpus]:
        print([[mydict[id], np.around(freq, decimals=2)] for id, freq in doc])
    # [['first', 0.66], ['is', 0.24], ['line', 0.66], ['the', 0.24]]
    # [['is', 0.24], ['the', 0.24], ['second', 0.66], ['sentence', 0.66]]
    # [['document', 0.71], ['third', 0.71]]
    

    Notice the difference in weights of the words between the original corpus and the tfidf weighted corpus.

    The words ‘is’ and ‘the’ occur in two documents and were weighted down. The word ‘this’ appearing in all three documents was removed altogether. In simple terms, words that occur more frequently across the documents get smaller weights.

    9. How to use gensim downloader API to load datasets?

    Gensim provides an inbuilt API to download popular text datasets and word embedding models.

    A comprehensive list of available datasets and models is maintained here.

    Using the API to download the dataset is as simple as calling the api.load() method with the right data or model name.

    The below example shows how to download the ‘glove-wiki-gigaword-50’ model.

    import gensim.downloader as api
    
    # Get information about the model or dataset
    api.info('glove-wiki-gigaword-50')
    # {'base_dataset': 'Wikipedia 2014 + Gigaword 5 (6B tokens, uncased)',
    #  'checksum': 'c289bc5d7f2f02c6dc9f2f9b67641813',
    #  'description': 'Pre-trained vectors based on Wikipedia 2014 + Gigaword, 5.6B tokens, 400K vocab, uncased (https://nlp.stanford.edu/projects/glove/).',
    #  'file_name': 'glove-wiki-gigaword-50.gz',
    #  'file_size': 69182535,
    #  'license': 'http://opendatacommons.org/licenses/pddl/',
    #  (... truncated...)
    
    # Download
    w2v_model = api.load("glove-wiki-gigaword-50")
    w2v_model.most_similar('blue')
    # [('red', 0.8901656866073608),
    #  ('black', 0.8648407459259033),
    #  ('pink', 0.8452916741371155),
    #  ('green', 0.8346816301345825),
    #  ... ]
    

    10. How to create bigrams and trigrams using Phraser models?

    Now you know how to download datasets and pre-trained models with gensim.

    Let’s download the text8 dataset, which is nothing but the “First 100,000,000 bytes of plain text from Wikipedia”. Then, from this, we will generate bigrams and trigrams.

    But what are bigrams and trigrams? and why do they matter?

In paragraphs, certain words always tend to occur in pairs (bigrams) or in groups of three (trigrams), because the words combined together form the actual entity. For example: the word ‘French’ refers to the language or region and the word ‘revolution’ can refer to the planetary revolution. But combined, ‘French Revolution’ refers to something completely different.

    It’s quite important to form bigrams and trigrams from sentences, especially when working with bag-of-words models.

    So how to create the bigrams?

    It’s quite easy and efficient with gensim’s Phrases model. The created Phrases model allows indexing, so, just pass the original text (list) to the built Phrases model to form the bigrams. An example is shown below:

    dataset = api.load("text8")
    dataset = [wd for wd in dataset]
    
    dct = corpora.Dictionary(dataset)
    corpus = [dct.doc2bow(line) for line in dataset]
    
    # Build the bigram models
    bigram = gensim.models.phrases.Phrases(dataset, min_count=3, threshold=10)
    
    # Construct bigram
    print(bigram[dataset[0]])
    # ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used',
    #  'against', 'early', 'working_class', 'radicals', 'including', 'the', 'diggers',
    #  'of', 'the', 'english', 'revolution', 'and', 'the', 'sans_culottes', 'of', 'the',
    #  'french_revolution', 'whilst',...]
    

    The bigrams are ready. Can you guess how to create a trigram?

Well, simply rinse and repeat the same procedure on the output of the bigram model. Once you’ve generated the bigrams, you can pass the output to train a new Phrases model. Then, apply the bigrammed corpus to the trained trigram model. Confused? See the example below.

    # Build the trigram models
    trigram = gensim.models.phrases.Phrases(bigram[dataset], threshold=10)
    
    # Construct trigram
    print(trigram[bigram[dataset[0]]])
    

    11. How to create Topic Models with LDA?

    The objective of topic models is to extract the underlying topics from a given collection of text documents. Each document in the text is considered as a combination of topics and each topic is considered as a combination of related words.

    Topic modeling can be done by algorithms like Latent Dirichlet Allocation (LDA) and Latent Semantic Indexing (LSI).

    In both cases you need to provide the number of topics as input. The topic model, in turn, will provide the topic keywords for each topic and the percentage contribution of topics in each document.

    The quality of topics is highly dependent on the quality of text processing and the number of topics you provide to the algorithm. The earlier post on how to build best topic models explains the procedure in more detail. However, I recommend understanding the basic steps involved and the interpretation in the example below.

    Step 0: Load the necessary packages and import the stopwords.

    # Step 0: Import packages and stopwords
    from gensim.models import LdaModel, LdaMulticore
    import gensim.downloader as api
    from gensim.utils import simple_preprocess, lemmatize
    from nltk.corpus import stopwords
    import re
    import logging
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s')
    logging.root.setLevel(level=logging.INFO)
    stop_words = stopwords.words('english')
    stop_words = stop_words + ['com', 'edu', 'subject', 'lines', 'organization', 'would', 'article', 'could']
    

    Step 1: Import the dataset. I am going to use the text8 dataset that can be downloaded using gensim’s downloader API.

    # Step 1: Import the dataset and get the text and real topic of each news article
    dataset = api.load("text8")
    data = [d for d in dataset]
    

Step 2: Prepare the downloaded data by removing stopwords and lemmatizing it. For lemmatization, gensim requires the pattern package, so be sure to do pip install pattern in your terminal or prompt before running this. I have set up lemmatization such that only Nouns (NN), Adjectives (JJ) and Adverbs (RB) are retained, because I prefer only such words to go in as topic keywords. This is a personal choice.

    # Step 2: Prepare Data (Remove stopwords and lemmatize)
    data_processed = []
    
    for i, doc in enumerate(data[:100]):
        doc_out = []
        for wd in doc:
            if wd not in stop_words:  # remove stopwords
                lemmatized_word = lemmatize(wd, allowed_tags=re.compile('(NN|JJ|RB)'))  # lemmatize
                if lemmatized_word:
                    doc_out = doc_out + [lemmatized_word[0].split(b'/')[0].decode('utf-8')]
            else:
                continue
        data_processed.append(doc_out)
    
    # Print a small sample    
    print(data_processed[0][:5]) 
    #> ['anarchism', 'originated', 'term', 'abuse', 'first']
    

data_processed is now a list of lists of words. You can now use it to create the Dictionary and Corpus, which will then be used as inputs to the LDA model.

    # Step 3: Create the Inputs of LDA model: Dictionary and Corpus
    dct = corpora.Dictionary(data_processed)
    corpus = [dct.doc2bow(line) for line in data_processed]
    

We have the Dictionary and Corpus created. Let’s build an LDA topic model with 7 topics, using LdaMulticore(). 7 topics is an arbitrary choice for now.

    # Step 4: Train the LDA model
    lda_model = LdaMulticore(corpus=corpus,
                             id2word=dct,
                             random_state=100,
                             num_topics=7,
                             passes=10,
                             chunksize=1000,
                             batch=False,
                             alpha='asymmetric',
                             decay=0.5,
                             offset=64,
                             eta=None,
                             eval_every=0,
                             iterations=100,
                             gamma_threshold=0.001,
                             per_word_topics=True)
    
    # save the model
    lda_model.save('lda_model.model')
    
    # See the topics
    lda_model.print_topics(-1)
    # [(0, '0.001*"also" + 0.000*"first" + 0.000*"state" + 0.000*"american" + 0.000*"time" + 0.000*"book" + 0.000*"year" + 0.000*"many" + 0.000*"person" + 0.000*"new"'),
    #  (1, '0.001*"also" + 0.001*"state" + 0.001*"ammonia" + 0.000*"first" + 0.000*"many" + 0.000*"american" + 0.000*"war" + 0.000*"time" + 0.000*"year" + 0.000*"name"'),
    #  (2, '0.005*"also" + 0.004*"american" + 0.004*"state" + 0.004*"first" + 0.003*"year" + 0.003*"many" + 0.003*"time" + 0.003*"new" + 0.003*"war" + 0.003*"person"'),
    #  (3, '0.001*"atheism" + 0.001*"also" + 0.001*"first" + 0.001*"atheist" + 0.001*"american" + 0.000*"god" + 0.000*"state" + 0.000*"many" + 0.000*"new" + 0.000*"year"'),
    #  (4, '0.001*"state" + 0.001*"also" + 0.001*"many" + 0.000*"world" + 0.000*"agave" + 0.000*"time" + 0.000*"new" + 0.000*"war" + 0.000*"god" + 0.000*"person"'),
    #  (5, '0.001*"also" + 0.001*"abortion" + 0.001*"first" + 0.001*"american" + 0.000*"state" + 0.000*"many" + 0.000*"year" + 0.000*"time" + 0.000*"war" + 0.000*"person"'),
    #  (6, '0.005*"also" + 0.004*"first" + 0.003*"time" + 0.003*"many" + 0.003*"state" + 0.003*"world" + 0.003*"american" + 0.003*"person" + 0.003*"apollo" + 0.003*"language"')]
    

    The lda_model.print_topics shows what words contributed to which of the 7 topics, along with the weightage of the word’s contribution to that topic.

You can see words like ‘also’ and ‘many’ coming up across different topics. So I would add such words to the stop_words list to remove them and further tune the topic model for the optimal number of topics, as sketched below.
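
A minimal sketch of that tuning step, assuming the stop_words list and data_processed from the earlier steps are still in scope (the extra words listed are purely illustrative):

# Extend the stop word list with words that show up across many topics, then rebuild the LDA inputs
extra_stops = ['also', 'many', 'first', 'time']   # illustrative picks from the topics above
stop_words_extended = stop_words + extra_stops

data_refined = [[wd for wd in doc if wd not in stop_words_extended] for doc in data_processed]
dct_refined = corpora.Dictionary(data_refined)
corpus_refined = [dct_refined.doc2bow(doc) for doc in data_refined]
# ...then retrain LdaMulticore() on corpus_refined and compare the resulting topics.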

LdaMulticore() supports parallel processing. Alternatively, you could also try and see what topics the LdaModel() gives.

    12. How to interpret the LDA Topic Model’s output?

    The lda_model object supports indexing. That is, if you pass a document (list of words) to the lda_model, it provides 3 things:

    1. The topic(s) that document belongs to along with percentage.
    2. The topic(s) each word in that document belongs to.
    3. The topic(s) each word in that document belongs to AND the phi values.

    So, what is phi value?

    Phi value is the probability of the word belonging to that particular topic. And the sum of phi values for a given word adds up to the number of times that word occurred in that document.

For example, in the output below for the first of these documents, the word with id=0 gets phi values of about 2.89 for topic 2 and 0.11 for topic 6, which sum to 3. That means the word with id=0 appeared 3 times in that document.

    # Reference: https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/topic_methods.ipynb
    for c in lda_model[corpus[5:8]]:
        print("Document Topics      : ", c[0])      # [(Topics, Perc Contrib)]
        print("Word id, Topics      : ", c[1][:3])  # [(Word id, [Topics])]
        print("Phi Values (word id) : ", c[2][:2])  # [(Word id, [(Topic, Phi Value)])]
        print("Word, Topics         : ", [(dct[wd], topic) for wd, topic in c[1][:2]])   # [(Word, [Topics])]
        print("Phi Values (word)    : ", [(dct[wd], topic) for wd, topic in c[2][:2]])  # [(Word, [(Topic, Phi Value)])]
        print("------------------------------------------------------\n")
    
    #> Document Topics      :  [(2, 0.96124125), (6, 0.038569752)]
    #> Word id, Topics      :  [(0, [2, 6]), (7, [2, 6]), (10, [2, 6])]
    #> Phi Values (word id) :  [(0, [(2, 2.887749), (6, 0.112249866)]), (7, [(2, 0.90105206), (6, 0.09893738)])]
    #> Word, Topics         :  [('ability', [2, 6]), ('absurdity', [2, 6])]
    #> Phi Values (word)    :  [('ability', [(2, 2.887749), (6, 0.112249866)]), ('absurdity', [(2, 0.90105206), (6, 0.09893738)])]
    #> ------------------------------------------------------
    
    #> Document Topics      :  [(6, 0.9997751)]
    #> Word id, Topics      :  [(0, [6]), (10, [6]), (16, [6])]
    #> Phi Values (word id) :  [(0, [(6, 5.9999967)]), (10, [(6, 2.9999983)])]
    #> Word, Topics         :  [('ability', [6]), ('academic', [6])]
    #> Phi Values (word)    :  [('ability', [(6, 5.9999967)]), ('academic', [(6, 2.9999983)])]
    #> ------------------------------------------------------
    
    #> Document Topics      :  [(6, 0.9998023)]
    #> Word id, Topics      :  [(1, [6]), (10, [6]), (15, [6])]
    #> Phi Values (word id) :  [(1, [(6, 0.99999917)]), (10, [(6, 5.999997)])]
    #> Word, Topics         :  [('able', [6]), ('academic', [6])]
    #> Phi Values (word)    :  [('able', [(6, 0.99999917)]), ('academic', [(6, 5.999997)])]
    #> ------------------------------------------------------
    

    13. How to create a LSI topic model using gensim?

    The syntax for using an LSI model is similar to how we built the LDA model, except that we will use the LsiModel().

    from gensim.models import LsiModel
    
    # Build the LSI Model
    lsi_model = LsiModel(corpus=corpus, id2word=dct, num_topics=7, decay=0.5)
    
    # View Topics
    pprint(lsi_model.print_topics(-1))
    #> [(0, '0.262*"also" + 0.197*"state" + 0.197*"american" + 0.178*"first" + '
    #>   '0.151*"many" + 0.149*"time" + 0.147*"year" + 0.130*"person" + 0.130*"world" '
    #>   '+ 0.124*"war"'),
    #>  (1, '0.937*"agave" + 0.164*"asia" + 0.100*"aruba" + 0.063*"plant" + 0.053*"var" '
    #>   '+ 0.052*"state" + 0.045*"east" + 0.044*"congress" + -0.042*"first" + '
    #>   '0.041*"maguey"'),
    #>  (2, '0.507*"american" + 0.180*"football" + 0.179*"player" + 0.168*"war" + '
    #>   '0.150*"british" + -0.140*"also" + 0.114*"ball" + 0.110*"day" + '
    #>   '-0.107*"atheism" + -0.106*"god"'),
    #>  (3, '-0.362*"apollo" + 0.248*"lincoln" + 0.211*"state" + -0.172*"player" + '
    #>   '-0.151*"football" + 0.127*"union" + -0.125*"ball" + 0.124*"government" + '
    #>   '-0.116*"moon" + 0.116*"jews"'),
    #>  (4, '-0.363*"atheism" + -0.334*"god" + -0.329*"lincoln" + -0.230*"apollo" + '
    #>   '-0.215*"atheist" + -0.143*"abraham" + 0.136*"island" + -0.132*"aristotle" + '
    #>   '0.124*"aluminium" + -0.119*"belief"'),
    #>  (5, '-0.360*"apollo" + 0.344*"atheism" + -0.326*"lincoln" + 0.226*"god" + '
    #>   '0.205*"atheist" + 0.139*"american" + -0.130*"lunar" + 0.128*"football" + '
    #>   '-0.125*"moon" + 0.114*"belief"'),
    #>  (6, '-0.313*"lincoln" + 0.226*"apollo" + -0.166*"football" + -0.163*"war" + '
    #>   '0.162*"god" + 0.153*"australia" + -0.148*"play" + -0.146*"ball" + '
    #>   '0.122*"atheism" + -0.122*"line"')]
    

    14. How to train Word2Vec model using gensim?

A word embedding model is a model that can provide numerical vectors for a given word. Using gensim’s downloader API, you can download pre-built word embedding models like word2vec, fasttext, GloVe and ConceptNet. These are built on large corpora of commonly occurring text data such as wikipedia, google news etc.

However, if you are working in a specialized niche such as technical documents, you may not be able to get word embeddings for all the words. So, in such cases, it is desirable to train your own model.

Gensim’s Word2Vec implementation lets you train your own word embedding model for a given corpus.

    from gensim.models.word2vec import Word2Vec
    from multiprocessing import cpu_count
    import gensim.downloader as api
    
    # Download dataset
    dataset = api.load("text8")
    data = [d for d in dataset]
    
    # Split the data into 2 parts. Part 2 will be used later to update the model
    data_part1 = data[:1000]
    data_part2 = data[1000:]
    
    # Train Word2Vec model. Defaults result vector size = 100
    model = Word2Vec(data_part1, min_count = 0, workers=cpu_count())
    
    # Get the word vector for given word
    model['topic']
    #> array([ 0.0512,  0.2555,  0.9393, ... ,-0.5669,  0.6737], dtype=float32)
    
    model.most_similar('topic')
    #> [('discussion', 0.7590423822402954),
    #>  ('consensus', 0.7253159284591675),
    #>  ('discussions', 0.7252693176269531),
    #>  ('interpretation', 0.7196053266525269),
    #>  ('viewpoint', 0.7053568959236145),
    #>  ('speculation', 0.7021505832672119),
    #>  ('discourse', 0.7001898884773254),
    #>  ('opinions', 0.6993060111999512),
    #>  ('focus', 0.6959210634231567),
    #>  ('scholarly', 0.6884037256240845)]
    
    # Save and Load Model
    model.save('newmodel')
    model = Word2Vec.load('newmodel')
    

    We have trained and saved a Word2Vec model for our document. However, when a new dataset comes, you want to update the model so as to account for new words.

    15. How to update an existing Word2Vec model with new data?

On an existing Word2Vec model, call build_vocab() on the new dataset and then call the train() method. build_vocab() is called first because the model has to be apprised of what new words to expect in the incoming corpus.

    # Update the model with new data.
    model.build_vocab(data_part2, update=True)
    model.train(data_part2, total_examples=model.corpus_count, epochs=model.iter)
    model['topic']
    # array([-0.6482, -0.5468,  1.0688,  0.82  , ... , -0.8411,  0.3974], dtype=float32)
    

    16. How to extract word vectors using pre-trained Word2Vec and FastText models?

    We just saw how to get the word vectors for Word2Vec model we just trained. However, gensim lets you download state of the art pretrained models through the downloader API. Let’s see how to extract the word vectors from a couple of these models.

    import gensim.downloader as api
    
    # Download the models
    fasttext_model300 = api.load('fasttext-wiki-news-subwords-300')
    word2vec_model300 = api.load('word2vec-google-news-300')
    glove_model300 = api.load('glove-wiki-gigaword-300')
    
    # Get word embeddings
    word2vec_model300.most_similar('support')
    # [('supporting', 0.6251285076141357),
    #  ...
    #  ('backing', 0.6007589101791382),
    #  ('supports', 0.5269277691841125),
    #  ('assistance', 0.520713746547699),
    #  ('supportive', 0.5110025405883789)]
    

    We have 3 different embedding models. You can evaluate which one performs better using the respective model’s evaluate_word_analogies() on a standard analogies dataset.

    # Word2ec_accuracy
    word2vec_model300.evaluate_word_analogies(analogies="questions-words.txt")[0]
    #> 0.7401448525607863
    
    # fasttext_accuracy
    fasttext_model300.evaluate_word_analogies(analogies="questions-words.txt")[0]
    #> 0.8827876424099353
    
    # GloVe accuracy
    glove_model300.evaluate_word_analogies(analogies="questions-words.txt")[0]
    #> 0.7195422354510931
    

    17. How to create document vectors using Doc2Vec?

    Unlike Word2Vec, a Doc2Vec model provides a vectorised representation of a group of words taken collectively as a single unit. It is not a simple average of the word vectors of the words in the sentence.

    Let’s use the text8 dataset to train the Doc2Vec.

    import gensim
    import gensim.downloader as api
    
    # Download dataset
    dataset = api.load("text8")
    data = [d for d in dataset]
    

    The training data for Doc2Vec should be a list of TaggedDocuments. To create one, we pass a list of words and a unique integer as input to the models.doc2vec.TaggedDocument().

    # Create the tagged document needed for Doc2Vec
    def create_tagged_document(list_of_list_of_words):
        for i, list_of_words in enumerate(list_of_list_of_words):
            yield gensim.models.doc2vec.TaggedDocument(list_of_words, [i])
    
    train_data = list(create_tagged_document(data))
    
    print(train_data[:1])
    #> [TaggedDocument(words=['anarchism', 'originated', ... 'social', 'or', 'emotional'], tags=[0])]
    

    The input is prepared. To train the model, you need to initialize the Doc2Vec model, build the vocabulary and then finally train the model.

    # Init the Doc2Vec model
    model = gensim.models.doc2vec.Doc2Vec(vector_size=50, min_count=2, epochs=40)
    
# Build the Vocabulary
    model.build_vocab(train_data)
    
    # Train the Doc2Vec model
    model.train(train_data, total_examples=model.corpus_count, epochs=model.epochs)
    

    To get the document vector of a sentence, pass it as a list of words to the infer_vector() method.

    print(model.infer_vector(['australian', 'captain', 'elected', 'to', 'bowl']))
    #> array([-0.11043505,  0.21719663, -0.21167697, -0.10790558,  0.5607173 ,
    #>        ...
    #>        0.16428669, -0.31307793, -0.28575218, -0.0113026 ,  0.08981086],
    #>       dtype=float32)
    

    18. How to compute similarity metrics like cosine similarity and soft cosine similarity?

    Soft cosine similarity is similar to cosine similarity but in addition considers the semantic relationship between the words through its vector representation.

    To compute soft cosines, you will need a word embedding model like Word2Vec or FastText. First, compute the similarity_matrix. Then convert the input sentences to bag-of-words corpus and pass them to the softcossim() along with the similarity matrix.

    from gensim.matutils import softcossim
    from gensim import corpora
    
    sent_1 = 'Sachin is a cricket player and a opening batsman'.split()
    sent_2 = 'Dhoni is a cricket player too He is a batsman and keeper'.split()
    sent_3 = 'Anand is a chess player'.split()
    
    # Prepare the similarity matrix
    similarity_matrix = fasttext_model300.similarity_matrix(dictionary, tfidf=None, threshold=0.0, exponent=2.0, nonzero_limit=100)
    
    # Prepare a dictionary and a corpus.
    documents = [sent_1, sent_2, sent_3]
    dictionary = corpora.Dictionary(documents)
    
    # Convert the sentences into bag-of-words vectors.
    sent_1 = dictionary.doc2bow(sent_1)
    sent_2 = dictionary.doc2bow(sent_2)
    sent_3 = dictionary.doc2bow(sent_3)
    
    # Compute soft cosine similarity
    print(softcossim(sent_1, sent_2, similarity_matrix))
    #> 0.7868705819999783
    
    print(softcossim(sent_1, sent_3, similarity_matrix))
    #> 0.6036445529268666
    
    print(softcossim(sent_2, sent_3, similarity_matrix))
    #> 0.60965453519611
    

    Below are some useful similarity and distance metrics based on the word embedding models like fasttext and GloVe. We have already downloaded these models using the downloader API.

    # Which word from the given list doesn't go with the others?
    print(fasttext_model300.doesnt_match(['india', 'australia', 'pakistan', 'china', 'beetroot']))  
    #> beetroot
    
    # Compute cosine distance between two words.
    print(fasttext_model300.distance('king', 'queen'))
    #> 0.22957539558410645
    
    
    # Compute cosine distances from given word or vector to all words in `other_words`.
    print(fasttext_model300.distances('king', ['queen', 'man', 'woman']))
    #> [0.22957546 0.465837   0.547001  ]
    
    
    # Compute cosine similarities
    print(fasttext_model300.cosine_similarities(fasttext_model300['king'], 
                                                vectors_all=(fasttext_model300['queen'], 
                                                            fasttext_model300['man'], 
                                                            fasttext_model300['woman'],
                                                            fasttext_model300['queen'] + fasttext_model300['man'])))  
    #> array([0.77042454, 0.534163  , 0.45299897, 0.76572555], dtype=float32)
    # Note: Queen + Man is very similar to King.
    
    # Get the words closer to w1 than w2
    print(glove_model300.words_closer_than(w1='king', w2='kingdom'))
    #> ['prince', 'queen', 'monarch']
    
    
    # Find the top-N most similar words.
    print(fasttext_model300.most_similar(positive='king', negative=None, topn=5, restrict_vocab=None, indexer=None))
    #> [('queen', 0.63), ('prince', 0.62), ('monarch', 0.59), ('kingdom', 0.58), ('throne', 0.56)]
    
    
    # Find the top-N most similar words, using the multiplicative combination objective,
    print(glove_model300.most_similar_cosmul(positive='king', negative=None, topn=5))
    #> [('queen', 0.82), ('prince', 0.81), ('monarch', 0.79), ('kingdom', 0.79), ('throne', 0.78)]
    

    19. How to summarize text documents?

Gensim implements the textrank summarization using the summarize() function in the summarization module. All you need to do is pass in the text string along with either the output summarization ratio or the maximum count of words in the summarized output.

There is no need to split the sentence into a tokenized list because gensim does the splitting using the built-in split_sentences() method in the gensim.summarization.textcleaner module.

Let’s summarize the clipping from a news article in sample.txt.

    from gensim.summarization import summarize, keywords
    from pprint import pprint
    
    text = " ".join((line for line in smart_open('sample.txt', encoding='utf-8')))
    
    # Summarize the paragraph
    pprint(summarize(text, word_count=20))
    #> ('the PLA Rocket Force national defense science and technology experts panel, '
    #>  'according to a report published by the')
    
    # Important keywords from the paragraph
    print(keywords(text))
    #> force zhang technology experts pla rocket
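
    If you would rather control the summary length as a fraction of the original text instead of a word count, the same summarize() call accepts a ratio argument (a fraction between 0 and 1). A quick sketch reusing the text variable loaded above:

    # Summarize to roughly 10% of the original text (output depends on your sample.txt)
    pprint(summarize(text, ratio=0.1))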
    

    For more information on summarization with gensim, refer to this tutorial.

    20. Conclusion

    We have covered a lot of ground about the various features of gensim and got a good grasp on how to work with and manipulate texts. The above examples should serve as nice templates to get you started and build upon for various NLP tasks. I hope you will find them helpful and feel comfortable using gensim more often in your NLP projects.


    Lemmatization Approaches with Examples in Python

    Lemmatization is the process of converting a word to its base form. The difference between stemming and lemmatization is, lemmatization considers the context and converts the word to its meaningful base form, whereas stemming just removes the last few characters, often leading to incorrect meanings and spelling errors.

    Comparing Lemmatization Approaches in Python. Photo by Jasmin Schreiber

    Contents

    1. Introduction
    2. Wordnet Lemmatizer
    3. Wordnet Lemmatizer with appropriate POS tag
    4. spaCy Lemmatization
    5. TextBlob Lemmatizer
    6. TextBlob Lemmatizer with appropriate POS tag
    7. Pattern Lemmatizer
    8. Stanford CoreNLP Lemmatization
    9. Gensim Lemmatize
    10. TreeTagger
    11. Comparing NLTK, TextBlob, spaCy, Pattern and Stanford CoreNLP
    12. Conclusion

    1. Introduction

    Lemmatization is the process of converting a word to its base form. The difference between stemming and lemmatization is, lemmatization considers the context and converts the word to its meaningful base form, whereas stemming just removes the last few characters, often leading to incorrect meanings and spelling errors.

    For example, lemmatization would correctly identify the base form of ‘caring’ as ‘care’, whereas stemming would cut off the ‘ing’ part and convert it to ‘car’.

    ‘Caring’ -> Lemmatization -> ‘Care’
    ‘Caring’ -> Stemming -> ‘Car’

    Also, sometimes the same word can have multiple different lemmas. So, based on the context it’s used in, you should identify the ‘part-of-speech’ (POS) tag for the word in that specific context and extract the appropriate lemma. Examples of implementing this come in the following sections.

    Today, we will see how to implement lemmatization using the following Python packages.

    1. Wordnet Lemmatizer

    2. Spacy Lemmatizer

    3. TextBlob

    4. CLiPS Pattern

    5. Stanford CoreNLP

    6. Gensim Lemmatizer

    7. TreeTagger

    2. Wordnet Lemmatizer with NLTK

    Wordnet is a large, freely and publicly available lexical database for the English language that aims to establish structured semantic relationships between words. It offers lemmatization capabilities as well and is one of the earliest and most commonly used lemmatizers.

    NLTK offers an interface to it, but you have to download it first in order to use it. Follow the below instructions to install nltk and download wordnet.

    # How to install and import NLTK
    # In terminal or prompt:
    # pip install nltk
    
    # Download Wordnet through NLTK in python console:
    import nltk
    nltk.download('wordnet')
    

    In order to lemmatize, you need to create an instance of the WordNetLemmatizer() and call the lemmatize() function on a single word.

    import nltk
    from nltk.stem import WordNetLemmatizer 
    
    # Init the Wordnet Lemmatizer
    lemmatizer = WordNetLemmatizer()
    
    # Lemmatize Single Word
    print(lemmatizer.lemmatize("bats"))
    #> bat
    
    print(lemmatizer.lemmatize("are"))
    #> are
    
    print(lemmatizer.lemmatize("feet"))
    #> foot
    

    Let’s lemmatize a simple sentence. We first tokenize the sentence into words using nltk.word_tokenize and then we will call lemmatizer.lemmatize() on each word. This can be done in a list comprehension (the for-loop inside square brackets to make a list).

    # Define the sentence to be lemmatized
    sentence = "The striped bats are hanging on their feet for best"
    
    # Tokenize: Split the sentence into words
    word_list = nltk.word_tokenize(sentence)
    print(word_list)
    #> ['The', 'striped', 'bats', 'are', 'hanging', 'on', 'their', 'feet', 'for', 'best']
    
    # Lemmatize list of words and join
    lemmatized_output = ' '.join([lemmatizer.lemmatize(w) for w in word_list])
    print(lemmatized_output)
    #> The striped bat are hanging on their foot for best
    

    The above code is a simple example of how to use the wordnet lemmatizer on words and sentences.

    Notice it didn’t do a good job, because ‘are’ is not converted to ‘be’ and ‘hanging’ is not converted to ‘hang’ as expected. This can be corrected if we provide the correct ‘part-of-speech’ tag (POS tag) as the second argument to lemmatize().
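
    For instance, supplying the verb tag explicitly fixes both words (a quick check with the same lemmatizer object):

    print(lemmatizer.lemmatize("are", 'v'))
    #> be
    
    print(lemmatizer.lemmatize("hanging", 'v'))
    #> hang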

    Sometimes, the same word can have multiple lemmas based on the meaning / context.

    print(lemmatizer.lemmatize("stripes", 'v'))  
    #> strip
    
    print(lemmatizer.lemmatize("stripes", 'n'))  
    #> stripe
    

    3. Wordnet Lemmatizer with appropriate POS tag

    It may not be possible to manually provide the correct POS tag for every word in large texts. So, instead, we will find the correct POS tag for each word, map it to the right input character that the WordnetLemmatizer accepts and pass it as the second argument to lemmatize().

    So how to get the POS tag for a given word?

    In nltk, it is available through the nltk.pos_tag() method. It accepts only a list (list of words), even if it’s a single word.

    print(nltk.pos_tag(['feet']))
    #> [('feet', 'NNS')]
    
    print(nltk.pos_tag(nltk.word_tokenize(sentence)))
    #> [('The', 'DT'), ('striped', 'JJ'), ('bats', 'NNS'), ('are', 'VBP'), ('hanging', 'VBG'), ('on', 'IN'), ('their', 'PRP$'), ('feet', 'NNS'), ('for', 'IN'), ('best', 'JJS')]
    

    nltk.pos_tag() returns a list of (word, tag) tuples. The key here is to map NLTK’s POS tags to the format the wordnet lemmatizer would accept. The get_wordnet_pos() function defined below does this mapping job.

    # Lemmatize with POS Tag
    from nltk.corpus import wordnet
    
    def get_wordnet_pos(word):
        """Map POS tag to first character lemmatize() accepts"""
        tag = nltk.pos_tag([word])[0][1][0].upper()
        tag_dict = {"J": wordnet.ADJ,
                    "N": wordnet.NOUN,
                    "V": wordnet.VERB,
                    "R": wordnet.ADV}
    
        return tag_dict.get(tag, wordnet.NOUN)
    
    
    # 1. Init Lemmatizer
    lemmatizer = WordNetLemmatizer()
    
    # 2. Lemmatize Single Word with the appropriate POS tag
    word = 'feet'
    print(lemmatizer.lemmatize(word, get_wordnet_pos(word)))
    
    # 3. Lemmatize a Sentence with the appropriate POS tag
    sentence = "The striped bats are hanging on their feet for best"
    print([lemmatizer.lemmatize(w, get_wordnet_pos(w)) for w in nltk.word_tokenize(sentence)])
    #> ['The', 'strip', 'bat', 'be', 'hang', 'on', 'their', 'foot', 'for', 'best']
    

    4. spaCy Lemmatization

    spaCy is a relative newcomer in the space and is billed as an industrial-strength NLP engine. It comes with pre-built models that can parse text and compute various NLP related features through one single function call. Of course, it provides the lemma of the word too.

    Before we begin, let’s install spaCy and download the ‘en’ model.

    # Install spaCy (run in terminal/prompt)
    import sys
    !{sys.executable} -m pip install spacy
    
    # Download spaCy's  'en' Model
    !{sys.executable} -m spacy download en
    

    spaCy determines the part-of-speech tag by default and assigns the corresponding lemma. It comes with a bunch of prebuilt models, where the ‘en’ model we just downloaded above is one of the standard ones for English.

    import spacy
    
    # Initialize spacy 'en' model, keeping only tagger component needed for lemmatization
    nlp = spacy.load('en', disable=['parser', 'ner'])
    
    sentence = "The striped bats are hanging on their feet for best"
    
    # Parse the sentence using the loaded 'en' model object `nlp`
    doc = nlp(sentence)
    
    # Extract the lemma for each token and join
    " ".join([token.lemma_ for token in doc])
    #> 'the strip bat be hang on -PRON- foot for good'
    

    It did all the lemmatizations that the Wordnet Lemmatizer did when supplied with the correct POS tags. Plus it also lemmatized ‘best’ to ‘good’. Nice!

    You’d see the -PRON- character coming up whenever spacy detects a pronoun.
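
    If you would rather keep the original pronoun than the -PRON- placeholder, a minimal sketch (reusing the nlp object and sentence from the snippet above, and assuming spaCy 2.x behaviour) is to fall back to token.text for those tokens:

    doc = nlp(sentence)
    
    # Use the raw token text whenever spaCy emits the -PRON- placeholder
    " ".join([token.text if token.lemma_ == '-PRON-' else token.lemma_ for token in doc])
    #> expected: 'the strip bat be hang on their foot for good'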

    5. TextBlob Lemmatizer

    TextBlob is a powerful, fast and convenient NLP package as well. Using the Word and TextBlob objects, it’s quite straightforward to parse and lemmatize words and sentences respectively.

    # pip install textblob
    from textblob import TextBlob, Word
    
    # Lemmatize a word
    word = 'stripes'
    w = Word(word)
    w.lemmatize()
    #> stripe
    

    However, to lemmatize a sentence or paragraph, we parse it using TextBlob and call lemmatize() on each of the parsed words.

    # Lemmatize a sentence
    sentence = "The striped bats are hanging on their feet for best"
    sent = TextBlob(sentence)
    " ". join([w.lemmatize() for w in sent.words])
    #> 'The striped bat are hanging on their foot for best'
    

    It did not do a great job at the outset, because, like NLTK, TextBlob also uses wordnet internally. So, let’s pass the appropriate POS tag to the lemmatize() method.

    6. TextBlob Lemmatizer with appropriate POS tag

    # Define function to lemmatize each word with its POS tag
    def lemmatize_with_postag(sentence):
        sent = TextBlob(sentence)
        tag_dict = {"J": 'a', 
                    "N": 'n', 
                    "V": 'v', 
                    "R": 'r'}
        words_and_tags = [(w, tag_dict.get(pos[0], 'n')) for w, pos in sent.tags]    
        lemmatized_list = [wd.lemmatize(tag) for wd, tag in words_and_tags]
        return " ".join(lemmatized_list)
    
    # Lemmatize
    sentence = "The striped bats are hanging on their feet for best"
    lemmatize_with_postag(sentence)
    #> 'The striped bat be hang on their foot for best'
    

    7. Pattern Lemmatizer

    Pattern by CLiPS is a versatile module with many useful NLP capabilities.

    !pip install pattern
    

    If you run into issues while installing pattern, check out the known issues on github. I myself faced this issue when installing on a mac.

    import pattern
    from pattern.en import lemma, lexeme
    
    sentence = "The striped bats were hanging on their feet and ate best fishes"
    " ".join([lemma(wd) for wd in sentence.split()])
    #> 'the stripe bat be hang on their feet and eat best fishes'
    

    You can also view the possible lexemes for each word.

    # Lexeme's for each word 
    [lexeme(wd) for wd in sentence.split()]
    
    #> [['the', 'thes', 'thing', 'thed'],
    #>  ['stripe', 'stripes', 'striping', 'striped'],
    #>  ['bat', 'bats', 'batting', 'batted'],
    #>  ['be', 'am', 'are', 'is', 'being', 'was', 'were', 'been', 
    #> . 'am not', "aren't", "isn't", "wasn't", "weren't"],
    #>  ['hang', 'hangs', 'hanging', 'hung'],
    #>  ['on', 'ons', 'oning', 'oned'],
    #>  ['their', 'theirs', 'theiring', 'theired'],
    #>  ['feet', 'feets', 'feeting', 'feeted'],
    #>  ['and', 'ands', 'anding', 'anded'],
    #>  ['eat', 'eats', 'eating', 'ate', 'eaten'],
    #>  ['best', 'bests', 'besting', 'bested'],
    #>  ['fishes', 'fishing', 'fishesed']]
    

    You could also obtain the lemma by parsing the text.

    from pattern.en import parse
    print(parse('The striped bats were hanging on their feet and ate best fishes', 
                lemmata=True, tags=False, chunks=False))
    #> The/DT/the striped/JJ/striped bats/NNS/bat were/VBD/be hanging/VBG/hang on/IN/on their/PRP$/their 
    #>  feet/NNS/foot and/CC/and ate/VBD/eat best/JJ/best fishes/NNS/fish
    

    8. Stanford CoreNLP Lemmatization

    Stanford CoreNLP is a popular NLP tool originally implemented in Java. There are many python wrappers written around it; the one I use below is quite convenient.

    But before that, you need to download Java and the Stanford CoreNLP software. Make sure you have the following requirements in place before getting to the lemmatization code:

    Step 1: Java 8 Installed

    You can download and install from Java download page.

    Mac users can check the java version by typing java -version in the terminal. If it’s 1.8+, then it’s OK. Else follow the below steps.

    brew update
    brew install jenv
    brew cask install java
    

    Step 2: Download the Stanford CoreNLP software and unzip it.

    Step 3: Start the Stanford CoreNLP server from the terminal. How? cd to the folder you just unzipped and run the below command in the terminal:

    cd stanford-corenlp-full-2018-02-27
    java -mx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -annotators "tokenize,ssplit,pos,lemma,parse,sentiment" -port 9000 -timeout 30000
    

    This will start a StanfordCoreNLPServer listening at port 9000. Now, we are ready to extract the lemmas in python.

    In the stanfordcorenlp package, the lemma is embedded in the output of the annotate() method of the StanfordCoreNLP connection object (see code below).

    # Run `pip install stanfordcorenlp` to install stanfordcorenlp package
    from stanfordcorenlp import StanfordCoreNLP
    import json
    
    # Connect to the CoreNLP server we just started
    nlp = StanfordCoreNLP('http://localhost', port=9000, timeout=30000)
    
    # Define properties needed to get the lemma
    props = {'annotators': 'pos,lemma',
             'pipelineLanguage': 'en',
             'outputFormat': 'json'}
    
    
    sentence = "The striped bats were hanging on their feet and ate best fishes"
    parsed_str = nlp.annotate(sentence, properties=props)
    parsed_dict = json.loads(parsed_str)
    parsed_dict
    #> {'sentences': [{'index': 0,
    #>    'tokens': [{'after': ' ',
    #>      'before': '',
    #>      'characterOffsetBegin': 0,
    #>      'characterOffsetEnd': 3,
    #>      'index': 1,
    #>      'lemma': 'the',      << ----------- LEMMA
    #>      'originalText': 'The',
    #>      'pos': 'DT',
    #>      'word': 'The'},
    #>     {'after': ' ',
    #>      'before': ' ',
    #>      'characterOffsetBegin': 4,
    #>      'characterOffsetEnd': 11,
    #>      'index': 2,
    #>      'lemma': 'striped',  << ----------- LEMMA
    #>      'originalText': 'striped',
    #>      'pos': 'JJ',
    #>      'word': 'striped'},
    #>     {'after': ' ',
    #>      'before': ' ',
    #>      'characterOffsetBegin': 12,
    #>      'characterOffsetEnd': 16,
    #>      'index': 3,
    #>      'lemma': 'bat',      << ----------- LEMMA
    #>      'originalText': 'bats',
    #>      'pos': 'NNS',
    #>      'word': 'bats'}
    #> ...
    #> ...              
    

    The output of nlp.annotate() was converted to a dict using json.loads. Now, the lemma we need is embedded a couple of layers inside parsed_dict. So here, we need to extract just the lemma value from each token dict. I use a list comprehension below to do the trick.

    lemma_list = [v for d in parsed_dict['sentences'][0]['tokens'] for k,v in d.items() if k == 'lemma']
    " ".join(lemma_list)
    #> 'the striped bat be hang on they foot and eat best fish'
    

    Let’s generalize this into a nice function so as to handle larger paragraphs.

    from stanfordcorenlp import StanfordCoreNLP
    import json, string
    
    def lemmatize_corenlp(conn_nlp, sentence):
        props = {
            'annotators': 'pos,lemma',
            'pipelineLanguage': 'en',
            'outputFormat': 'json'
        }
    
        # tokenize into words
        sents = conn_nlp.word_tokenize(sentence)
    
        # remove punctuations from tokenised list
        sents_no_punct = [s for s in sents if s not in string.punctuation]
    
        # form sentence
        sentence2 = " ".join(sents_no_punct)
    
        # annotate to get lemma
        parsed_str = conn_nlp.annotate(sentence2, properties=props)
        parsed_dict = json.loads(parsed_str)
    
        # extract the lemma for each word
        lemma_list = [v for d in parsed_dict['sentences'][0]['tokens'] for k,v in d.items() if k == 'lemma']
    
        # form sentence and return it
        return " ".join(lemma_list)
    
    
    # make the connection and call `lemmatize_corenlp`
    nlp = StanfordCoreNLP('http://localhost', port=9000, timeout=30000)
    lemmatize_corenlp(conn_nlp=nlp, sentence=sentence)
    #> 'the striped bat be hang on they foot and eat best fish'
    

    9. Gensim Lemmatize

    Gensim provides lemmatization facilities based on the pattern package, so make sure pattern is installed. Lemmatization can be done using the lemmatize() method in the utils module. By default, lemmatize() allows only the ‘JJ’, ‘VB’, ‘NN’ and ‘RB’ tags.

    
    from gensim.utils import lemmatize
    sentence = "The striped bats were hanging on their feet and ate best fishes"
    lemmatized_out = [wd.decode('utf-8').split('/')[0] for wd in lemmatize(sentence)]
    print(lemmatized_out)
    #> ['striped', 'bat', 'be', 'hang', 'foot', 'eat', 'best', 'fish']
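
    The tag filter is configurable. As a hedged sketch (assuming a gensim 3.x install, where utils.lemmatize exposes an allowed_tags parameter that takes a compiled regex), you could keep only nouns and verbs like this:

    import re
    from gensim.utils import lemmatize
    
    # Assumed allowed_tags regex: keep only noun (NN*) and verb (VB*) tokens
    print([wd.decode('utf-8') for wd in lemmatize(sentence, allowed_tags=re.compile('(NN|VB)'))])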
    

    10. TreeTagger

    Treetagger is a Part-of-Speech tagger for many languages. And it provides the lemma of the word as well.

    You will need to download and install the TreeTagger software itself in order to use it, by following its installation instructions.

    # pip install treetaggerwrapper
    
    import treetaggerwrapper as ttpw
    tagger = ttpw.TreeTagger(TAGLANG='en', TAGDIR='/Users/ecom-selva.p/Documents/MLPlus/11_Lemmatization/treetagger')
    tags = tagger.tag_text("The striped bats were hanging on their feet and ate best fishes")
    lemmas = [t.split('\t')[-1] for t in tags]
    print(lemmas)
    #> ['the', 'striped', 'bat', 'be', 'hang', 'on', 'their', 'foot', 'and', 'eat', 'good', 'fish']

    Treetagger indeed does a good job in converting ‘best’ to ‘good’ and for other words as well. For further reading, refer to TreeTaggerWrapper’s documentation.

    11. Comparing NLTK, TextBlob, spaCy, Pattern and Stanford CoreNLP

    Let’s run lemmatization using the 5 implementations on the following sentence and compare output.

    sentence = """Following mice attacks, caring farmers were marching to Delhi for better living conditions. 
    Delhi police on Tuesday fired water cannons and teargas shells at protesting farmers as they tried to 
    break barricades with their cars, automobiles and tractors."""
    
    # NLTK (reuses the get_wordnet_pos() helper defined earlier)
    import nltk, string
    from pprint import pprint
    from nltk.stem import WordNetLemmatizer
    lemmatizer = WordNetLemmatizer()
    pprint(" ".join([lemmatizer.lemmatize(w, get_wordnet_pos(w)) for w in nltk.word_tokenize(sentence) if w not in string.punctuation]))
    # ('Following mouse attack care farmer be march to Delhi for well living '
    #  'condition Delhi police on Tuesday fire water cannon and teargas shell at '
    #  'protest farmer a they try to break barricade with their car automobile and '
    #  'tractor')
    
    # Spacy
    import spacy
    nlp = spacy.load('en', disable=['parser', 'ner'])
    doc = nlp(sentence)
    pprint(" ".join([token.lemma_ for token in doc]))
    # ('follow mice attack , care farmer be march to delhi for good living condition '
    #  '. delhi police on tuesday fire water cannon and teargas shell at protest '
    #  'farmer as -PRON- try to break barricade with -PRON- car , automobile and '
    #  'tractor .')
    
    # TextBlob
    pprint(lemmatize_with_postag(sentence))
    # ('Following mouse attack care farmer be march to Delhi for good living '
    #  'condition Delhi police on Tuesday fire water cannon and teargas shell at '
    #  'protest farmer a they try to break barricade with their car automobile and '
    #  'tractor')
    
    # Pattern
    from pattern.en import lemma
    pprint(" ".join([lemma(wd) for wd in sentence.split()]))
    # ('follow mice attacks, care farmer be march to delhi for better live '
    #  'conditions. delhi police on tuesday fire water cannon and tearga shell at '
    #  'protest farmer a they try to break barricade with their cars, automobile and '
    #  'tractors.')
    
    # Stanford (open a fresh connection, since `nlp` now holds the spaCy model)
    conn_nlp = StanfordCoreNLP('http://localhost', port=9000, timeout=30000)
    pprint(lemmatize_corenlp(conn_nlp=conn_nlp, sentence=sentence))
    # ('follow mouse attack care farmer be march to Delhi for better living '
    #  'condition Delhi police on Tuesday fire water cannon and tearga shell at '
    #  'protest farmer as they try to break barricade with they car automobile and '
    #  'tractor')
    

    12. Conclusion

    So those are the methods you can use the next time you take up an NLP project. I would be happy to know if you have any new approaches or suggestions through your comments. Happy learning!

    Feature Selection – Ten Effective Techniques with Examples

    In machine learning, feature selection is the process of choosing variables that are useful in predicting the response (Y). It is considered a good practice to identify which features are important when building predictive models. In this post, you will see how to implement 10 powerful feature selection approaches in R.

    Feature Selection Methods

    Contents

    Introduction
    1. Boruta
    2. Variable Importance from Machine Learning Algorithms
    3. Lasso Regression
    4. Step wise Forward and Backward Selection
    5. Relative Importance from Linear Regression
    6. Recursive Feature Elimination (RFE)
    7. Genetic Algorithm
    8. Simulated Annealing
    9. Information Value and Weights of Evidence
    10. DALEX Package
    Conclusion

    Introduction

    In real-world datasets, it is fairly common to have columns that are nothing but noise.

    You are better off getting rid of such variables because of the memory space they occupy, and the time and computational resources they are going to cost, especially in large datasets.

    Sometimes, you have a variable that makes business sense, but you are not sure if it actually helps in predicting the Y. You also need to consider the fact that a feature that could be useful in one ML algorithm (say a decision tree) may go underrepresented or unused by another (like a regression model).

    Having said that, it is still possible that a variable that shows poor signs of helping to explain the response variable (Y) can turn out to be significantly useful in the presence of (or in combination with) other predictors. What I mean by that is, a variable might have a low correlation value (~0.2) with Y. But in the presence of other variables, it can help to explain certain patterns/phenomena that other variables can’t explain.

    In such cases, it can be hard to make a call whether to include or exclude such variables.

    The strategies we are about to discuss can help fix such problems. Not only that, they will also help you understand if a particular variable is important or not and how much it is contributing to the model.

    An important caveat: it is always best to have sound business logic backing the inclusion of a variable, and not to rely solely on variable importance metrics.

    Alright. Let’s load up the 'Glaucoma' dataset, where the goal is to predict if a patient has Glaucoma or not based on 62 different physiological measurements. You can directly run the codes or download the dataset here.

    A lot of interesting examples ahead. Let’s get started.

    # Load Packages and prepare dataset
    library(TH.data)
    library(caret)
    data("GlaucomaM", package = "TH.data")
    trainData <- GlaucomaM
    head(trainData)
    
    Glaucoma Dataset

    1. Boruta

    Boruta is a feature ranking and selection algorithm based on random forests.

    The advantage with Boruta is that it clearly decides if a variable is important or not and helps to select variables that are statistically significant. Besides, you can adjust the strictness of the algorithm by adjusting the p value (which defaults to 0.01) and the maxRuns.

    maxRuns is the number of times the algorithm is run. The higher the maxRuns the more selective you get in picking the variables. The default value is 100.

    In the process of deciding if a feature is important or not, some features may be marked by Boruta as 'Tentative'. Sometimes increasing the maxRuns can help resolve the 'Tentativeness' of the feature.

    Let’s see an example based on the Glaucoma training data (from the TH.data package) that we prepared earlier.

    # install.packages('Boruta')
    library(Boruta)
    

    The Boruta function uses a formula interface just like most predictive modeling functions. So the first argument to Boruta() is the formula with the response variable on the left and all the predictors on the right.

    By placing a dot, all the variables in trainData other than Class will be included in the model.

    The doTrace argument controls the amount of output printed to the console. The higher the value, the more log details you get. To save space I have set it to 0, but try setting it to 1 and 2 if you are running the code.

    Finally the output is stored in boruta_output.

    # Perform Boruta search
    boruta_output <- Boruta(Class ~ ., data=na.omit(trainData), doTrace=0)  
    

    Let’s see what the boruta_output contains.

    names(boruta_output)
    
    1. ‘finalDecision’
    2. ‘ImpHistory’
    3. ‘pValue’
    4. ‘maxRuns’
    5. ‘light’
    6. ‘mcAdj’
    7. ‘timeTaken’
    8. ‘roughfixed’
    9. ‘call’
    10. ‘impSource’
    # Get significant variables including tentatives
    boruta_signif <- getSelectedAttributes(boruta_output, withTentative = TRUE)
    print(boruta_signif)  
    
     [1] "as"   "ean"  "abrg" "abrs" "abrn" "abri" "hic"  "mhcg" "mhcn" "mhci"
    [11] "phcg" "phcn" "phci" "hvc"  "vbss" "vbsn" "vbsi" "vasg" "vass" "vasi"
    [21] "vbrg" "vbrs" "vbrn" "vbri" "varg" "vart" "vars" "varn" "vari" "mdn" 
    [31] "tmg"  "tmt"  "tms"  "tmn"  "tmi"  "rnf"  "mdic" "emd" 
    

    If you are not sure whether the tentative variables should be taken as selected, you can do a ‘TentativeRoughFix’ on boruta_output.

    # Do a tentative rough fix
    roughFixMod <- TentativeRoughFix(boruta_output)
    boruta_signif <- getSelectedAttributes(roughFixMod)
    print(boruta_signif)
    
     [1] "abrg" "abrs" "abrn" "abri" "hic"  "mhcg" "mhcn" "mhci" "phcg" "phcn"
    [11] "phci" "hvc"  "vbsn" "vbsi" "vasg" "vbrg" "vbrs" "vbrn" "vbri" "varg"
    [21] "vart" "vars" "varn" "vari" "tmg"  "tms"  "tmi"  "rnf"  "mdic" "emd" 
    

    There you go. Boruta has decided on the ‘Tentative’ variables on our behalf. Let’s find out the importance scores of these variables.

    # Variable Importance Scores
    imps <- attStats(roughFixMod)
    imps2 = imps[imps$decision != 'Rejected', c('meanImp', 'decision')]
    head(imps2[order(-imps2$meanImp), ])  # descending sort
    
             meanImp  decision
    varg   10.279747  Confirmed
    vari   10.245936  Confirmed
    tmi     9.067300  Confirmed
    vars    8.690654  Confirmed
    hic     8.324252  Confirmed
    varn    7.327045  Confirmed

    Let’s plot it to see the importances of these variables.

    # Plot variable importance
    plot(boruta_output, cex.axis=.7, las=2, xlab="", main="Variable Importance")  
    
    Variable Importance Boruta

    This plot reveals the importance of each of the features.

    The columns in green are ‘confirmed’ and the ones in red are not. There are a couple of blue bars representing ShadowMax and ShadowMin. They are not actual features, but are used by the boruta algorithm to decide if a variable is important or not.

    2. Variable Importance from Machine Learning Algorithms

    Another way to look at feature selection is to consider the variables that are used the most by various ML algorithms to be the most important.

    Depending on how the machine learning algorithm learns the relationship between X’s and Y, different machine learning algorithms may possibly end up using different variables (but mostly common vars) to various degrees.

    What I mean by that is, the variables that proved useful in a tree-based algorithm like rpart, can turn out to be less useful in a regression-based model. So all variables need not be equally useful to all algorithms.

    So how do we find the variable importance for a given ML algo?

    1. train() the desired model using the caret package.

    2. Then, use varImp() to determine the feature importances.

    You may want to try out multiple algorithms, to get a feel of the usefulness of the features across algos.

    # Train an rpart model and compute variable importance.
    library(caret)
    set.seed(100)
    rPartMod <- train(Class ~ ., data=trainData, method="rpart")
    rpartImp <- varImp(rPartMod)
    print(rpartImp)
    
    rpart variable importance
    
      only 20 most important variables shown (out of 62)
    
         Overall
    varg  100.00
    vari   93.19
    vars   85.20
    varn   76.86
    tmi    72.31
    vbss    0.00
    eai     0.00
    tmg     0.00
    tmt     0.00
    vbst    0.00
    vasg    0.00
    at      0.00
    abrg    0.00
    vbsg    0.00
    eag     0.00
    phcs    0.00
    abrs    0.00
    mdic    0.00
    abrt    0.00
    ean     0.00
    

    Only 5 of the 62 features were used by rpart and, if you look closely, the 5 variables used here are in the top 6 that boruta selected.

    Let’s do one more: the variable importances from Regularized Random Forest (RRF) algorithm.

    # Train an RRF model and compute variable importance.
    set.seed(100)
    rrfMod <- train(Class ~ ., data=trainData, method="RRF")
    rrfImp <- varImp(rrfMod, scale=F)
    rrfImp
    
    RRF variable importance
    
      only 20 most important variables shown (out of 62)
    
         Overall
    varg 24.0013
    vari 18.5349
    vars  6.0483
    tmi   3.8699
    hic   3.3926
    mhci  3.1856
    mhcg  3.0383
    mv    2.1570
    hvc   2.1357
    phci  1.8830
    vasg  1.8570
    tms   1.5705
    phcn  1.4475
    phct  1.4473
    vass  1.3097
    tmt   1.2485
    phcg  1.1992
    mdn   1.1737
    tmg   1.0988
    abrs  0.9537
    
    plot(rrfImp, top = 20, main='Variable Importance')
    
    Regularized Random Forest - Variable Importance

    The topmost important variables are pretty much from the top tier of Boruta's selections.

    Some of the other algorithms available in train() that you can use to compute varImp are the following:

    ada, AdaBag, AdaBoost.M1, adaboost, bagEarth, bagEarthGCV, bagFDA, bagFDAGCV, bartMachine, blasso, BstLm, bstSm, C5.0, C5.0Cost, C5.0Rules, C5.0Tree, cforest, chaid, ctree, ctree2, cubist, deepboost, earth, enet, evtree, extraTrees, fda, gamboost, gbm_h2o, gbm, gcvEarth, glmnet_h2o, glmnet, glmStepAIC, J48, JRip, lars, lars2, lasso, LMT, LogitBoost, M5, M5Rules, msaenet, nodeHarvest, OneR, ordinalNet, ORFlog, ORFpls, ORFridge, ORFsvm, pam, parRF, PART, penalized, PenalizedLDA, qrf, ranger, Rborist, relaxo, rf, rFerns, rfRules, rotationForest, rotationForestCp, rpart, rpart1SE, rpart2, rpartCost, rpartScore, rqlasso, rqnc, RRF, RRFglobal, sdwd, smda, sparseLDA, spikeslab, wsrf, xgbLinear, xgbTree.

    3. Lasso Regression

    Least Absolute Shrinkage and Selection Operator (LASSO) regression is a type of regularization method that penalizes with L1-norm.

    It basically imposes a cost on having large weights (coefficient values). It is called L1 regularization because the cost added is proportional to the absolute value of the weight coefficients.

    As a result, in the process of shrinking the coefficients, it eventually reduces the coefficients of certain unwanted features all the way to zero. That is, it removes the unneeded variables altogether.

    So effectively, LASSO regression can be considered as a variable selection technique as well.

    library(glmnet)
    trainData <- read.csv('https://raw.githubusercontent.com/selva86/datasets/master/GlaucomaM.csv')
    
    x <- as.matrix(trainData[,-63]) # all X vars
    y <- as.double(as.matrix(ifelse(trainData[, 63]=='normal', 0, 1))) # Only Class
    
    # Fit the LASSO model (Lasso: Alpha = 1)
    set.seed(100)
    cv.lasso <- cv.glmnet(x, y, family='binomial', alpha=1, parallel=TRUE, standardize=TRUE, type.measure='auc')
    
    # Results
    plot(cv.lasso)
    
    Variable Importance LASSO

    Let’s see how to interpret this plot.

    The X axis of the plot is the log of lambda (glmnet plots the natural log, so a value of 2 corresponds to lambda = exp(2) ≈ 7.4).

    The numbers at the top of the plot show how many predictors are in the model at that value of lambda. The position of the red dots along the Y-axis tells you the AUC obtained at that lambda.

    You can also see two dashed vertical lines.

    The first one, on the left, marks the lambda that gives the best cross-validated performance (lambda.min). The one on the right marks the largest lambda whose performance is within one standard error of that best value (lambda.1se).

    The best lambda value is stored inside 'cv.lasso$lambda.min'.

    # plot(cv.lasso$glmnet.fit, xvar="lambda", label=TRUE)
    cat('Min Lambda: ', cv.lasso$lambda.min, '\n 1Sd Lambda: ', cv.lasso$lambda.1se)
    df_coef <- round(as.matrix(coef(cv.lasso, s=cv.lasso$lambda.min)), 2)
    
    # See all contributing variables
    df_coef[df_coef[, 1] != 0, ]
    
    Min Lambda:  0.01166507 
     1Sd Lambda:  0.2513163
    
    (Intercept) 3.65
    at         -0.17
    as         -2.05
    eat        -0.53
    mhci        6.22
    phcs       -0.83
    phci        6.03
    hvc        -4.15
    vass       -23.72
    vbrn       -0.26
    vars       -25.86
    mdt        -2.34
    mds         0.5
    mdn         0.83
    mdi         0.3
    tmg         0.01
    tms         3.02
    tmi         2.65
    mv          4.94
    

    The above output shows which variables LASSO considered important. A coefficient that is strongly positive or strongly negative implies that the variable is more important.

    4. Step wise Forward and Backward Selection

    Stepwise regression can be used to select features if the Y variable is a numeric variable. It is particularly used for selecting the best linear regression models.

    It searches for the best possible regression model by iteratively selecting and dropping variables to arrive at a model with the lowest possible AIC.

    It can be implemented using the step() function. You need to provide it with a lower model, which is the base model from which it won’t remove any features, and an upper model, which is the full model that has all the features you want to consider.

    Our case is not so complicated (< 20 vars), so let’s just do a simple stepwise in 'both' directions.

    I will use the ozone dataset for this where the objective is to predict the 'ozone_reading' based on other weather related observations.

    # Load data
    trainData <- read.csv("http://rstatistics.net/wp-content/uploads/2015/09/ozone1.csv", stringsAsFactors=F)
    print(head(trainData))
    
      Month Day_of_month Day_of_week ozone_reading pressure_height Wind_speed
    1     1            1           4             3            5480          8
    2     1            2           5             3            5660          6
    3     1            3           6             3            5710          4
    4     1            4           7             5            5700          3
    5     1            5           1             5            5760          3
    6     1            6           2             6            5720          4
    
      Humidity Temperature_Sandburg Temperature_ElMonte Inversion_base_height
    1 20.00000             40.53473            39.77461              5000.000
    2 40.96306             38.00000            46.74935              4108.904
    3 28.00000             40.00000            49.49278              2693.000
    4 37.00000             45.00000            52.29403               590.000
    5 51.00000             54.00000            45.32000              1450.000
    6 69.00000             35.00000            49.64000              1568.000
    
      Pressure_gradient Inversion_temperature Visibility
    1               -15              30.56000        200
    2               -14              48.02557        300
    3               -25              47.66000        250
    4               -24              55.04000        100
    5                25              57.02000         60
    6                15              53.78000         60
    

    The data is ready. Let’s perform the stepwise.

    # Step 1: Define base intercept only model
    base.mod <- lm(ozone_reading ~ 1 , data=trainData)  
    
    # Step 2: Full model with all predictors
    all.mod <- lm(ozone_reading ~ . , data= trainData) 
    
    # Step 3: Perform step-wise algorithm. direction='both' implies both forward and backward stepwise
    stepMod <- step(base.mod, scope = list(lower = base.mod, upper = all.mod), direction = "both", trace = 0, steps = 1000)  
    
    # Step 4: Get the shortlisted variable.
    shortlistedVars <- names(unlist(stepMod[[1]])) 
    shortlistedVars <- shortlistedVars[!shortlistedVars %in% "(Intercept)"] # remove intercept
    
    # Show
    print(shortlistedVars)
    
    [1] "Temperature_Sandburg"  "Humidity"              "Temperature_ElMonte"  
    [4] "Month"                 "pressure_height"       "Inversion_base_height"
    

    The selected model has the above 6 features in it.

    But if you have too many features (> 100) in training data, then it might be a good idea to split the dataset into chunks of 10 variables each with Y as mandatory in each dataset. Loop through all the chunks and collect the best features.

    We are doing it this way because some variables that came as important in a training data with fewer features may not show up in a linear reg model built on lots of features.

    Finally, from a pool of shortlisted features (from small chunk models), run a full stepwise model to get the final set of selected features.

    You can take this as a learning assignment to be solved within 20 minutes.

    5. Relative Importance from Linear Regression

    This technique is specific to linear regression models.

    Relative importance can be used to assess which variables contributed how much in explaining the linear model’s R-squared value. So, if you sum up the produced importances, it will add up to the model’s R-sq value.

    In essence, it is not directly a feature selection method, because you have already provided the features that go in the model. But after building the model, the relaimpo can provide a sense of how important each feature is in contributing to the R-sq, or in other words, in ‘explaining the Y variable’.

    So, how to calculate relative importance?

    It is implemented in the relaimpo package. Basically, you build a linear regression model and pass that as the main argument to calc.relimp(). relaimpo has multiple options to compute the relative importance, but the recommended method is to use type='lmg', as I have done below.

    # install.packages('relaimpo')
    library(relaimpo)
    
    # Build linear regression model
    model_formula = ozone_reading ~ Temperature_Sandburg + Humidity + Temperature_ElMonte + Month + pressure_height + Inversion_base_height
    lmMod <- lm(model_formula, data=trainData)
    
    # calculate relative importance
    relImportance <- calc.relimp(lmMod, type = "lmg", rela = F)  
    
    # Sort
    cat('Relative Importances: \n')
    sort(round(relImportance$lmg, 3), decreasing=TRUE)
    
    Relative Importances: 
    Temperature_ElMonte    0.214
    Temperature_Sandburg   0.203
    pressure_height        0.104
    Inversion_base_height  0.096
    Humidity               0.086
    Month                  0.012
    

    Additionally, you can use bootstrapping (using boot.relimp) to compute the confidence intervals of the produced relative importances.

    bootsub <- boot.relimp(ozone_reading ~ Temperature_Sandburg + Humidity + Temperature_ElMonte + Month + pressure_height + Inversion_base_height, data=trainData,
                           b = 1000, type = 'lmg', rank = TRUE, diff = TRUE)
    
    plot(booteval.relimp(bootsub, level=.95))
    
    Relative Importance of Features

    6. Recursive Feature Elimination (RFE)

    Recursive feature elimination (RFE) offers a rigorous way to determine the important variables before you even feed them into an ML algo.

    It can be implemented using the rfe() function from the caret package.

    The rfe() also takes two important parameters.

    • sizes
    • rfeControl

    So, what do sizes and rfeControl represent?

    The sizes parameter determines the subset sizes of the most important features that rfe should iterate over. Below, I have set the sizes to 1 to 5, 10, 15 and 18.

    Secondly, the rfeControl parameter receives the output of rfeControl(), where you set what type of variable evaluation algorithm must be used. Here, I have used the random forests based rfFuncs. The method='repeatedcv' means it will do a repeated k-fold cross validation with repeats=5.

    Once complete, you get the accuracy and kappa for each model size you provided. The final selected model subset size is marked with a * in the rightmost selected column.

    str(trainData)
    
    'data.frame':    366 obs. of  13 variables:
     $ Month                : int  1 1 1 1 1 1 1 1 1 1 ...
     $ Day_of_month         : int  1 2 3 4 5 6 7 8 9 10 ...
     $ Day_of_week          : int  4 5 6 7 1 2 3 4 5 6 ...
     $ ozone_reading        : num  3 3 3 5 5 6 4 4 6 7 ...
     $ pressure_height      : num  5480 5660 5710 5700 5760 5720 5790 5790 5700 5700 ...
     $ Wind_speed           : int  8 6 4 3 3 4 6 3 3 3 ...
     $ Humidity             : num  20 41 28 37 51 ...
     $ Temperature_Sandburg : num  40.5 38 40 45 54 ...
     $ Temperature_ElMonte  : num  39.8 46.7 49.5 52.3 45.3 ...
     $ Inversion_base_height: num  5000 4109 2693 590 1450 ...
     $ Pressure_gradient    : num  -15 -14 -25 -24 25 15 -33 -28 23 -2 ...
     $ Inversion_temperature: num  30.6 48 47.7 55 57 ...
     $ Visibility           : int  200 300 250 100 60 60 100 250 120 120 ...
    
    set.seed(100)
    options(warn=-1)
    
    subsets <- c(1:5, 10, 15, 18)
    
    ctrl <- rfeControl(functions = rfFuncs,
                       method = "repeatedcv",
                       repeats = 5,
                       verbose = FALSE)
    
    lmProfile <- rfe(x=trainData[, c(1:3, 5:13)], y=trainData$ozone_reading,
                     sizes = subsets,
                     rfeControl = ctrl)
    
    lmProfile
    
    Recursive feature selection
    
    Outer resampling method: Cross-Validated (10 fold, repeated 5 times) 
    
    Resampling performance over subset size:
    
     Variables  RMSE Rsquared   MAE RMSESD RsquaredSD  MAESD Selected
             1 5.222   0.5794 4.008 0.9757    0.15034 0.7879         
             2 3.971   0.7518 3.067 0.4614    0.07149 0.3276         
             3 3.944   0.7553 3.054 0.4675    0.06523 0.3708         
             4 3.924   0.7583 3.026 0.5132    0.06640 0.4163         
             5 3.880   0.7633 2.950 0.5525    0.07021 0.4334         
            10 3.751   0.7796 2.853 0.5550    0.06791 0.4457        *
            12 3.767   0.7779 2.869 0.5511    0.06664 0.4424         
    
    The top 5 variables (out of 10):
       Temperature_ElMonte, Pressure_gradient, Temperature_Sandburg, Inversion_temperature, Humidity
    

    So, it says, Temperature_ElMonte, Pressure_gradient, Temperature_Sandburg, Inversion_temperature, Humidity are the top 5 variables in that order.

    And the best model size, out of the provided model sizes (in subsets), is 10.

    You can see all of the top 10 variables in 'lmProfile$optVariables', which was created by the rfe() call above.

    7. Genetic Algorithm

    You can perform a supervised feature selection with genetic algorithms using gafs(). This is quite resource-intensive, so consider that before choosing the number of iterations (iters) and the number of repeats in gafsControl().

    # Define control function
    ga_ctrl <- gafsControl(functions = rfGA,  # another option is `caretGA`.
                            method = "cv",
                            repeats = 3)
    
    # Genetic Algorithm feature selection
    set.seed(100)
    ga_obj <- gafs(x=trainData[, c(1:3, 5:13)], 
                   y=trainData[, 4], 
                   iters = 3,   # normally much higher (100+)
                   gafsControl = ga_ctrl)
    
    ga_obj
    
    Genetic Algorithm Feature Selection
    
    366 samples
    12 predictors
    
    Maximum generations: 3 
    Population per generation: 50 
    Crossover probability: 0.8 
    Mutation probability: 0.1 
    Elitism: 0 
    
    Internal performance values: RMSE, Rsquared
    Subset selection driven to minimize internal RMSE 
    
    External performance values: RMSE, Rsquared, MAE
    Best iteration chose by minimizing external RMSE 
    External resampling method: Cross-Validated (10 fold) 
    
    During resampling:
      * the top 5 selected variables (out of a possible 12):
        Month (100%), Pressure_gradient (100%), Temperature_ElMonte (100%), Visibility (90%), Inversion_temperature (80%)
      * on average, 7.5 variables were selected (min = 5, max = 10)
    
    In the final search using the entire training set:
       * 6 features selected at iteration 3 including:
         Month, Day_of_month, Wind_speed, Temperature_ElMonte, Pressure_gradient ... 
       * external performance at this iteration is
    
           RMSE    Rsquared         MAE 
         3.6605      0.7901      2.8010 
    
    # Optimal variables
    ga_obj$optVariables
    
    1. ‘Month’
    2. ‘Day_of_month’
    3. ‘Wind_speed’
    4. ‘Temperature_ElMonte’
    5. ‘Pressure_gradient’
    6. ‘Visibility’

    So the optimal variables according to the genetic algorithms are listed above. But, I wouldn’t use it just yet because, the above variant was tuned for only 3 iterations, which is quite low. I had to set it so low to save computing time.

    8. Simulated Annealing

    Simulated annealing is a global search algorithm that allows a suboptimal solution to be accepted in hope that a better solution will show up eventually.

    It works by making small random changes to an initial solution and checking if the performance improved. The change is accepted if it improves, else it can still be accepted if the difference in performance meets an acceptance criterion.

    In caret it has been implemented in the safs() which accepts a control parameter that can be set using the safsControl() function.

    safsControl is similar to the other control functions in caret (like you saw in rfe and ga). Additionally, it accepts an improve parameter, which is the number of iterations it should wait without improvement before the search is reset to the previous best solution.

    # Define control function
    sa_ctrl <- safsControl(functions = rfSA,
                            method = "repeatedcv",
                            repeats = 3,
                            improve = 5) # n iterations without improvement before a reset
    
    # Simulated Annealing feature selection
    set.seed(100)
    sa_obj <- safs(x=trainData[, c(1:3, 5:13)], 
                   y=trainData[, 4],
                   safsControl = sa_ctrl)
    
    sa_obj
    
    Simulated Annealing Feature Selection
    
    366 samples
    12 predictors
    
    Maximum search iterations: 10 
    Restart after 5 iterations without improvement (0.2 restarts on average)
    
    Internal performance values: RMSE, Rsquared
    Subset selection driven to minimize internal RMSE 
    
    External performance values: RMSE, Rsquared, MAE
    Best iteration chose by minimizing external RMSE 
    External resampling method: Cross-Validated (10 fold, repeated 3 times) 
    
    During resampling:
      * the top 5 selected variables (out of a possible 12):
        Temperature_ElMonte (73.3%), Inversion_temperature (63.3%), Month (60%), Day_of_week (50%), Inversion_base_height (50%)
      * on average, 6 variables were selected (min = 3, max = 8)
    
    In the final search using the entire training set:
       * 6 features selected at iteration 10 including:
         Month, Day_of_month, Day_of_week, Wind_speed, Temperature_ElMonte ... 
       * external performance at this iteration is
    
           RMSE    Rsquared         MAE 
         4.0574      0.7382      3.0727 
    
    # Optimal variables
    print(sa_obj$optVariables)
    
    [1] "Month"               "Day_of_month"        "Day_of_week"        
    [4] "Wind_speed"          "Temperature_ElMonte" "Visibility"         
    

    9. Information Value and Weights of Evidence

    The Information Value can be used to judge how important a given categorical variable is in explaining the binary Y variable. It goes well with logistic regression and other classification models that can model binary variables.

    Let’s try to find out how important the categorical variables are in predicting if an individual will earn >50k from the ‘adult.csv’ dataset. Just run the code below to import the dataset.

    library(InformationValue)
    inputData <- read.csv("http://rstatistics.net/wp-content/uploads/2015/09/adult.csv")
    print(head(inputData))
    
      AGE         WORKCLASS FNLWGT  EDUCATION EDUCATIONNUM       MARITALSTATUS
    1  39         State-gov  77516  Bachelors           13       Never-married
    2  50  Self-emp-not-inc  83311  Bachelors           13  Married-civ-spouse
    3  38           Private 215646    HS-grad            9            Divorced
    4  53           Private 234721       11th            7  Married-civ-spouse
    5  28           Private 338409  Bachelors           13  Married-civ-spouse
    6  37           Private 284582    Masters           14  Married-civ-spouse
    
              OCCUPATION   RELATIONSHIP   RACE     SEX CAPITALGAIN CAPITALLOSS
    1       Adm-clerical  Not-in-family  White    Male        2174           0
    2    Exec-managerial        Husband  White    Male           0           0
    3  Handlers-cleaners  Not-in-family  White    Male           0           0
    4  Handlers-cleaners        Husband  Black    Male           0           0
    5     Prof-specialty           Wife  Black  Female           0           0
    6    Exec-managerial           Wife  White  Female           0           0
    
      HOURSPERWEEK  NATIVECOUNTRY ABOVE50K
    1           40  United-States        0
    2           13  United-States        0
    3           40  United-States        0
    4           40  United-States        0
    5           40           Cuba        0
    6           40  United-States        0
    

    Alright, let’s now find the information value for the categorical variables in the inputData.

    # Choose Categorical Variables to compute Info Value.
    cat_vars <- c ("WORKCLASS", "EDUCATION", "MARITALSTATUS", "OCCUPATION", "RELATIONSHIP", "RACE", "SEX", "NATIVECOUNTRY")  # get all categorical variables
    
    # Init Output
    df_iv <- data.frame(VARS=cat_vars, IV=numeric(length(cat_vars)), STRENGTH=character(length(cat_vars)), stringsAsFactors = F)  # init output dataframe
    
    # Get Information Value for each variable
    for (factor_var in cat_vars){
      df_iv[df_iv$VARS == factor_var, "IV"] <- InformationValue::IV(X=inputData[, factor_var], Y=inputData$ABOVE50K)
      df_iv[df_iv$VARS == factor_var, "STRENGTH"] <- attr(InformationValue::IV(X=inputData[, factor_var], Y=inputData$ABOVE50K), "howgood")
    }
    
    # Sort
    df_iv <- df_iv[order(-df_iv$IV), ]
    
    df_iv
    
      VARS           IV          STRENGTH
    5 RELATIONSHIP   1.53560810  Highly Predictive
    3 MARITALSTATUS  1.33882907  Highly Predictive
    4 OCCUPATION     0.77622839  Highly Predictive
    2 EDUCATION      0.74105372  Highly Predictive
    7 SEX            0.30328938  Highly Predictive
    1 WORKCLASS      0.16338802  Highly Predictive
    8 NATIVECOUNTRY  0.07939344  Somewhat Predictive
    6 RACE           0.06929987  Somewhat Predictive

    Here is how to interpret the magnitude of the Information Value:

    • Less than 0.02, then the predictor is not useful for modeling (separating the Goods from the Bads)
    • 0.02 to 0.1, then the predictor has only a weak relationship.
    • 0.1 to 0.3, then the predictor has a medium strength relationship.
    • 0.3 or higher, then the predictor has a strong relationship.

    That was about IV. Then what is Weight of Evidence?

    Weights of evidence can be useful to find out how important a given categorical variable is in explaining the ‘events’ (called ‘Goods’ in the table below).

    Weights of Evidence

    The ‘Information Value’ of the categorical variable can then be derived from the respective WOE values.

    IV = (perc good of all goods−perc bad of all bads) * WOE

    The ‘WOETable’ below gives the computation in more detail.

    WOETable(X=inputData[, 'WORKCLASS'], Y=inputData$ABOVE50K)
    
    CAT                GOODS   BADS  TOTAL         PCT_G         PCT_B        WOE           IV
    ?                    191   1645   1836  0.0242940728  0.0665453074  -1.0076506  0.0425744832
    Federal-gov          371    589    960  0.0471890104  0.0238268608   0.6833475  0.0159644662
    Local-gov            617   1476   2093  0.0784787586  0.0597087379   0.2733496  0.0051307781
    Never-worked           7      7      7  0.0008903587  0.0002831715   1.1455716  0.0006955764
    Private             4963  17733  22696  0.6312643093  0.7173543689  -0.1278453  0.0110062102
    Self-emp-inc         622    494   1116  0.0791147291  0.0199838188   1.3759762  0.0813627242
    Self-emp-not-inc     724   1817   2541  0.0920885271  0.0735032362   0.2254209  0.0041895135
    State-gov            353    945   1298  0.0448995167  0.0382281553   0.1608547  0.0010731201
    Without-pay           14     14     14  0.0017807174  0.0005663430   1.1455716  0.0013911528

    The total IV of a variable is the sum of the IVs of its categories.
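
    As a quick sanity check using the Federal-gov row above: (0.0471890104 - 0.0238268608) * 0.6833475 ≈ 0.01596, which matches that row's IV column, and summing the IV column over all nine categories gives ≈ 0.1634, the information value reported for WORKCLASS in the earlier table.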

    10. DALEX Package

    DALEX is a powerful package that explains various things about the variables used in an ML model.

    For example, using the variable_dropout() function you can find out how important a variable is based on a dropout loss, that is how much loss is incurred by removing a variable from the model.

    Apart from this, it also has the single_variable() function that gives you an idea of how the model’s output will change by changing the values of one of the X’s in the model.

    It also has the single_prediction() that can decompose a single model prediction so as to understand which variable caused what effect in predicting the value of Y.

    library(randomForest)
    library(DALEX)
    
    # Load data
    inputData <- read.csv("http://rstatistics.net/wp-content/uploads/2015/09/adult.csv")
    
    # Train random forest model
    rf_mod <- randomForest(factor(ABOVE50K) ~ ., data=inputData, ntree=100)
    rf_mod
    
    # Variable importance with DALEX
    explained_rf <- explain(rf_mod, data=inputData, y=inputData$ABOVE50K)
    
    # Get the variable importances
    varimps = variable_dropout(explained_rf, type='raw')
    
    print(varimps)
    
    Call:
     randomForest(formula = factor(ABOVE50K) ~ ., data = inputData,      ntree = 100) 
                   Type of random forest: classification
                         Number of trees: 100
    No. of variables tried at each split: 3
    
            OOB estimate of  error rate: 17.4%
    Confusion matrix:
          0    1 class.error
    0 24600  120 0.004854369
    1  5547 2294 0.707435276
    
    
            variable dropout_loss        label
    1   _full_model_          852 randomForest
    2   EDUCATIONNUM          842 randomForest
    3      EDUCATION          843 randomForest
    4  MARITALSTATUS          844 randomForest
    5         FNLWGT          845 randomForest
    6     OCCUPATION          847 randomForest
    7            SEX          847 randomForest
    8    CAPITALLOSS          847 randomForest
    9   HOURSPERWEEK          847 randomForest
    10           AGE          848 randomForest
    11          RACE          848 randomForest
    12     WORKCLASS          849 randomForest
    13  RELATIONSHIP          850 randomForest
    14 NATIVECOUNTRY          853 randomForest
    15      ABOVE50K          853 randomForest
    16   CAPITALGAIN          893 randomForest
    17    _baseline_          975 randomForest
    
    plot(varimps)
    
    Dalex Variable Importance

    Conclusion

    Hope you find these methods useful. As it turns out, different methods showed different variables as important, or at least the degree of importance changed. This need not be a conflict, because each method gives a different perspective on how a variable can be useful, depending on how the algorithms learn Y ~ x. So it’s cool.

    If you find any code breaks or bugs, report the issue here or just write it below.

    101 Pandas Exercises for Data Analysis

    These 101 Python pandas exercises are designed to challenge your logical muscle and to help you internalize data manipulation with Python's favorite package for data analysis. The questions are of 3 levels of difficulty, with L1 being the easiest and L3 being the hardest.

    101 Pandas Exercises. Photo by Chester Ho.

    You might also like to practice the 101 NumPy exercises; they are often used together.

    1. How to import pandas and check the version?

    Show Solution
    import numpy as np  # optional
    import pandas as pd
    print(pd.__version__)
    print(pd.show_versions(as_json=True))
    
    0.20.3
    {'system': {'commit': None}, 'dependencies': {'pandas': '0.20.3', 'pytest': '3.2.1', 'pip': '9.0.1', 'setuptools': '36.5.0.post20170921', 'Cython': '0.26.1', 'numpy': '1.13.3', 'scipy': '0.19.1', 'xarray': None, 'IPython': '6.1.0', 'sphinx': '1.6.3', 'patsy': '0.4.1', 'dateutil': '2.6.1', 'pytz': '2017.2', 'blosc': None, 'bottleneck': '1.2.1', 'tables': '3.4.2', 'numexpr': '2.6.2', 'feather': None, 'matplotlib': '2.1.0', 'openpyxl': '2.4.8', 'xlrd': '1.1.0', 'xlwt': '1.2.0', 'xlsxwriter': '1.0.2', 'lxml': '4.1.0', 'bs4': '4.6.0', 'html5lib': '0.999999999', 'sqlalchemy': '1.1.13', 'pymysql': None, 'psycopg2': None, 'jinja2': '2.9.6', 's3fs': None, 'pandas_gbq': None, 'pandas_datareader': None}}
    None
    

    2. How to create a series from a list, numpy array and dict?

    Create a pandas series from each of the items below: a list, a numpy array and a dictionary

    Input

    import numpy as np
    mylist = list('abcedfghijklmnopqrstuvwxyz')
    myarr = np.arange(26)
    mydict = dict(zip(mylist, myarr))
    
    Show Solution
    # Inputs
    import numpy as np
    import pandas as pd
    mylist = list('abcedfghijklmnopqrstuvwxyz')
    myarr = np.arange(26)
    mydict = dict(zip(mylist, myarr))
    
    # Solution
    ser1 = pd.Series(mylist)
    ser2 = pd.Series(myarr)
    ser3 = pd.Series(mydict)
    print(ser3.head())
    
    a    0
    b    1
    c    2
    d    4
    e    3
    dtype: int64
    

    3. How to convert the index of a series into a column of a dataframe?

    Difficulty Level: L1

    Convert the series ser into a dataframe with its index as another column on the dataframe.

    Input

    mylist = list('abcedfghijklmnopqrstuvwxyz')
    myarr = np.arange(26)
    mydict = dict(zip(mylist, myarr))
    ser = pd.Series(mydict)
    
    Show Solution
    # Input
    mylist = list('abcedfghijklmnopqrstuvwxyz')
    myarr = np.arange(26)
    mydict = dict(zip(mylist, myarr))
    ser = pd.Series(mydict)
    
    # Solution
    df = ser.to_frame().reset_index()
    print(df.head())
    
      index  0
    0     a  0
    1     b  1
    2     c  2
    3     d  4
    4     e  3
    

    4. How to combine many series to form a dataframe?

    Difficulty Level: L1

    Combine ser1 and ser2 to form a dataframe.

    Input

    import numpy as np
    ser1 = pd.Series(list('abcedfghijklmnopqrstuvwxyz'))
    ser2 = pd.Series(np.arange(26))
    
    Show Solution
    # Input
    import numpy as np
    ser1 = pd.Series(list('abcedfghijklmnopqrstuvwxyz'))
    ser2 = pd.Series(np.arange(26))
    
    # Solution 1
    df = pd.concat([ser1, ser2], axis=1)
    
    # Solution 2
    df = pd.DataFrame({'col1': ser1, 'col2': ser2})
    print(df.head())
    
      col1  col2
    0    a     0
    1    b     1
    2    c     2
    3    e     3
    4    d     4
    

    5. How to assign name to the series’ index?

    Difficulty Level: L1

    Give a name to the series ser calling it ‘alphabets’.

    Input

    ser = pd.Series(list('abcedfghijklmnopqrstuvwxyz'))
    
    Show Solution
    # Input
    ser = pd.Series(list('abcedfghijklmnopqrstuvwxyz'))
    
    # Solution
    ser.name = 'alphabets'
    ser.head()
    
    0    a
    1    b
    2    c
    3    e
    4    d
    Name: alphabets, dtype: object
    

    6. How to get the items of series A not present in series B?

    Difficulty Level: L2

    From ser1 remove items present in ser2.

    ser1 = pd.Series([1, 2, 3, 4, 5])
    ser2 = pd.Series([4, 5, 6, 7, 8])
    
    Show Solution
    # Input
    ser1 = pd.Series([1, 2, 3, 4, 5])
    ser2 = pd.Series([4, 5, 6, 7, 8])
    
    # Solution
    ser1[~ser1.isin(ser2)]
    
    0    1
    1    2
    2    3
    dtype: int64
    

    7. How to get the items not common to both series A and series B?

    Difficulty Level: L2

    Get all items of ser1 and ser2 not common to both.

    Input

    ser1 = pd.Series([1, 2, 3, 4, 5])
    ser2 = pd.Series([4, 5, 6, 7, 8])
    
    Show Solution
    # Input
    ser1 = pd.Series([1, 2, 3, 4, 5])
    ser2 = pd.Series([4, 5, 6, 7, 8])
    
    # Solution
    ser_u = pd.Series(np.union1d(ser1, ser2))  # union
    ser_i = pd.Series(np.intersect1d(ser1, ser2))  # intersect
    ser_u[~ser_u.isin(ser_i)]
    
    0    1
    1    2
    2    3
    5    6
    6    7
    7    8
    dtype: int64
    

    8. How to get the minimum, 25th percentile, median, 75th, and max of a numeric series?

    Difficulty Level: L2

    Compute the minimum, 25th percentile, median, 75th, and maximum of ser.

    Input

    ser = pd.Series(np.random.normal(10, 5, 25))
    
    Show Solution
    # Input
    state = np.random.RandomState(100)
    ser = pd.Series(state.normal(10, 5, 25))
    
    # Solution
    np.percentile(ser, q=[0, 25, 50, 75, 100])
    
    array([  1.39267584,   6.49135133,  10.2578186 ,  13.06985067,  25.80920994])
    

    9. How to get frequency counts of unique items of a series?

    Difficulty Level: L1

    Calculate the frequency counts of each unique value in ser.

    Input

    ser = pd.Series(np.take(list('abcdefgh'), np.random.randint(8, size=30)))
    
    Show Solution
    # Input
    ser = pd.Series(np.take(list('abcdefgh'), np.random.randint(8, size=30)))
    
    # Solution
    ser.value_counts()
    
    f    8
    g    7
    b    6
    c    4
    a    2
    e    2
    h    1
    dtype: int64
    

    10. How to keep only the top 2 most frequent values and replace everything else with ‘Other’?

    Difficulty Level: L2

    From ser, keep the top 2 most frequent items as they are and replace everything else with ‘Other’.

    Input

    np.random.RandomState(100)
    ser = pd.Series(np.random.randint(1, 5, [12]))
    
    Show Solution
    # Input
    np.random.RandomState(100)
    ser = pd.Series(np.random.randint(1, 5, [12]))
    
    # Solution
    print("Top 2 Freq:", ser.value_counts())
    ser[~ser.isin(ser.value_counts().index[:2])] = 'Other'
    ser
    
    Top 2 Freq: 4    5
    3    3
    2    2
    1    2
    dtype: int64
    
    0     Other
    1     Other
    2         3
    3         4
    4     Other
    5         4
    6         4
    7         3
    8         3
    9         4
    10        4
    11    Other
    dtype: object
    

    11. How to bin a numeric series to 10 groups of equal size?

    Difficulty Level: L2

    Bin the series ser into 10 equal deciles and replace the values with the bin name.

    Input

    ser = pd.Series(np.random.random(20))
    

    Desired Output

    # First 5 items
    0    7th
    1    9th
    2    7th
    3    3rd
    4    8th
    dtype: category
    Categories (10, object): [1st < 2nd < 3rd < 4th ... 7th < 8th < 9th < 10th]
    
    Show Solution
    # Input
    ser = pd.Series(np.random.random(20))
    print(ser.head())
    
    # Solution
    pd.qcut(ser, q=[0, .10, .20, .3, .4, .5, .6, .7, .8, .9, 1], 
            labels=['1st', '2nd', '3rd', '4th', '5th', '6th', '7th', '8th', '9th', '10th']).head()
    
    0    0.556912
    1    0.892955
    2    0.566632
    3    0.146656
    4    0.881579
    dtype: float64
    
    0    7th
    1    9th
    2    7th
    3    3rd
    4    8th
    dtype: category
    Categories (10, object): [1st < 2nd < 3rd < 4th ... 7th < 8th < 9th < 10th]
    

    12. How to convert a numpy array to a dataframe of given shape? (L1)

    Difficulty Level: L1

    Reshape the series ser into a dataframe with 7 rows and 5 columns

    Input

    ser = pd.Series(np.random.randint(1, 10, 35))
    
    Show Solution
    # Input
    ser = pd.Series(np.random.randint(1, 10, 35))
    
    # Solution
    df = pd.DataFrame(ser.values.reshape(7,5))
    print(df)
    
       0  1  2  3  4
    0  1  2  1  2  5
    1  1  2  4  5  2
    2  1  3  3  2  8
    3  8  6  4  9  6
    4  2  1  1  8  5
    5  3  2  8  5  6
    6  1  5  5  4  6
    

    13. How to find the positions of numbers that are multiples of 3 from a series?

    Difficulty Level: L2

    Find the positions of numbers that are multiples of 3 from ser.

    Input

    ser = pd.Series(np.random.randint(1, 10, 7))
    
    Show Solution
    # Input
    ser = pd.Series(np.random.randint(1, 10, 7))
    ser
    
    # Solution
    print(ser)
    np.argwhere(ser % 3==0)
    
    0    6
    1    8
    2    6
    3    7
    4    6
    5    2
    6    4
    dtype: int64
    
    array([[0],
           [2],
           [4]])
    

    14. How to extract items at given positions from a series

    Difficulty Level: L1

    From ser, extract the items at positions in list pos.

    Input

    ser = pd.Series(list('abcdefghijklmnopqrstuvwxyz'))
    pos = [0, 4, 8, 14, 20]
    
    Show Solution
    # Input
    ser = pd.Series(list('abcdefghijklmnopqrstuvwxyz'))
    pos = [0, 4, 8, 14, 20]
    
    # Solution
    ser.take(pos)
    
    0     a
    4     e
    8     i
    14    o
    20    u
    dtype: object
    

    15. How to stack two series vertically and horizontally ?

    Difficulty Level: L1

    Stack ser1 and ser2 vertically and horizontally (to form a dataframe).

    Input

    ser1 = pd.Series(range(5))
    ser2 = pd.Series(list('abcde'))
    
    Show Solution
    # Input
    ser1 = pd.Series(range(5))
    ser2 = pd.Series(list('abcde'))
    
    # Output
    # Vertical
    ser1.append(ser2)
    
    # Horizontal
    df = pd.concat([ser1, ser2], axis=1)
    print(df)
    
       0  1
    0  0  a
    1  1  b
    2  2  c
    3  3  d
    4  4  e
    

    16. How to get the positions of items of series A in another series B?

    Difficulty Level: L2

    Get the positions of items of ser2 in ser1 as a list.

    Input

    ser1 = pd.Series([10, 9, 6, 5, 3, 1, 12, 8, 13])
    ser2 = pd.Series([1, 3, 10, 13])
    
    Show Solution
    # Input
    ser1 = pd.Series([10, 9, 6, 5, 3, 1, 12, 8, 13])
    ser2 = pd.Series([1, 3, 10, 13])
    
    # Solution 1
    [np.where(i == ser1)[0].tolist()[0] for i in ser2]
    
    # Solution 2
    [pd.Index(ser1).get_loc(i) for i in ser2]
    
    [5, 4, 0, 8]
    

    17. How to compute the mean squared error on a truth and predicted series?

    Difficulty Level: L2

    Compute the mean squared error of truth and pred series.

    Input

    truth = pd.Series(range(10))
    pred = pd.Series(range(10)) + np.random.random(10)
    
    Show Solution
    # Input
    truth = pd.Series(range(10))
    pred = pd.Series(range(10)) + np.random.random(10)
    
    # Solution
    np.mean((truth-pred)**2)
    
    0.28448128110629545
    

    18. How to convert the first character of each element in a series to uppercase?

    Difficulty Level: L2

    Change the first character of each word in ser to upper case.

    ser = pd.Series(['how', 'to', 'kick', 'ass?'])
    
    Show Solution
    # Input
    ser = pd.Series(['how', 'to', 'kick', 'ass?'])
    
    # Solution 1
    ser.map(lambda x: x.title())
    
    # Solution 2
    ser.map(lambda x: x[0].upper() + x[1:])
    
    # Solution 3
    pd.Series([i.title() for i in ser])
    
    0     How
    1      To
    2    Kick
    3    Ass?
    dtype: object
    

    19. How to calculate the number of characters in each word in a series?

    Difficulty Level: L2

    Input

    ser = pd.Series(['how', 'to', 'kick', 'ass?'])
    
    Show Solution
    # Input
    ser = pd.Series(['how', 'to', 'kick', 'ass?'])
    
    # Solution
    ser.map(lambda x: len(x))
    
    0    3
    1    2
    2    4
    3    4
    dtype: int64
    

    20. How to compute difference of differences between consecutive numbers of a series?

    Difficulty Level: L1

    Difference of differences between the consecutive numbers of ser.

    Input

    ser = pd.Series([1, 3, 6, 10, 15, 21, 27, 35])
    

    Desired Output

    [nan, 2.0, 3.0, 4.0, 5.0, 6.0, 6.0, 8.0]
    [nan, nan, 1.0, 1.0, 1.0, 1.0, 0.0, 2.0]
    
    Show Solution
    # Input
    ser = pd.Series([1, 3, 6, 10, 15, 21, 27, 35])
    
    # Solution
    print(ser.diff().tolist())
    print(ser.diff().diff().tolist())
    
    [nan, 2.0, 3.0, 4.0, 5.0, 6.0, 6.0, 8.0]
    [nan, nan, 1.0, 1.0, 1.0, 1.0, 0.0, 2.0]
    

    21. How to convert a series of date-strings to a timeseries?

    Difficulty Level: L2

    Input

    ser = pd.Series(['01 Jan 2010', '02-02-2011', '20120303', '2013/04/04', '2014-05-05', '2015-06-06T12:20'])
    

    Desired Output

    0   2010-01-01 00:00:00
    1   2011-02-02 00:00:00
    2   2012-03-03 00:00:00
    3   2013-04-04 00:00:00
    4   2014-05-05 00:00:00
    5   2015-06-06 12:20:00
    dtype: datetime64[ns]
    
    Show Solution
    # Input
    ser = pd.Series(['01 Jan 2010', '02-02-2011', '20120303', '2013/04/04', '2014-05-05', '2015-06-06T12:20'])
    
    # Solution 1
    from dateutil.parser import parse
    ser.map(lambda x: parse(x))
    
    # Solution 2
    pd.to_datetime(ser)
    
    0   2010-01-01 00:00:00
    1   2011-02-02 00:00:00
    2   2012-03-03 00:00:00
    3   2013-04-04 00:00:00
    4   2014-05-05 00:00:00
    5   2015-06-06 12:20:00
    dtype: datetime64[ns]
    

    22. How to get the day of month, week number, day of year and day of week from a series of date strings?

    Difficulty Level: L2

    Get the day of month, week number, day of year and day of week from ser.

    Input

    ser = pd.Series(['01 Jan 2010', '02-02-2011', '20120303', '2013/04/04', '2014-05-05', '2015-06-06T12:20'])
    

    Desired output

    Date:  [1, 2, 3, 4, 5, 6]
    Week number:  [53, 5, 9, 14, 19, 23]
    Day num of year:  [1, 33, 63, 94, 125, 157]
    Day of week:  ['Friday', 'Wednesday', 'Saturday', 'Thursday', 'Monday', 'Saturday']
    
    Show Solution
    # Input
    ser = pd.Series(['01 Jan 2010', '02-02-2011', '20120303', '2013/04/04', '2014-05-05', '2015-06-06T12:20'])
    
    # Solution
    from dateutil.parser import parse
    ser_ts = ser.map(lambda x: parse(x))
    
    # day of month
    print("Date: ", ser_ts.dt.day.tolist())
    
    # week number
    print("Week number: ", ser_ts.dt.weekofyear.tolist())
    
    # day of year
    print("Day number of year: ", ser_ts.dt.dayofyear.tolist())
    
    # day of week
    print("Day of week: ", ser_ts.dt.weekday_name.tolist())
    
    Date:  [1, 2, 3, 4, 5, 6]
    Week number:  [53, 5, 9, 14, 19, 23]
    Day num of year:  [1, 33, 63, 94, 125, 157]
    Day of week:  ['Friday', 'Wednesday', 'Saturday', 'Thursday', 'Monday', 'Saturday']
    

    23. How to convert year-month string to dates corresponding to the 4th day of the month?

    Difficulty Level: L2

    Change ser to dates corresponding to the 4th day of the respective months.

    Input

    ser = pd.Series(['Jan 2010', 'Feb 2011', 'Mar 2012'])
    

    Desired Output

    0   2010-01-04
    1   2011-02-04
    2   2012-03-04
    dtype: datetime64[ns]
    
    Show Solution
    import pandas as pd
    # Input
    ser = pd.Series(['Jan 2010', 'Feb 2011', 'Mar 2012'])
    
    # Solution 1
    from dateutil.parser import parse
    # Parse the date
    ser_ts = ser.map(lambda x: parse(x))
    
    # Construct date string with date as 4
    ser_datestr = ser_ts.dt.year.astype('str') + '-' + ser_ts.dt.month.astype('str') + '-' + '04'
    
    # Format it.
    [parse(i).strftime('%Y-%m-%d') for i in ser_datestr]
    
    # Solution 2
    ser.map(lambda x: parse('04 ' + x))
    
    0   2010-01-04
    1   2011-02-04
    2   2012-03-04
    dtype: datetime64[ns]
    

    24. How to filter words that contain at least 2 vowels from a series?

    Difficulty Level: L3

    From ser, extract words that contain at least 2 vowels.

    Input

    ser = pd.Series(['Apple', 'Orange', 'Plan', 'Python', 'Money'])
    

    Desired Output

    0     Apple
    1    Orange
    4     Money
    dtype: object
    
    Show Solution
    # Input
    ser = pd.Series(['Apple', 'Orange', 'Plan', 'Python', 'Money'])
    
    # Solution
    from collections import Counter
    mask = ser.map(lambda x: sum([Counter(x.lower()).get(i, 0) for i in list('aeiou')]) >= 2)
    ser[mask]
    
    0     Apple
    1    Orange
    4     Money
    dtype: object
    

    25. How to filter valid emails from a series?

    Difficulty Level: L3

    Extract the valid emails from the series emails. The regex pattern for valid emails is provided as reference.

    Input

    emails = pd.Series(['buying books at amazom.com', '[email protected]', '[email protected]', '[email protected]'])
    pattern ='[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,4}'
    

    Desired Output

    1    [email protected]
    2            [email protected]
    3    [email protected]
    dtype: object
    
    Show Solution
    # Input
    emails = pd.Series(['buying books at amazom.com', '[email protected]', '[email protected]', '[email protected]'])
    
    # Solution 1 (as series of strings)
    import re
    pattern ='[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,4}'
    mask = emails.map(lambda x: bool(re.match(pattern, x)))
    emails[mask]
    
    # Solution 2 (as series of list)
    emails.str.findall(pattern, flags=re.IGNORECASE)
    
    # Solution 3 (as list)
    [x[0] for x in [re.findall(pattern, email) for email in emails] if len(x) > 0]
    
    ['[email protected]', '[email protected]', '[email protected]']
    

    26. How to get the mean of a series grouped by another series?

    Difficulty Level: L2

    Compute the mean of weights of each fruit.

    Input

    fruit = pd.Series(np.random.choice(['apple', 'banana', 'carrot'], 10))
    weights = pd.Series(np.linspace(1, 10, 10))
    print(weights.tolist())
    print(fruit.tolist())
    #> [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
    #> ['banana', 'carrot', 'apple', 'carrot', 'carrot', 'apple', 'banana', 'carrot', 'apple', 'carrot']
    

    Desired output

    # values can change due to randomness
    apple     6.0
    banana    4.0
    carrot    5.8
    dtype: float64
    
    Show Solution
    # Input
    fruit = pd.Series(np.random.choice(['apple', 'banana', 'carrot'], 10))
    weights = pd.Series(np.linspace(1, 10, 10))
    
    # Solution
    weights.groupby(fruit).mean()
    
    apple     7.4
    banana    2.0
    carrot    6.0
    dtype: float64
    

    27. How to compute the euclidean distance between two series?

    Difficulty Level: L2

    Compute the euclidean distance between series (points) p and q, without using a packaged formula.

    Input

    p = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    q = pd.Series([10, 9, 8, 7, 6, 5, 4, 3, 2, 1])
    

    Desired Output

    18.165
    
    Show Solution
    # Input
    p = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    q = pd.Series([10, 9, 8, 7, 6, 5, 4, 3, 2, 1])
    
    # Solution 
    sum((p - q)**2)**.5
    
    # Solution (using func)
    np.linalg.norm(p-q)
    
    18.165902124584949
    

    28. How to find all the local maxima (or peaks) in a numeric series?

    Difficulty Level: L3

    Get the positions of peaks (values surrounded by smaller values on both sides) in ser.

    Input

    ser = pd.Series([2, 10, 3, 4, 9, 10, 2, 7, 3])
    

    Desired output

    array([1, 5, 7])
    
    Show Solution
    # Input
    ser = pd.Series([2, 10, 3, 4, 9, 10, 2, 7, 3])
    
    # Solution
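    # the sign of the first difference flips from +1 to -1 at a peak, so the second difference equals -2 there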
    dd = np.diff(np.sign(np.diff(ser)))
    peak_locs = np.where(dd == -2)[0] + 1
    peak_locs
    
    array([1, 5, 7])
    

    29. How to replace missing spaces in a string with the least frequent character?

    Replace the spaces in my_str with the least frequent character.

    Difficulty Level: L2

    Input

    my_str = 'dbc deb abed gade'
    

    Desired Output

    'dbccdebcabedcgade'  # least frequent is 'c'
    
    Show Solution
    # Input
    my_str = 'dbc deb abed gade'
    
    # Solution
    ser = pd.Series(list('dbc deb abed gade'))
    freq = ser.value_counts()
    print(freq)
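    # value_counts() is sorted by decreasing frequency, so the last index is a least frequent character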
    least_freq = freq.dropna().index[-1]
    "".join(ser.replace(' ', least_freq))
    
    d    4
    b    3
    e    3
         3
    a    2
    g    1
    c    1
    dtype: int64
    
    'dbccdebcabedcgade'
    

    30. How to create a TimeSeries starting ‘2000-01-01’ and 10 weekends (saturdays) after that having random numbers as values?

    Difficulty Level: L2

    Desired output

    # values can be random
    2000-01-01    4
    2000-01-08    1
    2000-01-15    8
    2000-01-22    4
    2000-01-29    4
    2000-02-05    2
    2000-02-12    4
    2000-02-19    9
    2000-02-26    6
    2000-03-04    6
    
    Show Solution
    # Solution
    ser = pd.Series(np.random.randint(1,10,10), pd.date_range('2000-01-01', periods=10, freq='W-SAT'))
    ser
    
    2000-01-01    6
    2000-01-08    7
    2000-01-15    4
    2000-01-22    6
    2000-01-29    8
    2000-02-05    6
    2000-02-12    5
    2000-02-19    8
    2000-02-26    1
    2000-03-04    7
    Freq: W-SAT, dtype: int64
    

    31. How to fill an intermittent time series so all missing dates show up with values of previous non-missing date?

    Difficulty Level: L2

    ser has missing dates and values. Make all missing dates appear and fill up with value from previous date.

    Input

    ser = pd.Series([1,10,3,np.nan], index=pd.to_datetime(['2000-01-01', '2000-01-03', '2000-01-06', '2000-01-08']))
    print(ser)
    #> 2000-01-01     1.0
    #> 2000-01-03    10.0
    #> 2000-01-06     3.0
    #> 2000-01-08     NaN
    #> dtype: float64
    

    Desired Output

    2000-01-01     1.0
    2000-01-02     1.0
    2000-01-03    10.0
    2000-01-04    10.0
    2000-01-05    10.0
    2000-01-06     3.0
    2000-01-07     3.0
    2000-01-08     NaN
    
    Show Solution
    # Input
    ser = pd.Series([1,10,3, np.nan], index=pd.to_datetime(['2000-01-01', '2000-01-03', '2000-01-06', '2000-01-08']))
    
    # Solution
    ser.resample('D').ffill()  # fill with previous value
    
    # Alternatives
    ser.resample('D').bfill()  # fill with next value
    ser.resample('D').bfill().ffill()  # fill next else prev value
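    # Note: the output shown below comes from this last expression (bfill followed by ffill)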
    
    2000-01-01     1.0
    2000-01-02    10.0
    2000-01-03    10.0
    2000-01-04     3.0
    2000-01-05     3.0
    2000-01-06     3.0
    2000-01-07     3.0
    2000-01-08     3.0
    Freq: D, dtype: float64
    

    32. How to compute the autocorrelations of a numeric series?

    Difficulty Level: L3

    Compute autocorrelations for the first 10 lags of ser. Find out which lag has the largest correlation.

    Input

    ser = pd.Series(np.arange(20) + np.random.normal(1, 10, 20))
    

    Desired output

    # values will change due to randomness
    [0.29999999999999999, -0.11, -0.17000000000000001, 0.46000000000000002, 0.28000000000000003, -0.040000000000000001, -0.37, 0.41999999999999998, 0.47999999999999998, 0.17999999999999999]
    Lag having highest correlation:  9
    
    Show Solution
    # Input
    ser = pd.Series(np.arange(20) + np.random.normal(1, 10, 20))
    
    # Solution
    autocorrelations = [ser.autocorr(i).round(2) for i in range(11)]
    print(autocorrelations[1:])
    print('Lag having highest correlation: ', np.argmax(np.abs(autocorrelations[1:]))+1)
    
    [0.29999999999999999, -0.11, -0.17000000000000001, 0.46000000000000002, 0.28000000000000003, -0.040000000000000001, -0.37, 0.41999999999999998, 0.47999999999999998, 0.17999999999999999]
    Lag having highest correlation:  9
    

    33. How to import only every nth row from a csv file to create a dataframe?

    Difficulty Level: L2

    Import every 50th row of BostonHousing dataset as a dataframe.

    Show Solution
    # Solution 1: Use chunks and for-loop
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv', chunksize=50)
    df2 = pd.DataFrame()
    for chunk in df:
        df2 = df2.append(chunk.iloc[0,:])
    
    
    # Solution 2: Use chunks and list comprehension
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv', chunksize=50)
    df2 = pd.concat([chunk.iloc[0] for chunk in df], axis=1)
    df2 = df2.transpose()
    
    # Solution 3: Use csv reader
    import csv          
    with open('BostonHousing.csv', 'r') as f:
        reader = csv.reader(f)
        out = []
        for i, row in enumerate(reader):
            if i%50 == 0:
                out.append(row)
    
    df2 = pd.DataFrame(out[1:], columns=out[0])
    print(df2.head())
    
                      crim    zn  indus chas                  nox     rm   age  \
    0              0.21977   0.0   6.91    0  0.44799999999999995  5.602  62.0   
    1               0.0686   0.0   2.89    0                0.445  7.416  62.5   
    2   2.7339700000000002   0.0  19.58    0                0.871  5.597  94.9   
    3               0.0315  95.0   1.47    0  0.40299999999999997  6.975  15.3   
    4  0.19072999999999998  22.0   5.86    0                0.431  6.718  17.5   
    
          dis rad  tax ptratio       b  lstat  medv  
    0  6.0877   3  233    17.9   396.9   16.2  19.4  
    1  3.4952   2  276    18.0   396.9   6.19  33.2  
    2  1.5257   5  403    14.7  351.85  21.45  15.4  
    3  7.6534   3  402    17.0   396.9   4.56  34.9  
    4  7.8265   7  330    19.1  393.74   6.56  26.2  
    

    34. How to change column values when importing csv to a dataframe?

    Difficulty Level: L2

    Import the boston housing dataset, but while importing change the 'medv' (median house value) column so that values greater than 25 become ‘High’ and the rest become ‘Low’.

    Show Solution
    # Solution 1: Using converter parameter
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv', 
                     converters={'medv': lambda x: 'High' if float(x) > 25 else 'Low'})
    
    
    # Solution 2: Using csv reader
    import csv
    with open('BostonHousing.csv', 'r') as f:
        reader = csv.reader(f)
        out = []
        for i, row in enumerate(reader):
            if i > 0:
                row[13] = 'High' if float(row[13]) > 25 else 'Low'
            out.append(row)
    
    df = pd.DataFrame(out[1:], columns=out[0])
    print(df.head())
    
                       crim    zn indus chas                  nox  \
    0               0.00632  18.0  2.31    0   0.5379999999999999   
    1               0.02731   0.0  7.07    0                0.469   
    2               0.02729   0.0  7.07    0                0.469   
    3  0.032369999999999996   0.0  2.18    0  0.45799999999999996   
    4               0.06905   0.0  2.18    0  0.45799999999999996   
    
                      rm   age     dis rad  tax ptratio       b lstat  medv  
    0              6.575  65.2    4.09   1  296    15.3   396.9  4.98   Low  
    1              6.421  78.9  4.9671   2  242    17.8   396.9  9.14   Low  
    2              7.185  61.1  4.9671   2  242    17.8  392.83  4.03  High  
    3  6.997999999999999  45.8  6.0622   3  222    18.7  394.63  2.94  High  
    4              7.147  54.2  6.0622   3  222    18.7   396.9  5.33  High  
    

    35. How to create a dataframe with rows as strides from a given series?

    Difficulty Level: L3

    Input

    L = pd.Series(range(15))
    

    Desired Output

    array([[ 0,  1,  2,  3],
           [ 2,  3,  4,  5],
           [ 4,  5,  6,  7],
           [ 6,  7,  8,  9],
           [ 8,  9, 10, 11],
           [10, 11, 12, 13]])
    
    Show Solution
    L = pd.Series(range(15))
    
    def gen_strides(a, stride_len=5, window_len=5):
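        # number of complete windows of length window_len that fit when stepping by stride_len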
        n_strides = ((a.size-window_len)//stride_len) + 1
        return np.array([a[s:(s+window_len)] for s in np.arange(0, a.size, stride_len)[:n_strides]])
    
    gen_strides(L, stride_len=2, window_len=4)
    
    array([[ 0,  1,  2,  3],
           [ 2,  3,  4,  5],
           [ 4,  5,  6,  7],
           [ 6,  7,  8,  9],
           [ 8,  9, 10, 11],
           [10, 11, 12, 13]])
    

    36. How to import only specified columns from a csv file?

    Difficulty Level: L1

    Import ‘crim’ and ‘medv’ columns of the BostonHousing dataset as a dataframe.

    Show Solution
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/BostonHousing.csv', usecols=['crim', 'medv'])
    print(df.head())
    
          crim  medv
    0  0.00632  24.0
    1  0.02731  21.6
    2  0.02729  34.7
    3  0.03237  33.4
    4  0.06905  36.2
    

    37. How to get the nrows, ncolumns, datatype, summary stats of each column of a dataframe? Also get the array and list equivalent.

    Difficulty Level: L2

    Get the number of rows, columns, datatype and summary statistics of each column of the Cars93 dataset. Also get the numpy array and list equivalent of the dataframe.

    Show Solution
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    #  number of rows and columns
    print(df.shape)
    
    # datatypes
    print(df.dtypes)
    
    # how many columns under each dtype
    print(df.get_dtype_counts())
    print(df.dtypes.value_counts())
    
    # summary statistics
    df_stats = df.describe()
    
    # numpy array 
    df_arr = df.values
    
    # list
    df_list = df.values.tolist()
    
    (93, 27)
    Manufacturer           object
    Model                  object
    Type                   object
    Min.Price             float64
    Price                 float64
    Max.Price             float64
    MPG.city              float64
    MPG.highway           float64
    AirBags                object
    DriveTrain             object
    Cylinders              object
    EngineSize            float64
    Horsepower            float64
    RPM                   float64
    Rev.per.mile          float64
    Man.trans.avail        object
    Fuel.tank.capacity    float64
    Passengers            float64
    Length                float64
    Wheelbase             float64
    Width                 float64
    Turn.circle           float64
    Rear.seat.room        float64
    Luggage.room          float64
    Weight                float64
    Origin                 object
    Make                   object
    dtype: object
    float64    18
    object      9
    dtype: int64
    float64    18
    object      9
    dtype: int64
    

    38. How to extract the row and column number of a particular cell with given criterion?

    Difficulty Level: L1

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    

    Which manufacturer, model and type has the highest Price? What is the row and column number of the cell with the highest Price value?

    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    # Solution
    # Get Manufacturer with highest price
    df.loc[df.Price == np.max(df.Price), ['Manufacturer', 'Model', 'Type']]
    
    # Get Row and Column number
    row, col = np.where(df.values == np.max(df.Price))
    
    # Get the value
    df.iat[row[0], col[0]]
    df.iloc[row[0], col[0]]
    
    # Alternates
    df.at[row[0], 'Price']
    df.get_value(row[0], 'Price')
    
    # The difference between `iat`/`iloc` and `at`/`loc` is:
    # `iat` and `iloc` accept row and column numbers,
    # whereas `at` and `loc` accept index and column names. See the short example after the output below.
    
    61.899999999999999
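    A quick illustration of that difference, as a minimal sketch on a small hypothetical frame:

    import pandas as pd
    tmp = pd.DataFrame({'Price': [10.0, 20.0, 30.0]}, index=['x', 'y', 'z'])
    print(tmp.iat[1, 0])         # positional: 2nd row, 1st column -> 20.0
    print(tmp.at['y', 'Price'])  # label-based: row label 'y', column name 'Price' -> 20.0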
    

    39. How to rename a specific columns in a dataframe?

    Difficulty Level: L2

    Rename the column Type as CarType in df and replace the ‘.’ in column names with ‘_’.

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    print(df.columns)
    #> Index(['Manufacturer', 'Model', 'Type', 'Min.Price', 'Price', 'Max.Price',
    #>        'MPG.city', 'MPG.highway', 'AirBags', 'DriveTrain', 'Cylinders',
    #>        'EngineSize', 'Horsepower', 'RPM', 'Rev.per.mile', 'Man.trans.avail',
    #>        'Fuel.tank.capacity', 'Passengers', 'Length', 'Wheelbase', 'Width',
    #>        'Turn.circle', 'Rear.seat.room', 'Luggage.room', 'Weight', 'Origin',
    #>        'Make'],
    #>       dtype='object')
    

    Desired Solution

    print(df.columns)
    #> Index(['Manufacturer', 'Model', 'CarType', 'Min_Price', 'Price', 'Max_Price',
    #>        'MPG_city', 'MPG_highway', 'AirBags', 'DriveTrain', 'Cylinders',
    #>        'EngineSize', 'Horsepower', 'RPM', 'Rev_per_mile', 'Man_trans_avail',
    #>        'Fuel_tank_capacity', 'Passengers', 'Length', 'Wheelbase', 'Width',
    #>        'Turn_circle', 'Rear_seat_room', 'Luggage_room', 'Weight', 'Origin',
    #>        'Make'],
    #>       dtype='object')
    
    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    # Solution
    # Step 1:
    df=df.rename(columns = {'Type':'CarType'})
    # or
    df.columns.values[2] = "CarType"
    
    # Step 2:
    df.columns = df.columns.map(lambda x: x.replace('.', '_'))
    print(df.columns)
    
    Index(['Manufacturer', 'Model', 'CarType', 'Min_Price', 'Price', 'Max_Price',
           'MPG_city', 'MPG_highway', 'AirBags', 'DriveTrain', 'Cylinders',
           'EngineSize', 'Horsepower', 'RPM', 'Rev_per_mile', 'Man_trans_avail',
           'Fuel_tank_capacity', 'Passengers', 'Length', 'Wheelbase', 'Width',
           'Turn_circle', 'Rear_seat_room', 'Luggage_room', 'Weight', 'Origin',
           'Make'],
          dtype='object')
    

    40. How to check if a dataframe has any missing values?

    Difficulty Level: L1

    Check if df has any missing values.

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    # Solution
    df.isnull().values.any()
    

    41. How to count the number of missing values in each column?

    Difficulty Level: L2

    Count the number of missing values in each column of df. Which column has the maximum number of missing values?

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    # Solution
    n_missings_each_col = df.apply(lambda x: x.isnull().sum())
    n_missings_each_col.argmax()
    
    'Luggage.room'
    

    42. How to replace missing values of multiple numeric columns with the mean?

    Difficulty Level: L2

    Replace missing values in Min.Price and Max.Price columns with their respective mean.

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    # Solution
    df_out = df[['Min.Price', 'Max.Price']] = df[['Min.Price', 'Max.Price']].apply(lambda x: x.fillna(x.mean()))
    print(df_out.head())
    
       Min.Price  Max.Price
    0  12.900000  18.800000
    1  29.200000  38.700000
    2  25.900000  32.300000
    3  17.118605  44.600000
    4  17.118605  21.459091
    

    43. How to use apply function on existing columns with global variables as additional arguments?

    Difficulty Level: L3

    In df, use apply method to replace the missing values in Min.Price with the column’s mean and those in Max.Price with the column’s median.

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    

    Use Hint from StackOverflow

    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    # Solution
    d = {'Min.Price': np.nanmean, 'Max.Price': np.nanmedian}
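    # x.name is the column's name, so d[x.name] picks np.nanmean for Min.Price and np.nanmedian for Max.Price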
    df[['Min.Price', 'Max.Price']] = df[['Min.Price', 'Max.Price']].apply(lambda x, d: x.fillna(d[x.name](x)), args=(d, ))
    

    44. How to select a specific column from a dataframe as a dataframe instead of a series?

    Difficulty Level: L2

    Get the first column (a) in df as a dataframe (rather than as a Series).

    Input

    df = pd.DataFrame(np.arange(20).reshape(-1, 5), columns=list('abcde'))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.arange(20).reshape(-1, 5), columns=list('abcde'))
    
    # Solution
    type(df[['a']])
    type(df.loc[:, ['a']])
    type(df.iloc[:, [0]])
    
    # Alternately the following returns a Series
    type(df.a)
    type(df['a'])
    type(df.loc[:, 'a'])
    type(df.iloc[:, 0])
    
    pandas.core.series.Series
    

    45. How to change the order of columns of a dataframe?

    Difficulty Level: L3

    Actually 3 questions.

    1. In df, interchange columns 'a' and 'c'.

    2. Create a generic function to interchange two columns, without hardcoding column names.

    3. Sort the columns in reverse alphabetical order, that is column 'e' first through column 'a' last.

    Input

    df = pd.DataFrame(np.arange(20).reshape(-1, 5), columns=list('abcde'))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.arange(20).reshape(-1, 5), columns=list('abcde'))
    
    # Solution Q1
    df[list('cbade')]
    
    # Solution Q2 - No hard coding
    def switch_columns(df, col1=None, col2=None):
        colnames = df.columns.tolist()
        i1, i2 = colnames.index(col1), colnames.index(col2)
        colnames[i2], colnames[i1] = colnames[i1], colnames[i2]
        return df[colnames]
    
    df1 = switch_columns(df, 'a', 'c')
    
    # Solution Q3
    df[sorted(df.columns, reverse=True)]
    # or
    df.sort_index(axis=1, ascending=False, inplace=True)
    

    46. How to set the number of rows and columns displayed in the output?

    Difficulty Level: L2

    Change the pandas display settings so that when printing the dataframe df it shows a maximum of 10 rows and 10 columns.

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    # Solution
    pd.set_option('display.max_columns', 10)
    pd.set_option('display.max_rows', 10)
    # df
    
    # Show all available options
    # pd.describe_option()
    

    47. How to format or suppress scientific notations in a pandas dataframe?

    Difficulty Level: L2

    Suppress scientific notations like ‘e-03’ in df and print up to 4 digits after the decimal.

    Input

    df = pd.DataFrame(np.random.random(4)**10, columns=['random'])
    df
    #>          random
    #> 0  3.474280e-03
    #> 1  3.951517e-05
    #> 2  7.469702e-02
    #> 3  5.541282e-28
    

    Desired Output

    #>    random
    #> 0  0.0035
    #> 1  0.0000
    #> 2  0.0747
    #> 3  0.0000
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.random(4)**10, columns=['random'])
    
    # Solution 1: Rounding
    df.round(4)
    
    # Solution 2: Use apply to change format
    df.apply(lambda x: '%.4f' % x, axis=1)
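    # (works here because each row has a single value; for multiple columns use applymap as below)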
    # or
    df.applymap(lambda x: '%.4f' % x)
    
    # Solution 3: Use set_option
    pd.set_option('display.float_format', lambda x: '%.4f' % x)
    
    # Solution 4: Assign display.float_format
    pd.options.display.float_format = '{:.4f}'.format
    print(df)
    
    # Reset/undo float formatting
    pd.options.display.float_format = None
    
       random
    0  0.0002
    1  0.5942
    2  0.0000
    3  0.0030
    

    48. How to format all the values in a dataframe as percentages?

    Difficulty Level: L2

    Format the values in column 'random' of df as percentages.

    Input

    df = pd.DataFrame(np.random.random(4), columns=['random'])
    df
    #>      random
    #> 0    .689723
    #> 1    .957224
    #> 2    .159157
    #> 3    .21082
    

    Desired Output

    #>      random
    #> 0    68.97%
    #> 1    95.72%
    #> 2    15.91%
    #> 3    2.10%
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.random(4), columns=['random'])
    
    # Solution
    out = df.style.format({
        'random': '{0:.2%}'.format,
    })
    
    out
    

       random
    0  21.66%
    1  44.90%
    2  85.69%
    3  92.12%

    49. How to filter every nth row in a dataframe?

    Difficulty Level: L1

    From df, filter the 'Manufacturer', 'Model' and 'Type' for every 20th row starting from 1st (row 0).

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv')
    
    # Solution
    print(df.iloc[::20, :][['Manufacturer', 'Model', 'Type']])
    
       Manufacturer    Model     Type
    0         Acura  Integra    Small
    20     Chrysler  LeBaron  Compact
    40        Honda  Prelude   Sporty
    60      Mercury   Cougar  Midsize
    80       Subaru   Loyale    Small
    

    50. How to create a primary key index by combining relevant columns?

    Difficulty Level: L2

    In df, replace NaNs with ‘missing’ in columns 'Manufacturer', 'Model' and 'Type', create an index as a combination of these three columns, and check if the index is a primary key.

    Input

    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv', usecols=[0,1,2,3,5])
    

    Desired Output

                           Manufacturer    Model     Type  Min.Price  Max.Price
    Acura_Integra_Small           Acura  Integra    Small       12.9       18.8
    missing_Legend_Midsize      missing   Legend  Midsize       29.2       38.7
    Audi_90_Compact                Audi       90  Compact       25.9       32.3
    Audi_100_Midsize               Audi      100  Midsize        NaN       44.6
    BMW_535i_Midsize                BMW     535i  Midsize        NaN        NaN
    
    Show Solution
    # Input
    df = pd.read_csv('https://raw.githubusercontent.com/selva86/datasets/master/Cars93_miss.csv', usecols=[0,1,2,3,5])
    
    # Solution
    df[['Manufacturer', 'Model', 'Type']] = df[['Manufacturer', 'Model', 'Type']].fillna('missing')
    df.index = df.Manufacturer + '_' + df.Model + '_' + df.Type
    print(df.index.is_unique)
    
    True
    

    51. How to get the row number of the nth largest value in a column?

    Difficulty Level: L2

    Find the row position of the 5th largest value of column 'a' in df.

    Input

    df = pd.DataFrame(np.random.randint(1, 30, 30).reshape(10,-1), columns=list('abc'))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(1, 30, 30).reshape(10,-1), columns=list('abc'))
    
    # Solution
    n = 5
    df['a'].argsort()[::-1][n]
    
        a   b   c
    0  27   7  25
    1   8   4  20
    2   1   7  17
    3  24   9  17
    4  21  15   9
    5  21  16  20
    6  19  27  25
    7  12   8  20
    8  11  16  28
    9  24  13   4
    
    4
    

    52. How to find the position of the nth largest value greater than a given value?

    Difficulty Level: L2

    In ser, find the position of the 2nd largest value greater than the mean.

    Input

    ser = pd.Series(np.random.randint(1, 100, 15))
    
    Show Solution
    # Input
    ser = pd.Series(np.random.randint(1, 100, 15))
    
    # Solution
    print('ser: ', ser.tolist(), 'mean: ', round(ser.mean()))
    np.argwhere(ser > ser.mean())[1]
    
    ser:  [7, 77, 16, 86, 60, 38, 34, 36, 83, 27, 16, 52, 50, 52, 54] mean:  46
    
    array([3])
    

    53. How to get the last n rows of a dataframe with row sum > 100?

    Difficulty Level: L2

    Get the last two rows of df whose row sum is greater than 100.

    df = pd.DataFrame(np.random.randint(10, 40, 60).reshape(-1, 4))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(10, 40, 60).reshape(-1, 4))
    
    # Solution
    # print row sums
    rowsums = df.apply(np.sum, axis=1)
    
    # last two rows with row sum greater than 100
    last_two_rows = df.iloc[np.where(rowsums > 100)[0][-2:], :]
    

    54. How to find and cap outliers from a series or dataframe column?

    Difficulty Level: L2

    Replace all values of ser lower than the 5th percentile and greater than the 95th percentile with the respective 5th and 95th percentile values.

    Input

    ser = pd.Series(np.logspace(-2, 2, 30))
    
    Show Solution
    # Input
    ser = pd.Series(np.logspace(-2, 2, 30))
    
    # Solution
    def cap_outliers(ser, low_perc, high_perc):
        low, high = ser.quantile([low_perc, high_perc])
        print(low_perc, '%ile: ', low, '|', high_perc, '%ile: ', high)
        ser[ser < low] = low
        ser[ser > high] = high
        return(ser)
    
    capped_ser = cap_outliers(ser, .05, .95)
    
    0.05 %ile:  0.016049294077 | 0.95 %ile:  63.8766722202
    

    55. How to reshape a dataframe to the largest possible square after removing the negative values?

    Difficulty Level: L3

    Reshape df to the largest possible square with negative values removed. Drop the smallest values if need be. The order of the positive numbers in the result should remain the same as the original.

    Input

    df = pd.DataFrame(np.random.randint(-20, 50, 100).reshape(10,-1))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(-20, 50, 100).reshape(10,-1))
    print(df)
    
    # Solution
    # Step 1: remove negative values from arr
    arr = df[df > 0].values.flatten()
    arr_qualified = arr[~np.isnan(arr)]
    
    # Step 2: find side-length of largest possible square
    n = int(np.floor(arr_qualified.shape[0]**.5))
    
    # Step 3: Take top n^2 items without changing positions
    top_indexes = np.argsort(arr_qualified)[::-1]
    output = np.take(arr_qualified, sorted(top_indexes[:n**2])).reshape(n, -1)
    print(output)
    
        0   1   2   3   4   5   6   7   8   9
    0  25 -13  17  16   0   6  22  44  10 -19
    1  47   4  -1  29 -13  12  41 -13  49  42
    2  20 -20   9  16 -17  -1  37  39  41  37
    3  27  44  -5   5   3 -12   0 -13  23  45
    4   8  27  -8  -3  48 -16  -5  40  16  10
    5  12  12  41 -12   3 -17  -3  27 -15  -1
    6  -9  -3  41 -13   1   0  28  33  -2  18
    7  18 -14  35   5   4  14   4  44  14  34
    8   1  24  26  28 -10  17 -14  14  38  17
    9  13  12   5   9 -16  -7  12 -18   1  24
    [[ 25.  17.  16.   6.  22.  44.  10.  47.]
     [  4.  29.  12.  41.  49.  42.  20.   9.]
     [ 16.  37.  39.  41.  37.  27.  44.   5.]
     [  3.  23.  45.   8.  27.  48.  40.  16.]
     [ 10.  12.  12.  41.   3.  27.  41.  28.]
     [ 33.  18.  18.  35.   5.   4.  14.   4.]
     [ 44.  14.  34.  24.  26.  28.  17.  14.]
     [ 38.  17.  13.  12.   5.   9.  12.  24.]]
    

    56. How to swap two rows of a dataframe?

    Difficulty Level: L2

    Swap rows 1 and 2 in df.

    Input

    df = pd.DataFrame(np.arange(25).reshape(5, -1))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.arange(25).reshape(5, -1))
    
    # Solution
    def swap_rows(df, i1, i2):
        a, b = df.iloc[i1, :].copy(), df.iloc[i2, :].copy()
        df.iloc[i1, :], df.iloc[i2, :] = b, a
        return df
    
    print(swap_rows(df, 1, 2))
    
        0   1   2   3   4
    0   0   1   2   3   4
    1  10  11  12  13  14
    2   5   6   7   8   9
    3  15  16  17  18  19
    4  20  21  22  23  24
    

    57. How to reverse the rows of a dataframe?

    Difficulty Level: L2

    Reverse all the rows of dataframe df.

    Input

    df = pd.DataFrame(np.arange(25).reshape(5, -1))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.arange(25).reshape(5, -1))
    
    # Solution 1
    df.iloc[::-1, :]
    
    # Solution 2
    print(df.loc[df.index[::-1], :])
    
        0   1   2   3   4
    4  20  21  22  23  24
    3  15  16  17  18  19
    2  10  11  12  13  14
    1   5   6   7   8   9
    0   0   1   2   3   4
    

    58. How to create one-hot encodings of a categorical variable (dummy variables)?

    Difficulty Level: L2

    Get one-hot encodings for column 'a' in the dataframe df and append it as columns.

    Input

    df = pd.DataFrame(np.arange(25).reshape(5,-1), columns=list('abcde'))
        a   b   c   d   e
    0   0   1   2   3   4
    1   5   6   7   8   9
    2  10  11  12  13  14
    3  15  16  17  18  19
    4  20  21  22  23  24
    

    Output

       0  5  10  15  20   b   c   d   e
    0  1  0   0   0   0   1   2   3   4
    1  0  1   0   0   0   6   7   8   9
    2  0  0   1   0   0  11  12  13  14
    3  0  0   0   1   0  16  17  18  19
    4  0  0   0   0   1  21  22  23  24
    
    Show Solution
    # Input
    df = pd.DataFrame(np.arange(25).reshape(5,-1), columns=list('abcde'))
    
    # Solution
    df_onehot = pd.concat([pd.get_dummies(df['a']), df[list('bcde')]], axis=1)
    print(df_onehot)
    
        a   b   c   d   e
    0   0   1   2   3   4
    1   5   6   7   8   9
    2  10  11  12  13  14
    3  15  16  17  18  19
    4  20  21  22  23  24
       0  5  10  15  20   b   c   d   e
    0  1  0   0   0   0   1   2   3   4
    1  0  1   0   0   0   6   7   8   9
    2  0  0   1   0   0  11  12  13  14
    3  0  0   0   1   0  16  17  18  19
    4  0  0   0   0   1  21  22  23  24
    

    59. Which column contains the highest number of row-wise maximum values?

    Difficulty Level: L2

    Obtain the column name with the highest number of row-wise maximums in df.

    df = pd.DataFrame(np.random.randint(1,100, 40).reshape(10, -1))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(1,100, 40).reshape(10, -1))
    
    # Solution
    print('Column with highest row maxes: ', df.apply(np.argmax, axis=1).value_counts().index[0])
    
    Column with highest row maxes:  2
    

    60. How to create a new column that contains the row number of nearest column by euclidean distance?

    Create a new column such that each row contains the row number of the nearest row (by euclidean distance).

    Difficulty Level: L3

    Input

    df = pd.DataFrame(np.random.randint(1,100, 40).reshape(10, -1), columns=list('pqrs'), index=list('abcdefghij'))
    df
    #     p   q   r   s
    # a  57  77  13  62
    # b  68   5  92  24
    # c  74  40  18  37
    # d  80  17  39  60
    # e  93  48  85  33
    # f  69  55   8  11
    # g  39  23  88  53
    # h  63  28  25  61
    # i  18   4  73   7
    # j  79  12  45  34
    

    Desired Output

    df
    #    p   q   r   s nearest_row   dist
    # a  57  77  13  62           i  116.0
    # b  68   5  92  24           a  114.0
    # c  74  40  18  37           i   91.0
    # d  80  17  39  60           i   89.0
    # e  93  48  85  33           i   92.0
    # f  69  55   8  11           g  100.0
    # g  39  23  88  53           f  100.0
    # h  63  28  25  61           i   88.0
    # i  18   4  73   7           a  116.0
    # j  79  12  45  34           a   81.0
    
    Show Solution
    df = pd.DataFrame(np.random.randint(1,100, 40).reshape(10, -1), columns=list('pqrs'), index=list('abcdefghij'))
    
    # Solution
    import numpy as np
    
    # init outputs
    nearest_rows = []
    nearest_distance = []
    
    # iterate rows.
    for i, row in df.iterrows():
        curr = row
        rest = df.drop(i)
        e_dists = {}  # init dict to store euclidean dists for current row.
        # iterate rest of rows for current row
        for j, contestant in rest.iterrows():
            # compute euclidean dist and update e_dists
            e_dists.update({j: round(np.linalg.norm(curr.values - contestant.values))})
        # the nearest row is the one with the smallest euclidean distance
        nearest_rows.append(min(e_dists, key=e_dists.get))
        nearest_distance.append(min(e_dists.values()))
    
    df['nearest_row'] = nearest_rows
    df['dist'] = nearest_distance
    

    61. How to know the maximum possible correlation value of each column against other columns?

    Difficulty Level: L2

    Compute maximum possible absolute correlation value of each column against other columns in df.

    Input

    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1), columns=list('pqrstuvwxy'), index=list('abcdefgh'))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1), columns=list('pqrstuvwxy'), index=list('abcdefgh'))
    df
    
    # Solution
    abs_corrmat = np.abs(df.corr())
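    # the largest value in each column of the correlation matrix is the self-correlation (1.0), so take the second largest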
    max_corr = abs_corrmat.apply(lambda x: sorted(x)[-2])
    print('Maximum Correlation possible for each column: ', np.round(max_corr.tolist(), 2))
    
    Maximum Correlation possible for each column:  [ 0.91  0.57  0.55  0.71  0.53  0.26  0.91  0.71  0.69  0.71]
    

    62. How to create a column containing the minimum by maximum of each row?

    Difficulty Level: L2

    Compute the minimum-by-maximum for every row of df.

    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1))
    
    # Solution 1
    min_by_max = df.apply(lambda x: np.min(x)/np.max(x), axis=1)
    
    # Solution 2
    min_by_max = np.min(df, axis=1)/np.max(df, axis=1)
    

    63. How to create a column that contains the penultimate value in each row?

    Difficulty Level: L2

    Create a new column 'penultimate' which has the second largest value of each row of df.

    Input

    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1))
    
    # Solution
    out = df.apply(lambda x: x.sort_values().unique()[-2], axis=1)
    df['penultimate'] = out
    print(df)
    
        0   1   2   3   4   5   6   7   8   9  penultimate
    0  52  69  62   7  20  69  38  10  57  17           62
    1  52  94  49  63   1  90  14  76  20  84           90
    2  78  37  58   7  27  41  27  26  48  51           58
    3   6  39  99  36  62  90  47  25  60  84           90
    4  37  36  91  93  76  69  86  95  69   6           93
    5   5  54  73  61  22  29  99  27  46  24           73
    6  71  65  45   9  63  46   4  93  36  18           71
    7  85   7  76  46  65  97  64  52  28  80           85
    

    64. How to normalize all columns in a dataframe?

    Difficulty Level: L2

    1. Normalize all columns of df by subtracting the column mean and divide by standard deviation.
    2. Range all columns of df such that the minimum value in each column is 0 and max is 1.

    Don’t use external packages like sklearn.

    Input

    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1))
    
    # Solution Q1
    out1 = df.apply(lambda x: ((x - x.mean())/x.std()).round(2))
    print('Solution Q1\n',out1)
    
    # Solution Q2
    out2 = df.apply(lambda x: ((x - x.min())/(x.max() - x.min())).round(2))
    print('Solution Q2\n', out2)  
    
    Solution Q1
           0     1     2     3     4     5     6     7     8     9
    0  1.09  0.64 -0.33 -0.96 -1.30  0.06  0.38  1.18 -1.60  1.66
    1 -0.93 -2.36  0.87  1.47 -1.15  1.27  0.07 -0.87 -0.18  0.23
    2  1.53  0.48 -0.90  0.18 -0.33  0.81 -1.29  0.34  0.06 -0.55
    3  0.59 -0.24 -1.06  0.61  1.18 -1.23 -0.53 -0.45  0.34 -1.25
    4  0.18  0.33  1.07  1.17  0.50 -0.26 -0.25 -1.45  1.11  1.11
    5 -1.16  0.64 -0.93 -0.59 -0.15  0.63  1.02  1.13  1.20 -0.19
    6 -0.58  0.07 -0.20 -0.87 -0.22 -1.62 -1.04  0.81 -1.23 -1.04
    7 -0.73  0.45  1.47 -1.02  1.47  0.34  1.65 -0.71  0.31  0.02
    Solution Q2
           0     1     2     3     4     5     6     7     8     9
    0  0.16  0.00  0.71  0.98  1.00  0.42  0.43  0.00  1.00  0.00
    1  0.91  1.00  0.24  0.00  0.95  0.00  0.54  0.78  0.49  0.49
    2  0.00  0.05  0.93  0.52  0.65  0.16  1.00  0.32  0.41  0.76
    3  0.35  0.29  1.00  0.35  0.10  0.86  0.74  0.62  0.31  1.00
    4  0.50  0.10  0.16  0.12  0.35  0.53  0.65  1.00  0.03  0.19
    5  1.00  0.00  0.95  0.83  0.58  0.22  0.22  0.02  0.00  0.64
    6  0.78  0.19  0.66  0.94  0.61  1.00  0.91  0.14  0.87  0.93
    7  0.84  0.06  0.00  1.00  0.00  0.32  0.00  0.72  0.32  0.56
    

    65. How to compute the correlation of each row with the succeeding row?

    Difficulty Level: L2

    Compute the correlation of each row of df with its succeeding row.

    Input

    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1))
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(1,100, 80).reshape(8, -1))
    
    # Solution
    [df.iloc[i].corr(df.iloc[i + 1]).round(2) for i in range(df.shape[0] - 1)]
    
        0   1   2   3   4   5   6   7   8   9
    0  93  49  26   2  96  56  11  73  90  65
    1  54  17  47  52  65   9  21  87  94   4
    2  51  11  44  77  37  57  17  25  95  26
    3  84   8  61  43  63  63  59  65  69  29
    4   8  27  53  95  10  35  16  61  39  83
    5  30  70  91  26  12  44  37  71  21  48
    6  66  44  47  44  29  99  86  78  31   1
    7  17  40  28  12  89  95  79  54  81  47
    
    [0.41, 0.48, 0.43, -0.37, 0.23, 0.14, 0.22]
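    
    A vectorized sketch of the same idea: transpose so rows become columns, then correlate each column with a copy shifted by one position using corrwith (the last entry is dropped because the final row has no succeeding row).
    
    # Alternative (sketch): correlation of each row with the next row in one call
    out = df.T.corrwith(df.T.shift(-1, axis=1)).round(2).iloc[:-1]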
    

    66. How to replace both the diagonals of dataframe with 0?

    Difficulty Level: L2

    Replace the values in both diagonals of df with 0.

    Input

    df = pd.DataFrame(np.random.randint(1,100, 100).reshape(10, -1))
    df
    #     0   1   2   3   4   5   6   7   8   9
    # 0  11  46  26  44  11  62  18  70  68  26
    # 1  87  71  52  50  81  43  83  39   3  59
    # 2  47  76  93  77  73   2   2  16  14  26
    # 3  64  18  74  22  16  37  60   8  66  39
    # 4  10  18  39  98  25   8  32   6   3  29
    # 5  29  91  27  86  23  84  28  31  97  10
    # 6  37  71  70  65   4  72  82  89  12  97
    # 7  65  22  97  75  17  10  43  78  12  77
    # 8  47  57  96  55  17  83  61  85  26  86
    # 9  76  80  28  45  77  12  67  80   7  63
    

    Desired output

    #     0   1   2   3   4   5   6   7   8   9
    # 0   0  46  26  44  11  62  18  70  68   0
    # 1  87   0  52  50  81  43  83  39   0  59
    # 2  47  76   0  77  73   2   2   0  14  26
    # 3  64  18  74   0  16  37   0   8  66  39
    # 4  10  18  39  98   0   0  32   6   3  29
    # 5  29  91  27  86   0   0  28  31  97  10
    # 6  37  71  70   0   4  72   0  89  12  97
    # 7  65  22   0  75  17  10  43   0  12  77
    # 8  47   0  96  55  17  83  61  85   0  86
    # 9   0  80  28  45  77  12  67  80   7   0
    
    Show Solution
    # Input
    df = pd.DataFrame(np.random.randint(1,100, 100).reshape(10, -1))
    
    # Solution
    for i in range(df.shape[0]):
        df.iat[i, i] = 0
        df.iat[df.shape[0]-i-1, i] = 0
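    
    An equivalent numpy-based sketch, assuming numpy is imported as np: np.fill_diagonal writes in place and also works through the flipped view returned by np.fliplr, so both diagonals can be zeroed without an explicit loop.
    
    # Alternative (sketch): zero both diagonals via numpy
    arr = df.to_numpy()
    np.fill_diagonal(arr, 0)               # main diagonal
    np.fill_diagonal(np.fliplr(arr), 0)    # anti-diagonal
    df = pd.DataFrame(arr)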
    

    67. How to get the particular group of a groupby dataframe by key?

    Difficulty Level: L2

    This question is about understanding grouped dataframes. From df_grouped, get the group belonging to 'apple' as a dataframe.

    Input

    df = pd.DataFrame({'col1': ['apple', 'banana', 'orange'] * 3,
                       'col2': np.random.rand(9),
                       'col3': np.random.randint(0, 15, 9)})
    
    df_grouped = df.groupby(['col1'])
    
    Show Solution
    # Input
    df = pd.DataFrame({'col1': ['apple', 'banana', 'orange'] * 3,
                       'col2': np.random.rand(9),
                       'col3': np.random.randint(0, 15, 9)})
    
    df_grouped = df.groupby(['col1'])
    
    # Solution 1
    df_grouped.get_group('apple')
    
    # Solution 2
    for i, dff in df_grouped:
        if i == 'apple':
            print(dff)
    
        col1      col2  col3
    0  apple  0.673434     7
    3  apple  0.182348    14
    6  apple  0.050457     3
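    
    If you only need the rows and not the groupby machinery, plain boolean filtering gives the same frame here; a minimal sketch:
    
    # Alternative (sketch): boolean filtering
    df[df['col1'] == 'apple']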
    

    68. How to get the n’th largest value of a column when grouped by another column?

    Difficulty Level: L2

    In df, find the second largest value of 'taste' for 'banana'.

    Input

    df = pd.DataFrame({'fruit': ['apple', 'banana', 'orange'] * 3,
                       'taste': np.random.rand(9),
                       'price': np.random.randint(0, 15, 9)})
    
                   
    
    Show Solution
    # Input
    df = pd.DataFrame({'fruit': ['apple', 'banana', 'orange'] * 3,
                       'taste': np.random.rand(9),
                       'price': np.random.randint(0, 15, 9)})
    
    print(df)
    
    # Solution
    df_grpd = df['taste'].groupby(df.fruit)
    df_grpd.get_group('banana').sort_values().iloc[-2]
    
        fruit  price     taste
    0   apple      7  0.190229
    1  banana      2  0.438063
    2  orange      1  0.860182
    3   apple      6  0.042149
    4  banana      2  0.896021
    5  orange      5  0.255107
    6   apple      6  0.874533
    7  banana      4  0.696274
    8  orange      9  0.140713
    
    0.69627423645996078
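    
    A groupby-free sketch of the same lookup, using the taste column defined above:
    
    # Alternative (sketch): filter the fruit, then take the second largest taste
    df.loc[df.fruit == 'banana', 'taste'].nlargest(2).iloc[-1]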
    

    69. How to compute grouped mean on pandas dataframe and keep the grouped column as another column (not index)?

    Difficulty Level: L1

    In df, compute the mean price of every fruit, keeping the fruit as a column instead of an index.

    Input

    df = pd.DataFrame({'fruit': ['apple', 'banana', 'orange'] * 3,
                       'rating': np.random.rand(9),
                       'price': np.random.randint(0, 15, 9)})
    
                   
    
    Show Solution
    # Input
    df = pd.DataFrame({'fruit': ['apple', 'banana', 'orange'] * 3,
                       'rating': np.random.rand(9),
                       'price': np.random.randint(0, 15, 9)})
    
    # Solution
    out = df.groupby('fruit', as_index=False)['price'].mean()
    print(out)
    
        fruit      price
    0   apple  11.000000
    1  banana   6.333333
    2  orange   6.333333
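    
    An equivalent sketch that resets the group key back into a column instead of passing as_index=False:
    
    # Alternative (sketch)
    out = df.groupby('fruit')['price'].mean().reset_index()
    print(out)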
    

    70. How to join two dataframes by 2 columns so they have only the common rows?

    Difficulty Level: L2

    Join dataframes df1 and df2 by &#