首页 > 解决方案 > 我可以在这个线性函数问题中使用什么算法?

问题描述

首先,我想为我糟糕的英语道歉。当我将下面的代码提交给 DomJudge 时,我得到了TimeLimit ERROR. 我想不出解决这个问题的方法,尽管我在整个互联网上进行了搜索,但仍然找不到解决方案。有人可以给我一个提示吗?

问题:

Here are N linear function fi(x) = aix + bi, where 1 ≤ i ≤ N。Define F(x) = maxifi(x). Please compute
the following equation for the input c[i], where 1 ≤ i ≤ m.   
**Σ(i=1 to m) F(c[i])**  
For example, given 4 linear function as follows. f1(x) = –x, f2 = x, f3 = –2x – 3, f4 = 2x – 3. And the
input is c[1] = 4, c[2] = –5, c[3] = –1, c[4] = 0, c[5] = 2. We have F(c[1]) = 5, F(c[2]) = 7, F(c[3])
= 1, F(c[4]) = 0, F(c[5]) = 2. Then,  
**Σ(i=1 to 5)([])
= ([1]) + ([2]) + ([3]) + ([4]) + ([5]) = 5 + 7 + 1 + 0 + 2 = 15** 

输入格式:

The first line contains two positive integers N and m. The next N lines will contain two integers ai
and bi, and the last line will contain m integers c[1], c[2], c[3],…, c[m]. Each element separated by
a space.  

输出格式:
请输出上述函数的值。

问题图片:https ://i.stack.imgur.com/6HeaA.png

样本输入:

4 5  
 -1 0  
 1 0  
 -2 -3  
 2 -3  
 4 -5 -1 0 2  

样本输出:

 15

我的程序

 #include <iostream>
    #include <vector>

struct L
{
    int a;
    int b;
};

int main()
{
    int n{ 0 };
    int m{ 0 };
    while (std::cin >> n >> m)
    {
        //input
        std::vector<L> arrL(n);
        for (int iii{ 0 }; iii < n; ++iii)
        {
            std::cin >> arrL[iii].a >> arrL[iii].b;
        }
        //find max in every linear polymore
        int answer{ 0 };
        for (int iii{ 0 }; iii < m; ++iii)
        {
            int input{ 0 };
            int max{ 0 };
            std::cin >> input;
            max = arrL[0].a * input + arrL[0].b;
            for (int jjj{ 1 }; jjj < n; ++jjj)
            {
                int tmp{arrL[jjj].a * input + arrL[jjj].b };
                if (tmp > max)max = tmp;
            }
            answer += max;
        }
        //output
        std::cout << answer << '\n';
    }
    return 0;
}

标签: c++

解决方案


您的解决方案是 O(n*m)。

通过迭代确定“主导段”和相应的交叉点(以下称为“锚点”)来获得更快的解决方案。

每个锚点都链接到其左侧和右侧的两个段。

第一步包括根据a值对行进行排序,然后迭代地添加每个新行。
在添加 linei时,我们知道这条线对于大输入值是显性的,并且必须添加(即使它将在以下步骤中删除)。我们计算这条线与之前添加的线的交点:

  • 如果交点值高于最右边的锚点,那么我们添加一个新的锚点对应于这条新线
  • 如果交点值低于最右边的锚点,那么我们知道我们必须抑制最后一个锚点值。在这种情况下,我们迭代该过程,现在计算与前一个锚点的右段的交集。

复杂性由排序决定:O(nlogn + mlogm)。锚点确定过程是 O(n)。

当我们有了锚点时,确定每个输入x值的右段是 O(n+m)。如果需要,可以通过二分搜索(未实现)进一步减少最后一个值。

与代码的第一个版本相比,已更正了一些错误。这些错误与一些极端情况有关,在最左边有一些相同的线(即 的最低值a)。此外,还生成了随机序列(超过 10^7),用于将结果与 OP 代码获得的结果进行比较。没有发现差异。如果某些错误仍然存​​在,它们很可能对应于其他未知的极端情况。该算法本身看起来非常有效。

#include <iostream>
#include <vector>
#include <algorithm>
#include <cassert>

//  lines of equation `y = ax + b`
struct line {
    int a;
    int b;
    friend std::ostream& operator << (std::ostream& os, const line& coef) {
        os << "(" << coef.a << ", " << coef.b << ")";
        return os;
    }
};

