首页 > 技术文章 > 「THUSCH 2017」大魔法师

MYCui 2020-11-16 13:02 原文

这道题目的难点在于要发现正解是矩阵乘法优化线段树

难点主要是极其难写

这个题目的正解就是:矩阵乘法+\(luogu\)线段树模板2

规定矩阵 \(1\):

\[ \left[ \begin{matrix} A \\ B \\ C \end{matrix} \right] \tag{1} \]

  • 对于操作一:

\[ \left[ \begin{matrix} 1 & 1 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{matrix} \right] \tag{2} \]

矩阵\(1\)*矩阵\(2\)得到:

\[ \left[ \begin{matrix} A + B \\ B \\ C \end{matrix} \right] \]

  • 对于操作2:

\[ \left[ \begin{matrix} 1 & 0 & 0 \\ 0 & 1 & 1 \\ 0 & 0 & 1 \end{matrix} \right] \tag{3} \]

矩阵\(1\) * 矩阵 \(3\)得到:

\[ \left[ \begin{matrix} A \\ B+C \\ C \end{matrix} \right] \]

  • 对于操作三:

\[ \left[ \begin{matrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 1 & 0 & 1 \end{matrix} \right] \tag{4} \]

矩阵\(1\) * 矩阵\(4\)得到:

\[ \left[ \begin{matrix} A \\ B \\ C+A \end{matrix} \right] \]

  • 对于操作四:

\[ \left[ \begin{matrix} v \\ 0 \\ 0 \end{matrix} \right] \tag{5} \]

矩阵\(1\) + \(矩阵5\)得到:

\[ \left[ \begin{matrix} A + v \\ B \\ C \end{matrix} \right] \]

  • 对于操作5:

\[ \left[ \begin{matrix} 1 & 0 & 0 \\ 0 & v & 0 \\ 0 & 0 & 1 \end{matrix} \right] \tag{6} \]

矩阵\(1\) * 矩阵\(6\)得到:

\[ \left[ \begin{matrix} A \\ B*v \\ C \end{matrix} \right] \]

  • 对于操作六:

\[ \left[ \begin{matrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 0 \end{matrix} \right] \tag{7} \]

\[ \left[ \begin{matrix} 0 \\ 0 \\ v \end{matrix} \right] \tag{8} \]

矩阵\(1\) * \(矩阵7 + 矩阵8\)得到:

\[ \left[ \begin{matrix} A \\ B \\ v \end{matrix} \right] \]

综上,就完成了六种修改操作。很简单的发现,这个东西用线段树比较好维护。

考虑如何下传标记。

这里发现一共要实现矩阵的乘法以及加法。

类似于线段树的模板二,也要支持序列的乘法以及加法。

将线段树模板二的序列乘法加法改为矩阵乘法加法即可。

下传:先乘后加

那么假设现在的答案矩阵为:\((x + lazadd) * lazmul\)(\(lazadd\) 表示加法懒标记矩阵,\(lazmul\) 表示乘法懒标记矩阵, \(x\) 还没有传懒标记前的答案矩阵)

现在要乘以一个\(t\),加上一个\(k\),先将 \(lazadd\) 以及 \(lazmul\) 还有 \(sum\) 都乘以矩阵 \(t\) ,然后再将 \(lazadd\) 还有 \(sum\) 加上 矩阵 \(k\) 即可。

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int MAXN = 250000 + 50,Mod = 998244353;
int n,Q;
int data[MAXN][3];
int mul1[3][3],mul2[3][3],mul3[3][3],mul5[3][3],mul6[3][3],Sum[3],M[3][3];

struct Node {
	int sum[3];
	int lazadd[3];
	int lazmul[3][3];
	int l , r ;
	void Clean()
	{
		for(int i = 0 ; i <= 2 ; i ++)
		{
			lazadd[i] = 0;
			for(int j = 0 ; j <= 2; j ++)
			lazmul[i][j] = 0;
			lazmul[i][i] = 1;
		}
		return ;
	}
	void ad(int Mul[3][3],int add[3])
	{
		int C[3][3] = {0};
		for(int i = 0 ; i <= 2 ; i ++)
		{
			for(int j = 0 ; j <= 2 ; j ++)
				for(int k = 0 ; k <= 2 ; k ++)
			C[i][j] += (lazmul[k][j] * Mul[i][k]) % Mod,C[i][j] %= Mod;
		}
		for(int i = 0 ; i <= 2 ; i ++)
			for(int j = 0 ; j <= 2 ; j ++)
			lazmul[i][j] = C[i][j],C[i][j] = 0;
		for(int i = 0 ; i <= 2 ; i ++)
		{
			for(int k = 0 ; k <= 2 ; k ++)
			C[i][0] += (Mul[i][k] * lazadd[k]) % Mod,C[i][0] %= Mod;
		}
		for(int i = 0 ; i <= 2 ; i ++)
			lazadd[i] = C[i][0],C[i][0] = 0,lazadd[i] += add[i],lazadd[i] %= Mod;
		for(int i = 0 ; i <= 2 ; i ++)
		{
			for(int k = 0 ; k <= 2 ; k ++)
			C[i][0] += (Mul[i][k] * sum[k]) % Mod,C[i][0] %= Mod;
		}
		for(int i = 0 ; i <= 2 ; i ++)
		sum[i] = C[i][0],sum[i] %= Mod,C[i][0] = 0,sum[i] += ((r - l + 1) * add[i]) % Mod,sum[i] %= Mod;
		return ;
	}
	
} T[MAXN * 4];

struct Output
{
	int A,B,C;
};

bool pd(int x)
{
	if(T[x].lazadd[0] || T[x].lazadd[1] || T[x].lazadd[2])return 1;
	for(int i = 0 ; i <= 2 ; i ++)
		for(int j = 0 ; j <= 2 ; j ++)
	if(M[i][j] != T[x].lazmul[i][j])return 1;
	return 0;
}

void pushdown(int x)
{
	if(pd(x) == 0) return;
	T[x << 1].ad(T[x].lazmul,T[x].lazadd);
	T[x << 1 | 1].ad(T[x].lazmul,T[x].lazadd);
	T[x].Clean();
	return ;
}
void build(int x,int l,int r)
{
	T[x].l = l , T[x].r = r;
	T[x].Clean();
	if(l == r)
	{
		T[x].sum[0] = data[l][0];T[x].sum[0] %= Mod;
		T[x].sum[1] = data[l][1];T[x].sum[1] %= Mod;
		T[x].sum[2] = data[l][2];T[x].sum[2] %= Mod;
		return ;
	}dianwo
	int mid = ( l + r ) >> 1;
	build(x << 1 , l , mid );
	build(x << 1 | 1 , mid + 1 , r);
	T[x].sum[0] = T[x << 1].sum[0] % Mod + T[x << 1 | 1].sum[0] % Mod;T[x].sum[0] %= Mod;
	T[x].sum[1] = T[x << 1].sum[1] % Mod + T[x << 1 | 1].sum[1] % Mod;T[x].sum[1] %= Mod;
	T[x].sum[2] = T[x << 1].sum[2] % Mod + T[x << 1 | 1].sum[2] % Mod;T[x].sum[2] %= Mod;
	return ;
}

void prepare()
{
	for(int i =  0 ; i <= 2 ; i ++)
	{
		for(int j = 0 ; j <= 2 ; j ++)
		mul1[i][j] = mul2[i][j] = mul3[i][j] = mul5[i][j] = mul6[i][j] = 0ll;
		mul1[i][i] = mul2[i][i] = mul3[i][i] = mul5[i][i] = mul6[i][i] = 1ll;
		M[i][i] = 1;
	}
	mul1[0][1] = 1ll;
	mul2[1][2] = 1ll;
	mul3[2][0] = 1ll;
	mul6[2][2] = 0ll;
	return ;
}

Output Add(Output A, Output B)
{
	A.A += B.A;A.A %= Mod;
	A.B += B.B;A.B %= Mod;
	A.C += B.C;A.C %= Mod;
	return A;
}

Output GetSum(int x,int l,int r)
{
	Output Ans;
	Ans.A = Ans.B = Ans.C = 0ll;
	if(T[x].l >= l && T[x].r <= r)
	{
		Ans.A = T[x].sum[0] % Mod;
		Ans.B = T[x].sum[1] % Mod;
		Ans.C = T[x].sum[2] % Mod;
		return Ans;
	}
	pushdown(x);
	int mid = (T[x].l + T[x].r) >> 1;
	if(l <= mid)Ans = Add(Ans,GetSum(x << 1 , l , r));
	if(r  > mid)Ans = Add(Ans,GetSum(x << 1 | 1 , l, r));
	Ans.A %= Mod;
	Ans.B %= Mod;
	Ans.C %= Mod;
	return Ans;
}

void change(int x,int l,int r,int op)
{
	if(T[x].l >= l && T[x].r <= r)
	{
		if(op == 1)T[x].ad(mul1,Sum);
		if(op == 2)T[x].ad(mul2,Sum);
		if(op == 3)T[x].ad(mul3,Sum);
		if(op == 4)T[x].ad(M,Sum);
		if(op == 5)T[x].ad(mul5,Sum);
		if(op == 6)T[x].ad(mul6,Sum);
		return ;
	}
	pushdown(x);
	int mid = (T[x].l + T[x].r) >> 1;
	if(l <= mid)change(x << 1 , l , r , op);
	if(r  > mid)change(x << 1 | 1 , l, r , op);
	T[x].sum[0] = T[x << 1].sum[0] % Mod + T[x << 1 | 1].sum[0] % Mod;T[x].sum[0] %= Mod;
	T[x].sum[1] = T[x << 1].sum[1] % Mod + T[x << 1 | 1].sum[1] % Mod;T[x].sum[1] %= Mod;
	T[x].sum[2] = T[x << 1].sum[2] % Mod + T[x << 1 | 1].sum[2] % Mod;T[x].sum[2] %= Mod;
	return ;
}

signed main()
{	
	cin >> n;
	for(int i = 1 ; i <= n ; i ++)
	{
		int A,B,C;
		cin >> A >> B >> C;
		data[i][0] = A;
		data[i][1] = B;
		data[i][2] = C;
	}
	build(1 , 1 , n);
	prepare();
	cin >> Q;
	for(int i = 1 ; i <= Q ; i ++)
	{
		int op,l,r,v;
		cin >> op >> l >> r ;
		if( 4ll <= op && op <= 6ll )cin >> v;
		if(op == 4)
			Sum[0] = v,Sum[1] = Sum[2] = 0ll;
		if(op == 6)
			Sum[0] = Sum[1] = 0ll ,Sum[2] = v;
		if(op == 5)mul5[1][1] = v;
		if(op <= 6)change(1 , l , r , op);
		if(op == 7)
		{
			Output Ans = GetSum(1 , l , r);
			cout << Ans.A % Mod;putchar(' ');
			cout << Ans.B % Mod;putchar(' ');
			cout << Ans.C % Mod;putchar('\n');
		}
		Sum[0] = Sum[1] = Sum[2] = 0ll;
	}
	return 0;
}

后话:

下穿标记后,一定要记住清空 懒标记矩阵 并且记得要给局部变量赋值!

时间复杂度O(\(27 * n * log_2(n) * 线段树常数\))

另外还有一种开4*4的矩阵的做法,只用维护区间矩阵乘法。

算算时间复杂度我们发现:O(\(64 * n * log_2(n)\))(\(n == 2.5*10^5时\)) ≈ \(64 * n * 19\)\(O(304000000)\),线段树的常数有亿点点大.....亿点点危。

所以还是写这个 \(3*3\) 的矩阵的做法比较好。

推荐阅读