矩阵快速幂小结

感谢&资料:

非常感谢一些 up 主和博主分析的资料,矩阵快速幂好多书上没详细讲,于是搜了一些资料。

简介:

矩阵快速幂其实就是矩阵乘法 + 快速幂,上面的视频还讲了扩展的幂运算的含义,挺不错的,代码上的体现就是,将原本的数字相乘换成了矩阵相乘,其他都是一样的,实质就是加速了矩阵的幂运算,对原本 $O(n)$的幂运算加速到$O(\log n)$。

1
2
3
4
5
6
7
8
9
10
Mat pow_mod(Mat a, int k) {
Mat rst(a.n, a.m);
rst.unit();
while (k) {
if (k & 1) rst = rst * a;
a = a * a;
k >>= 1;
}
return rst;
}

这就是一个矩阵快速幂的核心代码,我这里重载了乘号,然后看起来和普通的快速幂根本没什么区别,unit 函数是将矩阵初始化为单位矩阵。

应用:

通过上面的描述,矩阵快速幂其实就是加速矩阵的幂运算,但是这个有什么用呢?

在 ACM 竞赛中矩阵快速幂通常用来加速递推式的计算,将原本 $O (n)$ 的计算过程加速到 $O(k\log n)$ k 与矩阵的大小有关。

一个简单的例子就是用来计算斐波那契数列,假设要求斐波那契数列的第 n 项 $f_n$ 当 n 比较小直接用循环跑就可以了,这样是 $O(n)$ ,但是当 n 非常大甚至高达 $10^9$ 时用过这种朴素的方法进行递推肯定是超时的。

这时候就要用到矩阵快速幂,首先,也是最核心的一部,要找到递推关系和转移矩阵,对于斐波那契数列,非常任意就可以找到:

$\begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} \cdot \begin{bmatrix} f_{n-1} \\ f_{n-2} \end{bmatrix} = \begin{bmatrix} f_{n-1} + f_{n-2} \\ f_{n-1} \end{bmatrix} = \begin{bmatrix} f_n \\ f_{n-1} \end{bmatrix}$

我们称第一个矩阵为转移矩阵,第二个一般是一个$n \times 1$的向量,右边的向量每乘一次转移矩阵,向量中的每个元素根据递推式往后推了一次。

这时我设 $A_n = \begin{bmatrix} f_n \\ f_{n-1} \end{bmatrix}$ , $T = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}$那么显然有 $A_n = T^{n-1} \cdot A_{1}$.

转移我们对 T 这个转移矩阵就行快速幂,就可以再乘以 $A_1$ 就能算出来 $A_n$,复杂度是 $O(k\log n)$ 。

通过上面的分析,其实矩阵快速幂在应用的时候,难点是找出转移矩阵 T 和 $A_1$ ,下面给出一些简单的递推式练手:

$f_n = a\cdot f_{n-1} + b\cdot f_{n-2} + c$ (a, b, c是常数)

$\begin{bmatrix} a & b & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{bmatrix} \cdot \begin{bmatrix} f_{n-1} \\ f_{n-2} \\\ c \end{bmatrix} = \begin{bmatrix} f_n \\ f_{n-1} \\\ c\end{bmatrix}$

$f_n = c^n - f_{n-1}$

$\begin{bmatrix} -1 & c \\ 0 & c \end{bmatrix} \cdot \begin{bmatrix} f_{n-1} \\ c^{n-1} \end{bmatrix} = \begin{bmatrix} f_n \\ c^{n} \end{bmatrix}$

板子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

const int MOD = 1e9 + 7;

struct Mat {
ll v[112][112];
int n, m;

Mat(int tn, int tm) {
n = tn, m = tm;
memset(v, 0, sizeof(v));
}

void unit() {
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
if (i == j) v[i][j] = 1;
else v[i][j] = 0;
}
}
}

void read() {
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
scanf("%lld", &v[i][j]);
}
}
}

void out() {
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
printf(j == m - 1 ? "%d\n" : "%d ", v[i][j]);
}
}
}

Mat operator*(const Mat &a) {
Mat rst(n, a.m);
for (int i = 0; i < n; i++) {
for (int j = 0; j < a.m; j++) {
ll t = 0;
for (int k = 0; k < m; k++) {
t += v[i][k] * a.v[k][j] % MOD;
t = t % MOD;
}
rst.v[i][j] = t;
}
}
return rst;
}
};

Mat pow_mod(Mat a, int k) {
Mat rst(a.n, a.m);
rst.unit();
while (k) {
if (k & 1) rst = rst * a;
a = a * a;
k >>= 1;
}
return rst;
}


int main() {
int n, m;
scanf("%d %d", &n, &m);
Mat a(n, n);
a.read();
pow_mod(a, m).out();
}

简单练习: