题目概述
给出一棵树,边权为 \(1\),现在让你求一个完全图的最大生成树的边权和,其中 \((i,j)\) 的边权为两个点在树上的路径。
动态加叶子节点。
分析
许多 trick 的题目。
考虑不加点怎么做,可以想到 Tree MST,用了点分治进行边优化,使得边的数量为 \(\mathcal{O}(n\log n)\),然后直接跑。
但是显然此题的限制更小。
由于是要求最大生成树,我们可以贪心地对于每一个点 \(i\) 连向他最远的一个点。
由此不难想到树的直径,这样每个点最远的点只可能是树的直径的两个端点。
但是现在有 \(n\) 条边,其实是有一条重复的直径,减去就行了。
然后静态的事情我们就考虑完了。
考虑动态的情况。
因为我们发现动态维护每个点的最远的点或者说是直接的距离似乎很难,不难想到可以拆路径。
有两个 \(trick\):
- \(p\) 连到最远的点 \(x\in\{s,t\}\)。
- \(p\rightarrow x\) 一定经过 \(s\rightarrow t\) 的中点 \(mid\)。
这样子我们就转化为了 \(p\rightarrow mid\rightarrow x\),那显然地,后面的我们是知道的,只需要维护 \(mid\) 就行了。
而这个 \(mid\) 有可能在边上,我们可以在每条边上加一个虚点就行了,最后答案除以 \(2\)。
注意到:
- 每一次加叶子节点所构成的新直径一定有一个端点是之前的直径就有的。
- \(mid\) 每次最多移动一条边。
因为移动一条边,我们考虑在原 \(mid\) 子树内外的情况距离一定是 \(+1\) 或者 \(-1\),所以说只需维护子树大小就行了,这个可以用树状数组简单维护。
真的很妙。
代码
时间复杂度 \(\mathcal{O}(nlogn+q\log n)\)。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <stdlib.h>
#include <algorithm>
#include <vector>
#define int long long
#define N 800005
using namespace std;
int tr[N];
void update(int x,int val,int n) {for (;x <= n;x += x & -x) tr[x] += val;
}
int query(int x) {int res = 0;for (;x;x -= x & -x) res += tr[x];return res;
}
int querysum(int l,int r) {return query(r) - query(l - 1);
}
int n,q;
vector<int> g[N];
int dep[N],fa[N][25],yz[N],L[N],R[N];
int s,t,mid,ds,ans;
int cnt;
void dfs0(int cur,int father){fa[cur][0] = father;dep[cur] = dep[father] + 1;L[cur] = ++cnt;for (auto i : g[cur])if (i != father) dfs0(i,cur);R[cur] = cnt;
}
int LCA(int x,int y) {if (dep[x] < dep[y]) x ^= y ^= x ^= y;for (int j = 20;j >= 0;j --)if (dep[fa[x][j]] >= dep[y]) x = fa[x][j];if (x == y) return x;for (int j = 20;j >= 0;j --)if (fa[x][j] != fa[y][j]) x = fa[x][j],y = fa[y][j];return fa[x][0];
}
int dis(int x,int y) {return dep[x] + dep[y] - 2 * dep[LCA(x,y)];}
int get(int x,int step) {for (;step;step -= step & -step) x = fa[x][__lg(step & -step)];return x;
}
int middle(int x,int y) {if (dep[x] < dep[y]) x ^= y ^= x ^= y;return get(x,dis(x,y) / 2);
}
int diss[N];
void dfs(int cur,int father) {diss[cur] = diss[father] + 1;for (auto i : g[cur])if (i != father) dfs(i,cur);
}
int tot;
void add(int x,int y) {g[x].push_back(++tot);g[tot].push_back(y);g[y].push_back(tot);g[tot].push_back(x);
}
signed main(){cin >> n >> q;tot = n + q;for (int i = 1;i < n;i ++) {int u,v;scanf("%lld%lld",&u,&v);add(u,v);}for (int i = 1;i <= q;i ++) {scanf("%lld",&yz[i]);add(yz[i],n + i);}dfs0(1,0);for (int j = 1;j <= 20;j ++)for (int i = 1;i <= tot;i ++)fa[i][j] = fa[fa[i][j - 1]][j - 1];diss[0] = -1;s = 1;for (int i = 1;i <= n;i ++)if (dis(1, i) > dis(1, s)) s = i;t = s;for (int i = 1;i <= n;i ++)if (dis(s, i) > dis(s, t)) t = i;ds = dis(s, t);mid = middle(s,t);for (int i = 1;i <= n;i ++) update(L[i],1,tot),ans += dis(i,mid);// cout << ans << ' ' << ds << ' ' << s << ' ' << t << ' ' << mid << '\n';printf("%lld\n",(ans + ds / 2 * n - ds) / 2);for (int i = 1;i <= q;i ++) {n ++;update(L[n],1,tot);ans += dis(n,mid);pair<int,int> dis1({dis(s,n),s}),dis2({dis(t,n),t});auto nw = max(dis1,dis2);int nmid = mid;if (nw.first > ds) {ds = nw.first;s = nw.second;t = n;nmid = middle(s,t);}if (nmid != mid) {if (fa[mid][0] == nmid) {int sum = querysum(L[mid],R[mid]);ans += sum - (n - sum);}else {int sum = querysum(L[nmid],R[nmid]);ans += -sum + (n - sum);}}mid = nmid;printf("%lld\n",(ans + ds / 2 * n - ds) / 2);}return 0;
}