题目描述。
很不错的题。Tag:虚树、树链剖分、换根相关。
本文不讲解上述前置知识。
看到树上颜色段覆盖、查询,容易联想到树链剖分。
树上颜色段数量是不难统计的。先用树链剖分把树拍到序列上,然后线段树区间只要维护颜色段数、左右端点颜色即可:合并两个相邻区间时,若左区间的右端颜色等于右区间的左端颜色,则段数减一。
看到每次询问大小为 \(m\) 的子集,并不难联想到虚树相关。对于每次询问建立点集 \(S\) 的虚树。
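钦定 \(1\) 号点为根,对 \(S\) 中每个点用树链剖分查询它到根的路径颜色段数并加和,即可在 \(O(m \log^2 n)\) 的时间内求出 \(W_1\),即 \(1\) 号点的答案 \(W_1=\sum_{y\in S}\mathrm{seg}(1,y)\),其中 \(\mathrm{seg}(x,y)\) 表示 \(x\) 到 \(y\) 路径上的颜色段数。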
但是题目要求点集 \(S\) 中每个点的答案?这启发我们做换根。
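具体地,记录一下虚树上每个子树有多少点在点集 \(S\) 内,然后换根其实和求“每个点到 \(u\) 的路径长度之和”是一样的,只是路径长度改成了路上的颜色段数。设虚树边 \((u,v)\)(\(u\) 是 \(v\) 的父亲)对应原树路径上有 \(c\) 个颜色段,\(v\) 的虚树子树中有 \(\mathrm{cnt}_v\) 个点属于 \(S\)。两段路径在公共端点处颜色相同,拼接时段数要减一,所以把根从 \(u\) 移到 \(v\) 时,子树内的点各少 \(c-1\) 段,子树外的点各多 \(c-1\) 段:\(W_v = W_u - \mathrm{cnt}_v(c-1) + (m-\mathrm{cnt}_v)(c-1)\)。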
变量重名挂得早,封装就是好!
代码中保留了少量调试输出和 Windows 控制台着色代码(均已注释),需要时可自行取消注释。
- 在建虚树的时候,需要查询父亲到当前节点路径上的颜色段数和两端点颜色,这里注意不要合并反了,因为合并颜色段是不满足交换律的:从下往上跳链时,新查到的(更靠上的)一段要放在左边,即 `res = merge(val, res)`。
#include <bits/stdc++.h>
//#include <windows.h>
using namespace std;
const int N = 1e5 + 5, M = N << 1;  // max vertex count; max directed edges (each undirected edge is stored twice)
int n, q, a[N];                      // vertex count, operation count, initial color of each vertex
int h[N], e[M], ne[M], idx;          // adjacency lists: head per vertex (-1 = empty), edge target, next edge, edge counter

// Appends the directed edge a -> b to the adjacency list of a.
inline void add(int a, int b) { e[idx] = b, ne[idx] = h[a], h[a] = idx++; }

// Summary of the color sequence along a path:
//   a = color at the start end, b = color at the finish end,
//   c = number of maximal runs of equal color ("color segments").
struct Path { int a, b, c; };

// Concatenates path x followed by path y. When the last color of x equals the
// first color of y, the two boundary runs fuse into one, so one segment is removed.
// The operation is NOT commutative: merge(x, y) describes "x, then y".
Path merge(Path x, Path y) { return Path{x.a, y.b, x.c + y.c - (x.b == y.a)}; }

// Segment-tree operations over the heavy-light order (defined further below).
void change(int u, int l, int r, int d);
Path query(int u, int l, int r);

//----------------------------------------------------------------------------
// Heavy-light decomposition: depth, parent, subtree size and heavy child of each vertex.
int dep[N], fa[N], sz[N], son[N];
// First heavy-light pass over the subtree of u (entered from `father`).
// Fills fa[], dep[] and sz[], and chooses son[u] as the child with the
// largest subtree; on ties the child visited first is kept.
void dfs(int u, int father) {
    fa[u] = father;
    dep[u] = dep[father] + 1;
    sz[u] = 1;
    for (int edge = h[u]; edge != -1; edge = ne[edge]) {
        const int child = e[edge];
        if (child == father) continue;
        dfs(child, u);
        sz[u] += sz[child];
        const bool heavier = (son[u] == 0) || sz[child] > sz[son[u]];
        if (heavier) son[u] = child;
    }
}
int top[N], dfn[N], id[N], tot;
void dfs2(int u, int t) {top[u] = t, dfn[u] = ++tot, id[tot] = u;if (son[u]) dfs2(son[u], t);for (int i = h[u]; ~i; i = ne[i]) {int v = e[i]; if (v == fa[u] || v == son[u]) continue;dfs2(v, v);}
}inline int lca(int a, int b) {while (top[a] ^ top[b]) {if (dep[top[a]] < dep[top[b]]) swap(a, b);a = fa[top[a]];}if (dep[a] > dep[b]) swap(a, b);return a;
}void change_path(int u, int v, int d) {while (top[u] ^ top[v]) {if (dep[top[u]] < dep[top[v]]) swap(u, v);
// cout << "Change: "; for (int i = dfn[top[u]]; i <= dfn[u]; i++) cout << id[i] << ' '; puts("");change(1, dfn[top[u]], dfn[u], d), u = fa[top[u]];}
// cout << "Change: "; for (int i = dfn[u]; i <= dfn[v]; i++) cout << id[i] << ' '; puts("");if (dep[u] > dep[v]) swap(u, v);change(1, dfn[u], dfn[v], d);
}
// Returns the color summary of the vertical path from u down to v (both included),
// oriented so that .a is the color at u and .b is the color at v.
// Precondition: u is an ancestor of v (or u == v); only v is lifted.
Path query_path(int u, int v) {
    Path res{};
    bool has = false;  // whether res already holds the lower part of the path
    while (top[u] ^ top[v]) {
        Path val = query(1, dfn[top[v]], dfn[v]);
        // val lies above everything gathered so far, so it goes on the left (merge is not commutative).
        res = has ? merge(val, res) : val;
        has = true;
        v = fa[top[v]];
    }
    Path val = query(1, dfn[u], dfn[v]);
    return has ? merge(val, res) : val;
}
//----------------------------------------------------------------------------struct Tree {int l, r, cov;Path res;
} tr[N << 2];
// Rebuilds node u's summary from its children; the left child must come first
// because merging color runs is order-sensitive.
inline void pushup(int u) {
    const int lc = u << 1, rc = lc | 1;
    tr[u].res = merge(tr[lc].res, tr[rc].res);
}

// Paints node u's whole range with color d and records it as a pending tag.
inline void update(int u, int d) {
    tr[u].cov = d;
    tr[u].res = Path{d, d, 1};
}

// Hands a pending paint tag of node u down to both children, then clears it.
inline void pushdown(int u) {
    if (!tr[u].cov) return;
    const int d = tr[u].cov;
    update(u << 1, d);
    update(u << 1 | 1, d);
    tr[u].cov = 0;
}
void build(int u, int l, int r) {tr[u].l = l, tr[u].r = r, tr[u].cov = 0;if (l == r) return tr[u].res = (Path){a[id[l]], a[id[l]], 1}, void();int mid = l + r >> 1;build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);pushup(u);
}
void change(int u, int l, int r, int d) {if (tr[u].l >= l && tr[u].r <= r) return update(u, d), void();pushdown(u);int mid = tr[u].l + tr[u].r >> 1;if (l <= mid) change(u << 1, l, r, d);if (r > mid) change(u << 1 | 1, l, r, d);pushup(u);
}
// Returns the color summary of dfn positions [l, r], read left to right.
// Assumes [l, r] overlaps node u's range.
Path query(int u, int l, int r) {
    if (l <= tr[u].l && tr[u].r <= r) return tr[u].res;
    pushdown(u);
    const int mid = (tr[u].l + tr[u].r) >> 1;
    const bool goLeft = l <= mid, goRight = r > mid;
    if (!goRight) return query(u << 1, l, r);
    if (!goLeft) return query(u << 1 | 1, l, r);
    return merge(query(u << 1, l, r), query(u << 1 | 1, l, r));
}

int qwq[N];  // debug buffer written by print(): current color of each vertex
void print(int u) {if (tr[u].l == tr[u].r) return qwq[id[tr[u].l]] = tr[u].res.a, void();pushdown(u);print(u << 1), print(u << 1 | 1);
}//----------------------------------------------------------------------------struct Virtual_Tree {int m, h[N], e[M], ne[M], idx;Path w[M]; // fa -> soninline void add(int a, int b, Path c) { e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++; }int cnt[N];long long Ans[N], ans;void dfs1(int u) { for (int i = h[u]; ~i; i = ne[i]) dfs1(e[i]), cnt[u] += cnt[e[i]]; }void dfs2(int u) {Ans[u] = ans;
// cout << "Ans: " << u << ' ' << ans << endl;long long pre = ans;for (int i = h[u]; ~i; i = ne[i]) {int v = e[i];int in = cnt[v], out = m - in, a = w[i].a, b = w[i].b, c = w[i].c;ans -= in * 1ll * (c - 1), ans += out * 1ll * (c - 1);dfs2(v), ans = pre;}}void clr(int u) {cnt[u] = Ans[u] = 0;for (int i = h[u]; ~i; i = ne[i]) clr(e[i]);h[u] = -1;}
} VT;int arr[N], qry[N], m;
int stk[N], tt;
inline bool cmp(int a, int b) { return dfn[a] < dfn[b]; }
inline void insert(int u, int v) { // ÇÕ¶¨ u ÊÇ v µÄ׿ÏÈ Path w = query_path(u, v);
// cout << "Virtual Tree Add edge: " << u << ' ' << v << '\t' << w.a << ' ' << w.b << ' ' << w.c << endl;VT.add(u, v, w);
}
void build_VTree() {VT.ans = 0ll, VT.idx = 0;sort(arr + 1, arr + 1 + m, cmp);for (int i = 1; i <= m; i++) VT.cnt[arr[i]] = 1, VT.ans += query_path(1, arr[i]).c/*, cout << "Init: " << arr[i] << ' ' << query_path(1, arr[i]).c << endl*/;stk[tt = 1] = 1;for (int i = 1; i <= m; i++) {if (arr[i] == 1) continue;int u = arr[i], l = lca(arr[i], stk[tt]);while (dfn[l] < dfn[stk[tt - 1]]) insert(stk[tt - 1], stk[tt]), tt--;if (l == stk[tt]) stk[++tt] = u;else if (dfn[stk[tt - 1]] < dfn[l] && dfn[l] < dfn[stk[tt]]) insert(l, stk[tt]), tt--, stk[++tt] = l, stk[++tt] = u;else if (l == stk[tt - 1]) insert(stk[tt - 1], stk[tt]), tt--, stk[++tt] = u;}while (tt > 1) insert(stk[tt - 1], stk[tt]), tt--;
}int main() {
// freopen("rubbish.out", "w", stdout); scanf("%d%d", &n, &q);for (int i = 1; i <= n; i++) h[i] = VT.h[i] = -1; idx = VT.idx = 0;for (int i = 1; i <= n; i++) scanf("%d", &a[i]);for (int i = 1, a, b; i < n; i++) scanf("%d%d", &a, &b), add(a, b), add(b, a);dfs(1, 0), dfs2(1, 1), build(1, 1, n);// cout << "Dfn: "; for (int i = 1; i <= n; i++) cout << dfn[i] << ' '; puts("");
// cout << "Top: "; for (int i = 1; i <= n; i++) cout << top[i] << ' '; puts("");while (q--) {int op; scanf("%d", &op);if (op == 1) {int u, v, d; scanf("%d%d%d", &u, &v, &d);change_path(u, v, d);} else {scanf("%d", &m), VT.m = m;for (int i = 1; i <= m; i++) scanf("%d", &arr[i]), qry[i] = arr[i];build_VTree();VT.dfs1(1), VT.dfs2(1);
// SetConsoleTextAttribute(GetStdHandle(STD_OUTPUT_HANDLE), FOREGROUND_INTENSITY | FOREGROUND_BLUE);//????for (int i = 1; i <= m; i++) printf("%lld ", VT.Ans[qry[i]]); puts("");
// SetConsoleTextAttribute(GetStdHandle(STD_OUTPUT_HANDLE), FOREGROUND_INTENSITY |FOREGROUND_RED |FOREGROUND_GREEN | FOREGROUND_BLUE);//??????VT.clr(1);}
// cout << "Print: "; print(1); for (int i = 1; i <= n; i++) printf("%d ", qwq[i]); puts("");}return 0;
}