首页 > 解决方案 > 我不知道我用 Kruskal 算法实现 MST 有什么问题

问题描述

我正在尝试解决算法问题(使用 Kruskal 算法的 MST)。它有时有效,有时无效(出现运行时错误)。我用另一种方法解决了这个问题,但我想弄清楚这段代码有什么问题,以应对未来的问题。估计是内存问题,不过两天查不出来。

typedef struct edges
{
    struct edges *nextNode;
    int weight;
    int src;
    int dest;
}EDGES;

typedef struct vertex
{
    EDGES *edgePtr;
    int verNum;
    int weight;
    COLOR color;
}VERTEX;

typedef struct setTree
{
    struct setTree *parent;
    int verNum;
    int rank;
}SETNODE;

SETNODE *findSet(SETNODE *tree)
{
    if (tree != tree->parent)
        tree->parent = findSet(tree->parent);
    return tree->parent;
}

void makeSet(SETNODE *tree, int vernum)
{
    tree->parent = tree;
    tree->verNum = vernum;
    tree->rank = 0;
}

int linkSet(SETNODE *x, SETNODE *y)
{
    SETNODE *x_root = findSet(x);
    SETNODE *y_root = findSet(y);

    if (x_root->rank > y_root->rank)
        y_root->parent = x_root;
    else if (x_root->rank < y_root->rank)
    {
        x_root->parent = y_root;
    }
    else
    {
        y_root->parent = x_root;
        ++x_root->rank;
    }

    return 1;
}

int unionSet(SETNODE *x, SETNODE *y)
{
    return linkSet(x, y);
}

void MST_Kruskal(VERTEX **graph, int N, EDGES **edgeList, int edgeNum)
{
    SETNODE *set[MAXN];
    EDGES *result[MAXN*(MAXN - 1)];
    EDGES *tmp[MAXN];
    int edgeCnt = 0;
    int tmpVerNum1, tmpVerNum2;


    for (int i = 0; i < N; ++i)
    {
        set[i] = new SETNODE;
    }

    for (int i = 0; i < N; ++i)
    {
        makeSet(set[i], (i + 1));
    }

    mergeSort(edgeList, tmp, 0, edgeNum - 1);

    /*for (int i = 0; i < edgeNum; ++i)
    printf("%d->%d : %d\n", edgeList[i]->src, edgeList[i]->dest, edgeList[i]->weight);
    */
    for (int i = 0; i < edgeNum; ++i)
    {
        tmpVerNum1 = edgeList[i]->src; tmpVerNum2 = edgeList[i]->dest;

        if (findSet(set[tmpVerNum1 - 1]) != findSet(set[tmpVerNum2 - 1]))
        {
            int success = unionSet(set[tmpVerNum1 - 1], set[tmpVerNum2 - 1]);
            if (success)
                result[edgeCnt++] = edgeList[i];
        }
    }

    printf("%d\n", edgeCnt);
    for (int i = 0; i < edgeCnt; ++i)
    {
        printf("%d %d %d\n", result[i]->src, result[i]->dest, result[i]->weight);
    }

    for (int i = 0; i < N; ++i)
    {
        delete set[i];
    }
}

对于mergeSort,这个问题有某种限制,如果边之间的权重相同,首先应该有更少的顶点数。

void merge(EDGES **arr, EDGES **tmp, int left, int middle, int right)
{
    int i = left, j = middle + 1, k = left, l;

    while (i <= middle && j <= right)
    {
        if (arr[i]->weight < arr[j]->weight)
        {
            tmp[k++] = arr[i++];
        }
        else if (arr[i]->weight == arr[j]->weight)
        {
            if (arr[i]->src < arr[j]->src)
                tmp[k++] = arr[i++];
            else if (arr[i]->src == arr[j]->src)
            {
                if (arr[i]->dest <= arr[j]->dest)
                    tmp[k++] = arr[i++];
                else
                    tmp[k++] = arr[j++];
            }
            else
                tmp[k++] = arr[j++];
        }
        else
        {
            tmp[k++] = arr[j++];
        }
    }
    if (i > middle)
    {
        for (l = j; l <= right; l++)
        {
            tmp[k++] = arr[l];
        }
    }
    else
    {
        for (l = i; l <= middle; l++)
        {
            tmp[k++] = arr[l];
        }
    }


    for (l = left; l <= right; l++)
    {
        arr[l] = tmp[l];
    }

}


