首页 > 技术文章 > codeforces 1469E - A Bit Similar

blogofddw 2021-01-09 16:33 原文

这是19岁写的第一份代码 XD


题意:有一个长度为n的01字符串T,需要找到一个长度为k且字典序最小的01字符串s,使得s与T的任意一个长度为k的子串a bit similar,即至少有一位相同。


官方题解有非常简洁且简单的\(O(n\text{log}n)\)做法,然而我竟没有想到QAQ,最后用了非常暴力的后缀数据结构。。
首先把所有长度为k的子串建成一棵trie。显而易见,这棵树非常类似后缀树。由于不会直接建后缀树,我只能非常可耻地先求出后缀数组,再按照建虚树的方法建出后缀树...
然后就可以在树上dp了。设\(f(v)\)表示从结点v出发,是否存在一个长度为v的高度,且与以v为根的所有字符串a bit similar的字符串(以v为根的字符串的长度都是v的高度)。
求出f以后,就可以利用f构造出字典序最小的合法字符串了。构造方法虽然不难,但是条件分支太多了,写出来丑陋无比。。。


这个做法除了\(O(n \text{lglg}n)\)的较优的理论复杂度以外一无是处qaq
不知道正常的后缀树建树方法会不会更快...
附上丑陋的代码:

#include <cstdlib>
#include <iostream>
#include <cstdio>
#include <math.h>
#include <cstring>
#include <time.h>
#include <complex>
#include <algorithm>
#include <queue>
#include <unordered_map>
#include <set>
#include <bitset>

#pragma warning(disable:4996)
#define PII std::pair<long long, long long>
#define PTT std::pair<tree *, tree *>

template<typename T> T min(T x, T y)
{
	return x < y ? x : y;
}
template<typename T> T max(T x, T y)
{
	return x > y ? x : y;
};

const long long INF = 2000000005;//00000;// autojs.org
const long long mod = 1000000007;// 998244353;//
const long long MAXN = 3000005;
const long long A = 17;
long long pa[MAXN];

int N, K, KK;
char str[MAXN], *s;
long long h[MAXN], cnt[MAXN];
long long hash(int l, int r)
{
	return ((h[r] - h[l - 1] * pa[r - l + 1]) % mod + mod) % mod;
}
long long count(int l, int r)
{
	return cnt[r] - cnt[l - 1];
}

struct data { int id, x, y; };
void sort(int *a)
{
	static data tmp[MAXN], t[MAXN];
	static int cnt[MAXN];
	for (int i = 1; i <= N; i++)
		tmp[i] = { i,(s[i] == '1') + 1,0 };
	for (int i = 1; i < K; i <<= 1)
	{									//tmp按id排列,x为当前排名
		for (int j = N; j >= 1; j--)
			tmp[j].y = (i * 2 <= K ? tmp[max(j - i, 0)].x : tmp[max(j - K + i, 0)].x),
			std::swap(tmp[j].x, tmp[j].y), cnt[j] = 0;
		for (int j = 1; j <= N; j++) cnt[tmp[j].y]++;
		for (int j = 1; j <= N; j++) cnt[j] += cnt[j - 1];
		for (int j = N; j >= 1; j--) t[cnt[tmp[j].y]--] = tmp[j];
		for (int j = 1; j <= N; j++) cnt[j] = 0;
		for (int j = 1; j <= N; j++) cnt[t[j].x]++;
		for (int j = 1; j <= N; j++) cnt[j] += cnt[j - 1];
		for (int j = N; j >= 1; j--) tmp[cnt[t[j].x]--] = t[j];
		for (int j = 1; j <= N; j++)
			if (tmp[j].x == tmp[j - 1].x && tmp[j].y == tmp[j - 1].y)
				t[tmp[j].id] = tmp[j], t[tmp[j].id].x = t[tmp[j - 1].id].x;
			else
				t[tmp[j].id] = tmp[j], t[tmp[j].id].x = j;
		for (int j = 1; j <= N; j++)
			tmp[j] = t[j];
	}
	for (int j = 1; j <= N; j++) cnt[j] = 0;
	for (int j = 1; j <= N; j++) cnt[tmp[j].x]++;
	for (int j = 1; j <= N; j++) cnt[j] += cnt[j - 1];
	for (int j = 1; j <= N; j++) a[cnt[tmp[j].x]--] = tmp[j].id;
}
/*int find(int x, int y, int len)
{
	int lx = x - len + 1, ly = y - len + 1;
	int l = 0, r = len + 1;
	while (r - l > 1)
	{
		int mid = (l + r) / 2;
		if (hash(lx, lx + mid - 1) == hash(ly, ly + mid - 1))
			l = mid;
		else
			r = mid;
	}
	return l;
}
bool cmp(int x, int y)
{
	int lx = x - K + 1, ly = y - K + 1;
	int r = find(x, y, K) + 1;
	return r != K + 1 && s[lx + r - 1] < s[ly + r - 1];
}*/

namespace tree {
	int tot, beg[MAXN], end[MAXN], len[MAXN];
	std::vector<int> son[MAXN];
	int f[MAXN], ans[MAXN];

