poj 1625(ac自动机+dp+高精度)

  • Post author:
  • Post category:其他


题意:有p个模式串,长度为m的目标串中不出现模式串的种类是多少,且给出了字符串会出现的n个字母。

题解:由于串的长度最多到50,可以用dp,f[i][j]表示串长度为i结尾节点是j的路径数,状态转移方程f[i][j] = sum(f[i – 1][k]),串长度为i-1,结尾节点是k,且添加字符c后能安全跳到节点j的所有情况和。因为总种类最多有n^m,没有让取模,需要用高精度。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>
using namespace std;
const int N = 105;
int Next[N][55], fail[N], val[N], sz, n, m, p;
int f[55][N][N], res[N];
char str[N];
map<char, int> mp;

void init() {
    memset(Next[0], 0, sizeof(Next[0]));
    val[0] = 0;
    sz = 1;
}

void insert(char *s) {
    int u = 0, len = strlen(s);
    for (int i = 0; i < len; i++) {
        int k = mp[s[i]];
        if (!Next[u][k]) {
            memset(Next[sz], 0, sizeof(Next[sz]));
            val[sz] = 0;
            Next[u][k] = sz++;
        }
        u = Next[u][k];
    }
    val[u] = 1;
}

void getFail() {
    queue<int> Q;
    fail[0] = 0;
    for (int i = 0; i < n; i++)
        if (Next[0][i]) {
            fail[Next[0][i]] = 0;
            Q.push(Next[0][i]);
        }
    while (!Q.empty()) {
        int u = Q.front();
        Q.pop();
        if (val[fail[u]])
            val[u] = 1;
        for (int i = 0; i < n; i++) {
            if (!Next[u][i])
                Next[u][i] = Next[fail[u]][i];
            else {
                fail[Next[u][i]] = Next[fail[u]][i];
                Q.push(Next[u][i]);
            }
        }
    }
}

void Sum(int* a, int* b) {
    int temp = 0;
    for (int i = 0; i < N; i++) {
        a[i] += b[i] + temp;
        temp = a[i] / 10;
        a[i] %= 10;
    }
}

int main() {
    while (scanf("%d%d%d", &n, &m, &p) == 3) {
        init();
        mp.clear();
        scanf("%s", str);
        for (int i = 0; i < n; i++)
            mp[str[i]] = i;
        for (int i = 0; i < p; i++) {
            scanf("%s", str);
            insert(str);
        }
        getFail();
        memset(f, 0, sizeof(f));        
        f[0][0][0] = 1;
        for (int i = 1; i <= m; i++)
            for (int j = 0; j < sz; j++)
                for (int k = 0; k < n; k++)
                    if (!val[Next[j][k]])
                        Sum(f[i][Next[j][k]], f[i - 1][j]);
        memset(res, 0, sizeof(res));
        for (int i = 0; i < sz; i++)
            Sum(res, f[m][i]);
        int pos;
        for (pos = N - 1; pos >= 0; pos--)
            if (res[pos] != 0)
                break;
        if (pos == -1)
            printf("0\n");
        else {
            for (; pos >= 0; pos--)
                printf("%d", res[pos]);
            printf("\n");
        }
    }
    return 0;
}



版权声明:本文为u013392752原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。