首页 > 解决方案 > Dijkstra 的二叉堆算法实现

问题描述

我一直在尝试用二进制堆为我的算法类实现 Dijkstra 算法。目标是找到通过 Dijkstra 找到的最小生成树的长度。我以前使用过,但使用的是优先队列。我发现我被堆管理绊倒了,如果有人能发现错误。主要方法如下:

import java.io.File;
import java.util.regex.*;
import java.util.ArrayList;
public class main {

    public static void main(String[] args) {
        // TODO Auto-generated method stub
        try {
            File input = new File("filepath\1000.txt");
            Scanner scanner = new Scanner(input);

            //Regex patterns used to find input size, nodes, and edges
            String size = "n=([0-9]*)";
            String node = "^[0-9]+";
            String edge = "^[\\s]+([0-9]+)[\\s]+([0-9]+)";

            Pattern sizePattern = Pattern.compile(size);
            Pattern edgePattern = Pattern.compile(edge);
            Pattern nodePattern = Pattern.compile(node);


            //Loop to fill nodeList
            ArrayList <Integer> finalDistances = new ArrayList<>();
            MinHeap list = null;
            Node currentNode = null;
            Edge currentEdge = null;
            while(scanner.hasNextLine()) {

                String line = scanner.nextLine();
                Matcher match1 = sizePattern.matcher(line);
                Matcher match2 = nodePattern.matcher(line);
                Matcher match3 = edgePattern.matcher(line);
                //catches size
                if(match1.find()) {
                    int numberOfNodes = Integer.parseInt(match1.group(1));
                    //Holds all the nodes to process with Dijkstra's Algo
                    list = new MinHeap(numberOfNodes);
                }
                //catches node
                if (match2.find()) {
                    if (!list.contains(Integer.parseInt(line))) {
                        currentNode = new Node(Integer.parseInt(line));
                        currentNode.seen = false;
                        currentNode.distance = Integer.MAX_VALUE;
                        list.insert(currentNode);
                    }
                    else {
                        currentNode = list.getNode(Integer.parseInt(line));
                    }
                }
                //catches edge
                if(match3.find()) {
                    if (!list.contains(Integer.parseInt(match3.group(1)))) {
                        Node temp = new Node(Integer.parseInt(match3.group(1)));
                        temp.seen = false;
                        temp.distance = Integer.MAX_VALUE;
                        list.insert(temp);
                        currentEdge = new Edge(Integer.parseInt(match3.group(1)), Integer.parseInt(match3.group(2)));
                        currentNode.add(currentEdge);
                    } else {
                        currentEdge = new Edge(Integer.parseInt(match3.group(1)), Integer.parseInt(match3.group(2)));
                        currentNode.add(currentEdge);
                    }
                }
            }
            Node source = list.getNode(0);
            source.distance=0;
            list.updateNode(source);
            int treeLength = 0;
            while(!list.isEmpty()) {
                currentNode = list.extractMin();
                currentNode.seen = true;
                ArrayList<Edge> edgeList = currentNode.edges;
                for (int i = 0; i < edgeList.size(); i++) {
                    currentEdge = edgeList.get(i);
                    if (list.contains(currentEdge.end)) {
                        int calcDist = currentNode.distance + currentEdge.weight;
                        if (calcDist < list.getNode(currentEdge.end).distance) {
                            list.decreaseKey(list.getNode(currentEdge.end), calcDist);
                        }
                    }
                }
                System.out.println(currentNode.toString() + "with distance" +currentNode.distance);
                finalDistances.add(currentNode.distance);
            }
            for (int j = 0; j < finalDistances.size(); j++) treeLength+= finalDistances.get(j);
            System.out.println("Tree length is: "+treeLength);
        }

        //fail safe
        catch (Exception ex){
            System.out.println("Shit broke!");
        }
    }

}

输入只是一个txt文件,格式如下:

0
    25    244
   108    275
   140    273
   159    313
   219    199
   254    392
   369    171
   518    271
   538    250
   568    253
   603    307
   613    196
   638    314

1
    24    187
    43    182
    65    331
   155    369
   182    222
   186    426
   224    395
   233     72
   240    128
   250    101
   269    251
   371     73
   409    301
   444     40
   451    262
   464    337
   517    393
   569    171
   586    384
   599    221
   601    145
   611    209
   616    330
   629    324
   644    254
   646    316
   675    237
   684    327
   695    439
   696    288

第一行是节点和边的数量。一行上的数字本身就是一个节点,后面的数字是带权重的边。我的 Node 类只有一个 int 值、边缘数组列表和与源的距离。我的 Edge 类有一个 int 值和 int 权重。

import java.util.ArrayList;

