题意:给出一个大小为 \(n\) 的全集 \(A = \{1,2,\cdots n\}\),再给出 \(m\) 个集 \(S_1,S_2\cdots S_m\),要求从这些集里选出至多 \(k\) 个,满足 \(S\) 间没有交集且并集是全集,\(k\le n\le 21,m\le 262144\)。
做法:
以下多项式乘法若没有特殊说明均是多项式乘法而非子集卷积。
看到这个东西,没有交集求并集,很有子集卷积的感觉,先上占位多项式试一下。设 \(z_S(t) = \sum\limits_{i=0}^n t^i fwt(f_i)[s]\),也就是我们视 \(f_t = \sum\limits_{i=0}^{2^n-1} cnt_ix^i[\operatorname{popcount}(i) ==t]\),这个东西是一行,那么 z 就是一列并加上占位的出来的多项式。
考虑然后该怎么做,本来应该是直接对 \(f\) 这个拆成 \(n\) 位的东西直接做 \(k\) 次子集卷积得到,但是比较难算因为我们这里要的是无序对,需要除掉一个阶乘有点麻烦,同时我们没有必要每次 fwt 重新做,所以我们考虑先转成 \(z_S(t)\) 处理,因为 fwt 之后每一位就独立了,然后再 fwt 回来。
有这个 \(z_S(t)\) 还不够,我们还没算选 \(k\) 个集的事情,考虑直接枚举选了 \(i\) 个,那么就是 \(w_S(t) = \sum\limits_{i=1}^k\frac{z_S(t)^i}{i!}[t^n]\),感觉很像一个 exp 的 形式,但是这里被 \(k\) 限制了,但是我们仍然考虑对两侧求导。
我们这里为了方便,用 \(G=w_S(t),F=z_s(t)\),那么有
两侧求导,得到:
同一提出去一个 \(F'\):
注意到后面这个求和式和 \(G\) 相当像,只是差两位求和,直接写开得到:
直接对着这个柿子做就行,这里全都是多项式乘法,不要和我一样傻乎乎地想不明白还以为是位运算相关的。
这里还有一个问题是两侧都带有 \(G\),我们考虑每次先两侧模 \(x^t\),然后左侧求导所以可以算出来右侧 \(x^t\) 系数,往右加上再返回来推出 \(x^{t+1}\) 的系数,可以看看我的代码。
我因为比较懒,这里求 \(k\) 次直接用的是快速幂,但是也可以 \(O(n^2)\) 递推 ln, exp 但是因为这里 \(F\) 并不一定具有末尾为 1 的性质所以非常恶心。
但是这个做法不知道为什么我写的奇慢无比,本地能跑出 2min 的超快速度,这里把我的代码(也可以帮我改改怎么样卡过去 qwq)和一个我本地 30s 但是 loj 上 5s 的代码:
我的代码:
#include <bits/stdc++.h>
using namespace std;const int maxn = 22, N = (1 << 21), mod = 998244353;
int jc[N], revjc[N], inv[N];
int n, m, k, cnt[N], ans;
int add(int x, int y) {x = x + y;if(x > mod)return x - mod;return x;
}
int mul(int x, int y) {long long t = 1ll * x * y;return (t > mod ? t % mod : t);
}
struct Poly {int a[maxn * 2], n;int size() {return n;}Poly() {memset(a, 0, sizeof(a));}void resize(int N) {n = N;}int& operator[](int x) {return a[x];}friend Poly operator*(Poly f, Poly g) {Poly ans; ans.resize(f.size() + g.size() - 1);for (int i = 0; i < f.size(); i++) {if(!f[i])continue;for (int j = 0; j < g.size(); j++)ans[i + j] = add(ans[i + j], mul(f[i], g[j]));}return ans;}friend Poly operator+(Poly f, Poly g) {for (int i = 0; i < f.size(); i++)f[i] = add(f[i], g[i]);return f;}friend Poly operator*(Poly f, int v) {v = add(v, mod);for (int i = 0; i < f.size(); i++)f[i] = mul(f[i], v);return f;}friend Poly operator-(Poly f, Poly g) {return f + g * (-1);}
} f[N];
Poly qpow(Poly x, int k) {Poly res; res.resize(n + 1); res[0] = 1;while(k) {if(k & 1)res = res * x, res.resize(n + 1);x = x * x, x.resize(n + 1), k >>= 1;}return res;
}
Poly get_deriv(Poly f) {for (int i = 1; i <= f.size() - 1; i++)f[i - 1] = mul(i, f[i]);f[f.size() - 1] = 0;return f;
}
void fwt(int n, int v) {for (int h = 2; h <= n; h <<= 1)for (int i = 0; i < n; i += h)for (int j = i; j < i + h / 2; j++)f[j + h / 2] = (f[j + h / 2] + (f[j] * v));
}
int read() {int sum = 0;char c = getchar();while(!isdigit(c))c = getchar();while(isdigit(c))sum = sum * 10 + c - '0', c = getchar();return sum;
}
signed main() {n = read(), m = read(), k = read();for (int i = 0; i < (1 << n); i++) f[i].resize(n + 1), cnt[i] = cnt[i >> 1] + (i & 1);jc[0] = 1, inv[0] = inv[1] = 1, revjc[0] = 1;for (int i = 1; i <= n; i++)jc[i] = mul(jc[i - 1], i);for (int i = 2; i <= n; i++)inv[i] = mul((mod - mod / i), inv[mod % i]);for (int i = 1; i <= n; i++)revjc[i] = mul(revjc[i - 1], inv[i]);
// cout << 123 << endl;for (int i = 1; i <= m; i++) {int x = read();f[x][cnt[x]]++;}fwt(1 << n, 1);for (int i = 0; i < (1 << n); i++) {Poly g1 = qpow(f[i], k), g2 = get_deriv(f[i]);for (int i = 0; i < g1.size(); i++)g1[i] = mul(g1[i], revjc[k]), g1[i] = mod - g1[i], g1[i] %= mod;Poly nw = g2 * g1 + g2;for (int j = 1; j <= n - 1; j++) {int del = mul(inv[j], nw[j - 1]);for (int k = n; k >= j; k--)g2[k] = g2[k - 1];g2[j - 1] = 0;nw = nw + g2 * del;}if(cnt[(1 << n) - 1 - i] & 1)ans = add(ans, mod - mul(nw[n - 1], inv[n]));elseans = add(ans, mul(nw[n - 1], inv[n]));}cout << ans << endl;return 0;
}
AC 代码:
#include <cstdio>
#include <iostream>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long int_;
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
inline int readint() {int a = 0;char c = getchar(), f = 1;for (; c < '0' || c > '9'; c = getchar())if (c == '-')f = -f;for (; '0' <= c && c <= '9'; c = getchar())a = (a << 3) + (a << 1) + (c ^ 48);return a * f;
}const int Mod = 998244353;namespace Math {
const int MaxN = 262150;
int jc[MaxN], inv[MaxN];
int jc_inv[MaxN];
void import(int n = MaxN - 1) {jc[1] = inv[1] = 1;rep(i, 2, n) {jc[i] = 1ll * jc[i - 1] * i % Mod;inv[i] = (0ll + Mod - Mod / i) * inv[Mod % i] % Mod;}jc_inv[0] = jc[0] = 1;rep(i, 1, n)jc_inv[i] = 1ll * jc_inv[i - 1] * inv[i] % Mod;
}
inline int C(int n, int m) {if (n < m || m < 0)return 0;return 1ll * jc[n] * jc_inv[m] % Mod * jc_inv[n - m] % Mod;
}
}const int MaxN = 22;
struct Poly {int a[MaxN];int &operator[](const int &x) {return a[max(0, x)];}void operator += (const Poly &b) {rep(i, 0, MaxN - 1)a[i] = (a[i] + b.a[i]) % Mod;}
};
Poly operator * (const int &v, const Poly &f) {Poly g;rep(i, 0, MaxN - 1)g[i] = 1ll * f.a[i] * v % Mod;return g;
}
Poly operator * (const Poly &f, const Poly &g) {Poly h;memset(h.a, 0, MaxN << 2);rep(i, 0, MaxN - 1) if (f.a[i])rep(j, 0, MaxN - 1 - i)h[i + j] = (h[i + j] + 1ll * f.a[i] * g.a[j]) % Mod;return h;
}
Poly qkpow(Poly b, int q) {Poly a;a[0] = 1;rep(i, 1, MaxN - 1) a[i] = 0;for (; q; q >>= 1, b = b * b)if (q & 1)a = a * b;return a;
}void FWT(Poly a[], int n) {rep(i, 0, n - 1) rep(j, 0, (1 << n) - 1)if (j >> i & 1)a[j] += a[j ^ (1 << i)];
}Poly a[1 << MaxN];
int cnt[1 << MaxN];
int main() {int n = readint(), m = readint();int k = readint();Math::import();rep(i, 1, (1 << n) - 1) // bitcntcnt[i] = cnt[i >> 1] + (i & 1);rep(i, 1, m) {int x = readint();++ a[x][cnt[x]];}FWT(a, n);int ans = 0;Poly g_, now, gg;for (int i = 0; i < (1 << n); ++i) {rep(j, 1, n) g_[j - 1] = a[i][j] * j;g_[n] = 0;gg = qkpow(a[i], k);rep(j, 0, n) gg[j] = 1ll * gg[j] * Math::jc_inv[k] % Mod;now = gg * g_;rep(j, 0, n) now[j] = Mod - now[j];now += g_;rep(j, 1, n - 1) {int delta = 1ll * Math::inv[j] * now[j - 1] % Mod;drep(k, n, j) g_[k] = g_[k - 1]; // shiftg_[j - 1] = 0;now += delta * g_;}if (cnt[((1 << n) - 1)^i] & 1)ans = (ans + Mod - 1ll * now[n - 1] * Math::inv[n] % Mod) % Mod;elseans = (ans + 1ll * now[n - 1] * Math::inv[n]) % Mod;}printf("%d\n", ans);return 0;
}