	void DP(int v)
	{
		if (len[v] == K)
		{
			f[v] = 0;
			return;
		}
		for (int i = 0; i < son[v].size(); i++)
			DP(son[v][i]);
		if (son[v].size() == 1)
		{
			f[v] = 1;
			return;
		}
		int el0 = len[son[v][0]] - len[v];
		int el1 = len[son[v][1]] - len[v];
		f[v] = el0 > 1 || el1 > 1 || f[son[v][0]] || f[son[v][1]];
	}
	void construct(int v)
	{
		int cn = son[v].size(), el0 = len[son[v][0]] - len[v];
		int r0 = end[son[v][0]], l0 = r0 - el0 + 1;
		if (cn == 0)
			return;
		if (cn == 1)
		{
			if (count(l0, r0) < el0)
			{
				for (int i = len[v] + 1; i <= K; i++)
					ans[i] = 0;
				return;
			}
			if (f[son[v][0]])
			{
				for (int i = len[v] + 1; i <= len[son[v][0]]; i++)
					ans[i] = 0;
				construct(son[v][0]);
				return;
			}
			for (int i = len[v] + 1; i <= K; i++)
				ans[i] = 0;
			ans[len[son[v][0]]] = 1;
			return;
		}
		int el1 = len[son[v][1]] - len[v];
		int r1 = end[son[v][1]], l1 = r1 - el1 + 1;
		bool b0 = count(l0, r0) < el0, b1 = count(l1, r1) < el1;
		int f0 = f[son[v][0]], f1 = f[son[v][1]];
		if (b0 && b1)
		{
			for (int i = len[v] + 1; i <= K; i++)
				ans[i] = 0;
			return;
		}
		if (f1)
		{
			for (int i = len[v] + 1; i <= len[son[v][1]]; i++)
				ans[i] = 0;
			construct(son[v][1]);
			return;
		}
		if (el1 > 1)
		{
			for (int i = len[v] + 1; i <= K; i++)
				ans[i] = 0;
			ans[len[son[v][1]]] = 1;
		}
		else
		{
			ans[len[v] + 1] = 1;
			for (int i = len[v] + 2; i <= len[son[v][0]]; i++)
				ans[i] = 0;
			if (f0)
				construct(son[v][0]);
			else
			{
				ans[len[son[v][0]]] = (s[end[son[v][0]]] == '1');
				for (int i = len[son[v][0]] + 1; i <= K; i++)
					ans[i] = 0;
			}
		}
	}
	void solve()
	{
		DP(0);
		if (!f[0])
		{
			printf("NO\n");
			return;
		}
		construct(0);
		printf("YES\n");
		for (int i = K + 1; i <= KK; i++)
			printf("0");
		for (int i = 1; i <= K; i++)
			printf("%d", ans[i]);
		printf("\n");
	}
	void build()
	{
		static int tmp[MAXN], tmp2[MAXN], S[MAXN], top;
		for (int i = 0; i <= tot; i++)
			son[i].clear();
		tot = 0, beg[0] = end[0] = len[0] = 0;
		S[top = 1] = 0;

		for (int i = 1; i <= N; i++)
			tmp[i] = tmp2[i] = i;
		//	std::sort(tmp + K, tmp + N + 1, cmp);
		sort(tmp);

		tot++, beg[tot] = tmp[K] - K + 1, end[tot] = tmp[K], len[tot] = K;
		S[++top] = tot;
		for (int i = 1; i <= N; i++)
		{
			if (tmp[i] < KK || count(tmp[i] - KK + 1, tmp[i] - K) < KK - K) 
				continue;

			int id = ++tot;
			beg[tot] = tmp[i] - K + 1, end[tot] = tmp[i], len[tot] = K;
			int t = find(end[tot], end[S[top]], K);
			if (t == K) continue;
			while (t < len[S[top - 1]])
			{
				son[S[top - 1]].push_back(S[top]);
				top--;
			}
			int v = S[top--];
			if (t > len[S[top]])
			{
				tot++;
				beg[tot] = tmp[i] - K + 1, end[tot] = tmp[i] - K + t, len[tot] = t;
				S[++top] = tot;
			}
			son[S[top]].push_back(v);
			S[++top] = id;
		}
		while (top > 1)
			son[S[top - 1]].push_back(S[top]), top--;
	}
}

void init()
{
	s = str;
	scanf("%d %d", &N, &K); KK = K;
	scanf("%s", s + 1);

	K = min(K, 30);

	h[0] = 1;
	for (int i = 1; i <= N; i++)
		h[i] = (h[i - 1] * A + (s[i] == '1')) % mod,
		cnt[i] = cnt[i - 1] + (s[i] == '1');
}

int main()
{
	pa[0] = 1;
	for (int i = 1; i < MAXN; i++)
		pa[i] = pa[i - 1] * A % mod;

	int t;
	scanf("%d", &t);
	while (t--)
	{
		init();
		tree::build();
		tree::solve();
	}

	return 0;
}
/*
1
9 2
111001000

*/

推荐阅读