首页 > 解决方案 > 如何使用 HDBSCAN 正确聚类一维数据集?

问题描述

我下面的数据集显示了每个价格的产品销售额(下载数据集 csv 的链接):

     price      quantity
0    5098.0        20
1    5098.5        40
2    5099.0        10
3    5100.0        90
4    5100.5        20
..      ...       ...
290  5247.0       150
291  5247.5        30
292  5248.0       150
293  5248.5        20
294  5249.0        55

[295 rows x 2 columns]

我想要实现的是使用 HDBSCAN 和 sklearn 对密集区域(下面的矩形)进行聚类。我们有四个区域,但是区域 3 和 4 也可以组合成一个大区域,通过更改函数调用中 的参数min_cluster_sizemin_samples ,这将导致整个数据集上只有 3 个区域。在此处输入图像描述

这是我的代码:

import hdbscan
import plotly.express as px
import pandas as pd
import numpy as np

data = pd.read_csv('data_set.csv')

price = group['price'].values.flatten()
price = price[:,np.newaxis]
weight = group['quantity'].values.flatten()

kde = KernelDensity(kernel='gaussian', bandwidth=1.5).fit(price,sample_weight=weight)
#the multiplication factor is only for visualization purposes
data['prob'] = np.exp(kde.score_samples(price))*85000

fig = px.line(data,x='price',y='prob')
fig.add_bar(x=data['price'],y=data['quantity'])
fig.show()

在此处输入图像描述

data = data[['price','quantity']]

clusterer = hdbscan.HDBSCAN(min_cluster_size=4,min_samples=8)
clusterer.fit(data)
data['cluster'] = clusterer.labels_

fig = px.bar(data,x='price',y='quantity',color='cluster',orientation='v')
fig.show()

在此处输入图像描述

问题是结果,聚类没有按预期工作(下图 x 上图)它对幅度进行聚类,而不是算法中提到的密集区域。我在代码中遗漏了什么? 在此处输入图像描述

我已经尝试了以下事情:规范化数据(两个轴)并在调用 HDBSCAN 类之前交换轴。任何帮助,将不胜感激。我有点迷失在这段代码中,但我认为通过阅读文档可以直接解决这个特定问题,因为 HDBSCAN 可以很好地处理密度和噪声。

标签: pythonmachine-learningscikit-learnhierarchical-clusteringhdbscan

解决方案


您实现这一点的方式,实际上是在尝试对二维数据进行聚类。当您将聚类结果可视化为散点图时,这更有意义:

scatter_clustering

为了按照我相信您的意图对一维数据进行聚类,您可以重塑数据。本质上,您需要一个价格列表,其中每个price值在列表quantity时间中重复。这对于 numpy 来说非常简单:

data_1d = np.array(np.repeat(data.price, data.quantity)).reshape(-1, 1)

这使

array([[5098.],
       [5098.],
       [5098.],
       ...,
       [5249.],
       [5249.],
       [5249.]])

然后你可以直接在这个 numpy 数组上进行集群,但是你需要显着增加min_cluster_sizemin_samples因为你现在有更多的值可以集群:

clusterer = hdbscan.HDBSCAN(min_cluster_size=5000, min_samples=200)
clusterer.fit(data_1d)

最后,我们可以组合集群标签,为每个 选择出现频率最高的标签*** price,然后按 分组price

clustered_data_1d = pd.DataFrame(np.concatenate((data_1d, clusterer.labels_.reshape(-1, 1)), axis=1), columns=['price', 'cluster'])
clustered_data_1d['quantity'] = 1
grouped_data_1d = clustered_data_1d.groupby('price').agg({'cluster': lambda x: x.value_counts().index[0], 'quantity': np.sum}).reset_index()

为了验证我们得到了预期的结果,让我们绘制:

fig = px.bar(grouped_data_1d, x='price', y='quantity', color='cluster', orientation='v')
fig.update_traces(dict(marker_line_width=0))
fig.show()

在此处输入图像描述

看起来 HDBSCAN 使用其他默认参数生成的集群与您的预期非常相似,但我相信如果您的最终应用程序需要更少的集群,您可以稍微调整一下。

