In [1]:
import pandas as pd
from mlxtend.frequent_patterns import apriori
from mlxtend.frequent_patterns import association_rules
import numpy as np
import matplotlib.pyplot as plt

pd.set_option('display.max_columns',100)
#pd.set_option('display.width',2000)
pd.set_option('display.max_colwidth', -1)
In [2]:
#lets have a look at the data file
import csv
with open('./groceries.csv', newline='') as csvfile:
    reader = csv.reader(csvfile, delimiter=',', quotechar='|')
    for row_number, row in enumerate(reader):
        print("[{}] {}".format(row_number,', '.join(row)))        
        if (row_number>10):
            break
[0] citrus fruit, semi-finished bread, margarine, ready soups
[1] tropical fruit, yogurt, coffee
[2] whole milk
[3] pip fruit, yogurt, cream cheese, meat spreads
[4] other vegetables, whole milk, condensed milk, long life bakery product
[5] whole milk, butter, yogurt, rice, abrasive cleaner
[6] rolls/buns
[7] other vegetables, UHT-milk, rolls/buns, bottled beer, liquor (appetizer)
[8] potted plants
[9] whole milk, cereals
[10] tropical fruit, other vegetables, white bread, bottled water, chocolate
[11] citrus fruit, tropical fruit, whole milk, butter, curd, yogurt, flour, bottled water, dishes
In [3]:
#load dataset
df = pd.read_table('./groceries.csv',header=None)
#create binary matrix to be used as input to apriori
df1= df.iloc[:,0].str.get_dummies(sep=',')
df1.head()
Out[3]:
Instant food products UHT-milk abrasive cleaner artif. sweetener baby cosmetics baby food bags baking powder bathroom cleaner beef berries beverages bottled beer bottled water brandy brown bread butter butter milk cake bar candles candy canned beer canned fish canned fruit canned vegetables cat food cereals chewing gum chicken chocolate chocolate marshmallow citrus fruit cleaner cling film/bags cocoa drinks coffee condensed milk cooking chocolate cookware cream cream cheese curd curd cheese decalcifier dental care dessert detergent dish cleaner dishes dog food ... ready soups red/blush wine rice roll products rolls/buns root vegetables rubbing alcohol rum salad dressing salt salty snack sauces sausage seasonal products semi-finished bread shopping bags skin care sliced cheese snack products soap soda soft cheese softener sound storage medium soups sparkling wine specialty bar specialty cheese specialty chocolate specialty fat specialty vegetables spices spread cheese sugar sweet spreads syrup tea tidbits toilet cleaner tropical fruit turkey vinegar waffles whipped/sour cream whisky white bread white wine whole milk yogurt zwieback
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0
2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0
3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0
4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0

5 rows × 169 columns

