【洛谷】P8306 【模板】字典树

  • Post author:
  • Post category:其他




题目地址:


https://www.luogu.com.cn/problem/P8306

题目描述:

给定



n

n






n





个模式串



s

1

,

s

2

,

,

s

n

s_1, s_2, \dots, s_n







s










1


















,





s










2


















,









,





s










n

























q

q






q





次询问,每次询问给定一个文本串



t

i

t_i







t










i





















,请回答



s

1

s

n

s_1 \sim s_n







s










1






























s










n





















中有多少个字符串



s

j

s_j







s










j





















满足



t

i

t_i







t










i

























s

j

s_j







s










j























前缀

。一个字符串



t

t






t









s

s






s





的前缀当且仅当从



s

s






s





的末尾删去若干个(可以为



0

0






0





个)连续的字符后与



t

t






t





相同。输入的字符串大小敏感。例如,字符串

Fusu

和字符串

fusu

不同。

输入格式:

输入的第一行是一个整数,表示数据组数



T

T






T







对于每组数据,格式如下:

第一行是两个整数,分别表示模式串的个数



n

n






n





和询问的个数



q

q






q







接下来



n

n






n





行,每行一个字符串,表示一个模式串。

接下来



q

q






q





行,每行一个字符串,表示一次询问。

输出格式:

按照输入的顺序依次输出各测试数据的答案。

对于每次询问,输出一行一个整数表示答案。

数据范围:

对于全部的测试点,保证



1

T

,

n

,

q

1

0

5

1 \leq T, n, q\leq 10^5






1













T


,




n


,




q













1



0










5












,且输入字符串的总长度不超过



3

×

1

0

6

3 \times 10^6






3




×








1



0










6












。输入的字符串只含大小写字母和数字,且不含空串。

说明:

std的IO使用的是关闭同步后的

cin/cout

,本题不卡常。

可以用Trie,每个节点还需要另外存一下经过该节点的字符串的总个数



c

c






c





,这样查询的时候,可以顺着查询字符串向下走,如果走不动了则返回



0

0






0





,否则返回最后停留的节点的



c

c






c





值。代码如下:

#include <iostream>
using namespace std;

const int N = 3e6 + 10;
int n, q;
char s[N];
int tr[N][65], idx;
int cnt[N];
int mp['z' + 1];

void add() {
  cnt[0]++;
  int c = 0;
  for (int i = 1; s[i]; i++) {
    int pos = mp[s[i]];
    if (!tr[c][pos]) tr[c][pos] = ++idx;
    c = tr[c][pos];
    cnt[c]++;
  }
}

int query() {
  int c = 0;
  for (int i = 1; s[i]; i++) {
    int pos = mp[s[i]];
    if (!tr[c][pos]) return 0;
    c = tr[c][pos];
  }

  return cnt[c];
}

int main() {
  int cc = 0;
  for (char ch = 'A'; ch <= 'Z'; ch++) mp[ch] = cc++;
  for (char ch = 'a'; ch <= 'z'; ch++) mp[ch] = cc++;
  for (char ch = '0'; ch <= '9'; ch++) mp[ch] = cc++;

  int T;
  scanf("%d", &T);
  while (T--) {
    for (int i = 0; i <= idx; i++)
      for (int j = 0; j <= 'z'; j++)
        tr[i][j] = 0;
    for (int i = 0; i <= idx; i++)
      cnt[i] = 0;
    idx = 0;

    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++) {
      scanf("%s", s + 1);
      add();
    }

    while (q--) {
      scanf("%s", s + 1);
      printf("%d\n", query());
    }
  }
}

每组数据时间复杂度



O

(

i

s

i

+

i

q

i

)

O(\sum_i s_i+\sum_iq_i)






O


(














i





















s










i




















+




















i





















q










i


















)









s

i

s_i







s










i





















是每次插入的字符串长度,



q

i

q_i







q










i





















是每次查询的字符串长度,空间



O

(

i

s

i

)

O(\sum_i s_i)






O


(














i





















s










i


















)







版权声明:本文为qq_46105170原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。