python - k-mean python的图像分离
问题描述
我是机器学习的新手,我正在学习图像分离的 k-mean,但我无法理解它的代码:
from matplotlib.image import imread
image = imread(os.path.join("images","unsupervised_learning","ladybug.png"))
image.shape
X = image.reshape(-1, 3)
kmeans = KMeans(n_clusters=8, random_state=42).fit(X)
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
segmented_img = segmented_img.reshape(image.shape)
segmented_imgs = []
n_colors = (10, 8, 6, 4, 2)
for n_clusters in n_colors:
kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X)
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
segmented_imgs.append(segmented_img.reshape(image.shape))
plt.figure(figsize=(10,5))
plt.subplots_adjust(wspace=0.05, hspace=0.1)
plt.subplot(231)
plt.imshow(image)
plt.title("Original image")
plt.axis('off')
for idx, n_clusters in enumerate(n_colors):
plt.subplot(232 + idx)
plt.imshow(segmented_imgs[idx])
plt.title("{} colors".format(n_clusters))
plt.axis('off')
plt.show()
特别是,这段代码是什么意思
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
解决方案
我建议您从阅读无监督学习开始,特别是使用 K-means 进行聚类及其一般的应用程序/概念,而不仅仅是图像。
我将注释此代码的每一行以解释正在发生的事情。
from matplotlib.image import imread #import module
image = imread(os.path.join("images","unsupervised_learning","ladybug.png")) #Read Image
image.shape #Get shape of image, which is height, width and channel(colours)
X = image.reshape(-1, 3) #Reshaping to get color channel first
kmeans = KMeans(n_clusters=8, random_state=42).fit(X) #Applying and fitting K-means clustering
segmented_img = kmeans.cluster_centers_[kmeans.labels_] #centres of the 8 clusters made
segmented_img = segmented_img.reshape(image.shape) #reshape them using the changed image shape
segmented_imgs = []
n_colors = (10, 8, 6, 4, 2)
for n_clusters in n_colors:
kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X) #Applyting kmeans for each colour, using 10,8,6.... as number of clusters
segmented_img = kmeans.cluster_centers_[kmeans.labels_] #Repeating as mentioned above
segmented_imgs.append(segmented_img.reshape(image.shape))
plt.figure(figsize=(10,5)) #Plotting code
plt.subplots_adjust(wspace=0.05, hspace=0.1)
plt.subplot(231)
plt.imshow(image)
plt.title("Original image")
plt.axis('off')
for idx, n_clusters in enumerate(n_colors):
plt.subplot(232 + idx)
plt.imshow(segmented_imgs[idx])
plt.title("{} colors".format(n_clusters))
plt.axis('off')
plt.show()
在这行代码中,
segmented_img = kmeans.cluster_centers_[kmeans.labels_]
顾名思义,cluster_centers_ 是一个数组属性,它返回集群中心的坐标,而labels_ 是一个返回每个点的标签的属性。因此,segmented_img 包含每个点标签的聚类中心坐标。点。
推荐阅读
- angular - 如何将任何类型的数组传递给查找
- javascript - 如何在 Meteor.js 的 React 组件中导入 CSS
- c# - OOP - 使用改变实现的方法和半固态正确设计对象
- rest - 我可以使用 AAD 保护 REST API(在 springboot 中说)并以该 AAD 的用户身份访问它吗
- sql - 声明表时默认使用 select 语句
- vue.js - 无法在输入事件上访问 keyCode/key 事件
- vb.net - vb.net中的Concat字节数组
- python - AssertEqual 在比较两个相同的 int 对象时失败
- apache - 如何让 Apache 返回 404 而不是 502
- python-3.x - 如何避免 Jupyter Notebook 中的内存错误