简单题。
题意:给出 \(n\) 个数 \(w_i\)。现在每轮删除一个数,假设现在有 \(i_1,i_2\cdots i_k\) 这些下标的数还在,那么对于 \(i_x\) 就有 \(\frac{w_{i_x}}{\sum w_{i_j}}\) 的概率删除。问 \(1\) 号元素最后一个被删除的概率。\(n\le 10^5,\sum w_i\le 10^5\)
做法:
首先因为有个 \(\sum w\) 在分母上,这个事情很烦,但是稍加考虑,我们其实可以认为每个数被删除的概率还是 \(\frac{w_i}{sum}\),这里 \(sum\) 是所有数总和,下文同。但是可以多次被删,执行无限轮,只是后面是无效删除而已,概率是不变的。
然后就是考虑最后一个被删这个条件了,显然很可以容斥,设 \(S\) 是我钦定他们必须在 \(1\) 后被删掉,那么答案应该就是 \(\sum\limits_S P(S)(-1)^{|S|}\),这里 \(P(S)\) 是 \(S\) 全在 \(1\) 前被删掉的概率。
考虑如何计算 \(P(S)\),那么就要求只要不碰到 \(1,S\) 这些元素随便选,记 \(s(S)\) 是集合 \(S\) 的 \(w\) 之和,那么概率为 \(\frac{sum-w_1-s(S)}{sum}\),枚举执行多少轮后 \(1\) 被删除,那么概率就是 \(\sum (\frac{sum-w_1-s(S)}{sum})^i\frac{w_1}{sum}\)。
把 \(\frac{w_1}{sum}\) 这个常量提出再用等比数列求和稍微化简,得到 \(P(S)=\frac{w_1}{w_1+s(S)}\),很漂亮的柿子。
注意到题目中有 \(\sum w_i \le 10^5\),考虑直接枚举 \(s(S)\) 然后计算容斥系数的贡献即可,这个直接用多项式分治乘去做就可以,复杂度 \(O(n\log^2 n)\)。
代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 5e5 + 5, mod = 998244353, gb = 3, gi = (mod + 1) / gb;
int rev[maxn];
void init(int n) {for (int i = 1; i < n; i++) {rev[i] = rev[i >> 1] >> 1;if(i & 1)rev[i] |= (n >> 1);}
}
int qpow(int x, int k, int p) {int res = 1;while(k) {if(k & 1)res = res * x % p;x = x * x % p, k >>= 1;}return res;
}
struct Poly {vector<int> a;int& operator[](int x) {return a[x];}void resize(int N) {a.resize(N);}int size() {return a.size();}Poly() {}Poly(int N) {a.resize(N);}void NTT(int f) {int n = size();for (int i = 0; i < n; i++)if(i < rev[i])swap(a[i], a[rev[i]]);for (int h = 2; h <= n; h <<= 1) {int d = qpow((f == 1 ? gb : gi), (mod - 1) / h, mod);for (int i = 0; i < n; i += h) {int nw = 1;for (int j = i; j < i + h / 2; j++) {int a0 = a[j], a1 = a[j + h / 2] * nw % mod;a[j] = (a0 + a1) % mod, a[j + h / 2] = (a0 - a1 + mod) % mod;nw = nw * d % mod;}}}if(f == -1) {int inv = qpow(n, mod - 2, mod);for (int i = 0; i < n; i++)a[i] = a[i] * inv % mod;}}friend Poly operator*(Poly f, Poly g) {int len = 1, t = f.size() + g.size() - 1;while(len < t)len <<= 1;init(len), f.resize(len), g.resize(len);f.NTT(1), g.NTT(1);for (int i = 0; i < len; i++)f[i] = f[i] * g[i] % mod;f.NTT(-1);f.resize(t);return f;}
};
int n, a[maxn];
Poly solve(int l, int r) {if(l == r) {Poly f(a[l] + 1); f[0] = 1, f[a[l]] = mod - 1;return f;}int mid = l + r >> 1;return solve(l, mid) * solve(mid + 1, r);
}
signed main() {cin >> n;for (int i = 1; i <= n; i++)cin >> a[i];Poly res = solve(2, n);int ans = 0;for (int i = 0; i < res.size(); i++)ans = (ans + qpow(a[1] + i, mod - 2, mod) * a[1] % mod * res[i]) % mod;cout << ans << endl;return 0;
}