struct anchor {
    double x;
    int segment_left;
    int segment_right;
    friend std::ostream& operator << (std::ostream& os, const anchor& anc) {
        os << "(" << anc.x << ", " << anc.segment_left << ", " << anc.segment_right << ")";
        return os;
    }
};

//  intersection of two lines
double intersect (line& seg1, line& seg2) {
    double x;
    x = double (seg1.b - seg2.b) / (seg2.a - seg1.a);
    return x;
}

long long int max_funct (std::vector<line>& lines, std::vector<int> absc) {
    long long int sum = 0;
    auto comp = [&] (line& x, line& y) {
        if (x.a == y.a) return x.b < y.b;
        return x.a < y.a;
    };
    std::sort (lines.begin(), lines.end(), comp);
    std::sort (absc.begin(), absc.end());

    // anchors and dominating segments determination
    
    int n = lines.size();
    std::vector<anchor> anchors (n+1);
    int n_anchor = 1;
    
    int l0 = 0;
    while ((l0 < n-1) && (lines[l0].a == lines[l0+1].a)) l0++;
    int l1 = l0 + 1;

    if (l0 == n-1) {
        anchors[0] = {0.0, l0, l0};
    } else {
        while ((l1 < n-1) && (lines[l1].a == lines[l1+1].a)) l1++;
        double x = intersect(lines[l0], lines[l1]);
        anchors[0] = {x, l0, l1};
        
        for (int i = l1 + 1; i < n; ++i) {
            if ((i != (n-1)) && lines[i].a == lines[i+1].a) continue;
            double x = intersect(lines[anchors[n_anchor-1].segment_right], lines[i]);
            if (x > anchors[n_anchor-1].x) {
                anchors[n_anchor].x = x;
                anchors[n_anchor].segment_left = anchors[n_anchor - 1].segment_right;
                anchors[n_anchor].segment_right = i;
                n_anchor++;
            } else {
                n_anchor--;
                if (n_anchor == 0) {
                    x = intersect(lines[anchors[0].segment_left], lines[i]);
                    anchors[0] = {x, anchors[0].segment_left, i};
                    n_anchor = 1;
                } else {
                    i--;
                }
            }
        }
    }
    
    // sum calculation
    
    int j = 0;      // segment index (always increasing)
    for (int x: absc) {
        while (j < n_anchor && anchors[j].x < x) j++;
        line seg;
        if (j == 0) {
            seg = lines[anchors[0].segment_left];
        } else {
            if (j == n_anchor) {
                if (anchors[n_anchor-1].x < x) {
                    seg = lines[anchors[n_anchor-1].segment_right];
                } else {
                    seg = lines[anchors[n_anchor-1].segment_left];
                }
            } else {                
                seg = lines[anchors[j-1].segment_right];
            }
        }
        sum += seg.a * x + seg.b;
    }
    
    return sum;
}

int main() {
    std::vector<line> lines = {{-1, 0}, {1, 0}, {-2, -3}, {2, -3}};
    std::vector<int> x = {4, -5, -1, 0, 2};
    long long int sum = max_funct (lines, x);
    std::cout << "sum = " << sum << "\n";
    
    lines = {{1,0}, {2, -12}, {3, 1}};
    x = {-3, -1, 1, 5};
    sum = max_funct (lines, x);
    std::cout << "sum = " << sum << "\n";
}   

一个可能的问题是在计算double x对应的线交叉点时会丢失信息,因此会丢失锚点。这是Rational用于避免此类损失的版本。

#include <iostream>
#include <vector>
#include <algorithm>
#include <cassert>
    
struct Rational {
    int p, q;
    Rational () {p = 0; q = 1;}
    Rational (int x, int y) {
        p = x;
        q = y;
        if (q < 0) {
            q -= q;
            p -= p;
        }
    }
    Rational (int x) {
        p = x;
        q = 1;
    }
    friend std::ostream& operator << (std::ostream& os, const Rational& x) {
        os << x.p << "/" << x.q;
        return os;
    }
    friend bool operator< (const Rational& x1, const Rational& x2) {return x1.p*x2.q < x1.q*x2.p;}
    friend bool operator> (const Rational& x1, const Rational& x2) {return x2 < x1;}
    friend bool operator<= (const Rational& x1, const Rational& x2) {return !(x1 > x2);}
    friend bool operator>= (const Rational& x1, const Rational& x2) {return !(x1 < x2);}
    friend bool operator== (const Rational& x1, const Rational& x2) {return x1.p*x2.q == x1.q*x2.p;}
    friend bool operator!= (const Rational& x1, const Rational& x2) {return !(x1 == x2);}
};

