AC自动机(模板)

感谢&资料:

主要参考紫书的写法,其他的写法都用到了指针,不太喜欢指针…

因为紫书没有完整的代码,然后参考了一下 dalao 的博客

简介:

AC自动机主要解决的是字符串匹配问题,不同于 KMP 的是,AC 自动机可以进行多个模式串的匹配。KMP 的失配指针是一个线性的数组,但是 AC 自动机是多个模式串匹配,所以适配指针是一个树型的结构。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <bits/stdc++.h>

using namespace std;

//字典树的最大节点个数
const int MAX = 250001;
const int N = 1000010;
//有多少个不同的字符
const int SIGMA_SIZE = 26;

//字典树的节点
int ch[MAX][SIGMA_SIZE];
//当前节点是否为一个模式串的结尾, 当前节点的上一个模式串结尾, fail指针
int val[MAX], last[MAX], f[MAX], sz;
int ANS;

void init() {
sz = 1;
memset(ch, 0, sizeof(ch));
memset(val, 0, sizeof(val));
memset(f, 0, sizeof(f));
memset(last, 0, sizeof(last));
}

int idx(char c) {
return c - 'a';
}

//更新答案
void add(int u) {
while (u) {
ANS += val[u];
val[u] = 0;
u = last[u];
}
}

//添加模式串
void Creat(char *s) {
int u = 0, len = strlen(s);
for (int i = 0; i < len; i++) {
int c = idx(s[i]);
if (!ch[u][c]) ch[u][c] = sz++;
u = ch[u][c];
}
val[u]++;
}


//获得fail指针
void getFail() {
queue<int> q;
for (int i = 0; i < SIGMA_SIZE; i++)
if (ch[0][i]) q.push(ch[0][i]);

while (!q.empty()) {
int r = q.front(); q.pop();
for (int c = 0; c < SIGMA_SIZE; c++) {
int u = ch[r][c];
if (!u) continue;
q.push(u);
int v = f[r];
//和kmp相似,和根据父亲的fail指针获得当前的
while (v && ch[v][c] == 0) v = f[v];
f[u] = ch[v][c];
//更新last
last[u] = val[f[u]] ? f[u] : last[f[u]];
}
}
}

//进行匹配
void find(char * T) {
int len = strlen(T), j = 0;
for (int i = 0; i < len; i++) {
int c = idx(T[i]);
while (j && ch[j][c] == 0) j = f[j];
j = ch[j][c];
//如果当前是模式串的结尾,那么更新答案
//else if 里是处理一个模式串包含另一个模式串的情况
if (val[j]) add(j);
else if (last[j]) add(last[j]);
}
}

char str[N];

int main() {
int T;
scanf("%d", &T);
while (T--) {
init();
int n;
scanf("%d", &n);
while (n--) {
scanf("%s", str);
Creat(str);
}
getFail();
scanf("%s", str);
ANS = 0;
find(str);
printf("%d\n", ANS);
}
}

代码二

get_fail 改了下,预处理了需要跳到的位置,匹配是 O(文章长度的)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <bits/stdc++.h>

using namespace std;

const int MAX = 250001;
const int N = 1000010;

const int SIGMA_SIZE = 26;

int ch[MAX][SIGMA_SIZE];
int val[MAX], last[MAX], f[MAX], sz;
int ANS;

void init() {
sz = 1;
memset(ch, 0, sizeof(ch));
memset(val, 0, sizeof(val));
memset(f, 0, sizeof(f));
memset(last, 0, sizeof(last));
}

int idx(char c) {
return c - 'a';
}

void add(int u) {
while (u) {
ANS += val[u];
val[u] = 0;
u = last[u];
}
}

void creat(char *s) {
int u = 0, len = strlen(s);
for (int i = 0; i < len; i++) {
int c = idx(s[i]);
if (!ch[u][c]) ch[u][c] = sz++;
u = ch[u][c];
}
val[u]++;
}

void get_fail() {
queue<int> q;
for (int i = 0; i < SIGMA_SIZE; i++)
if (ch[0][i]) q.push(ch[0][i]);
while (!q.empty()) {
int r = q.front();
q.pop();
for (int c = 0; c < SIGMA_SIZE; c++) {
int u = ch[r][c];
if (!u) {
ch[r][c] = ch[f[r]][c];
continue;
}
q.push(u);
int v = f[r];
while (v && ch[v][c] == 0) v = f[v];
f[u] = ch[v][c];
last[u] = val[f[u]] ? f[u] : last[f[u]];
}
}
}

void find(char *T) {
int len = strlen(T), j = 0;
for (int i = 0; i < len; i++) {
int c = idx(T[i]);
j = ch[j][c];
if (val[j]) add(j);
// else if (last[j]) add(last[j]);
}
}

char str[N];

int main() {
ios::sync_with_stdio(false);
int T;
cin >> T;
while (T--) {
init();
int n;
cin >> n;
while (n--) {
cin >> str;
creat(str);
}
get_fail();
cin >> str;
ANS = 0;
find(str);
cout << ANS << endl;
}
}