\(\text{CF1608F MEX counting 题解}\)
求方案数显然是考虑 dp。考虑每次填一个数时 MEX 的变化:显然不降,但增幅不确定。我们并没有办法通过状压等技巧维护各个数的取值,考虑在 MEX 变化的时候维护每次大于当前 MEX 值的数的个数,这样做的原因是运用贡献延后计算的思想,大概的思路是通过枚举两个 MEX 值直接从大于 MEX 值的数中取出一些个确定它们的取值。
然后你就会发现一个问题:我们要求的是在两个 MEX 之间值域所有的数都要取到,但我们维护这样的数的个数是难以统计的,因而转而维护大于 MEX 的数的种类数,这样我们的状态变化是唯一的。那么设 \(dp_{i,j,k}\) 表示填到第 \(i\) 个位置,当前 MEX 为 \(k\),有 \(j\) 个种类的数大于 MEX 的方案数。考虑以下的转移:
-
填一个小于 MEX 的数:\(k\times dp_{i,j,k}\to dp_{i+1,j,k}\)
-
填大于 MEX 的数,但是出现过了,那么 MEX 值没有变化:\(j\times dp_{i,j,k}\to dp_{i+1,j,k}\)
-
填大于 MEX 的数,但是没有出现过,那么显然有 \(dp_{i,j,k}\to dp_{i,j+1,k}\)
-
填当前的 MEX,那么设 \(t\) 为更新后的 MEX,那么要求 \((k,t)\) 都要有数字填到,于是转移的贡献是 \(dp_{i,j,k}\times{j\choose{t-k-1}}\times (t-k-1)!\to dp_{i+1,j-(t-k-1),t}\)。这里转移的含义是我们先只钦定出 \(t-k-1\) 个数来算贡献,多出来没有处理的数我们统计答案的时候统一计算。
然后这个东西的复杂度是 \(O(n^2k^2)\) 的。考虑优化。瓶颈显然在最后一种转移,容易发现的是转移 \(j,k\) 下标加和是一个递增 1 的定值,那么让 \(dp_{i,j,k}\) 表示 \(dp_{i,j+k,k}\) 来转移即可,这样前缀和优化便是容易的了。如果提交的语言选择的不好可能需要卡常。
代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 2005, mod = 998244353;
void add(unsigned long long &x, unsigned long long y) {x = x + y >= mod ? x + y - mod : x + y;
}
int qpow(int x, int y) {int ans = 1;while (y) {if (y & 1) ans = 1ll * ans * x % mod;x = 1ll * x * x % mod;y >>= 1;}return ans;
}
unsigned long long fac[N], inv[N];int n, k;
int tl[N], tr[N];
unsigned long long dp[2][N][N], sm[2][N][N];signed main() {fac[0] = 1;for (int i = 1; i < N; i++) fac[i] = 1ll * fac[i - 1] * i % mod;inv[N - 1] = qpow(fac[N - 1], mod - 2);for (int i = N - 2; ~i; --i) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;ios::sync_with_stdio(0);cin.tie(0);cin >> n >> k;for (int i = 1; i <= n; i++) {int x;cin >> x;tl[i] = max(0, x - k);tr[i] = min(i, x + k);}dp[0][0][0] = sm[0][0][0] = 1;int p = 0;for (int i = 1; i <= n; i++) {p ^= 1;for (int j = 0; j <= i; j++)for (int k = tl[i]; k <= min(tr[i], j); k++) {dp[p][j][k] += dp[p ^ 1][j][k] * j;if (j) dp[p][j][k] += dp[p ^ 1][j - 1][k];if (j && k) dp[p][j][k] += sm[p ^ 1][j - 1][min(k - 1, tr[i - 1])] * inv[j - k];dp[p][j][k] %= mod;if (k) sm[p][j][k] = sm[p][j][k - 1];sm[p][j][k] += dp[p][j][k] * fac[j - k];sm[p][j][k] %= mod;}for (int j = 0; j < i; j++)for (int k = tl[i - 1]; k <= tr[i - 1]; k++)dp[p ^ 1][j][k] = sm[p ^ 1][j][k] = 0;}unsigned long long ans = 0;for (int j = 0; j <= n; j++)for (int k = tl[n]; k <= min(tr[n], j); k++)add(ans, dp[p][j][k] * fac[n - k] % mod * inv[n - j] % mod);cout << ans << '\n';return 0;
}