题意:
问不包含n个子串的长度为m的字符串的构造个数。
解法:
构造trie图,然后DP求长度为m的合法串个数
以前高精度都靠java,这次手写,各种错误。。。唉。。
代码:
1 # include <iostream>
2 # include <cstring>
3 using namespace std;
4 struct BigInteger
5 {
6 int bit[100];
7 bool init;
8 BigInteger()
9 {
10 memset(bit,0,sizeof(bit));
11 init=true;
12 }
13 BigInteger operator+(const BigInteger &pos)
14 {
15 BigInteger res;
16 res.init=init;
17 for(int i=0;i<99;i++)
18 {
19 res.bit[i]+=bit[i]+pos.bit[i];
20 res.bit[i+1]+=res.bit[i]/10;
21 res.bit[i]%=10;
22 }
23 return res;
24 }
25 void print()
26 {
27 int i;
28 for(i=99;i>0&&!bit[i];i--);
29 for(int j=i;j>=0;j--)
30 cout<<bit[j];
31 cout<<endl;
32 }
33 };
34 struct node
35 {
36 node *nxt[51],*pre;
37 bool end;
38 void clear()
39 {
40 memset(nxt,NULL,sizeof(nxt));
41 end=false;
42 pre=NULL;
43 }
44 node()
45 {
46 clear();
47 }
48 }buffer[200];
49 int map[255];
50 int c=1,n,m,num;
51 void insert(char *str)
52 {
53 node *p=buffer;
54 for(int i=0;str[i]!='\0';i++)
55 {
56 if(!(p->nxt[map[str[i]]]))
57 p->nxt[map[str[i]]]=&buffer[c++];
58 p=p->nxt[map[str[i]]];
59 }
60 p->end=true;
61 }
62 void make_per()
63 {
64 int s=-1,e=-1;
65 node *q[200];
66 node *p=buffer;
67 for(int i=0;i<n;i++)
68 if(p->nxt[i])
69 {
70 p->nxt[i]->pre=p;
71 q[++e]=p->nxt[i];
72 }
73 else
74 p->nxt[i]=p;
75 while(s!=e)
76 {
77 p=q[++s];
78 for(int i=0;i<n;i++)
79 {
80 node *pre=p->pre;
81 while(pre->nxt[i]==NULL) pre=pre->pre;
82 if(p->nxt[i])
83 {
84 p->nxt[i]->pre=pre->nxt[i];
85 p->nxt[i]->end=(p->nxt[i]->pre->end||p->nxt[i]->end);
86 q[++e]=p->nxt[i];
87 }
88 else
89 p->nxt[i]=pre->nxt[i];
90 }
91 }
92 }
93 BigInteger dp[200][55],zero,one;
94 BigInteger solve(int s,node *p)
95 {
96 if(p->end) return zero;
97 else if(s==m) return one;
98 else if(!(dp[p-buffer][s].init)) return dp[p-buffer][s];
99 else
100 {
101 dp[p-buffer][s].init=false;
102 for(int i=0;i<n;i++)
103 {
104 dp[p-buffer][s]=dp[p-buffer][s]+solve(s+1,p->nxt[i]);
105 }
106 return dp[p-buffer][s];
107 }
108
109 }
110 int main()
111 {
112 cin>>n>>m>>num;
113 char str[128];
114 cin>>str;
115 int tc=0;
116 for(int i=0;str[i]!='\0';i++)
117 map[str[i]]=tc++;
118 while(num--)
119 {
120 cin>>str;
121 insert(str);
122 }
123 make_per();
124 one.bit[0]=1;;
125 solve(0,buffer);
126 dp[0][0].print();
127 return 0;
128
129 }