考虑设 \(f_{i, j}\) 为前 \(i\) 个人死了 \(j\) 个,由于不知道哪些人选了所以无法转移。原因是前面的决策会影响后面的决策,所以考虑贡献延后计算。
会发现一个事情,对于当前 \(c_x \leq j\) 的东西之后不会再决策所以对后面是没有影响的,这启发我们在 \(c_x = j\) 的时候结算贡献。
不妨设 \(f_{i, j, k}\),其中 \(k\) 为前面钦定的 \(c_x > j\) 的位置的个数。
转移分三种:
- \(s_{i+ 1} = 1\) 且 \(c_x > j\),\(f_{i, j, k} \to f_{i + 1, j, k + 1}\)
- \(s_{i + 1} = 0\) 且 \(c_x > j\),\(f_{i, j, k} \times \binom{k + 1}{l} \times \binom{cnt_{j + 1}}{l} \times l! \to f_{i + 1, j + 1, k + 1 - l}\)
- \(c_x \leq j\),\(f_{i, j, k} \times \binom{k}{l} \times \binom{cnt_{j + 1}}{l} \times l! \times [pre_j - (i - k)] \to f_{i + 1, j + 1, k - l}\)
其中 \(pre\) 和 \(cnt\) 分别是 \(\leq j\) 和 \(= j\) 的数量。
答案考虑枚举死了多少人,即 \(\sum_{i = 0}^{n - m} f_{n, i, n - pre_i} (n - pre_i)!\)
注意到 \(l\) 的总和不超过 \(n\),所以是 \(O(n^3)\) 的。
我们做完了,主要思路就是找到难做的地方消除这个难点。
#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (int i = (a); i <= (b); i ++)
#define fro(i, a, b) for (int i = (a); i >= b; i --)
#define INF 0x3f3f3f3f
#define eps 1e-6
#define lowbit(x) (x & (-x))
#define reg register
#define IL inline
typedef long long LL;
typedef std::pair<int, int> PII;
inline int read() {int x = 0, f = 1;char ch = getchar();while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }while (ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }return x * f;
}
// mt19937_64 sj(chrono::steady_clock::now().time_since_epoch().count());
// uniform_int_distribution<LL> u0(0, 1ll << 60);const int N = 510, Mod = 998244353;
int n, m, c[N], pre[N], cnt[N];
int C[N][N], fac[N];
int f[2][N][N];
char s[N];IL int add(int &a, int b) {return a = (a + b) % Mod;
}void init() {for (int i = 0; i <= n; i ++) {C[i][0] = 1;for (int j = 1; j <= i; j ++) C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % Mod;}fac[0] = 1;for (int i = 1; i <= n; i ++) fac[i] = 1ll * fac[i - 1] * i % Mod;
}int main() {n = read(), m = read();scanf("%s", s + 1);for (int i = 1; i <= n; i ++) c[i] = read(), cnt[c[i]] ++;pre[0] = cnt[0];for (int i = 1; i <= n; i ++) pre[i] = pre[i - 1] + cnt[i];init();f[0][0][0] = 1;for (int i = 0; i < n; i ++) {for (int j = 0; j <= i; j ++) {for (int k = 0; k <= i; k ++) {if (!f[i & 1][j][k]) continue;if (s[i + 1] == '1') add(f[i + 1 & 1][j][k + 1], f[i & 1][j][k]);if (s[i + 1] == '0') {for (int l = 0; l <= min(cnt[j + 1], k + 1); l ++) {int g = 1ll * C[k + 1][l] * C[cnt[j + 1]][l] % Mod * fac[l] % Mod;add(f[i + 1 & 1][j + 1][k + 1 - l], 1ll * f[i & 1][j][k] * g % Mod);}}if (pre[j] - (i - k) < 0) continue; for (int l = 0; l <= min(cnt[j + 1], k); l ++) {int g = 1ll * C[k][l] * C[cnt[j + 1]][l] % Mod * fac[l] % Mod * (pre[j] - (i - k)) % Mod;add(f[i + 1 & 1][j + 1][k - l], 1ll * f[i & 1][j][k] * g % Mod);} }}memset(f[i & 1], 0, sizeof f[i & 1]);}int ans = 0;for (int i = 0; i <= n - m; i ++) ans = add(ans, 1ll * f[n & 1][i][n - pre[i]] * fac[n - pre[i]] % Mod);printf("%d\n", ans);return 0;
}