| 比赛 | 2025.12.20 | 评测结果 | AAAAAAAAAA |
|---|---|---|---|
| 题目名称 | cogito的树 | 最终得分 | 100 |
| 用户昵称 | LikableP | 运行时间 | 1.089 s |
| 代码语言 | C++ | 内存使用 | 13.78 MiB |
| 提交时间 | 2025-12-20 11:30:16 | | |
显示代码纯文本
#include <cstdio>
#include <cctype>
// Reads the next integer of type T from stdin.
// Every non-digit character before the number is skipped. A '-' among the
// skipped characters makes the result negative. Reading stops at the first
// non-digit after the number, and that character is consumed.
// Returns 0 if EOF is reached before any digit is seen.
template <typename T> T read() {
  T res = 0, f = 1;
  // Keep the character in an int, not a char. getchar() returns an
  // unsigned-char value or EOF. Passing a negative char to isdigit() is UB,
  // and without an EOF check the skip loop below would spin forever on
  // truncated input.
  int ch = getchar();
  for (; ch != EOF && !isdigit(ch); ch = getchar())
    if (ch == '-') f = -1;
  for (; ch != EOF && isdigit(ch); ch = getchar())
    res = (res << 3) + (res << 1) + (ch ^ 48);  // res * 10 + digit value
  return res * f;
}
#include <algorithm>
using ll = long long;
constexpr int MAXN = 500000 + 10;

// Forward-star adjacency storage. edge[1..edgeNum] hold directed arcs, and
// head[x] is the index of the most recently added arc leaving x (0 = none).
struct EDGE {
  int u;     // tail of the arc
  int v;     // head of the arc
  int next;  // index of the previous arc leaving u; 0 ends the list
};
EDGE edge[MAXN * 2];  // room for two arcs per undirected edge
int head[MAXN];
int edgeNum;

// Appends the arc u -> v and makes it the first arc in u's list.
void AddEdge(int u, int v) {
  const int id = ++edgeNum;
  edge[id].u = u;
  edge[id].v = v;
  edge[id].next = head[u];
  head[u] = id;
}

// Per-node value interval [IntervalLeft, IntervalRight].
// main() sets nodes 1..m to [w, w] from the input, and dfs() fills in the
// other nodes that have children.
struct Weight {
  int IntervalLeft;
  int IntervalRight;
};
Weight weight[MAXN];
ll ans;                   // total cost accumulated by dfs()
int sortArray[MAXN * 2];  // 1-indexed scratch buffer of children's endpoints
int scnt;                 // number of entries currently used in sortArray
void dfs(int u, int fa) {
// fprintf(stderr, "%d %d\n", u, fa);
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].v;
if (v == fa) continue;
dfs(v, u);
}
scnt = 0;
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].v;
if (v == fa) continue;
sortArray[++scnt] = weight[v].IntervalLeft;
sortArray[++scnt] = weight[v].IntervalRight;
}
if (!scnt) return;
std::sort(sortArray + 1, sortArray + scnt + 1);
weight[u].IntervalLeft = sortArray[scnt / 2];
weight[u].IntervalRight = sortArray[scnt / 2 + 1];
// fprintf(stderr, "node[%d] updated: left=%d, right=%d\n", u, weight[u].IntervalLeft, weight[u].IntervalRight);
for (int i = head[u]; i; i = edge[i].next) {
int v = edge[i].v;
if (v == fa) continue;
// fprintf(stderr, "\tweight[%d].left=%d, .right=%d\n", v, weight[v].IntervalLeft, weight[v].IntervalRight);
if (weight[v].IntervalRight < weight[u].IntervalLeft) {
ans += weight[u].IntervalLeft - weight[v].IntervalRight;
} else if (weight[v].IntervalLeft > weight[u].IntervalLeft) {
ans += weight[v].IntervalLeft - weight[u].IntervalLeft;
}
}
}
// n: number of tree nodes (node n is used as the root).
// m: number of nodes, 1..m, whose weights are read from input; presumably
// these are the leaves — confirm against the problem statement.
int n, m;
int main() {
#ifdef LOCAL
  freopen("!input.in", "r", stdin);
  freopen("!output.out", "w", stdout);
#elif !defined(ONLINE_JUDGE)
  freopen("starria.in", "r", stdin);
  freopen("starria.out", "w", stdout);
#endif
  n = read<int>();
  m = read<int>();
  // n - 1 undirected edges, each stored as two directed arcs.
  for (int e = 1; e < n; ++e) {
    const int a = read<int>();
    const int b = read<int>();
    AddEdge(a, b);
    AddEdge(b, a);
  }
  // Nodes with a given weight w start with the degenerate interval [w, w].
  for (int node = 1; node <= m; ++node) {
    const int w = read<int>();
    weight[node].IntervalLeft = w;
    weight[node].IntervalRight = w;
  }
  dfs(n, 0);
  printf("%lld\n", ans);
  return 0;
}