首页 > 技术文章 > 寻找符合子序列要求的区间个数 - 牛客

ccut-ry 2018-11-10 10:02 原文

链接:https://ac.nowcoder.com/acm/contest/217/B
来源:牛客网

题目描述

msc和mcc是一对好朋友,有一天他们得到了一个长度为n的字符串s.

这个字符串s十分妙,其中只有’m’,’s’和’c’三种字符。

定义s[i,j]表示s中从第i个到第j个字符按顺序拼接起来得到的字符串。

定义一个字符串t的子序列为从t中选出一些位置并且将这些位置上面的字符按顺序拼接起来得到的字符串。

两个子序列重合当且仅当存在一个位置x使得两个子序列同时选择了位置x。

由于msc和mcc是一对很好很好的好朋友,所以她们希望选择两个数字x和y满足1≤x≤y≤n使得s[x,y]中同时存在两个**不重合的子序列**使得其中一个是’msc’且另外一个是’mcc’

现在给出n和字符串s,问她们可以选出多少对不同的(x,y)。

输入描述:

第一行一个正整数n,表示字符串s的长度。

第二行一个长度为n的字符串s,其中s只包含字符’m’,’s’和,’c’。

输出描述:

一行一个正整数,表示答案。
示例1

输入

复制
6
mscmcc

输出

复制
1

备注:

1≤n≤100,000

题意 : 给你一个字符串, 寻找有多少个子区间存在 "msc" 以及 "mcc" , 并且要求找到的子序列中没有共用的字母
思路分析 :
  暴力的想法是直接枚举所有的区间, n ^ 3 的做法,显然是TLE
  要怎么优化一下呢 ?
  考虑以每一个字母为起始的情况,去寻找到哪个位置刚好存在一个这样的两个子序列,这时可以直接统计答案
  而且因为字母比较少,所有符合要求的最短的序列总共有 8 种,
  再预处理一下从每个位置到 m, s, c 的最近的位置是多少,这样每次寻找一个子序列就是 O(6)
  因此总体复杂度是 6*8*n
代码示例 :
#define ll long long
const ll maxn = 1e5+5;
const ll inf = 0x3f3f3f3f;

ll n;
char s[maxn];
ll f[maxn][4]; // m - 1 , s - 2, c - 3
string str[10];

void init() {
    ll p1 = 0, p2 = 0, p3 = 0;
    char ch1='\0', ch2='\0', ch3='\0';
    for(ll i = 1; i <= n; i++){
        while((p1 <= i || ch1 != 'm') && p1 <= n) {
            p1++; ch1 = s[p1];
        }
        while((p2 <= i || ch2 != 's') && p2 <= n) {
            p2++; ch2 = s[p2];
        }
        while((p3 <= i || ch3 != 'c') && p3 <= n) {
            p3++; ch3 = s[p3];
        }
        f[i][1] = p1, f[i][2] = p2, f[i][3] = p3;
    }
    str[1] = "mscmcc";  str[5] = "mmccsc";
    str[2] = "msmccc";  str[6] = "mccmsc";
    str[3] = "mmsccc";  str[7] = "mcmscc";
    str[4] = "mmcscc";  str[8] = "mcmcsc";
}

ll fun(ll pos, ll p){
    ll pt, start = 0;
    if (s[pos] == 'm') start = 1;
    for(ll i = start; i < 6; i++){
        if (str[p][i] == 'm') pt = 1;
        else if (str[p][i] == 's') pt = 2;
        else pt = 3; 
        pos = f[pos][pt]; 
        if (pos > n) return inf;
    }
    return pos;
}

void solve() {
    ll ans = 0;
    
    for(ll i = 1; i <= n; i++){
        ll len = inf;
        for(ll j = 1; j <= 8; j++){
            len = min(len, fun(i, j));
        }
        if (len > n) break;
        ans += n-len+1; 
    }   
    printf("%lld\n", ans);
}

int main() {
    cin >> n;
    scanf("%s", s+1);
    init();
    solve();
    return 0;
}

 



推荐阅读