public class Node {
    int value;
    ArrayList<Edge> edges;
    boolean seen;
    int distance;
    public Node(int value) {
        edges = new ArrayList<>();
        this.value = value;
        seen = false;
    }
    public void add (Edge edge) {
        edges.add(edge);
    }
    @Override
    //debugging use
    public String toString() {
        String listOfEdges= "";
        //for (int i = 0; i < edges.size(); i++) {
        //  listOfEdges = listOfEdges + " ," + Integer.toString(edges.get(i).end);
        //}
        return "Node: "+Integer.toString(value); //+" Edges: "+listOfEdges;
    }
}
public class Edge {
    int end;
    int weight;
    public Edge(int end, int weight) {
        this.end = end;
        this.weight = weight;
    }
    @Override
    //debugging use
    public String toString() {
        return "Edge: "+end+" Weight: "+weight;
    }
}

似乎在运行算法的某个点之后,一些顶点保持无穷大,而另一些顶点保持负无穷大。我使用了两个数组:第一个是存储节点的堆,第二个是将节点的索引存储在堆中。完整课程如下:

    private int capacity;
    private int currentSize;
    private Node[] heap;
    private int[] locations;
    //private HashMap<Integer, Integer> locations = new HashMap<>();

    public MinHeap(int capacity) {
        this.capacity = capacity;
        heap = new Node[capacity];
        locations = new int[capacity];
        //heap[0] = new Node(-1);
        //heap[0].distance = Integer.MIN_VALUE;
        currentSize = 0;
    }
    private int getParent(int index) {
        return (index-1)/2;
    }
    private int getLeftChild(int index) {
        return index*2+1;
    }
    private int getRightChild(int index) {
        return index*2+2;
    }
    private void swap(int a, int b) {
        Node temp = heap[a];
        heap[a] = heap[b];
        heap[b] = temp;
        //maybe bug


    }
    Node extractMin() {
        Node min = heap[0];
        Node last = heap[currentSize-1];
        swap (0,currentSize-1);
        locations[last.value] = 0;
        locations[min.value] = -1;
        currentSize--;
        heapifyDown(0);
        System.out.println("heap size: "+currentSize);
        return min;
    }
    void heapifyDown(int index) {
        int currentIndex = index;
        int leftIndex = getLeftChild(currentIndex);
        int rightIndex = getRightChild(currentIndex);
        if (leftIndex < currentSize && heap[currentIndex].distance > heap[leftIndex].distance) {
            currentIndex = leftIndex;
        }
        if (rightIndex < currentSize && heap[currentIndex].distance > heap[rightIndex].distance) {
            currentIndex = rightIndex;
        }
        if (currentIndex != index) {
            Node newTop = heap[currentIndex];
            Node oldTop = heap[index];

            locations[newTop.value] = index;
            locations[oldTop.value] = currentIndex;
            swap(index,currentIndex);
            heapifyDown(currentIndex);
        }
    }
    void heapifyUp(int index) {
        int parentIndex = getParent(index);
        int currentIndex = index;
        Node currentNode = heap[currentIndex];
        Node parentNode = heap[parentIndex];
        while (currentIndex > 0 && heap[parentIndex].distance > heap[currentIndex].distance) {
            System.out.println("Swapped: "+heap[getParent(currentIndex)].toString()+" That has a distance of: "+heap[getParent(currentIndex)].distance+ " With: "+heap[currentIndex]+" Which has a distance of: "+heap[currentIndex].distance);
            swap(parentIndex,currentIndex);
            locations[currentNode.value] = parentIndex;
            locations[parentNode.value] = currentIndex;
            currentIndex = parentIndex;
            parentIndex = getParent(parentIndex);

            System.out.println("min: "+heap[0].toString());
        }
    }
    public void decreaseKey(Node node, int distance) {
        int location = locations[node.value];
        heap[location].distance = distance;
        heapifyUp(location);
    }
    public void insert(Node node) {
        //currentSize++;
        int index = currentSize;
        heap[currentSize] = node;
        locations[node.value] = currentSize;
        currentSize++;
        heapifyUp(index);
    }
    public boolean contains(int node) {
        return locations[node] != 0 && locations[node] != -1;
    }
    public boolean isEmpty() {
        return currentSize==0;
    }
    public Node getNode(Node node) {
        return heap[locations[node.value]];
    }
    public Node getNode(int nodeValue) {
        return heap[locations[nodeValue]];
    }
    public void updateNode(Node node) {
        heap[locations[node.value]] = node;
    }
    public void print() {
        for (int i = 0; i < currentSize; i++) {
            System.out.println(heap[i].toString());
        }
    }
    public Node peek() {
        return heap[0];
    }
    public int size() {
        return currentSize;
    }
}

任何反馈表示赞赏。

标签: javaheapdijkstraminimum-spanning-treemin-heap

解决方案


推荐阅读