首页 > 技术文章 > 矩阵加速(数列) 学习笔记

jiangtaizhe001 2021-10-10 19:59 原文

前置知识——矩阵

定义

对于一个矩阵 \(A\) 主对角线指元素 \(A_{i,i}\) 上的元素。
单位矩阵式指一个矩阵中主对角线上的元素全是 \(1\),其他元素是 \(0\) 的一个矩阵,用 \(I\) 表示。单位矩阵 \(I\) 满足对于任意的矩阵 \(A\) 都有 \(A\times I=A\)
一个矩阵 \(A\) 的逆矩阵式指使得 \(A\times B=I\) 的矩阵 \(B\),可以用高斯消元的方法求,这里不多叙述。其实是我不会

运算

矩阵的加法和减法很简单,只需要把两个矩阵的元素逐个相加就可以了。矩阵的数乘亦是如此。
乘法则较为复杂。矩阵乘法在第一个矩阵的列数和第二个矩阵的行数相同的时候才有意义。
假设有两个矩阵 \(A\)\(B\),其中矩阵 \(A\)\(P\times M\) 的矩阵, 矩阵 \(B\)\(M\times Q\) 的矩阵,那么矩阵 \(C=A\times B\) 的第 \(i\) 行第 \(j\) 列的元素表达为:

\[C_{i,j}=\sum_{k=1}^{m} A_{i,k}\times B_{k,j} \]

通俗的讲,就是两个矩阵相乘的时候,结果的第 \(i\) 行第 \(j\) 列的结果是由第一个矩阵的第 \(i\) 行和第二个矩阵的第 \(j\) 列的数字乘起来的和。
矩阵乘法不满足交换律,但是满足结合律。
在写代码的时候,注意三重循环的顺序,让内存访问更连续,这样可以做到常数级别的优化,建议按顺序枚举 i-k-j 而不是 i-j-k 。

struct Mat{
	ll a[maxn][maxn]; int n,m;
	Mat operator + (const Mat x) const {
		Mat res; RI i,j; res.n=this->n; res.m=this->m;
		for(i=1;i<=res.n;i++) for(j=1;j<=res.m;j++) res.a[i][j]=this->a[i][j]+x.a[i][j];
		return res;
	}
	Mat operator - (const Mat x) const {
		Mat res; RI i,j; res.n=this->n; res.m=this->m;
		for(i=1;i<=res.n;i++) for(j=1;j<=res.m;j++) res.a[i][j]=this->a[i][j]-x.a[i][j];
		return res;
	}
	Mat operator * (const Mat x) const {
		Mat res; RI i,k,j,r; res.n=this->n; res.m=x.m;
		for(i=1;i<=res.n;i++) for(j=1;j<=res.m;j++) res.a[i][j]=0;
		for(i=1;i<=res.n;i++) for(k=1;k<=this->m;k++){
			r=this->a[i][k];
			for(j=1;j<=res.m;j++) res.a[i][j]+=x.a[k][j]*r,res.a[i][j]%=MOD;
		} return res;
	}
};

板子题

