观察函数,发现 \(f(1) = a_n, f(n) = a_1, ......\) 其实描述了一种对应关系,如果一个对应矛盾则该序列不合法。
考虑 \(O(n^2)\) 的暴力怎么写,枚举区间的中点,向左右拓展维护是否合法,和已知的对应关系。
发现这个过程和求回文串很像,考虑 manacher,manacher 的一个重要性质就是 box 左右对称的两个区间的回文半径在没有超出 box 范围的情况下是一样的,不难验证新“回文”也满足这个性质。
所以在没有超出 box 范围的情况下直接继承回文半径,否则需要将回文半径缩到 box 内,直接搞要上可持久化数据结构,不过我们可以证明,每次暴力一个一个缩复杂度均摊是 \(O(n)\) 的,每缩一个新的贡献可以找到后继来快速计算。
证明:每次缩一个都会使 box 的左端点向右移动一格,因为当前点的最长回文半径包含的区间一定会成为新的 box(两个 box 如果右端点相同取左端点最大的),而只有拓展回文半径的操作会减少左端点,不过右端点始终递增,所以这个操作最多减少 \(n\) 个,均摊下来最多会有 \(2n\) 次缩的操作。
时间复杂度线性。
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#include <ctime>
#define int long long
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
typedef long double ld;
const int N = 6e5 + 10, mod = 998244353;int n, m, a[N], b[N], idx, f[N], pw[N], ne[N], pre[N], lst[N], d[N], ans[N], cnt[N];signed main() {ios::sync_with_stdio(0), cin.tie(0);clock_t stt = clock();cin >> n >> m; pw[0] = 1;for(int i = 1; i < N; i ++) pw[i] = pw[i - 1] * m % mod;for(int i = 1; i <= n; i ++) cin >> a[i];for(int i = 1; i <= n; i ++)b[++ idx] = a[i], b[++ idx] = 0;for(int i = 0; i <= m; i ++) lst[i] = idx + 5;for(int i = idx; i; i --) {ne[i] = lst[b[i]];pre[lst[b[i]]] = i;lst[b[i]] = i;}for(int i = 1, l = 1, r = 0; i <= idx; i ++) {if(i <= r) {d[i] = d[l + r - i], ans[i] = ans[l + r - i], cnt[i] = cnt[l + r - i];while(i + d[i] - 1 > r) {int p = l + r - i;if(b[p - d[i] + 1]) {ans[i] = (ans[i] - pw[m - cnt[i]]) % mod;if(ne[p - d[i] + 1] > idx || ne[p - d[i] + 1] >= p + d[i] - 1) {cnt[i] -= 1 + (b[p - d[i] + 1] != b[p + d[i] - 1]);} }d[i] --;}}while(i - d[i] >= 1 && i + d[i] <= idx) {int x = (ne[i - d[i]] > i + d[i] - 1 ? 0 : b[i * 2 - (ne[i - d[i]])]), y = (pre[i + d[i]] < i - d[i] + 1 ? 0 : b[i * 2 - (pre[i + d[i]])]);if(!b[i - d[i]]) {d[i] ++;continue;}if(!x && !y) {if(b[i - d[i]]) {cnt[i] += (b[i - d[i]] != b[i + d[i]]) + 1, ans[i] = (ans[i] + pw[m - cnt[i]]) % mod;}d[i] ++;}else {if(x == b[i + d[i]] && y == b[i - d[i]]) ans[i] = (ans[i] + pw[m - cnt[i]]) % mod, d[i] ++;else break;}}if(i + d[i] - 1 >= r) r = i + d[i] - 1, l = i - d[i] + 1;}int res = 0;for(int i = 1; i <= idx; i ++)(res += ans[i]) %= mod;cout << (res + mod) % mod << '\n';cerr << (1.0 * clock() - stt) / CLOCKS_PER_SEC << "s\n";return 0;
}
神秘题