首页 > 技术文章 > k-means 简单实现

zhangdebin 2016-04-20 19:16 原文

这段代码是同学很久以前写的。那时我刚开始实习，他刚参加工作（他读的是两年制），那时候开始对数据挖掘感兴趣，他便把自己做的 demo 发给了我。记得他快毕业时，我们还一起帮着构思 k-means 的创新点；如今他从事数据挖掘工作已经两年了。
他的博客地址：http://www.cnblogs.com/niuxiaoha/p/4645989.html

package neugle.kmeans;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;

/**
 * Minimal k-means (Lloyd's algorithm) clustering of Iris records read from a text file.
 *
 * <p>Each data line is expected as {@code sepLen,sepWid,petLen,petWid,type}. A cluster is held as a
 * list whose element 0 is the center used for that pass (a sentinel, not a member), followed by the
 * member points assigned to it.
 */
public class Kmeans {
    private static int k = 3; // number of clusters
    private static int n = 0; // number of iterations performed

    /** Data file used when no path is passed on the command line. */
    private static final String DEFAULT_DATA_FILE = "D:\\iris.data";

    /** Safety cap so a run that never settles exactly cannot loop forever. */
    private static final int MAX_ITERATIONS = 1000;

    /**
     * Reads the data, iterates k-means until the centers stop moving, and prints the clusters.
     *
     * @param args optional; {@code args[0]} overrides the data file path
     *     (default {@code D:\iris.data})
     */
    public static void main(String[] args) {
        String path = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_FILE;
        ArrayList<IrisModel> irisList = ReadFile(path);
        if (irisList == null) {
            return; // ReadFile has already reported the failure
        }
        if (k <= 0 || irisList.size() < k) {
            System.out.println("数据量不足，无法划分为 " + k + " 个簇");
            return;
        }

        ArrayList<IrisModel> beforeRandomPot = new ArrayList<IrisModel>(); // centers of the previous pass
        ArrayList<IrisModel> randomPot = RandomPot(irisList); // current centers, seeded from random points
        ArrayList<ArrayList<IrisModel>> kMeansList = null;

        // CompareRandomPot also records the current centers as "previous" whenever they differ,
        // so it must run exactly once per pass. The first call always returns false (no history),
        // which guarantees at least one clustering pass and a non-null kMeansList.
        boolean converged = CompareRandomPot(beforeRandomPot, randomPot);
        while (!converged && n < MAX_ITERATIONS) {
            kMeansList = KMeans(irisList, randomPot);
            n++;
            converged = CompareRandomPot(beforeRandomPot, randomPot);
        }

        Print(kMeansList);
        System.out.println("迭代了" + n + "次");
        if (!converged) {
            System.out.println("达到最大迭代次数 " + MAX_ITERATIONS + "，质心仍未完全稳定");
        }
    }

    /**
     * Reads comma-separated Iris records from {@code path}.
     *
     * <p>Blank lines are skipped; lines with fewer than five fields or non-numeric measurements are
     * reported and skipped instead of aborting the whole read.
     *
     * @param path file to read
     * @return the parsed records, or {@code null} if the file could not be opened or read
     */
    private static ArrayList<IrisModel> ReadFile(String path) {
        FileReader read = null;
        BufferedReader br = null;
        ArrayList<IrisModel> irisList = new ArrayList<IrisModel>();
        try {
            read = new FileReader(path);
            br = new BufferedReader(read);
            String readLine;
            int lineNo = 0;
            while ((readLine = br.readLine()) != null) {
                lineNo++;
                if (readLine.trim().length() == 0) {
                    continue; // e.g. a trailing empty line at end of file
                }
                String[] fields = readLine.split(",");
                if (fields.length < 5) {
                    System.out.println("跳过格式错误的行 " + lineNo + ": " + readLine);
                    continue;
                }
                try {
                    IrisModel iris = new IrisModel();
                    iris.Sep_len = Double.parseDouble(fields[0]);
                    iris.Sep_wid = Double.parseDouble(fields[1]);
                    iris.Pet_len = Double.parseDouble(fields[2]);
                    iris.Pet_wid = Double.parseDouble(fields[3]);
                    iris.Iris_type = fields[4];
                    irisList.add(iris);
                } catch (NumberFormatException e) {
                    System.out.println("跳过无法解析的行 " + lineNo + ": " + readLine);
                }
            }
        } catch (FileNotFoundException e) {
            System.out.println("读取文件异常");
            irisList = null;
        } catch (IOException e) {
            System.out.println("读取文件异常");
            irisList = null;
        } finally {
            // br is still null when the FileReader constructor failed; guard to avoid an NPE here.
            try {
                if (br != null) {
                    br.close(); // closing the BufferedReader also closes the wrapped FileReader
                } else if (read != null) {
                    read.close();
                }
            } catch (IOException e) {
                System.out.println("关闭文件异常");
            }
        }
        return irisList;
    }

