| 比赛 |
2025.12.20 |
评测结果 |
EEEEEEEEEE |
| 题目名称 |
cogito的树 |
最终得分 |
0 |
| 用户昵称 |
彭欣越 |
运行时间 |
2.380 s |
| 代码语言 |
C++ |
内存使用 |
24.81 MiB |
| 提交时间 |
2025-12-20 12:26:56 |
显示代码纯文本
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer used for the accumulated answer.
typedef long long ll;
// Capacity of every per-node array below (node indices are used 1-based).
const int N=500010;
// n, m : the first two integers read by main; main then reads n-1 edges and
//        the values a[1..m]. Nodes with index <= m are never descended into
//        by dfs and only serve as value sources.
// mk   : not referenced anywhere in the visible code.
// s[x] : number of edges incident to node x (incremented once per endpoint
//        while the edges are read).
// a[x] : value of node x; read from input for x <= m, assigned by dfs for the
//        nodes it visits.
int n,m,mk[N],s[N],a[N];
// Sum of |a[u]-a[v]| over the edges traversed by get_ans; printed by main.
ll ans;
// g[x] : sorted list of values collected for node x. main seeds it with the
//        values of x's neighbours whose index is <= m; dfs adds the values it
//        picks for some of x's children.
vector<int>g[N];
// Forward-linked adjacency list: tot is the number of entries of e[] in use,
// head[u] is the index of the most recently added entry for u (0 = empty list).
int tot,head[N];
// One directed entry of the adjacency list built by add().
struct edge {
    int v;    // node this entry points to
    int nxt;  // index in e[] of the next entry of the same list, 0 ends the list
};
// Entry storage; slot 0 is never used, each undirected edge takes two slots.
edge e[N*2];
// Prepends the directed entry u -> v to u's adjacency list.
void add (int u,int v) {
    const int slot=++tot;      // slots are handed out starting from 1
    e[slot].nxt=head[u];       // old list head becomes the successor
    e[slot].v=v;
    head[u]=slot;
}
// Upper median of a non-empty sorted list (element at index size/2).
static int dfs_median (const std::vector<int>& vals) {
    return vals[vals.size()/2];
}

// Inserts x into an ascending list so that it stays ascending. On a sorted
// list this yields the same contents as push_back followed by a full sort.
static void dfs_insert_sorted (std::vector<int>& vals,int x) {
    vals.insert(std::upper_bound(vals.begin(),vals.end(),x),x);
}

// Picks a[v] for child v of u from the sorted lists g[v] and g[u], and records
// the pick in g[u] in the same cases as before (odd g[v], or odd g[u]).
static void dfs_settle_child (int u,int v) {
    std::vector<int>& up=g[u];
    const std::vector<int>& down=g[v];

    // Odd number of values below v: the median is the single best pick.
    if (down.size()%2==1) {
        a[v]=dfs_median(down);
        dfs_insert_sorted(up,a[v]);
        return;
    }

    // Nothing was recorded for v, so there is no middle pair to read.
    // v simply follows u's current median when u has one.
    if (down.empty()) {
        if (!up.empty()) a[v]=dfs_median(up);
        return;
    }

    // Even number of values: any pick inside the middle pair [l, r] is equally
    // good for v, so choose the point of [l, r] closest to what u holds.
    const int l=down[down.size()/2-1],r=down[down.size()/2];
    if (up.size()%2==1) {
        const int t=dfs_median(up);
        if (t>r) a[v]=r;
        else if (t<l) a[v]=l;
        else a[v]=t;
        dfs_insert_sorted(up,a[v]);
        return;
    }

    // Midpoint computed in 64 bits so that l+r cannot overflow int.
    const int mid=static_cast<int>((static_cast<long long>(l)+r)/2);

    // u has no values yet: there is no middle pair of u to compare against.
    // Take the midpoint and seed g[u] with it so u ends up with a median.
    if (up.empty()) {
        a[v]=mid;
        dfs_insert_sorted(up,a[v]);
        return;
    }

    const int l1=up[up.size()/2-1],r1=up[up.size()/2];
    if (r<l1) a[v]=r;
    else if (l>r1) a[v]=l;
    else a[v]=mid;
    // As before, nothing is added to g[u] in this even/even case.
}

// Assigns a[] to the nodes with index > m in the subtree of u (fa is u's
// parent, 0 for the root) and finally to u itself.
//
// Precondition: every g[x] is sorted ascending on entry (main sorts them);
// the sorted inserts above keep that invariant.
//
// Children with index <= m are skipped. A child whose other neighbours were
// all collected into g[v] already (g[v].size() == s[v]-1) is settled directly
// without recursing; any other child is processed recursively first.
//
// Recursion depth equals the depth of the tree below u.
//
// NOTE(review): this is the original greedy (medians + clamping against the
// parent's partial list); it has not been verified to give an optimal
// assignment - confirm against the task statement.
void dfs (int u,int fa) {
    for (int i=head[u];i;i=e[i].nxt) {
        const int v=e[i].v;
        if (v==fa||v<=m) continue;
        if (static_cast<int>(g[v].size())!=s[v]-1) dfs(v,u);
        dfs_settle_child(u,v);
    }
    // Keep a value chosen by the parent; otherwise use u's own median. With an
    // empty list there is nothing to index, so a[u] is left untouched.
    if (!a[u]&&!g[u].empty()) a[u]=dfs_median(g[u]);
}
// Adds |a[x]-a[y]| to ans for every edge (x,y) reachable from u without going
// back through fa, i.e. for every edge of the tree when called on the root
// with fa = 0.
//
// The walk uses an explicit stack instead of recursion, so its depth is not
// bounded by the call stack (the arrays allow up to N = 500010 nodes). The
// difference is taken in 64 bits so it cannot overflow int.
void get_ans (int u,int fa) {
    std::vector<std::pair<int,int> > pending;   // (node, parent) still to expand
    pending.push_back(std::make_pair(u,fa));
    while (!pending.empty()) {
        const std::pair<int,int> cur=pending.back();
        pending.pop_back();
        const int node=cur.first,parent=cur.second;
        for (int i=head[node];i;i=e[i].nxt) {
            const int v=e[i].v;
            if (v==parent) continue;
            const long long diff=static_cast<long long>(a[node])-a[v];
            ans+=diff<0?-diff:diff;
            pending.push_back(std::make_pair(v,node));
        }
    }
}
int main () {
freopen("starria.in","r",stdin);
freopen("starria.out","w",stdout);
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin >> n >> m;
for (int i=1;i<n;i++) {
int a,b;
cin >> a >> b;
add(a,b);
add(b,a);
s[a]++,s[b]++;
}
for (int i=1;i<=m;i++) {
cin >> a[i];
}
for (int i=m+1;i<=n;i++) {
//flag=0;
for (int j=head[i];j;j=e[j].nxt) {
int v=e[j].v;
if (v>m) {
continue;
}
g[i].push_back(a[v]);
}
sort(g[i].begin(),g[i].end());
}
dfs(m+1,0);
get_ans(m+1,0);
cout << ans <<endl;
return 0;
}