首页 > 技术文章 > apriori算法

Qi-Lin 2022-02-18 17:35 原文

Apriori算法简单实现

前言

以如下数据为例,来说明算法的运行过程,找出其频繁项。数据中每一行代表一条数据,每一列可以代表待关联的事物,比如每个客户购买的每个商品

[['a','c','e'],
['b','d'],
['b','c'],
['a','b','c','d'],
['a','b'],
['b','c'],
['a','b'],
['a','b','c','e'],
['a','b','c'],
['a','c','e']]

为了便于处理,将数据转化为如下格式,其中每个元素为1,表示用户购买了该商品,为0表示没有购买

   a  b  c  d  e
0  1  0  1  0  1
1  0  1  0  1  0
2  0  1  1  0  0
3  1  1  1  1  0
4  1  1  0  0  0
5  0  1  1  0  0
6  1  1  0  0  0
7  1  1  1  0  1
8  1  1  1  0  0
9  1  0  1  0  1

最后,设置阈值为3。

第一轮处理

首先得到个数为1的项集,然后统计数据中每一列值为1的个数(其实就是统计每一个项集的个数),最后将个数小于阈值的项集去除,并保存在stop表中。这样得到第一轮处理后的频繁项集。

开始循环处理

循环结束的条件是,频繁项集的个数为0。
首先,利用上一步得到的频繁项集进行连接(即求各个集合的并),然后去除重复的项集和项集中元素错误的(如第二次处理时,每个项集元素个数应该为2,如果出现个数为1,为3的,就要删除)。然后进行剪枝操作(即去除子集在stop表中的项集),最后统计项集在原数据中的个数,最后将个数小于阈值的项集去除,并保存在stop表中。这样得到第二轮处理后的频繁项集。
统计时采用的策略如下,以统计ac的次数为例,只要找到a,c对应的列,将每一行相加,然后统计其和为2的个数,就是ac出现的次数:

之后进行循环,直至满足循环结束条件。

以上述数据为例的算法过程可视化如下:

程序运行流程及结果如下:

  • 去掉['d']
  • 频繁项[(['a', 'b'], 5), (['a', 'c'], 5), (['a', 'e'], 3), (['c', 'b'], 5), (['e', 'c'], 3)]
  • 去掉[['e', 'b']]
  • 频繁项[(['a', 'c', 'b'], 3), (['a', 'c', 'e'], 3)]
  • 去掉[]
  • 频繁项[]
  • 去掉[['a', 'b', 'e', 'c']]
  • 结果:[(['a', 'b'], 5), (['c', 'a'], 5), (['a', 'e'], 3), (['c', 'b'], 5), (['c', 'e'], 3), (['c', 'a', 'b'], 3), (['c', 'a', 'e'], 3)]

运行可视化如下:

代码

from copy import copy
import pandas as pd

def Apriori(data,th):
    #保存最终频繁项
    result=[]

    #得到要统计的每一项
    col=data.columns.tolist()
    #统计每一项个数
    tmp=data[data==1].count().tolist()

    #根据每一项统计的数目,删除比阈值小的
    stop=[]
    f=list(zip(col,tmp))
    f_copy = copy(f)
    for i, j in f_copy:
        if j<th:
            stop.append(i)
            f.remove((i,j))
    # result.extend(f)

    # print('频繁项')
    # print(f)
    print('去掉')
    print(stop)


    turn=1
    while True:
        turn=turn+1

        # 得到要统计的每一项
        tmp = []
        col=[]
        d=[i[0] for i in f]
        col_set = []
        #连接操作
        for i in d:
            for j in d[d.index(i) + 1:]:
                #取并集
                item=set(i).union(set(j))
                #删除连接后不符合要求的元素
                #删除个数不对的,删除重复的
                if (len(item)==turn) and  (item not in col_set) :
                    #剪枝操作,删除子集就不是频繁项的
                    if (len(stop)!=0):
                        judge=[set(n).issubset(item) for n in stop]
                        if (True not in judge):
                            col_set.append(item)
                    else:
                        col_set.append(item)
                col=[list(i) for i in col_set]

        #直到没有频繁项就跳出循环
        if len(col)==0:
            break

        # 统计每一项个数
        for i in col:
            tmp.append(data[data[i].sum(axis=1) == turn].count().tolist()[0])

        # 根据每一项统计的数目,删除比阈值小的
        stop = []
        f = list(zip(col, tmp))
        f_copy=copy(f)
        for i, j in f_copy:
            if j < th:
                stop.append(i)
                f.remove((i, j))
        result.extend(f)

        print('频繁项')
        print(f)
        print('去掉')
        print(stop)
    print('结果')
    print(result)


data=pd.DataFrame({'a':[1,0,0,1,1,0,1,1,1,1],'b':[0,1,1,1,1,1,1,1,1,0],'c':[1,0,1,1,0,1,0,1,1,1],'d':[0,1,0,1,0,0,0,0,0,0],'e':[1,0,0,0,0,0,0,1,0,1]})
print(data)
Apriori(data,3)

推荐阅读