题目大意
给定一个序列,定义其权值为划分序列的方案数,使得划分出来的每个区间 \([l, r]\) 有 \(\max_{i = l}^r {a_i} \leq r - l + 1\) 。对于每个 \(1 \leq i \leq n\) 求只将 \(a_i\) 修改为 \(1\) ,序列的权值。
做法详解
首先我们先想一想假如不修改 \(a_i\) 怎么做。对于这种序列划分的问题一般来说都是线性dp ,这道题也不例外,令 \(dp_i\) 表示划分 \([1, i]\) 的方案数,暴力转移显然有
当我们考虑转移时发现限制转移的状态不连续,没有单调性,十分难处理,所以考虑dp中的暴力优化——CDQ分治优化。
对于 \([l, r]\) ,我们从 \([l, mid]\) 向 \([mid + 1, r]\) 转移,那么我们首先要处理原本的区间最大值,由于转移一定是会跨过 \(mid\) 从左向右,那么我们可以维护 \([l, mid]\) 的后缀最大值数组 \(lmx\) 和 \([mid + 1, r]\) 的前缀最大值数组 \(rmx\),那么我们转移条件就变成了 \(max\{ lmx_j, rmx_i \} \leq i - j + 1\) ,拆分一下有
移项
我们发现两个式子左边只跟 \(i\) 有关,右边只跟 \(j\) 有关,所以将 \([l, mid]\) 按 \(lmx_j + j - 1\) 排序,然后双指针保证第一个式子,将 \(j - 1\) 拍到树状数组上,对于 \(i\) 查询树状数组 \([1, i - rmx_i]\) 的和,加到 \(dp_i\) 上。
现在不带修改做完了,我们想想带修改的怎么做。我们发现将 \(a_i\) 修改为 \(1\) 相当于是将 \(i\) 的限制给消削掉了,所以答案中每一种方案一定包含什么都不改的序列的划分方案,所以我们考虑增加了那些方案。
首先我们不可能将 \(a_i\) 改了后再跑一遍dp(这不废话),同时对于划分序列正着dp倒着dp没有影响,\(i\) 也只包含与一个区间,所以我们可以预处理两个dp数组 \(f_i\) 和 \(g_i\) 分别表示 \([1, i]\) 的划分方案数和 \([i, n]\) 的划分方案数。
那么增加的方案显然有 \(\sum_{l = 1}^i \sum_{r = i}^n f_{l - 1} g_{r + 1} \times [\max_{j = l}^r \{ a_j \} = a_i > r - l + 1 \wedge \text{secondmax}_{j = l}^r \{ a_j \} \leq r - l + 1]\) 。
同样,这个条件很难直接做,我们就把其扔到CDQ里面考虑。对于 \([l, r]\) ,我们依照预处理dp那样维护 \(lmx\) ,\(rmx\) ,再额外维护 \(selmx\) ,\(sermx\) ,以及 \(p_i\) 表示 \(i\) 的前缀/后缀最大值位置。考虑枚举 \(l/r\)(以 \(r\) 为例),首先,我们可以肯定的是当前所算出的方案数增加量是加在 \(p_r\) 的答案上的,然后,因为我们要同时保证 \(a_{p_i} > r - l + 1 \wedge \text{secondmax}_{j = l}^r \{ a_j \} \leq r - l + 1\) ,如果我们把 \(r - rmx_r \geq l - 1\) 拉出来排序那就需要跑两遍,还要加到同一个计算点上@#!%&#@*……总之,就是S上加S,所以我们把 \(r \geq lmx_l + l - 1\) 拉出来排序,然后将 \(l - 1\) 拍到树状数组上,权值为 \(f_{l - 1}\),查询时就可以简单地查询树状数组上 \((r - rmx_r, r - sermx_r]\) 的和,加到 \(ans_{p_r}\) 上即可。
就这样
for (int i = mid + 1, j = l;i <= r;++i) {while (j <= mid && lmx[id[j]] + id[j] <= i + 1) {BIT.change(id[j], f[id[j] - 1]);stk[++top] = id[j];++j;}trans(ans[p[i]], (BIT.query(i - rsemx[i] + 1) - BIT.query(i - rmx[i] + 1) + mod) * g[i + 1] % mod);
}
Solution
代码和讲的有些 \(+1\)、\(-1\) 的位置不一样,读者移下项就可以发现于解法中狮子是一样的。
还有代码输入输出跟本题有些不同,因为笔者是在另一个地方做的,见谅。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 2e5 + 5;
const int mod = 998244353;
int n, q;
int a[N];
LL f[N], g[N];
#define lowbit(x) (x & -x)
inline void trans(LL& x, LL y) { x = (x + y) % mod; }
struct BinaryTree {LL c[N << 1];inline void change(int x, LL k) { for (;x <= n + 1;x += lowbit(x)) trans(c[x], k); }inline void clean(int x) { for (;x <= n + 1;x += lowbit(x)) c[x] = 0; }inline LL query(int x) {LL ret = 0;for (;x > 0;x -= lowbit(x)) trans(ret, c[x]);return ret;}
}BIT;
int id[N], lmx[N], rmx[N], lsemx[N], rsemx[N], p[N];
inline void reset(int l, int r) { for (int i = l;i <= r;++i) id[i] = i; }
inline bool cmp(int i, int j) { return lmx[i] + i < lmx[j] + j; }
inline bool cmp2(int i, int j) { return i - rmx[i] + 1 < j - rmx[j] + 1; }
int stk[N], top;
inline void solve(int l, int r, LL* dp) {if (l == r) {if (a[l] == 1) trans(dp[l], dp[l - 1]);return;}int mid = l + r >> 1;solve(l, mid, dp);lmx[mid] = a[mid];for (int i = mid - 1;i >= l;--i) lmx[i] = max(lmx[i + 1], a[i]);rmx[mid + 1] = a[mid + 1];for (int i = mid + 2;i <= r;++i) rmx[i] = max(rmx[i - 1], a[i]);sort(id + l, id + mid + 1, cmp);for (int i = mid + 1, j = l;i <= r;++i) {while (j <= mid && lmx[id[j]] + id[j] <= i + 1) {BIT.change(id[j], dp[id[j] - 1]);stk[++top] = id[j];++j;}trans(dp[i], BIT.query(i - rmx[i] + 1));}reset(l, mid);while (top > 0) BIT.clean(stk[top--]);solve(mid + 1, r, dp);
}
LL ans[N], d[N];
inline void solve2(int l, int r) {if (l == r) return trans(ans[l], (a[l] > 1) * f[l - 1] * g[l + 1]);int mid = l + r >> 1;solve2(l, mid);solve2(mid + 1, r);lmx[mid] = a[mid];lsemx[mid] = 0;p[mid] = mid;for (int i = mid - 1;i >= l;--i) {lmx[i] = lmx[i + 1], lsemx[i] = lsemx[i + 1];p[i] = p[i + 1];if (a[i] > lmx[i]) lsemx[i] = lmx[i], lmx[i] = a[i], p[i] = i;else if (a[i] > lsemx[i]) lsemx[i] = a[i];}rmx[mid + 1] = a[mid + 1];p[mid + 1] = mid + 1;rsemx[mid + 1] = 0;for (int i = mid + 2;i <= r;++i) {rmx[i] = rmx[i - 1], rsemx[i] = rsemx[i - 1];p[i] = p[i - 1];if (a[i] > rmx[i]) rsemx[i] = rmx[i], rmx[i] = a[i], p[i] = i;else if (a[i] > rsemx[i]) rsemx[i] = a[i];}sort(id + l, id + mid + 1, cmp);for (int i = mid + 1, j = l;i <= r;++i) {while (j <= mid && lmx[id[j]] + id[j] <= i + 1) {BIT.change(id[j], f[id[j] - 1]);stk[++top] = id[j];++j;}trans(ans[p[i]], (BIT.query(i - rsemx[i] + 1) - BIT.query(i - rmx[i] + 1) + mod) * g[i + 1] % mod);}reset(l, mid);while (top > 0) BIT.clean(stk[top--]);sort(id + mid + 1, id + r + 1, cmp2);for (int i = mid, j = r;i >= l;--i) {while (j > mid && id[j] - rmx[id[j]] + 1 >= i) {BIT.change(id[j] + 1, g[id[j] + 1]);stk[++top] = id[j] + 1;--j;}trans(ans[p[i]], ((BIT.query(n + 1) - BIT.query(i + lsemx[i] - 1) + mod) - (BIT.query(n + 1) - BIT.query(i + lmx[i] - 1)) + mod) * f[i - 1] % mod);}reset(mid + 1, r);while (top > 0) BIT.clean(stk[top--]);
}
int main() {freopen("divide.in", "r", stdin);freopen("divide.out", "w", stdout);scanf("%d%d", &n, &q);for (int i = 1;i <= n;++i) scanf("%d", &a[i]);for (int i = 1;i <= n;++i) id[i] = i;f[0] = g[0] = 1;solve(1, n, f);for (int i = 1;i <= n;++i) trans(ans[i], f[n]);reverse(a + 1, a + n + 1);solve(1, n, g);reverse(a + 1, a + n + 1);reverse(g + 1, g + n + 1);g[n + 1] = 1;solve2(1, n);while (q--) {int x;scanf("%d", &x);printf("%lld\n", ans[x]);}return 0;
}