隐身术:挺巧妙的一道题。
首先考虑一个暴力:将“子串”的条件转化为对每一个后缀的前缀考虑,枚举每一个后缀。然后对每一个后缀做一个编辑距离的 DP,统计答案即可。
具体地,编辑距离的 DP 状态定义为:\(dp_{i, j}\) 表示 \(S_{1\sim i}, T_{1\sim j}\) 之间的最小步数。转移为:
- 当 \(S_{i+1} = T_{j+1}\),即后一个字符相等时,直接转移:\(dp_{i+1, j + 1}\overset{\min}\leftarrow dp_{i, j}\)。
- 对 \(T\) 执行删除操作 / 对 \(S\) 执行插入操作:此时编辑距离加 \(1\),\(T\) 的匹配长度增加了 \(1\),因此:\(dp_{i, j + 1}\overset{\min}\leftarrow dp_{i, j} + 1\)。
- 对 \(T\) 执行插入操作 / 对 \(S\) 执行删除操作:此时编辑距离加 \(1\),\(S\) 的匹配长度增加了 \(1\),因此:\(dp_{i +1, j }\overset{\min}\leftarrow dp_{i, j} + 1\)。
- 对 \(T\) 执行插入操作 / 对 \(S\) 执行插入操作:此时编辑距离加 \(1\),\(S, T\) 的匹配长度增加了 \(1\),因此:\(dp_{i +1, j + 1}\overset{\min}\leftarrow dp_{i, j} + 1\)。
考虑正解。注意到 \(\bm K\) 的值很小,容易想到 DP 优化中一个常见的套路:交换 DP 定义域与值域。并且要求该定义域关于原值域具有单调性。
原 DP 的任何一个维度显然关于编辑距离没有单调性。因此可以先找出编辑距离的一个单调性,然后再根据这个单调性找到对应的定义域,重新设计一个 DP。
我们发现,当两个串 \(A, B\) 确定时,如果把 \(A, B\) 的某个相同长度的后缀删除,那么编辑距离一定不会增加。因此,当两个串 \(\bm{A, B}\) 的长度之差确定时,\(\bm{A}\) 的长度和 \(\bm{B}\) 的长度关于编辑距离单调不降。
由此可以得到一个优化后的 DP:\(dp_{i, j}\) 表示编辑距离为 \(i\),\(\bm{|T| - |S| = j}\) 时最长 \(\bm S\) 能延伸的长度。转移如下:
- 当 \(S_{dp_{i, j}+1} = T_{dp_{i, j} + j +1}\),即后若干个字符相等时,直接让 DP 值自增,直到后一个字符不相等。即求 \(\bm{S, T}\) 某两个后缀的 LCP,可以通过后缀数组 / 哈希二分实现。
- 对 \(T\) 执行删除操作 / 对 \(S\) 执行插入操作:此时编辑距离加 \(1\),\(T\) 的延伸长度增加了 \(1\),\(|T| - |S|\) 增加了 \(1\),\(S\) 延伸的长度不增加,因此:\(dp_{i +1, j + 1}\overset{\max}\leftarrow dp_{i, j}\)。
- 对 \(T\) 执行插入操作 / 对 \(S\) 执行删除操作:此时编辑距离加 \(1\),\(S\) 的延伸长度增加了 \(1\),\(|T| - |S|\) 减少了 \(1\),因此:\(dp_{i +1, j - 1}\overset{\max}\leftarrow dp_{i, j} + 1\)。
- 对 \(T\) 执行替换操作 / 对 \(S\) 执行替换操作:此时编辑距离加 \(1\),\(S, T\) 的延伸长度增加了 \(1\),\(|T| - |S|\) 不变,因此:\(dp_{i +1, j}\overset{\max}\leftarrow dp_{i, j} + 1\)。
就此 DP,如果使用 SA 求 LCP 那么时间复杂度就是 \(O(nk^2)\)。但如果使用哈希二分求 LCP 那么时间复杂度就是 \(O(nk^2\log n)\)。
哈希二分 Version:
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 100005, K = 10;
const ll mod = 998244853, base = 13163;
mt19937_64 rnd((unsigned) time(NULL));
ll rd(ll l, ll r)
{return uniform_int_distribution<ll> (l, r) (rnd);
}
int n, m, p;
char s[N], t[N];
ll hs1[N], hs2[N], pw[N], cw[N], ans;
void do_hash()
{pw[0] = 1;for(int i = 1; i < N; i++) pw[i] = (pw[i - 1] * base) % mod;for(int i = 0; i < 30; i++) cw[i] = rd(1, base - 1);for(int i = 1; i <= n; i++)hs1[i] = (hs1[i - 1] * base % mod + cw[s[i] - 'A']) % mod;for(int i = 1; i <= m; i++)hs2[i] = (hs2[i - 1] * base % mod + cw[t[i] - 'A']) % mod;
}
ll geths1(int l, int r)
{return ((hs1[r] - (hs1[l - 1] * pw[r - l + 1] % mod)) % mod + mod) % mod;
}
ll geths2(int l, int r)
{return ((hs2[r] - (hs2[l - 1] * pw[r - l + 1] % mod)) % mod + mod) % mod;
}
ll LCP(int x, int y)
{int l = 1, r = min(n - x + 1, m - y + 1);while(l < r){int mid = (l + r + 1) >> 1;if(geths1(x, x + mid - 1) == geths2(y, y + mid - 1)) l = mid;else r = mid - 1;}if(geths1(x, x + l - 1) != geths2(y, y + l - 1)) return 0;return l;
}
int dp[K][2 * K];
void solve(int sx)
{memset(dp, -0x3f, sizeof(dp));dp[0][p] = 0;int mxv = 0;set<int> s;for(int k = 0; k <= p; k++){for(int j = 0; j <= 2 * p; j++){if(dp[k][j] < 0) continue;dp[k][j] = min(dp[k][j], min(n, m - sx - j + p + 1));int lcp = LCP(1 + dp[k][j], sx + dp[k][j] + j - p);dp[k][j] = (dp[k][j] + lcp);if(dp[k][j] == n){if(sx + n - 1 + j - p >= sx && sx + n - 1 + j - p <= m)s.insert(sx + n - 1 + j - p);}mxv = max(mxv, dp[k][j]);if(k == p) continue;dp[k + 1][j] = max(dp[k + 1][j], dp[k][j] + 1);if(j + 1 <= 2 * p) dp[k + 1][j + 1] = max(dp[k + 1][j + 1], dp[k][j]);if(j - 1 >= 0) dp[k + 1][j - 1] = max(dp[k + 1][j - 1], dp[k][j] + 1);}}ans += s.size();
}
int main()
{//freopen("sample.in", "r", stdin);//freopen("sample.out", "w", stdout);ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);cin >> p;cin >> s + 1 >> t + 1;n = strlen(s + 1);m = strlen(t + 1);do_hash();for(int i = 1; i <= m; i++)solve(i);cout << ans;return 0;
}
后缀数组 Version:
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 400005, K = 10;
int n, m, p;
int len, cm, x[N], y[N], c[N], sa[N];
ll ans;
char s[N], t[N];
void getsa()
{cm = 127;int i, k;for(i = 1; i <= len; i++) c[x[i] = s[i]]++;for(i = 1; i <= cm; i++) c[i] += c[i - 1];for(i = len; i >= 1; i--) sa[c[x[i]]--] = i;for(k = 1; k <= len; k <<= 1){memset(c, 0, sizeof(c));for(i = 1; i <= len; i++) y[i] = sa[i];for(i = 1; i <= len; i++) c[x[y[i] + k]]++;for(i = 1; i <= cm; i++) c[i] += c[i - 1];for(i = len; i >= 1; i--) sa[c[x[y[i] + k]]--] = y[i];memset(c, 0, sizeof(c));for(i = 1; i <= len; i++) y[i] = sa[i];for(i = 1; i <= len; i++) c[x[y[i]]]++;for(i = 1; i <= cm; i++) c[i] += c[i - 1];for(i = len; i >= 1; i--) sa[c[x[y[i]]]--] = y[i];for(i = 1; i <= len; i++) y[i] = x[i];for(cm = 0, i = 1; i <= len; i++){if(y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) x[sa[i]] = cm;else x[sa[i]] = ++cm;}if(cm == len) return;}
}
int rk[N], ht[N];
void getht()
{for(int i = 1; i <= len; i++) rk[sa[i]] = i;for(int i = 1, k = 0; i <= len; i++){if(k) k--;int j = sa[rk[i] - 1];while(i + k <= len && j + k <= len && s[i + k] == s[j + k]) k++;ht[rk[i]] = k;}
}
int st[20][N], lg2[N];
void init()
{for(int i = 1; i <= len; i++) st[0][i] = ht[i];lg2[1] = 0;for(int i = 2; i < N; i++) lg2[i] = (lg2[i >> 1] + 1);for(int j = 1; j < 20; j++)for(int i = 1; i + (1 << j) - 1 <= len; i++)st[j][i] = min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
}
int LCP(int x, int y)
{y += n + 1;x = rk[x]; y = rk[y];if(x > y) swap(x, y);x++;int lenx = y - x + 1;int lg = lg2[lenx];return min(st[lg][x], st[lg][y - (1 << lg) + 1]);
}
int dp[K][2 * K], cnt, dmdf[N];
bitset<N> vis;
void solve(int sx)
{memset(dp, -0x3f, sizeof(dp));dp[0][p] = 0;int mxv = 0;cnt = 0;for(int k = 0; k <= p; k++){for(int j = 0; j <= 2 * p; j++){if(dp[k][j] < 0) continue;dp[k][j] = min(dp[k][j], min(n, m - sx - j + p + 1));int lcp = LCP(1 + dp[k][j], sx + dp[k][j] + j - p);dp[k][j] = (dp[k][j] + lcp);if(dp[k][j] == n){if(sx + n - 1 + j - p >= sx && sx + n - 1 + j - p <= m){if(vis[sx + n - 1 + j - p] == 0){vis[sx + n - 1 + j - p] = 1;dmdf[++cnt] = sx + n - 1 + j - p;ans++;}}}mxv = max(mxv, dp[k][j]);if(k == p) continue;dp[k + 1][j] = max(dp[k + 1][j], dp[k][j] + 1);if(j + 1 <= 2 * p) dp[k + 1][j + 1] = max(dp[k + 1][j + 1], dp[k][j]);if(j - 1 >= 0) dp[k + 1][j - 1] = max(dp[k + 1][j - 1], dp[k][j] + 1);}}for(int i = 1; i <= cnt; i++) vis[dmdf[i]] = 0;
}
int main()
{//freopen("sample.in", "r", stdin);//freopen("sample.out", "w", stdout);ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);cin >> p;cin >> s + 1 >> t + 1;n = strlen(s + 1);m = strlen(t + 1);len = n + m + 1;s[n + 1] = '#';for(int i = n + 2, j = 1; i <= len; i++, j++) s[i] = t[j];getsa();getht();init();for(int i = 1; i <= m; i++)solve(i);cout << ans;return 0;
}