//  lines of equation `y = ax + b`
struct line {
    int a;
    int b;
    friend std::ostream& operator << (std::ostream& os, const line& coef) {
        os << "(" << coef.a << ", " << coef.b << ")";
        return os;
    }
};

struct anchor {
    Rational x;
    int segment_left;
    int segment_right;
    friend std::ostream& operator << (std::ostream& os, const anchor& anc) {
        os << "(" << anc.x << ", " << anc.segment_left << ", " << anc.segment_right << ")";
        return os;
    }
};

//  intersection of two lines
Rational intersect (line& seg1, line& seg2) {
    assert (seg2.a != seg1.a);
    Rational x = {seg1.b - seg2.b, seg2.a - seg1.a};
    return x;
}

long long int max_funct (std::vector<line>& lines, std::vector<int> absc) {
    long long int sum = 0;
    auto comp = [&] (line& x, line& y) {
        if (x.a == y.a) return x.b < y.b;
        return x.a < y.a;
    };
    std::sort (lines.begin(), lines.end(), comp);
    std::sort (absc.begin(), absc.end());
    
    // anchors and dominating segments determination
    
    int n = lines.size();
    std::vector<anchor> anchors (n+1);
    int n_anchor = 1;
    
    int l0 = 0;
    while ((l0 < n-1) && (lines[l0].a == lines[l0+1].a)) l0++;
    int l1 = l0 + 1;

    if (l0 == n-1) {
        anchors[0] = {0.0, l0, l0};
    } else {
        while ((l1 < n-1) && (lines[l1].a == lines[l1+1].a)) l1++;
        Rational x = intersect(lines[l0], lines[l1]);
        anchors[0] = {x, l0, l1};
        
        for (int i = l1 + 1; i < n; ++i) {
            if ((i != (n-1)) && lines[i].a == lines[i+1].a) continue;
            Rational x = intersect(lines[anchors[n_anchor-1].segment_right], lines[i]);
            if (x > anchors[n_anchor-1].x) {
                anchors[n_anchor].x = x;
                anchors[n_anchor].segment_left = anchors[n_anchor - 1].segment_right;
                anchors[n_anchor].segment_right = i;
                n_anchor++;
            } else {
                n_anchor--;
                if (n_anchor == 0) {
                    x = intersect(lines[anchors[0].segment_left], lines[i]);
                    anchors[0] = {x, anchors[0].segment_left, i};
                    n_anchor = 1;
                } else {
                    i--;
                }
            }
        }
    }
    
    
    // sum calculation
    
    int j = 0;      // segment index (always increasing)
    for (int x: absc) {
        while (j < n_anchor && anchors[j].x < x) j++;
        line seg;
        if (j == 0) {
            seg = lines[anchors[0].segment_left];
        } else {
            if (j == n_anchor) {
                if (anchors[n_anchor-1].x < x) {
                    seg = lines[anchors[n_anchor-1].segment_right];
                } else {
                    seg = lines[anchors[n_anchor-1].segment_left];
                }
            } else {                
                seg = lines[anchors[j-1].segment_right];
            }
        }
        sum += seg.a * x + seg.b;
    }
    
    return sum;
}

long long int max_funct_op (const std::vector<line> &arrL, const std::vector<int> &x) {
    long long int answer = 0;
    int n = arrL.size();
    int m = x.size();
    for (int i = 0; i < m; ++i) {
        int input = x[i];
        int vmax = arrL[0].a * input + arrL[0].b;
        for (int jjj = 1; jjj < n; ++jjj) {
            int tmp = arrL[jjj].a * input + arrL[jjj].b;
            if (tmp > vmax) vmax = tmp;
        }
        answer += vmax;
    }   
    return answer;
}

int main() {
    long long int sum, sum_op;
    std::vector<line> lines = {{-1, 0}, {1, 0}, {-2, -3}, {2, -3}};
    std::vector<int> x = {4, -5, -1, 0, 2};
    sum_op = max_funct_op (lines, x);
    sum = max_funct (lines, x);
    std::cout << "sum = " << sum << "  sum_op = " << sum_op << "\n";
}

推荐阅读