In [4]:
#count number of transactions in dataset 
len(df1.index)
Out[4]:
9835
In [5]:
#how many transactions contain beef
df1['beef'].sum()
Out[5]:
516
In [6]:
#count number of transactions per product
df1.sum().nlargest(10)
Out[6]:
whole milk          2513
other vegetables    1903
rolls/buns          1809
soda                1715
yogurt              1372
bottled water       1087
root vegetables     1072
tropical fruit      1032
shopping bags       969 
sausage             924 
dtype: int64
In [7]:
#plot products with most transactions
df1.sum().nlargest(10).plot.bar(title='Top-10 products with their cumulative sales')
Out[7]:
<matplotlib.axes._subplots.AxesSubplot at 0x118b0bc18>
In [19]:
#plot products with most transactions
df1.sum().nsmallest(10).plot.bar(title='Bottom-10 products with their cumulative sales')
Out[19]:
<matplotlib.axes._subplots.AxesSubplot at 0x11a9253c8>
In [9]:
#row sum computes number of products in a basket
dfcounts=df1.sum(axis=1)
plt.xlabel('Number of items in a transaction')
plt.ylabel('Number of Transactions')
plt.hist(dfcounts,bins=range(1,32),rwidth=0.8)
plt.show()
In [10]:
# Build up the frequent items
frequent_itemsets = apriori(df1, min_support=0.002, use_colnames=True)
frequent_itemsets
Out[10]:
support itemsets
0 0.008033 (Instant food products)
1 0.033452 (UHT-milk)
2 0.003559 (abrasive cleaner)
3 0.003254 (artif. sweetener)
4 0.017692 (baking powder)
5 0.002745 (bathroom cleaner)
6 0.052466 (beef)
7 0.033249 (berries)
8 0.026029 (beverages)
9 0.080529 (bottled beer)
10 0.110524 (bottled water)
11 0.004169 (brandy)
12 0.064870 (brown bread)
13 0.055414 (butter)
14 0.027961 (butter milk)
15 0.013218 (cake bar)
16 0.008948 (candles)
17 0.029893 (candy)
18 0.077682 (canned beer)
19 0.015048 (canned fish)
20 0.003254 (canned fruit)
21 0.010778 (canned vegetables)
22 0.023284 (cat food)
23 0.005694 (cereals)
24 0.021047 (chewing gum)
25 0.042908 (chicken)
26 0.049619 (chocolate)
27 0.009049 (chocolate marshmallow)
28 0.082766 (citrus fruit)
29 0.005084 (cleaner)
... ... ...
4193 0.004881 (tropical fruit, rolls/buns, yogurt, whole milk)
4194 0.003050 (whipped/sour cream, whole milk, rolls/buns, yogurt)
4195 0.002745 (tropical fruit, whole milk, sausage, root vegetables)
4196 0.003254 (yogurt, whole milk, sausage, root vegetables)
4197 0.002440 (yogurt, whole milk, soda, root vegetables)
4198 0.002745 (tropical fruit, whipped/sour cream, whole milk, root vegetables)
4199 0.002339 (tropical fruit, whipped/sour cream, yogurt, root vegetables)
4200 0.005694 (tropical fruit, yogurt, whole milk, root vegetables)
4201 0.003660 (whipped/sour cream, yogurt, whole milk, root vegetables)
4202 0.002237 (yogurt, whole milk, sausage, soda)
4203 0.003152 (tropical fruit, yogurt, whole milk, sausage)
4204 0.003152 (tropical fruit, yogurt, whole milk, soda)
4205 0.004372 (tropical fruit, whipped/sour cream, yogurt, whole milk)
4206 0.002034 (tropical fruit, yogurt, whole milk, white bread)
4207 0.002034 (whole milk, other vegetables, tropical fruit, yogurt, bottled water)
4208 0.002339 (butter, whole milk, other vegetables, tropical fruit, yogurt)
4209 0.003152 (whole milk, other vegetables, tropical fruit, citrus fruit, root vegetables)
4210 0.002339 (whole milk, other vegetables, yogurt, citrus fruit, root vegetables)
4211 0.002440 (whole milk, other vegetables, tropical fruit, yogurt, citrus fruit)
4212 0.002034 (fruit/vegetable juice, whole milk, other vegetables, yogurt, root vegetables)
4213 0.002440 (whole milk, other vegetables, tropical fruit, pip fruit, root vegetables)
4214 0.002339 (whole milk, other vegetables, yogurt, pip fruit, root vegetables)
4215 0.002339 (whole milk, other vegetables, tropical fruit, yogurt, pip fruit)
4216 0.002034 (rolls/buns, whole milk, other vegetables, tropical fruit, root vegetables)
4217 0.002440 (rolls/buns, whole milk, other vegetables, yogurt, root vegetables)
4218 0.002542 (rolls/buns, whole milk, other vegetables, tropical fruit, yogurt)
4219 0.003559 (whole milk, other vegetables, tropical fruit, yogurt, root vegetables)
4220 0.002339 (whipped/sour cream, whole milk, other vegetables, yogurt, root vegetables)
4221 0.002440 (whipped/sour cream, whole milk, other vegetables, tropical fruit, yogurt)
4222 0.002237 (rolls/buns, whole milk, tropical fruit, yogurt, root vegetables)

4223 rows × 2 columns

