配对
巨大困难题,赛时只会了启发式合并的部分分。
观察匹配的路径,显然有一个结论:所有匹配的路径不会经过同一条边。
所以可以先考虑不带修、\(n\) 为偶数的做法:对整棵树进行遍历,遍历到 \(u\) 的子树时如果遇到多个儿子中有多个未被匹配的儿子,则将他们自行进行匹配。如果未被匹配的节点有奇数个,则最后剩下哪一个都是一样的。因为此时匹配的贡献是他们到 \(u\) 的距离。
由上面的部分分做法,可以得到另外两个结论:
- 对于同一颗子树内的节点,将他们之间自行匹配掉一定比与其他子树内的节点匹配更优。
- 同一颗子树内,最多只存在一个未被匹配的点与其余子树内的点匹配。
和 P3177 相似的思想,我们可以从每条边的贡献的角度思考问题。发现如果 \(u\) 的子树内存在一个未被匹配的点,那么 \(u\to fa\) 的那条边就一定会被某条匹配的路径覆盖。
再将交换的操作转化,注意到它等效于将一个 \(c = 1\) 的点变为 \(c = 0\) 的点,同时又把一个 \(c = 0\) 的点变为 \(c = 1\) 的点,并且当匹配点为奇数时剩下的一个点也可以转化为将一个 \(c = 1\) 的点变为 \(c = 0\) 的点的操作。所以可以从反转颜色的角度去思考交换操作。
贡献会算了,由此就可以设计树形 DP 了。定义 \(dp_{u, i, j}\) 表示 \(u\) 子树内进行了 \(i\) 次 \(1\to 0\) 的操作,\(j\) 次 \(0\to 1\) 的操作。其中 \(i\in\{0, 1, 2\}, j\in \{0, 1\}\)。
转移是简单的,不需要大量的分类讨论,类似树上背包对两个节点的 DP 值进行 merge。然后在出某个节点的时候,对子树内存在未匹配点的情况特判,增加 \(u\to fa\) 这条边的贡献即可。
当 \(n\) 为偶数的时候,答案就是 \(\min\{dp_{1, 0, 0}, dp_{1, 1, 1}\}\);当 \(n\) 为奇数的时候,答案就是 \(\min\{dp_{1, 1, 0}, dp_{1, 2, 1}\}\),因为有一个始终无法被匹配的点。
时间复杂度 \(O(n)\)。
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb(x) emplace_back(x)
#define pb(x) push_back(x)
#define lc(x) (tr[x].ls)
#define rc(x) (tr[x].rs)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
using pi = pair<int, int>;
const int N = 1000005;
ll n, oric[N], c[N], dp[N][3][2], f[3][2];
vector<pi> g[N];
void merge(int u, int v)
{memcpy(f, dp[u], sizeof(f));memset(dp[u], 0x3f, sizeof(dp[u]));for(int a = 0; a <= 2; a++)for(int b = 0; a + b <= 2; b++)for(int c = 0; c <= 1; c++)for(int d = 0; c + d <= 1; d++)dp[u][a + b][c + d] = min(dp[u][a + b][c + d], f[a][c] + dp[v][b][d]);
}
void dfs(int u, int fa, ll wf)
{if(oric[u] == 0) dp[u][0][0] = dp[u][0][1] = 0;else dp[u][1][0] = dp[u][0][0] = 0;for(auto eg : g[u]){int v = eg.fi, w = eg.se;if(v == fa) continue;dfs(v, u, w);c[u] ^= c[v];merge(u, v);}if(u != 1){if(c[u]){dp[u][0][0] += wf;dp[u][1][1] += wf;dp[u][2][0] += wf;}else{dp[u][0][1] += wf;dp[u][1][0] += wf;dp[u][2][1] += wf;}}
}
int main()
{ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);cin >> n;int cnt = 0;for(int i = 1; i <= n; i++){cin >> c[i];oric[i] = c[i];cnt += c[i];}for(int i = 1; i < n; i++){int u, v, w;cin >> u >> v >> w;g[u].push_back({v, w});g[v].push_back({u, w});}memset(dp, 0x3f, sizeof(dp));dfs(1, 0, 0);if(cnt & 1) cout << min(dp[1][1][0], dp[1][2][1]);else cout << min(dp[1][0][0], dp[1][1][1]);return 0;
}