POJ 1741 Tree
思路
男人八题中的一题,写完这题算是18\frac{1}{8}81个男人了!
这题是树上距离的计数问题,能够通过巧妙地排序加双指针来解决,
统计距离应该大家都会地,我就来说明一下如何计数吧。
假设我们已经求得距离并且排好序之后是这样的:
1335671\ 3\ 3\ 5\ 6\ 71 3 3 5 6 7,我们要求的距离是小于等于777的,
初始设置i=1,j=6i = 1, j= 6i=1,j=6,也就是一头一尾,
接下来我们判断如果dis[i]+dis[j]>7dis[i] + dis[j] > 7dis[i]+dis[j]>7就不断让j−−j--j−−这个时候答案的贡献就是j−ij - ij−i了,
简单模拟一下:
i=1,j=6,value[1]+value[6]=8>7,j−−value[1]+value[5]=7i = 1, j = 6, value[1] + value[6] = 8 > 7, j--value[1] + value[5] = 7i=1,j=6,value[1]+value[6]=8>7,j−−value[1]+value[5]=7,答案贡献j−ij - ij−i。
然后就是i++i++i++,接下来就是上面的重复步骤了。直到i==ji == ji==j。
然后这题就这样水过了。
代码
/*Author : lifehappy
*/
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>
#include <stdlib.h>
#include <map>#define mp make_pair
#define pb push_back
#define endl '\n'
#define mid (l + r >> 1)
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define ls rt << 1
#define rs rt << 1 | 1using namespace std;typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;const double pi = acos(-1.0);
const double eps = 1e-7;
const int inf = 0x3f3f3f3f;inline ll read() {ll f = 1, x = 0;char c = getchar();while(c < '0' || c > '9') {if(c == '-') f = -1;c = getchar();}while(c >= '0' && c <= '9') {x = (x << 1) + (x << 3) + (c ^ 48);c = getchar();}return f * x;
}const int N = 1e4 + 10;int head[N], to[N << 1], nex[N << 1], value[N << 1], cnt;int sz[N], visit[N], msz[N], dis[N], pre[N], now[N], tot, root, n, m, sum, ans;void add(int x, int y, int w) {to[cnt] = y;nex[cnt] = head[x];value[cnt] = w;head[x] = cnt++;
}void get_root(int rt, int fa) {sz[rt] = 1, msz[rt] = 0;for(int i = head[rt]; i; i = nex[i]) {if(to[i] == fa || visit[to[i]]) continue;get_root(to[i], rt);sz[rt] += sz[to[i]];msz[rt] = max(msz[rt], sz[to[i]]);}msz[rt] = max(msz[rt], sum - sz[rt]);if(msz[rt] < msz[root]) root = rt;
}void get_dis(int rt, int fa) {now[++tot] = dis[rt];for(int i = head[rt]; i; i = nex[i]) {if(to[i] == fa || visit[to[i]]) continue;dis[to[i]] = dis[rt] + value[i];get_dis(to[i], rt);}
}int calc(int rt) {tot = 0;now[++tot] = 0;int ans = 0;for(int i = head[rt]; i; i = nex[i]) {if(visit[to[i]]) continue;dis[to[i]] = value[i];int st = tot + 1;get_dis(to[i], rt);sort(now + st, now + tot + 1);for(int j = st, k = tot; j <= tot && j <= k; j++) {while(now[j] + now[k] > m && j < k) {k--;}ans += k - j;}}ans = -ans;sort(now + 1, now + tot + 1);for(int i = 1, j = tot; i <= tot && i <= j; i++) {while(now[j] + now[i] > m && j > i) {j--;}ans += j - i;}return ans;
}void solve(int rt) {visit[rt] = 1;ans += calc(rt);for(int i = head[rt]; i; i = nex[i]) {if(visit[to[i]]) continue;sum = sz[to[i]];root = 0, msz[0] = inf;get_root(to[i], rt);solve(root);}
}int main() {// freopen("in.txt", "r", stdin);// freopen("out.txt", "w", stdout);// ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);while(scanf("%d %d", &n, &m) && (n + m)) {memset(head, 0, sizeof head), cnt = 1, ans = 0;memset(visit, 0, sizeof visit);for(int i = 1; i < n; i++) {int x, y, w;scanf("%d %d %d", &x, &y, &w);add(x, y, w);add(y, x, w);}root = 0, msz[0] = inf, sum = n;get_root(1, 0);solve(root);printf("%d\n", ans);}return 0;
}