    /**
     * Picks {@code k} distinct random data points as the initial centers.
     *
     * @param irisList data points to sample from
     * @return min(k, irisList.size()) distinct points, in the order they were drawn
     */
    private static ArrayList<IrisModel> RandomPot(ArrayList<IrisModel> irisList) {
        int size = irisList.size();
        // Never ask for more distinct indices than exist, otherwise the retry loop cannot finish.
        int wanted = Math.min(k, size);
        ArrayList<Integer> initCenter = new ArrayList<Integer>();
        while (initCenter.size() < wanted) {
            int num = (int) (Math.random() * size);
            if (!initCenter.contains(num)) {
                initCenter.add(num);
            }
        }
        ArrayList<IrisModel> randomPot = new ArrayList<IrisModel>();
        Iterator<Integer> it = initCenter.iterator();
        while (it.hasNext()) {
            randomPot.add(irisList.get(it.next()));
        }
        return randomPot;
    }

    /**
     * Runs one k-means pass: assigns every point to its nearest center, then replaces the contents
     * of {@code randomPot} with the recomputed centers.
     *
     * @param irisList all data points
     * @param randomPot current centers (non-empty); replaced in place with the new centers
     * @return one list per center: element 0 is the center used for this pass, followed by members
     */
    private static ArrayList<ArrayList<IrisModel>> KMeans(
            ArrayList<IrisModel> irisList, ArrayList<IrisModel> randomPot) {
        ArrayList<ArrayList<IrisModel>> groupNum = new ArrayList<ArrayList<IrisModel>>();
        for (int i = 0; i < randomPot.size(); i++) {
            ArrayList<IrisModel> group = new ArrayList<IrisModel>();
            group.add(randomPot.get(i)); // sentinel: the center, not a member
            groupNum.add(group);
        }
        for (int i = 0; i < irisList.size(); i++) {
            IrisModel point = irisList.get(i);
            // Strict '<' keeps the lowest-index center on ties.
            int nearest = 0;
            double best = DistanceOfTwoPoint(point, randomPot.get(0));
            for (int j = 1; j < randomPot.size(); j++) {
                double distance = DistanceOfTwoPoint(point, randomPot.get(j));
                if (distance < best) {
                    best = distance;
                    nearest = j;
                }
            }
            groupNum.get(nearest).add(point);
        }
        // Recompute the centers and publish them through the caller's list.
        ArrayList<IrisModel> newCenters = CalcCenter(groupNum);
        randomPot.clear();
        randomPot.addAll(newCenters);
        return groupNum;
    }

    /** Euclidean distance between two points over the four numeric measurements. */
    private static double DistanceOfTwoPoint(IrisModel d1, IrisModel d2) {
        double dSepLen = d1.Sep_len - d2.Sep_len;
        double dSepWid = d1.Sep_wid - d2.Sep_wid;
        double dPetLen = d1.Pet_len - d2.Pet_len;
        double dPetWid = d1.Pet_wid - d2.Pet_wid;
        return Math.sqrt(dSepLen * dSepLen + dSepWid * dSepWid
                + dPetLen * dPetLen + dPetWid * dPetWid);
    }

