首页 > 解决方案 > 机器学习中一种新分类方法的 Python 代码优化

问题描述

我设计了一种新的分类方法，并已经用 Python 实现了它，但代码运行得很慢。

请帮我优化下面的代码。

from joblib import Parallel, delayed
from datetime import datetime
from sklearn.metrics.pairwise import euclidean_distances
import copy
import pandas as pd
import numpy as np    

这个函数为指定对象在特征空间中构建超球体：半径是该对象到最近的异类对象的距离。

def find_sphere(ind, df, dist_sq, sidx):
    """Build the class-homogeneous hypersphere centred on object ``ind``.

    The radius is the distance from ``ind`` to the first object of another
    class in distance order, so every object before it ("relatives") belongs
    to ``ind``'s own class.

    Parameters
    ----------
    ind : int
        Row position of the centre object.
    df : pandas.DataFrame
        Data with a ``"Class"`` column. Rows are addressed by *position*, to
        match the rows/columns of ``dist_sq`` and ``sidx``.
    dist_sq : numpy.ndarray
        Pairwise distance matrix of shape (n, n).
    sidx : numpy.ndarray
        ``np.argsort(dist_sq, axis=0)``; column ``ind`` orders all objects by
        ``dist_sq[:, ind]`` (row ``ind`` for a symmetric distance matrix).

    Returns
    -------
    dict
        ``'index'``: ``ind``; ``'class'``: its class;
        ``'distances'``: row ``ind`` of ``dist_sq``;
        ``'relatives'``: set of objects ordered before the first other-class
        object (Python ints);
        ``'radius'``: distance to that first other-class object;
        ``'enemies'``: positions whose distance equals ``radius``;
        ``'coverages'``: set of relatives that are nearest to some enemy.
        If all objects share ``ind``'s class, ``radius`` is ``inf``, every
        object is a relative and there are no enemies or coverages.
    """
    classes = df["Class"].to_numpy()
    own_class = classes[ind]
    indexes = sidx[:, ind]
    distances = dist_sq[ind, :]

    # First position (in distance order) holding an object of another class.
    is_other = classes[indexes] != own_class
    if is_other.any():
        cut = int(np.argmax(is_other))
        relatives = indexes[:cut]
        radius = distances[indexes[cut]]
    else:
        # Single-class data: nothing bounds the sphere.
        relatives = indexes
        radius = np.inf
    enemies = np.where(distances == radius)[0]

    coverages = set()
    if relatives.size and enemies.size:
        # For each enemy, mark the relative(s) closest to it (ties included);
        # the union over all enemies is the sphere's coverage.
        block = dist_sq[np.ix_(enemies, relatives)]
        nearest = block == block.min(axis=1, keepdims=True)
        coverages.update(relatives[nearest.any(axis=0)].tolist())

    return {
        'index': ind,
        'relatives': set(relatives.tolist()),
        'class': own_class,
        'distances': distances,
        'radius': radius,
        'enemies': enemies,
        'coverages': coverages,
    }

此函数将空间中的对象聚类成若干组：

def find_groups(spheres):
    """Cluster spheres into groups linked through shared coverage objects.

    Starting from an arbitrary coverage object, a group collects every
    still-unassigned sphere that has it as a relative. The group then grows
    by repeatedly adding unassigned spheres whose relatives intersect the
    coverage objects found among the relatives of the group's members.

    Parameters
    ----------
    spheres : list of dict
        Output of ``find_sphere``; only ``'index'``, ``'relatives'`` (set)
        and ``'coverages'`` (set) are read. ``'index'`` values are assumed
        unique. The input spheres are not modified.

    Returns
    -------
    list of set
        Sets of sphere indices; each sphere appears in at most one group.
    """
    all_coverages = {x for s in spheres for x in s['coverages']}
    coverages = set(all_coverages)
    # Only 'index' and 'relatives' are needed here, so track them directly
    # instead of deep-copying every sphere (which also duplicated each
    # sphere's full 'distances' row, i.e. the whole distance matrix).
    not_seen = {s['index']: s['relatives'] for s in spheres}

    groups = []
    while coverages:
        obj = coverages.pop()
        group = {idx for idx, rel in not_seen.items() if obj in rel}
        if not group:
            continue
        linker_coverage = set()
        frontier = set(group)
        while True:
            # Earlier members were already removed from not_seen, so only the
            # members added in the previous round can contribute new linkers.
            for idx in frontier:
                linker_coverage |= not_seen.pop(idx) & all_coverages
            frontier = {idx for idx, rel in not_seen.items()
                        if not rel.isdisjoint(linker_coverage)}
            if not frontier:
                break
            group |= frontier
            coverages -= linker_coverage
        groups.append(group)
    return groups

此函数查找所有标准对象。也就是说，对新对象进行分类时只需与这些标准对象比较，而不必遍历所有对象。

_STANDART_ROW_CHUNK = 512  # rows per batch when priming the cache (bounds memory)


