由于这周初学SAM,变找些题在做。有了SAM,很多问题都变简单了,比如这题。。。
题目描述:
询问两个长度为100,000的字符串,不小于k的公共子串有多少个。
算法分析:
对一个串建立SAM,并预处理出某个节点可以匹配的串的终点位置的个数,另一个串在上面跑。
跑的同时记录最大可匹配范围len。
如果匹配到某个点,len的值大于k了。那么就说明该统计能匹配的串然后加入到answer中了。
这个统计应该是不断回溯这个节点的父亲。而且对于某个k,除了第一个点,其他的点答案都是一样,于是可以记忆化。
1 #include<iostream>
2 #include<algorithm>
3 #include<cstdio>
4 #include<cstring>
5 using namespace std;
6 const int N = (int)1e5+10;
7 char ch[N],pat[N];
8 int par[N<<1], G[N<<1][60], val[N<<1], sz, last;
9 inline int convert(char x){if(x >= 'a' && x <= 'z') return x - 'a'; else return x - 'A' + 26;}
10 void ins(int x){
11 int p = last , np = sz ++;
12 memset(G[np],0,sizeof(G[np]));
13 val[np] = val[p] + 1;
14 while(p!=-1&&G[p][x]==0)G[p][x]=np,p=par[p];
15 if(p==-1){
16 par[np]=0;
17 } else {
18 int q=G[p][x];
19 if(val[q]==val[p]+1) par[np]=q;
20 else {
21 int nq = sz ++;
22 val[nq] = val[p]+1;
23 memcpy(G[nq],G[q],sizeof(G[q]));
24 par[nq] = par[q];
25 par[q] = par[np] = nq;
26 while(p!=-1 && G[p][x]==q) G[p][x]=nq, p=par[p];
27 }
28 }
29 last = np;
30 }
31 int cnt[N<<1];
32 struct node{
33 int id,v;
34 node(){};
35 node(int _id,int _v):id(_id),v(_v){}
36 bool operator < (const node& a)const{
37 return v > a.v;
38 }
39 } num[N<<1];
40 long long dp[N<<1];
41 void build(){
42 sz = 1; last = 0; val[0] = 0; par[0] = -1;
43 memset(G[0],0,sizeof(G[0]));
44 int n = strlen(ch);
45 for(int i = 0; i < n; i++) ins(convert(ch[i]));
46 //for(int i = 0; i < sz; i++) cout<<G[i]['x'-'a']<< endl;
47 for(int i = 0; i < sz; i++) dp[i] = -1, cnt[i] = 0;
48 int p = last;
49 while(p) cnt[p]=1, p = par[p];
50 for(int i = 0; i < sz; i++) num[i] = node(i,val[i]);
51 sort(num,num+sz);
52 for(int i = 0; i < sz; i++){
53 int u = num[i].id,v;
54 for(int i = 0; i < 60; i++) if(v = G[u][i]) {
55 cnt[u] += cnt[v];
56 }
57 }
58 //for(int i = 0; i < sz; i++) cout<< cnt[i] <<" ";cout<<endl;
59 }
60 long long dfs(int p,int k){
61 long long &ans = dp[p];
62 if(ans != -1) return ans ;
63 if(val[par[p]] >= k) {
64 ans = dfs(par[p],k) + 1LL* (val[p] - val[par[p]]) * cnt[p];
65 } else {
66 ans = 1LL * (val[p] - k + 1) * cnt[p];
67 }
68 return ans ;
69 }
70 long long solve(int k){
71 int n = strlen(pat), len = 0, p = 0;
72 long long ans = 0;
73 for(int i = 0; i < n; i ++) {
74 int x= convert(pat[i]);
75 while(p && G[p][x] == 0) p = par[p], len = val[p];
76 if(p = G[p][x]){
77 len ++;
78 if(len >= k) {
79 if(val[par[p]] >= k) ans += 1LL * (len - val[par[p]]) * cnt[p] + dfs(par[p], k);
80 else ans += 1LL * (len - k + 1) * cnt[p];
81 }
82 }
83 // cout<< pat[i] <<" "<< ans <<" "<<len<<" "<<p<<endl;
84 }
85 return ans ;
86 }
87 int main(){
88 int k ;
89 while(~scanf("%d",&k)&&k){
90 scanf("%s%s",ch,pat);
91 build();
92 cout<< solve(k) << endl;
93 }
94 }
95
posted on 2012-10-25 19:23
西月弦 阅读(556)
评论(0) 编辑 收藏 引用 所属分类:
解题报告