    /**
     * Computes the new center of every cluster as the mean of its members.
     *
     * <p>Element 0 of each group is the previous center (see {@link #KMeans}). It is excluded from
     * the mean so it does not drag the update back toward the old position. A cluster with no
     * members keeps its previous center, avoiding a 0/0 division.
     *
     * @param c clusters as produced by {@link #KMeans}
     * @return one new center per cluster, in the same order
     */
    private static ArrayList<IrisModel> CalcCenter(
            ArrayList<ArrayList<IrisModel>> c) {
        ArrayList<IrisModel> cIris = new ArrayList<IrisModel>();
        for (int g = 0; g < c.size(); g++) {
            ArrayList<IrisModel> group = c.get(g);
            IrisModel center = new IrisModel();
            int members = group.size() - 1;
            if (members <= 0) {
                if (!group.isEmpty()) {
                    IrisModel previous = group.get(0);
                    center.Sep_len = previous.Sep_len;
                    center.Sep_wid = previous.Sep_wid;
                    center.Pet_len = previous.Pet_len;
                    center.Pet_wid = previous.Pet_wid;
                }
                cIris.add(center);
                continue;
            }
            for (int m = 1; m < group.size(); m++) {
                IrisModel p = group.get(m);
                center.Sep_len += p.Sep_len;
                center.Sep_wid += p.Sep_wid;
                center.Pet_len += p.Pet_len;
                center.Pet_wid += p.Pet_wid;
            }
            center.Sep_len = center.Sep_len / members;
            center.Sep_wid = center.Sep_wid / members;
            center.Pet_len = center.Pet_len / members;
            center.Pet_wid = center.Pet_wid / members;
            cIris.add(center);
        }
        return cIris;
    }

    /**
     * Checks whether the centers stopped moving since the previous pass.
     *
     * <p>Centers are compared by coordinates, position by position. Object identity/equals is not
     * used because {@link #CalcCenter} creates new objects every pass. When the centers differ,
     * {@code beforeRandomPot} is overwritten with the current centers.
     *
     * @param beforeRandomPot centers from the previous pass; updated in place when they differ
     * @param randomPot current centers
     * @return {@code true} if a previous pass exists and every center is at the same position
     */
    private static boolean CompareRandomPot(
            ArrayList<IrisModel> beforeRandomPot, ArrayList<IrisModel> randomPot) {
        boolean same = !beforeRandomPot.isEmpty()
                && beforeRandomPot.size() == randomPot.size();
        for (int i = 0; same && i < randomPot.size(); i++) {
            same = SamePosition(beforeRandomPot.get(i), randomPot.get(i));
        }
        if (!same) {
            beforeRandomPot.clear();
            beforeRandomPot.addAll(randomPot);
        }
        return same;
    }

    /**
     * Exact coordinate equality of two points.
     *
     * <p>Exact equality is appropriate here: once cluster assignments stop changing,
     * {@link #CalcCenter} sums the same points in the same order and yields identical values.
     */
    private static boolean SamePosition(IrisModel a, IrisModel b) {
        return Double.compare(a.Sep_len, b.Sep_len) == 0
                && Double.compare(a.Sep_wid, b.Sep_wid) == 0
                && Double.compare(a.Pet_len, b.Pet_len) == 0
                && Double.compare(a.Pet_wid, b.Pet_wid) == 0;
    }

    /**
     * Prints each cluster's members (skipping the center sentinel at element 0), followed by the
     * member count and a separator line.
     *
     * @param kmeansList clusters as produced by {@link #KMeans}; nothing but the header is printed
     *     if {@code null}
     */
    private static void Print(ArrayList<ArrayList<IrisModel>> kmeansList) {
        System.out.println("------------------------------------");
        if (kmeansList == null) {
            return;
        }
        Iterator<ArrayList<IrisModel>> groups = kmeansList.iterator();
        while (groups.hasNext()) {
            ArrayList<IrisModel> group = groups.next();
            for (int m = 1; m < group.size(); m++) {
                IrisModel irisModel = group.get(m);
                System.out.println(irisModel.Sep_len + " " + irisModel.Sep_wid
                        + " " + irisModel.Pet_len + " " + irisModel.Pet_wid
                        + " " + irisModel.Iris_type);
            }
            System.out.println(group.size() - 1);
            System.out.println("------------------------------------");
        }
    }
}

推荐阅读