题目传送门
给出序列 \(a\) ,满足:
\(a_i=\left\{ \begin{aligned} & 1 & & x\in \{ 1,2,3\}\\ & a_{i-1}+a_{i-3} & & x≥4\\ \end{aligned} \right.\)
\(a_n \bmod (10^9+7)\),其中 \(n \le 2\times 10^9\) ,多组数据,数据组数 \(\le 100\)

题目解析

显然我们可以 \(\Theta\left(n\right)\) 递推,但是效率太低了。因此我们考虑使用一些优化。
我会推通项公式!
好吧这里采用矩阵加速。
我们先定义一个矩阵 \(A=\left[\begin{matrix} a_n & a_{n+1} & a_{n+2} \\ \end{matrix}\right]\) 和矩阵 \(B=\left[\begin{matrix} a_{n+1} & a_{n+2} & a_{n+3} \\ \end{matrix}\right]\)
我们发现:
\(a_{n+1}= 0 \times a_{n} + 1 \times a_{n+1} + 0\times a_{n+2}\)
\(a_{n+2}= 0 \times a_{n} + 0 \times a_{n+1} + 1\times a_{n+2}\)
\(a_{n+3}= 1 \times a_{n} + 0 \times a_{n+1} + 1\times a_{n+2}\)
这样我们就发现了一个矩阵 \(base=\left[\begin{matrix} 0 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 1 \\ \end{matrix}\right]\) 使得 \(A \times base = B\)
这样只要我们定义一个初始矩阵 \(ans=\left[\begin{matrix} a_1 & a_2 & a_3 \end{matrix}\right]=\left[\begin{matrix} 1 & 1 & 1 \end{matrix}\right]\),答案就是矩阵 \(ans\times base^{n-3}\) 的第一行第三列的元素了,使用快速幂优化可以做到复杂度 \(\Theta\left(\log n \right)\),当然其实矩阵乘法可以直接循环展开。
当然注意 \(n\le 3\) 的情况。
代码:

#include<cstdio>
#define I inline
#define db double
#define U unsigned
#define Re register
#define ll long long
#define RI register int
#define ull unsigned long long
#define swap(x,y) x^=y^=x^=y;
#define abs(x) ((x)>0?(x):(-(x)))
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define Me(a,b) memset(a,b,sizeof(a))
#define EPS (1e-7)
#define INF (0x7fffffff)
#define LL_INF (0x7fffffffffffffff)
#define maxn 39
#define MOD 1000000007
//#define debug
using namespace std;
#define Type long long
I Type read(){
	Type sum=0; int flag=0; char c=getchar();
	while((c<'0'||c>'9')&&c!='-') c=getchar(); if(c=='-') c=getchar(),flag=1;
	while('0'<=c&&c<='9'){ sum=(sum<<1)+(sum<<3)+(c^48); c=getchar(); }
	if(flag) return -sum; return sum;
}
struct Mat{
	ll a[maxn][maxn]; int n,m;
	Mat operator + (const Mat x) const {
		Mat res; RI i,j; res.n=this->n; res.m=this->m;
		for(i=1;i<=res.n;i++) for(j=1;j<=res.m;j++) res.a[i][j]=this->a[i][j]+x.a[i][j];
		return res;
	}
	Mat operator - (const Mat x) const {
		Mat res; RI i,j; res.n=this->n; res.m=this->m;
		for(i=1;i<=res.n;i++) for(j=1;j<=res.m;j++) res.a[i][j]=this->a[i][j]-x.a[i][j];
		return res;
	}
	Mat operator * (const Mat x) const {
		Mat res; RI i,k,j,r; res.n=this->n; res.m=x.m;
		for(i=1;i<=res.n;i++) for(j=1;j<=res.m;j++) res.a[i][j]=0;
		for(i=1;i<=res.n;i++) for(k=1;k<=this->m;k++){
			r=this->a[i][k];
			for(j=1;j<=res.m;j++) res.a[i][j]+=x.a[k][j]*r,res.a[i][j]%=MOD;
		} return res;
	}
}st,base,MarI;
void getI(){ MarI.n=MarI.m=3; for(RI i=1;i<=3;i++) MarI.a[i][i]=1; return; }
Mat pow(Mat x,ll y){
	Mat tmp=x,res=MarI;
	while(y){ if(y&1) res=res*tmp; tmp=tmp*tmp; y>>=1; }
	return res;
}
int T; ll n;
int main(){
    //freopen(".in","r",stdin);
    //freopen(".out","w",stdout);
	T=read(); st.n=1; st.m=base.n=base.m=3; getI();
	st.a[1][1]=st.a[1][2]=st.a[1][3]=base.a[1][3]=base.a[2][1]=base.a[3][2]=base.a[3][3]=1;
	while(T--){
		n=read();
		if(n<=3) puts("1");
		else printf("%lld\n",(st*pow(base,n-3)).a[1][3]);
	}
	return 0;
}

推荐阅读