题目大意
给定一棵树,每个节点有一个括号。对于每个节点 \(i\),定义 \(s_i\) 为从根节点到 \(i\) 的路径上所有括号按顺序组成的字符串。求每个 \(s_i\) 中互不相同的合法括号子串的个数 \(k_i\)。
思路
首先,\(k_i\) 可以从父节点递推得到,\(k_i=k_{f_i}+a_i\)。其中 \(a_i\) 为以节点 \(i\) 结尾的合法括号序列数量。因此只要求出每个节点的 \(a\)。
以 ( 为 \(1\) ) 为 \(−1\) 做树上前缀和,设点 \(u\)
的前缀和为 \(sum_u\)。则以 \(u\) 结尾的合法括号子串的开头 \(v\) 需要满足:
- \(sum_{f_v}=sum_u\)。
- 对于 \(v\to u\) 这条链上的所有点 \(x\),有 \(sum_x\ge sum_u\)。
在 DFS 过程中开一棵值域线段树维护 \(1\to u\) 这条链上每个 \(sum\) 值对应的最大节点深度。这样就能找到 \(sum_p<sum_u\) 且深度最大的节点 \(p\)。
设 \(ask(x,y)\) 表示 \(1\to x\) 链上 \(sum=y\) 的节点数量。则 \(a_u=ask(f_u,k)-ask(p,k)\)。
第一遍 DFS 求出所有询问并离线下来。
第二遍 DFS 求出所有点的 \(a\)。
第三遍 DFS 对 \(a\) 做树上前缀和得到所有点的 \(k\) 即可。
Code
#include <bits/stdc++.h>
#define rept(i,a,b) for(int i(a);i<=b;++i)
#define ls(p) ((p)<<1)
#define rs(p) ((p)<<1|1)
#define eb emplace_back
#define int long long
using namespace std;
constexpr int N=5e5+5;
struct Query{int k,coef,id;// k:目标值// coef:贡献系数,1/-1// id:贡献给到的节点Query(int _k,int _coef,int _id):k(_k),coef(_coef),id(_id){}
};
struct SegTree{int t[N<<3];void update(int p,int pl,int pr,int pos,int x){ // 单点修改if(pl==pr) return void(t[p]=x);int mid=pl+pr>>1;if(pos<=mid) update(ls(p),pl,mid,pos,x);else update(rs(p),mid+1,pr,pos,x);t[p]=max(t[ls(p)],t[rs(p)]);}int query(int p,int pl,int pr,int l,int r){ // 区间maxif(l>r) return 0;if(l<=pl&&pr<=r) return t[p];int mid=pl+pr>>1,a=0;if(l<=mid) a=max(a,query(ls(p),pl,mid,l,r));if(mid<r) a=max(a,query(rs(p),mid+1,pr,l,r));return a;}
}sgt;
char s[N];
int sum[N],dep[N],cnt[N<<1],a[N],st[N];
int n,m,ans;
vector<int> g[N];
vector<Query> q[N];
void dfs1(int u){int lst=sgt.query(1,1,m,sum[u],sum[u]);sgt.update(1,1,m,sum[u],dep[u]); st[dep[u]]=u;for(int v:g[u]){sum[v]=sum[u]+(s[v]=='('?1:-1);dep[v]=dep[u]+1;if(s[v]==')'){int bound=sgt.query(1,1,m,1,sum[v]-1);q[u].eb(sum[v],1,v);if(bound) q[st[bound]].eb(sum[v],-1,v);}dfs1(v);}sgt.update(1,1,m,sum[u],lst);
}
void dfs2(int u){++cnt[sum[u]];for(Query x:q[u]){a[x.id]+=x.coef*cnt[x.k];}for(int v:g[u]) dfs2(v);--cnt[sum[u]];
}
void dfs3(int u){for(int v:g[u]){a[v]+=a[u];dfs3(v);}ans^=u*a[u];
}
signed main(){cin.tie(0)->sync_with_stdio(0);cin>>n>>s+1;m=n<<1;rept(i,2,n){int x;cin>>x;g[x].eb(i);}g[0].eb(1);sum[0]=n,dep[0]=1; // 为了不出负数,sum统一加上ndfs1(0),dfs2(0),dfs3(0);cout<<ans;return 0;
}