好题好题。
我们先对一个结点 \(u\) 进行分析。
发现能对 \(u\) 产生贡献的所有结点可以构成一个联通分量。
只有经过 \(u\) 才会对 \(u\) 产生贡献。
而我们不可能将一条链上的所有点都扔到 \(u\) 上,这显然不现实,肯定是进行计算,算出来的,而且光一个结点都需要 \(O(n^2)\) 的内存,明显很没有前途。
考虑一条链的两端 \(l,r\) 挂在 \(u\) 上,类似一个虚树,考虑有 \(n\) 个叶子的时候我们是如何计算一颗虚树的大小的。
先对其的 dfs 序排序,\(siz = \sum dep_{a_i} - \sum dep_{LCA(a_i,a_i+1)} - dep _{LCA(a_1,a_n)}\)。dfs 序从小到大放,加入 dep 的时候会与 dfs 前面那一个产生交错, 这一部分就是 \(dep_{LCA(a_i,a_{i-1})}\) 。 最后我们因为钦定了 1 为根,去掉的话就是 \(-dep_{LCA_{i=1}^n a_i}\) 有dfs 序的性质可知其等于 \(-dep_{LCA(a_1,a_n)}\)。
好了那我们需要将所有的结点都算出来,明显可以树上差分。而我们发现如上的操作刚好可以使用线段树来维护,那树上的线段树肯定会想到线段树合并,不然这空间都得炸。
代码:
#include<bits/stdc++.h>
#define hnczy "language"
using namespace std;
const int N=5E5+5;
int st[20][N],dfn[N],lg[N];
int dep[N],fa[N],cnt,n,m;
vector<int>e[N],del[N];long long ans;
int calc(int x,int y){return dep[x]<dep[y]?x:y;}
int LCA(int x,int y){if(!x || !y)return 0;if(x==y)return x;x=dfn[x],y=dfn[y];if(x>y)swap(x,y);int tmp=lg[y-x++];return calc(st[tmp][x],st[tmp][y-(1<<tmp)+1]);
}
struct SEG{struct node{int ls,rs,L,R,val,c;}c[N<<4];#define ls(p) c[p].ls#define rs(p) c[p].rs#define mid ((l+r)>>1) int rt[N],tot;void pushup(int p){c[p].L= c[ls(p)].L?c[ls(p)].L:c[rs(p)].L ;c[p].R= c[rs(p)].R?c[rs(p)].R:c[ls(p)].R ;c[p].val =c[ls(p)].val +c[rs(p)].val - dep[LCA(c[ls(p)].R,c[rs(p)].L )] ;}int query(int u){return c[u].val - dep[LCA(c[u].L,c[u].R)] ;}void change(int &p,int l,int r,int x,int w){if(!p)p = ++tot; if(l==r){c[p].c +=w;c[p].val = (c[p].c ? dep[x]:0);c[p].L = c[p].R = (c[p].c ? x:0);return ;}if(dfn[x]<=mid)change(ls(p),l,mid,x,w);else change(rs(p),mid+1,r,x,w);pushup(p); }void merge(int &p,int q,int l,int r){if(!p || !q)return p|=q,void();if(l==r)return c[p].c += c[q].c , c[p].val |= c[q].val, c[p].L |= c[q].L , c[p].R |= c[q].R,void() ;merge(ls(p),ls(q),l,mid),merge(rs(p),rs(q),mid+1,r);pushup(p);}
}seg;void dfs1(int u,int F){dep[u] =dep[F] +1,fa[u] =F;st[0][dfn[u] = ++cnt]=F;for(int v:e[u])if(v!=F)dfs1(v,u);
}void dfs2(int u){for(int v:e[u]) if(v!=fa[u])dfs2(v);for(int v:del[u]) seg.change(seg.rt[u],1,cnt,v,-1);ans+=seg.query(seg.rt[u]);seg.merge(seg.rt[fa[u]],seg.rt[u],1,cnt);
}
int main(){freopen(hnczy ".in", "r", stdin);freopen(hnczy ".out", "w", stdout);scanf("%d%d",&n,&m);for(int i=1,u,v;i<n;i++){scanf("%d%d",&u,&v);e[u].push_back(v);e[v].push_back(u); }dfs1(1,0);for(int i=1;i<=19;i++)for(int j=1;j+(1<<i-1)<=n;j++)st[i][j] = calc(st[i-1][j],st[i-1][j+(1<<i-1)]);for(int i=2;i<=n;i++)lg[i] = lg[i>>1]+1; for(int i=1,u,v;i<=m;i++){scanf("%d%d",&u,&v);int lca= LCA(u,v);seg.change(seg.rt[u],1,cnt,u,1);seg.change(seg.rt[u],1,cnt,v,1);seg.change(seg.rt[v],1,cnt,u,1);seg.change(seg.rt[v],1,cnt,v,1);del[lca] .push_back(u),del[lca].push_back(v);del[fa[lca]] .push_back(u),del[fa[lca]].push_back(v);}dfs2(1);cout<<ans/2; return 0;
}