#P10085. 聊天记录

聊天记录

#include <cstdio>
#include <iostream>
#include <map>
#include <vector>
#include <algorithm>

using namespace std;

// Capacity of the per-training-string tables (`training`, `match`).
const int maxn = 10010;
// Presumably the maximum string length; not referenced by the visible code.
const int maxlen = 10010;
// Largest error budget `w` the fixed-size tables below are dimensioned for.
const int maxw = 4;
// Width of the match window: maxw characters on either side plus one.
const int maxl = 2 * maxw + 1;
// Base of the polynomial hash getId() uses to fingerprint a state.
const int seed = 14733;

// One automaton state: a list of packed (offset, errors) positions.
struct State {
    // Number of leading entries of `s` that are in use.
    int cnt;
    // Packed positions. build() decodes an entry v as offset = v >> 3 and
    // error count = v & 7, i.e. the low three bits hold the error count.
    int s[maxl];
};

// Number of training strings (read in main).
int n;
// Error budget (read in main, consulted by build() and calc()).
int w;
// Not referenced by the visible code; build() declares its own local counter.
int cnt;
// Length of `query` (set in main, read by calc()).
int len;
// (1 << (2 * w + 1)) - 1, computed in main.
int mask;

// Hash fingerprint of a state -> index of that state in `s`.
map<int, int> ss;
// Every automaton state discovered so far; the index is the state id.
vector<State> s;
// trans[state id][match bitmask] = <next state id, shift>; next id -1 means dead.
pair<int, int> trans[1500][1 << maxl];
// Training strings.
string training[maxn];
// Text that is scanned against every training string.
string query;
// match[i][c][p]: bit k is set when training[i][p + k] equals character c
// (c in 0..25 for 'a'..'z', 26 for '_'), for k in 0..2w.
vector<int> match[maxn][27];

inline int getId(const State &nxt) {
    int val = 0;
    for (int i = 0; i < nxt.cnt; ++i)
        val = val * seed + nxt.s[i];

    if (ss.find(val) == ss.end()) {
        ss[val] = s.size();
        s.push_back(nxt);
    }
    return ss[val];
}

// True when two packed positions share the same offset, i.e. they agree on
// everything above the low three (error-count) bits.
bool stateEqual(int a, int b) {
    const int offsetA = a >> 3;
    const int offsetB = b >> 3;
    return offsetA == offsetB;
}

// Builds the transition table `trans` of the bounded-error matching automaton.
//
// Starting from the initial state (one position: offset 0, 0 errors), every
// state discovered so far is expanded, in discovery order, over all
// (1 << maxl) match bitmasks.  For state `id` and bitmask `t`, trans[id][t]
// becomes <next state id, shift>, or <-1, 0> when no position survives.
// Bit k of `t` says whether the input character matches the pattern character
// at offset k.  The global error budget `w` must be set before the call.
//
// NOTE(review): `trans` has 1500 rows and nothing here checks that the number
// of discovered states stays below that.
void build() {
    State now;
    now.cnt = 1; now.s[0] = 0;
    s.push_back(now);
    ss[0] = 0;  // fingerprint of the initial state, as getId() would compute it

    for (int cur = 0; cur < static_cast<int>(s.size()); ++cur) {
        const int id = cur;
        // Work on a copy: getId() below may grow `s` and move its elements.
        now = s[cur];

        for (int t = 0; t < (1 << maxl); ++t) {
            // Each source position contributes at most three successors.
            static int tmp[maxl * 3];
            int total = 0;

            for (int i = 0; i < now.cnt; ++i) {
                const int a = now.s[i] >> 3;  // offset
                const int e = now.s[i] & 7;   // errors spent so far
                int m = t >> a;               // match bits seen from this offset

                if (m & 1) {
                    // match first char
                    tmp[total++] = ((a + 1) << 3) + e;
                } else if (e < w) {
                    // allow mismatch
                    // delete input char
                    tmp[total++] = now.s[i] + 1;
                    // modify input char to match first char
                    tmp[total++] = ((a + 1) << 3) + e + 1;

                    // fz = index of the lowest set bit of m, or -1 if m == 0.
                    int fz = -1;
                    while (m) {
                        ++fz;
                        if (m & 1) break;
                        m >>= 1;
                    }
                    if (fz >= 0 && fz + e <= w) {
                        // delete chars until first match
                        tmp[total++] = ((a + fz + 1) << 3) + e + fz;
                    }
                }
            }

            if (total == 0) {
                trans[id][t] = make_pair(-1, 0);
                continue;
            }

            // Keep, per offset, only the entry with the fewest errors.
            sort(tmp, tmp + total);
            total = static_cast<int>(unique(tmp, tmp + total, stateEqual) - tmp);

            // Stack-based pruning over the entries in increasing offset order:
            // drop earlier entries whose offset + errors is not smaller than
            // the new one's, and keep the new one only if its offset - errors
            // exceeds that of the entry left on top.
            int kept = 1;
            for (int i = 1; i < total; ++i) {
                const int sum = (tmp[i] >> 3) + (tmp[i] & 7);
                const int diff = (tmp[i] >> 3) - (tmp[i] & 7);
                while (kept && (tmp[kept - 1] >> 3) + (tmp[kept - 1] & 7) >= sum)
                    --kept;
                // With an empty stack there is nothing to compare against, so
                // the entry is kept; this also avoids reading tmp[-1].
                if (kept == 0 || (tmp[kept - 1] >> 3) - (tmp[kept - 1] & 7) < diff)
                    tmp[kept++] = tmp[i];
            }

            // Normalise so the smallest offset becomes 0; the amount removed
            // is stored as the transition's shift.
            State nxt;
            nxt.cnt = kept;
            const int shift = tmp[0] >> 3;
            for (int i = 0; i < kept; ++i) tmp[i] -= shift << 3;
            std::copy(tmp, tmp + kept, nxt.s);

            const int nid = getId(nxt);
            trans[id][t] = make_pair(nid, shift);
        }
    }
}

int calc(int no, int st) {
    int n = training[no].length();
    int now = 0, pos = 0, val = (n <= w), mat;
    while (now >= 0 && st < len) {
        char c = query[st] == '_' ? 26 : query[st] - 'a';
        ++st;
        mat = match[no][c][pos];
        pos = min(n, pos + trans[now][mat].second);
        now = trans[now][mat].first;
        if (now >= 0) {
            int t = s[now].s[s[now].cnt - 1];
            if (n - (pos + (t >> 3)) + (t & 7) <= w)
                ++val;
        } else break;
    }
    return val;
}

int main() {
    //freopen("../chatlog.in", "r", stdin);
    scanf("%d%d", &n, &w);
    mask = (1 << (w + w + 1)) - 1;
    for (int i = 0; i < n; ++i) {
        cin >> training[i];
        
        const string &str = training[i];
        int l = str.length();
        for (int ic = 0; ic < 27; ++ic) {
            char c = ic < 26 ? 'a' + ic : '_';
            vector<int> &vec = match[i][ic];
            vec.resize(l + 1);
            int val = 0;
            for (int e = 0; e < l && e < w + w + 1; ++e)
                val += (str[e] == c) << e;
            vec[0] = val;
            for (int e = w + w + 1; e < l + w + w; ++e) {
                char sc = e < l ? str[e] : '*';
                val = (val >> 1) + ((sc == c) << (w + w));
                vec[e - w - w] = val;
            }
        }
    }
    cin >> query;
    len = query.length();

    build();
    int ans = 0;
    for (int i = 0; i < n; ++i)
        for (int s = 0; s < (int)query.length(); ++s)
            ans += calc(i, s);
    printf("%d\n", ans);
}