void mergeSort(EDGES **arr, EDGES **tmp, int left, int right)
{
    int middle = (left + right) / 2;

    if (left < right)
    {
        mergeSort(arr, tmp, left, middle);
        mergeSort(arr, tmp, middle + 1, right);
        merge(arr, tmp, left, middle, right);
    }
}

标签: c++algorithm

解决方案


通常,您使用的指针越少意味着您遇到内存问题的可能性越小。确保您不会意外访问 a 的成员nullptr,或者忘记将地址传递给函数。如果您能向我们提供有关该错误的更多信息,将会很有帮助:)

这是我不久前写的 Kruskal 的实现。它在模拟链表以EdgeList使用预分配数组时大量使用 STL 并且没有指针:

#include <cstdio>
#include <vector>
#include <queue>
#include <functional>
#include <algorithm>

const int MAXSZ = 1010;
struct Edge
{
    static int cnt;
    int from, to, weight, next;
} edges[MAXSZ * 2], full[MAXSZ * 2];
int Edge::cnt = 1;
int head[MAXSZ]; // for EdgeList impl
int bigb[MAXSZ]; // for DisjointSet impl

int find(const int c) // DisjointSet `find`
{
    if (bigb[c] == c)
        return c;
    bigb[c] = find(bigb[c]);
    return bigb[c];
}

void addEdge(const int a, const int b, const int w) // add edges to EdgeList from pure data
{
    edges[Edge::cnt].from = a;
    edges[Edge::cnt].to = b;
    edges[Edge::cnt].weight = w;
    edges[Edge::cnt].next = head[a];
    head[a] = Edge::cnt;
    ++Edge::cnt;

    bigb[find(a)] = bigb[find(b)] = std::min(find(a), find(b)); // needed to init DisjointSet
}

void addEdge(const Edge &base) // copy insert to EdgeList
{
    edges[Edge::cnt] = base;
    edges[Edge::cnt].next = head[base.from];
    head[base.from] = Edge::cnt;
    ++Edge::cnt;

    bigb[find(base.from)] = bigb[find(base.to)] = std::min(find(base.from), find(base.to));
}

int main()
{
    int m, n;
    scanf("%d%d", &m, &n);
    auto cmp = [](const int &l, const int &r) { return full[l].weight > full[r].weight; };
    std::priority_queue<int, std::vector<int>, std::function<bool(int, int)> > pq(cmp); // Used for kruskal
    /* input */
    for (int i = 1; i <= n; ++i)
    {
        int a, b, w;
        scanf("%d%d%d", &a, &b, &w);
        full[i].from = a;
        full[i].to = b;
        full[i].weight = w;
        pq.push(i);

        bigb[full[i].from] = full[i].from;
        bigb[full[i].to] = full[i].to;
    }

    int sum = 0;
    for (int vis = 0; !pq.empty(); pq.pop())
    {
        auto cur = full[pq.top()];
        if (find(cur.to) == vis && find(cur.from) == vis)
            continue; // if it leads back to something we already have
        addEdge(cur);
        addEdge(cur.to, cur.from, cur.weight); // other direction
        vis = std::min(find(cur.to), find(cur.from));
        printf("%d -> %d, %d\n", cur.from, cur.to, cur.weight);
        sum += cur.weight;
        //debug*/ for (int i=1; i<=m; ++i) printf("%3d", i); printf("\n"); for (int i=1; i<=m; ++i) printf("%3d", bigb[i]); printf("\n\n");
    }

    printf("total: %d\n", sum);

    return 0;
}

/* test data
first line: verticies, edges
next #edges lines: from, to, weight
3 3
1 2 3
3 2 1
3 1 2

5 6
1 2 2
1 4 1
2 3 2
3 4 -2
3 5 1
4 5 9
*/


推荐阅读