首页 > 技术文章 > CF1093F Vasya and Array DP

ww3113306 2019-02-14 16:30 原文

题面

题面
\(\Delta\)题面有点问题,应该是数列中没有长度大于等于\(len\)的连续数字才是合法的.

题解

\(f[i][j]\)表示DP到\(i\)位,以\(j\)为结尾的方案数, \(sum[i]\)表示\(\sum_{j = 1}^{k}f[i][j]\), \(g[i][j]\)表示第\(i\)位为结尾,当前段全都是数字\(j\)的最长长度(不考虑\(len\)的限制,能延长就尽量延长,你可以理解为把\(1\)\(i\)\(-1\)全都改成\(j\),然后再看第\(i\)位以\(j\)为结尾的连续数字有多长)。
那么有:

\[f[i][j] = \begin{cases} 0 \quad s_i \ne -1 \ and \ s_i \ne j \\ sum[i - 1] \quad g[i][j] < len \\ sum[i - 1] - (sum[i - len] - f[i - len][j]) \quad others \end{cases}\]

第一种转移比较简单,解释下最后两种。
第二种:
因为\(g[i][j] < len\),所以在第\(i\)位,以\(j\)结尾时,不管前面是什么情况,肯定都合法,所以直接加上上一位总的方案就可以了

第三种:(以下所说的连续数字含义均为连续相同数字)
先明确一点:对于一段长度大于等于\(len\)的连续数字而言,我们只会在它的第一个不合法位置减去它的方案数,例如一个长为\(len + k\)的连续数字,我们只会在它的第\(len\)位减去它的方案数。
显然这样可以保证不重不漏,也就相当于其实我们每次减去的都是以某个固定位置开头的不合法连续数字的方案数,所以肯定不会有重复和遗漏的。
那为什么不每次减去以某个固定位置结尾的不合法连续数字的方案数呢?
因为我们是从前向后DP的,所以对于前面一个固定位置的一些信息,我们已经处理出来了,对于以某个固定位置结尾的不合法连续数字而言,当我们DP到这个结尾位置的时候,这个地方的值正我们需要计算的,总不能自己调用自己吧。

再考虑计算:
首先\(sum[i - 1]\)是总的方案数,但是其中有一部分方案不合法,因为当前段长度大于等于\(len\).
因此我们再考虑如何计算不合法的方案。
根据我们的策略,对于一段长度大于等于\(len\)的连续数字,我们只会在它的第\(len\)位减去它的方案数。
也就是我们只能减去长度为\(len\)的连续数字
因此我们假定从第\(i\)位开始,向前\(len\)个都是\(j\).(注意此时\(g[i][j] >= len\),所以一定有方案可以使得从\(i\)向前\(len\)个都是\(j\))
那么因为已经不合法了,所以从第\(i - len\)位开始,往前走的就都可以任取,所以总方案数为\(sum[i - len]\).
但是我们只能统计长度为\(len\)的连续数字,所以第\(i - len\)位不能是\(j\),否则就会接在以前变成一段长度大于\(len\)的连续数字了。
因此我们还要减去以\(i - len\)为结尾的,结尾数字为\(j\)的方案数。

#include<bits/stdc++.h>
using namespace std;
#define R register int
#define p 998244353
#define AC 101000
#define ac 110

int n, k, len, ans;
int s[AC], f[AC][ac], g[AC][ac], sum[AC];

inline int read()
{
	int x = 0;char c = getchar();bool zz = false;
	while(c > '9' || c < '0') {if(c == '-') zz = true; c = getchar();}
	while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
	return zz ? -x : x;
} 

inline void up(int &a, int b) {a += b; if(a < 0) a += p; if(a >= p) a-= p;}
inline int ad(int a, int b) {a += b; if(a < 0) a += p; if(a >= p) a -= p; return a;}
inline int mul(int a, int b) {return 1LL * a * b % p;}

void pre()
{
	n = read(), k = read(), len = read();
	for(R i = 1; i <= n; i ++) s[i] = read();
}

void work()
{
	sum[0] = 1;
	for(R i = 1; i <= n; i ++)
	{
		for(R j = 1; j <= k; j ++)
		{
			if(s[i] != -1 && s[i] != j) continue;
			g[i][j] = ad(g[i - 1][j], (s[i] == -1 || s[i] == j));			
			if(g[i][j] < len) f[i][j] = sum[i - 1];
			else up(f[i][j], ad(ad(sum[i - 1], -sum[i - len]), f[i - len][j]));
			up(sum[i], f[i][j]);
		}
	}
	printf("%d\n", sum[n]);
}

int main()
{
	freopen("in.in", "r", stdin);
	pre();
	work();
	fclose(stdin);
	return 0;
}

推荐阅读