***使用“模式”或最常出现的集群标签对我来说可能有点懒惰。您还可以考虑取平均值和四舍五入,或者找到price每个标签的最低值和最高值并将它们用作集群端点,或者完全使用其他东西!


为希望复制的人复制粘贴的完整代码:

import pandas as pd
import numpy as np
import hdbscan
import plotly.express as px

data_dct = {'price': {0: 5098.0, 1: 5098.5, 2: 5099.0, 3: 5100.0, 4: 5100.5, 5: 5101.0, 6: 5101.5, 7: 5102.0, 8: 5102.5, 9: 5103.0, 10: 5103.5, 11: 5104.0, 12: 5104.5, 13: 5105.0, 14: 5105.5, 15: 5106.0, 16: 5106.5, 17: 5107.0, 18: 5107.5, 19: 5108.0, 20: 5108.5, 21: 5109.0, 22: 5109.5, 23: 5110.0, 24: 5110.5, 25: 5111.0, 26: 5111.5, 27: 5112.0, 28: 5112.5, 29: 5113.0, 30: 5113.5, 31: 5114.0, 32: 5114.5, 33: 5115.0, 34: 5115.5, 35: 5116.0, 36: 5116.5, 37: 5117.0, 38: 5117.5, 39: 5118.0, 40: 5118.5, 41: 5119.0, 42: 5119.5, 43: 5120.0, 44: 5120.5, 45: 5121.0, 46: 5121.5, 47: 5122.0, 48: 5122.5, 49: 5123.0, 50: 5123.5, 51: 5124.0, 52: 5124.5, 53: 5125.0, 54: 5125.5, 55: 5126.0, 56: 5126.5, 57: 5127.0, 58: 5127.5, 59: 5128.0, 60: 5128.5, 61: 5129.0, 62: 5129.5, 63: 5130.0, 64: 5130.5, 65: 5131.0, 66: 5131.5, 67: 5132.0, 68: 5132.5, 69: 5133.0, 70: 5133.5, 71: 5134.0, 72: 5134.5, 73: 5135.0, 74: 5135.5, 75: 5136.0, 76: 5136.5, 77: 5137.0, 78: 5137.5, 79: 5138.0, 80: 5138.5, 81: 5139.0, 82: 5139.5, 83: 5140.0, 84: 5140.5, 85: 5141.0, 86: 5141.5, 87: 5142.0, 88: 5142.5, 89: 5143.0, 90: 5143.5, 91: 5144.0, 92: 5144.5, 93: 5145.0, 94: 5145.5, 95: 5146.0, 96: 5147.0, 97: 5147.5, 98: 5148.0, 99: 5148.5, 100: 5149.0, 101: 5149.5, 102: 5150.0, 103: 5150.5, 104: 5151.0, 105: 5151.5, 106: 5152.0, 107: 5152.5, 108: 5153.0, 109: 5153.5, 110: 5154.0, 111: 5154.5, 112: 5155.0, 113: 5155.5, 114: 5156.0, 115: 5156.5, 116: 5157.0, 117: 5157.5, 118: 5158.0, 119: 5158.5, 120: 5159.0, 121: 5159.5, 122: 5160.0, 123: 5160.5, 124: 5161.0, 125: 5161.5, 126: 5162.0, 127: 5162.5, 128: 5163.0, 129: 5163.5, 130: 5164.0, 131: 5164.5, 132: 5165.0, 133: 5165.5, 134: 5166.0, 135: 5166.5, 136: 5167.0, 137: 5167.5, 138: 5168.0, 139: 5168.5, 140: 5169.0, 141: 5169.5, 142: 5170.0, 143: 5170.5, 144: 5171.0, 145: 5171.5, 146: 5172.0, 147: 5172.5, 148: 5173.0, 149: 5173.5, 150: 5174.0, 151: 5174.5, 152: 5175.0, 153: 5175.5, 154: 5176.0, 155: 5176.5, 156: 5177.0, 157: 5177.5, 158: 5178.0, 159: 5178.5, 160: 5179.0, 161: 5179.5, 162: 5180.0, 163: 5180.5, 164: 5181.0, 165: 5181.5, 166: 5182.0, 167: 5182.5, 168: 5183.0, 169: 5183.5, 170: 5184.0, 171: 5185.0, 172: 5185.5, 173: 5186.0, 174: 5186.5, 175: 5187.0, 176: 5188.0, 177: 5188.5, 178: 5189.0, 179: 5189.5, 180: 5190.0, 181: 5190.5, 182: 5191.0, 183: 5191.5, 184: 5192.0, 185: 5192.5, 186: 5193.0, 187: 5193.5, 188: 5194.0, 189: 5194.5, 190: 5195.0, 191: 5195.5, 192: 5196.0, 193: 5196.5, 194: 5197.0, 195: 5197.5, 196: 5198.0, 197: 5198.5, 198: 5199.0, 199: 5199.5, 200: 5200.0, 201: 5200.5, 202: 5201.0, 203: 5201.5, 204: 5202.0, 205: 5202.5, 206: 5203.0, 207: 5203.5, 208: 5204.0, 209: 5204.5, 210: 5205.0, 211: 5205.5, 212: 5206.0, 213: 5206.5, 214: 5207.0, 215: 5207.5, 216: 5208.0, 217: 5208.5, 218: 5209.0, 219: 5209.5, 220: 5210.0, 221: 5210.5, 222: 5211.0, 223: 5211.5, 224: 5212.0, 225: 5212.5, 226: 5213.0, 227: 5213.5, 228: 5214.0, 229: 5214.5, 230: 5215.0, 231: 5215.5, 232: 5216.0, 233: 5216.5, 234: 5217.0, 235: 5217.5, 236: 5218.0, 237: 5218.5, 238: 5219.0, 239: 5219.5, 240: 5220.0, 241: 5220.5, 242: 5221.0, 243: 5221.5, 244: 5222.0, 245: 5222.5, 246: 5223.0, 247: 5224.5, 248: 5225.0, 249: 5225.5, 250: 5226.0, 251: 5226.5, 252: 5227.0, 253: 5227.5, 254: 5228.0, 255: 5228.5, 256: 5229.0, 257: 5229.5, 258: 5230.0, 259: 5230.5, 260: 5231.0, 261: 5231.5, 262: 5232.0, 263: 5232.5, 264: 5233.0, 265: 5233.5, 266: 5234.0, 267: 5234.5, 268: 5235.0, 269: 5235.5, 270: 5236.5, 271: 5237.0, 272: 5237.5, 273: 5238.0, 274: 5238.5, 275: 5239.0, 276: 5239.5, 277: 5240.0, 278: 5240.5, 279: 5241.0, 280: 5241.5, 281: 5242.0, 282: 5242.5, 283: 5243.0, 284: 5243.5, 285: 5244.0, 286: 5244.5, 287: 5245.0, 288: 5246.0, 289: 5246.5, 290: 5247.0, 291: 5247.5, 292: 5248.0, 293: 5248.5, 294: 5249.0}, 'quantity': {0: 20, 1: 40, 2: 10, 3: 90, 4: 20, 5: 25, 6: 85, 7: 305, 8: 75, 9: 10, 10: 150, 11: 150, 12: 215, 13: 155, 14: 80, 15: 55, 16: 255, 17: 180, 18: 205, 19: 250, 20: 140, 21: 210, 22: 130, 23: 235, 24: 400, 25: 180, 26: 275, 27: 675, 28: 240, 29: 250, 30: 145, 31: 255, 32: 350, 33: 205, 34: 180, 35: 265, 36: 100, 37: 390, 38: 150, 39: 145, 40: 425, 41: 450, 42: 305, 43: 250, 44: 155, 45: 685, 46: 585, 47: 665, 48: 500, 49: 425, 50: 320, 51: 340, 52: 320, 53: 795, 54: 550, 55: 850, 56: 895, 57: 685, 58: 320, 59: 420, 60: 280, 61: 535, 62: 375, 63: 425, 64: 25, 65: 705, 66: 640, 67: 515, 68: 260, 69: 650, 70: 305, 71: 315, 72: 160, 73: 525, 74: 160, 75: 355, 76: 65, 77: 230, 78: 45, 79: 180, 80: 95, 81: 350, 82: 20, 83: 295, 84: 15, 85: 125, 86: 60, 87: 225, 88: 40, 89: 110, 90: 100, 91: 40, 92: 40, 93: 110, 94: 110, 95: 110, 96: 50, 97: 10, 98: 155, 99: 15, 100: 135, 101: 20, 102: 105, 103: 215, 104: 290, 105: 260, 106: 195, 107: 105, 108: 45, 109: 45, 110: 40, 111: 95, 112: 185, 113: 70, 114: 265, 115: 105, 116: 300, 117: 100, 118: 375, 119: 100, 120: 265, 121: 265, 122: 520, 123: 285, 124: 530, 125: 270, 126: 805, 127: 430, 128: 400, 129: 340, 130: 485, 131: 160, 132: 720, 133: 370, 134: 465, 135: 1250, 136: 890, 137: 310, 138: 810, 139: 455, 140: 815, 141: 525, 142: 600, 143: 300, 144: 375, 145: 265, 146: 690, 147: 115, 148: 60, 149: 125, 150: 455, 151: 290, 152: 20, 153: 115, 154: 25, 155: 20, 156: 80, 157: 60, 158: 110, 159: 60, 160: 65, 161: 100, 162: 100, 163: 20, 164: 15, 165: 30, 166: 150, 167: 15, 168: 50, 169: 85, 170: 265, 171: 180, 172: 15, 173: 15, 174: 20, 175: 95, 176: 70, 177: 55, 178: 360, 179: 295, 180: 665, 181: 330, 182: 390, 183: 225, 184: 680, 185: 215, 186: 135, 187: 120, 188: 215, 189: 75, 190: 420, 191: 210, 192: 250, 193: 110, 194: 155, 195: 125, 196: 145, 197: 25, 198: 375, 199: 10, 200: 30, 201: 10, 202: 120, 203: 75, 204: 60, 205: 55, 206: 55, 207: 140, 208: 265, 209: 175, 210: 190, 211: 80, 212: 145, 213: 225, 214: 45, 215: 85, 216: 185, 217: 70, 218: 215, 219: 130, 220: 345, 221: 125, 222: 55, 223: 165, 224: 200, 225: 80, 226: 125, 227: 235, 228: 385, 229: 280, 230: 605, 231: 695, 232: 860, 233: 175, 234: 450, 235: 200, 236: 625, 237: 160, 238: 260, 239: 60, 240: 175, 241: 130, 242: 45, 243: 480, 244: 220, 245: 90, 246: 315, 247: 20, 248: 585, 249: 105, 250: 40, 251: 85, 252: 120, 253: 205, 254: 105, 255: 225, 256: 745, 257: 255, 258: 775, 259: 105, 260: 615, 261: 155, 262: 370, 263: 315, 264: 100, 265: 35, 266: 190, 267: 70, 268: 585, 269: 85, 270: 75, 271: 80, 272: 295, 273: 35, 274: 165, 275: 175, 276: 190, 277: 575, 278: 200, 279: 140, 280: 65, 281: 80, 282: 75, 283: 55, 284: 265, 285: 155, 286: 10, 287: 150, 288: 60, 289: 115, 290: 150, 291: 30, 292: 150, 293: 20, 294: 55}}
data = pd.DataFrame(data_dct)

# Make data 1-dimensional
data_1d = np.array(np.repeat(data.price, data.quantity)).reshape(-1, 1)

# Cluster
clusterer = hdbscan.HDBSCAN(min_cluster_size=5000, min_samples=200)
clusterer.fit(data_1d)

# Merge the cluster labels to data and re-groupby `price`
clustered_data_1d = pd.DataFrame(np.concatenate((data_1d, clusterer.labels_.reshape(-1, 1)), axis=1), columns=['price', 'cluster'])
clustered_data_1d['quantity'] = 1
grouped_data_1d = clustered_data_1d.groupby('price').agg({'cluster': lambda x: x.value_counts().index[0], 'quantity': np.sum}).reset_index()

# Plot
fig = px.bar(grouped_data_1d, x='price', y='quantity', color='cluster', orientation='v')
fig.update_traces(dict(marker_line_width=0))
fig.show()

推荐阅读