In [11]:
#plot itemsets with exactly 3 items with highest support values
frequent_itemsets[frequent_itemsets.itemsets.str.len()==3].sort_values(by=['support'],ascending=False).head()
Out[11]:
support itemsets
3486 0.023183 (whole milk, root vegetables, other vegetables)
3546 0.022267 (yogurt, whole milk, other vegetables)
3473 0.017895 (rolls/buns, whole milk, other vegetables)
3535 0.017082 (tropical fruit, whole milk, other vegetables)
3698 0.015557 (rolls/buns, yogurt, whole milk)
In [12]:
# Create the rules
rules = association_rules(frequent_itemsets)
rules[['antecedents','consequents','support','confidence','lift']]
Out[12]:
antecedents consequents support confidence lift
0 (curd, hamburger meat) (whole milk) 0.002542 0.806452 3.156169
1 (rolls/buns, herbs) (whole milk) 0.002440 0.800000 3.130919
2 (tropical fruit, herbs) (whole milk) 0.002339 0.821429 3.214783
3 (pork, butter, other vegetables) (whole milk) 0.002237 0.846154 3.311549
4 (curd, domestic eggs, other vegetables) (whole milk) 0.002847 0.823529 3.223005
5 (tropical fruit, whole milk, grapes) (other vegetables) 0.002034 0.800000 4.134524
6 (tropical fruit, root vegetables, whole milk, citrus fruit) (other vegetables) 0.003152 0.885714 4.577509
7 (root vegetables, yogurt, citrus fruit, other vegetables) (whole milk) 0.002339 0.821429 3.214783
8 (fruit/vegetable juice, yogurt, whole milk, root vegetables) (other vegetables) 0.002034 0.800000 4.134524
9 (fruit/vegetable juice, yogurt, root vegetables, other vegetables) (whole milk) 0.002034 0.833333 3.261374
10 (tropical fruit, rolls/buns, yogurt, root vegetables) (whole milk) 0.002237 0.814815 3.188899
In [13]:
#list rules with high lift and confidence values
rules[ (rules['lift'] >= 4) &
       (rules['confidence'] >= 0.8) ] [['antecedents','consequents','support','confidence','lift']]
Out[13]:
antecedents consequents support confidence lift
5 (tropical fruit, whole milk, grapes) (other vegetables) 0.002034 0.800000 4.134524
6 (tropical fruit, root vegetables, whole milk, citrus fruit) (other vegetables) 0.003152 0.885714 4.577509
8 (fruit/vegetable juice, yogurt, whole milk, root vegetables) (other vegetables) 0.002034 0.800000 4.134524
In [14]:
def draw(rules):
    import networkx as nx  
    G = nx.DiGraph()

    for i,row in rules.iterrows():      
        for c in row['consequents']:
            G.add_nodes_from([c])
            for a in row['antecedents']: 
                G.add_nodes_from([a])
                G.add_edge(a,c,color='black',weight=row['confidence'])
    color_map=[]
    for node in G:
        #color blue nodes that participate as a consequent to a rule
        if (rules['consequents'].str.contains(node, regex=False).any()):
            color_map.append('blue') 
        else:
            color_map.append('red')  

    edges = G.edges()
    colors = [G[u][v]['color'] for u,v in edges]
    weights = [G[u][v]['weight'] for u,v in edges]
    
    min_weight=min(weights)-0.01
    max_weight=max(weights)
    weights = [5*(w-min_weight)/(max_weight-min_weight) for w in weights]
    
    pos = nx.spring_layout(G, k=10, scale=1)
    nx.draw(G, pos,edges=edges,edge_color=colors,node_color=color_map,width=weights,font_size=8,with_labels=False)            
    for p in pos:  # raise text positions
      pos[p][1] += 0.18
    nx.draw_networkx_labels(G, pos)
 
    plt.show()
In [15]:
draw(rules)  
In [16]:
#list rules that contain 'citrus fruit' in the LHS
rules[rules['antecedents'].apply(str).str.contains('citrus fruit')][['antecedents','consequents','support','confidence','lift']]
Out[16]:
antecedents consequents support confidence lift
6 (tropical fruit, root vegetables, whole milk, citrus fruit) (other vegetables) 0.003152 0.885714 4.577509
7 (root vegetables, yogurt, citrus fruit, other vegetables) (whole milk) 0.002339 0.821429 3.214783
In [17]:
rules.plot.scatter(x='support',y='confidence',c='lift', colormap='viridis')
Out[17]:
<matplotlib.axes._subplots.AxesSubplot at 0x11a714438>
In [18]:
rules[ (rules['lift'] >= 4.4)][['antecedents','consequents','support','confidence','lift']]
Out[18]:
antecedents consequents support confidence lift
6 (tropical fruit, root vegetables, whole milk, citrus fruit) (other vegetables) 0.003152 0.885714 4.577509
In [ ]: