题意
对于 \(01\) 串 \(S\),我们不断删除 \(S\) 中的 \(\texttt{01}\) 子串,直至 \(S\) 中不含 \(\texttt{01}\) 子串,最终得到的串设为 \(f(S)\)。定义 \(V(S)=\sum\limits_{|T|=|S|}[f(S)=f(T)]\)。
给定长度为 \(n\) 的 \(01\) 串 \(S\),你需要支持:
- 区间 flip。
- 查询 \(V(S_{l\sim r})\)。
\(1\leq n,q\leq 5\times 10^5\)。
题解
没打比赛,看完题 \(5\,\text{min}\) 击杀了,貌似是不带脑子做法。
下面将 \(01\) 串视作括号串。
容易发现,\(f(S)\) 一定是 \(a\) 个 \(\texttt{)}\) 加上 \(b\) 个 \(\texttt{(}\) 的形式,否则一定能找到可消除的 \(\texttt{01}\) 子串。我们将其表示为 \(f(S)=(a,b)\)。
考虑如果给定 \(|S|\) 和 \(f(S)=(a,b)\) 怎么做。如果去掉这 \(a\) 个 \(\texttt{)}\) 和 \(b\) 个 \(\texttt{(}\),则 \(S\) 剩下的部分一定是合法的括号串。设 \(n=\dfrac{|S|-a-b}{2}\),考虑枚举 \(S\) 剩下的长度为 \(2n\) 的括号串可以被分割成 \(k\) 段极短合法括号串。
首先需要指出一个结论:长度为 \(2n\) 且可以被分割成 \(k\) 段极短合法括号串的括号串个数是 \(\dfrac{k}{n}\dbinom{2n-k-1}{n-1}\)。
证明
对于一个长度为 \(2n\) 的合法括号串 \(S\),我们将它每一段的末尾的右括号删去,得到一个长度为 \(2n-k\) 的括号串 \(T\)。将 \(\texttt{(}\) 视作 \(+1\),\(\texttt{)}\) 视作 \(-1\),则 \(T\) 对应的和为 \(k\)。而容易看出 \(T\) 中前缀和为 \(1,2,\cdots,k\) 的位置恰好各出现一次,对应于 \(S\) 中前缀和为 \(0\) 的位置。因此我们构造出了双射。
于是转而对前缀和始终 \(>0\)、和为 \(k\)、长度为 \(2n-k\) 的 \(\pm 1\) 序列计数。显然序列的第一项一定是 \(+1\),因此可以转化为对前缀和始终 \(\geq 0\)、和为 \(k-1\)、长度为 \(2n-k-1\) 的 \(\pm 1\) 序列计数。这就是是反射容斥能解决的形式了,容易得到答案为
还需要乘上将 \(k\) 段括号串塞进 \(a+b+1\) 个空隙里的方案数,容易插板法得出方案数为 \(\dbinom{k+a+b}{a+b}\)。因此答案为
来推一下式子吧!这个 \(k\) 放在这里很难受,容易想到把 \(a+b+1\) 换出来:
于是答案变为
显然后面的和式可以用上指标卷积化简,因此答案就是
\(\mathcal{O}(n)\) 预处理逆元和阶乘及其逆元,这样即可 \(\mathcal{O}(1)\) 计算答案。需要特判 \(n=0\),显然此时答案为 \(1\)。
现在我们只需维护出 \(S_{l\sim r}\) 对应的二元组 \((a,b)\) 即可。考虑线段树,在节点上维护区间对应子串的二元组,信息是好合并的,讨论一下左节点的右括号和右节点的左括号的合并即可。但是还要支持区间 flip,好像没法快速更改信息怎么办?使用经典技巧,同时维护当前节点原本的信息和 flip 之后的信息,信息依然是可并的,区间 flip 只需要打 tag 之后交换一下这两个信息即可。时间复杂度 \(\mathcal{O}((n+q)\log{n})\)。
代码
#include <bits/stdc++.h>using namespace std;using ll = long long;
using ull = unsigned long long;
using ld = long double;
using pii = pair<int, int>;
const int N = 5e5 + 5;
const int mod = 998244353;template<typename T> T lowbit(T x) { return x & -x; }
template<typename T> void chk_min(T &x, T y) { x = min(x, y); }
template<typename T> void chk_max(T &x, T y) { x = max(x, y); }
template<typename T> T add(T x, T y) { return x += y - mod, x += x >> 31 & mod; }
template<typename T> T sub(T x, T y) { return x -= y, x += x >> 31 & mod; }
template<typename T> void cadd(T &x, T y) { x += y - mod, x += x >> 31 & mod; }
template<typename T> void csub(T &x, T y) { x -= y, x += x >> 31 & mod; }int n, q, inv[N];
int fac[N], ifac[N];
string s;int qpow(int a, int b) {int res = 1;for (; b; b >>= 1) {if (b & 1) res = (ll)res * a % mod;a = (ll)a * a % mod;}return res;
}
void init(int n) {fac[0] = 1;for (int i = 1; i <= n; ++i) fac[i] = (ll)fac[i - 1] * i % mod;ifac[n] = qpow(fac[n], mod - 2);for (int i = n - 1; ~i; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % mod;inv[1] = 1;for (int i = 2; i <= n; ++i) inv[i] = (ll)(mod - mod / i) * inv[mod % i] % mod;
}
int C(int n, int m) {return n < 0 || m < 0 || n < m ? 0 : (ll)fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}struct SegTree {
#define ls(p) (p << 1)
#define rs(p) (p << 1 | 1)struct Node {int a, b, ra, rb;Node operator+(const Node &x) const {return {a + max(x.a - b, 0), max(b - x.a, 0) + x.b,ra + max(x.ra - rb, 0), max(rb - x.ra, 0) + x.rb};}Node &operator+=(const Node &x) { return *this = *this + x; }} nodes[N << 2];int tg[N << 2];void push_up(int p) { nodes[p] = nodes[ls(p)] + nodes[rs(p)]; }void make_rev(int p) { tg[p] ^= 1, swap(nodes[p].a, nodes[p].ra), swap(nodes[p].b, nodes[p].rb); }void push_down(int p) { if (tg[p]) make_rev(ls(p)), make_rev(rs(p)), tg[p] = 0; }void build(int p, int l, int r) {if (l == r) return nodes[p] = s[l] == '0' ? Node{0, 1, 1, 0} : Node{1, 0, 0, 1}, void();int mid = l + r >> 1;build(ls(p), l, mid), build(rs(p), mid + 1, r);push_up(p);}void rev(int p, int l, int r, int x, int y) {if (x <= l && y >= r) return make_rev(p);push_down(p);int mid = l + r >> 1;if (x <= mid) rev(ls(p), l, mid, x, y);if (y > mid) rev(rs(p), mid + 1, r, x, y);push_up(p);}Node query(int p, int l, int r, int x, int y) {if (x <= l && y >= r) return nodes[p];push_down(p);int mid = l + r >> 1;Node res = {0, 0, 0, 0};if (x <= mid) res = query(ls(p), l, mid, x, y);if (y > mid) res += query(rs(p), mid + 1, r, x, y);return res;}
#undef ls
#undef rs
} sgt;int main() {ios::sync_with_stdio(0), cin.tie(0);cin >> n >> s, s = '#' + s;init(n), sgt.build(1, 1, n);cin >> q;while (q--) {int tp, l, r; cin >> tp >> l >> r;if (tp == 1) sgt.rev(1, 1, n, l, r);else {auto [a, b, ra, rb] = sgt.query(1, 1, n, l, r);int n = r - l + 1 - a - b >> 1;if (!n) cout << "1\n";else cout << (ll)C(n * 2 + a + b, n + a + b + 1) * (a + b + 1) % mod * inv[n] % mod << '\n';}}return 0;
}