def _nearest_standard_ratios(rows, dists, idx, radius, code, alive):
    """Return the smallest same-class and other-class scaled distance for each object in ``rows``.

    The scaled distance from object ``o`` to standard ``s`` is
    ``dists[idx[o], idx[s]] / radius[s]``. Only standards with ``alive[s]``
    are considered; ``inf`` is returned where no such standard exists.
    """
    ratios = dists[np.ix_(idx[rows], idx)] / radius
    same = code[rows][:, None] == code[None, :]
    best_same = np.min(ratios, axis=1, initial=np.inf, where=same & alive)
    best_diff = np.min(ratios, axis=1, initial=np.inf, where=~same & alive)
    return best_same, best_diff


def find_standart_objects(spheres, dists, groups):
    """Reduce the spheres to a set of standard objects that still classifies every object correctly.

    Every sphere starts as a standard ``(index, radius, class)``. An object is
    recognised by the standard(s) minimising ``dists[object, standard] /
    radius``. Recognition is correct when every standard tied at that minimum
    has the object's class. Groups are visited from largest to smallest, and
    within a group candidates go from smallest radius up. A candidate is
    dropped if all objects are still recognised correctly without it.
    Otherwise it is kept and ``'Standart found <index>'`` is printed.

    Instead of re-sorting every object's ratios for every candidate, the
    smallest same-class and other-class ratio of each object is cached. Only
    the objects for which the candidate attains that minimum are recomputed.
    The decisions are identical, ties included.

    Parameters
    ----------
    spheres : list of dict
        Output of ``find_sphere``; ``'index'``, ``'radius'`` and ``'class'``
        are read. Indices are assumed unique.
    dists : array-like
        Pairwise distance matrix indexed by sphere ``'index'``.
    groups : iterable of set
        Groups of sphere indices, e.g. from ``find_groups``.

    Returns
    -------
    list of tuple
        Retained ``(index, radius, class)`` tuples, in the order of ``spheres``.

    Notes
    -----
    The last remaining standard is never dropped; the naive formulation would
    index an empty result there. NOTE(review): radii are assumed positive. A
    zero radius (coincident points of different classes) gives inf/nan ratios.
    """
    standart_objects = [(s['index'], s['radius'], s['class']) for s in spheres]
    if not standart_objects:
        return standart_objects

    dists = np.asarray(dists)
    idx = np.array([s[0] for s in standart_objects])
    radius = np.array([s[1] for s in standart_objects])
    # Integer class codes make the same-class masks cheap to build.
    class_codes = {}
    code = np.array([class_codes.setdefault(s[2], len(class_codes))
                     for s in standart_objects])
    position = {s[0]: k for k, s in enumerate(standart_objects)}

    n = len(standart_objects)
    alive = np.ones(n, dtype=bool)
    best_same = np.empty(n)
    best_diff = np.empty(n)
    for start in range(0, n, _STANDART_ROW_CHUNK):
        rows = np.arange(start, min(start + _STANDART_ROW_CHUNK, n))
        best_same[rows], best_diff[rows] = _nearest_standard_ratios(
            rows, dists, idx, radius, code, alive)
    # Correct recognition <=> nearest same-class standard strictly beats
    # every other-class standard (a tie means a wrong-class standard is tied).
    correct = best_same < best_diff

    for group in sorted(groups, key=len, reverse=True):
        candidates = [s for s in standart_objects if s[0] in group]
        for candidate in sorted(candidates, key=lambda x: x[1]):
            c = position[candidate[0]]
            alive[c] = False  # tentatively drop the candidate
            # Removing c can only change an object's cached minimum if c attains it.
            scaled = dists[idx, idx[c]] / radius[c]
            affected = np.flatnonzero(np.where(
                code == code[c], scaled == best_same, scaled == best_diff))
            new_same, new_diff = _nearest_standard_ratios(
                affected, dists, idx, radius, code, alive)
            trial = correct.copy()
            trial[affected] = new_same < new_diff
            if trial.all():
                best_same[affected] = new_same
                best_diff[affected] = new_diff
                correct = trial
                standart_objects = [
                    s for s in standart_objects if s[0] != candidate[0]]
            else:
                alive[c] = True
                print(f'Standart found {candidate[0]}')
    return standart_objects

下面这部分代码调用上述函数：

# Dry Bean dataset:
# https://archive.ics.uci.edu/ml/machine-learning-databases/00602/DryBeanDataset.zip
# Every column except the last is used as a feature (presumably the last one
# is the "Class" label that find_sphere reads — confirm against the file).
df = pd.read_csv("Dry_Bean.txt", sep='\t')
objectCount = len(df)

# Plain (non-squared) pairwise Euclidean distances, rounded to 15 decimals
# (presumably so near-equal distances compare equal). Column j of sidx lists
# all objects ordered by dist_sq[:, j].
dist_sq = euclidean_distances(df.iloc[:, :-1]).round(15)
sidx = dist_sq.argsort(axis=0)

# One hypersphere per object, built in 6 worker processes.
spheres = Parallel(n_jobs=6)(
    delayed(find_sphere)(ind, df, dist_sq, sidx)
    for ind in range(objectCount)
)
groups = find_groups(spheres)
standartObjects = find_standart_objects(spheres, dist_sq, groups)

瓶颈部分是 find_standart_objects 和 find_sphere 函数。

标签: python pandas numpy optimization classification

解决方案


推荐阅读