比赛 2025.12.20 评测结果 EEEEEEEEEE
题目名称 cogito的树 最终得分 0
用户昵称 彭欣越 运行时间 2.380 s
代码语言 C++ 内存使用 24.81 MiB
提交时间 2025-12-20 12:26:56
显示代码纯文本
#include <bits/stdc++.h>
using namespace std;

using ll=long long;   // 64-bit integer, used for the accumulated answer

const int N=500010;   // capacity of every per-node array below

int n;                // number of nodes, read in main
int m;                // nodes 1..m receive their values from the input
int mk[N];            // declared but not referenced by the code visible in this file
int s[N];             // s[x]: number of edges incident to x, counted while reading edges
int a[N];             // a[x]: value of node x (read for x<=m, chosen by dfs otherwise)
ll ans;               // sum accumulated by get_ans and printed by main

vector<int> g[N];     // g[x] for x>m: values of the neighbours of x with index <=m, sorted in main

int tot;              // index of the last used slot of the edge pool e[]
int head[N];          // head[x]: first slot of x's adjacency list in e[], 0 when the list is empty
// One directed half of a tree edge, stored as a node of a singly linked
// adjacency list that lives inside the pool e[].
struct edge {
	int v;     // node this half-edge points to
	int nxt;   // slot of the next half-edge leaving the same node, 0 terminates the list
};

// Pool for the half-edges; add() hands out slots starting from 1, and every
// undirected edge read in main occupies two of them.
edge e[N*2];
// Prepends the directed half-edge u -> v to u's adjacency list.
// Takes the next free slot of e[] (tot is incremented first, so slot 0 is
// never used and can serve as the end-of-list marker).
void add (int u,int v) {
	const int slot=++tot;
	e[slot]={v,head[u]};   // new entry links to the previous list head
	head[u]=slot;
}
void dfs (int u,int fa) {
	//cout << u <<endl;
	for (int i=head[u];i;i=e[i].nxt) {
		int v=e[i].v,p=0;
		if (v==fa||v<=m) continue;
		//if (fa!=0) p=1;
	    if (g[v].size()==s[v]-1) {
	    	//cout << g[v].size() <<' '<< g[u].size() <<endl;
	    	if (g[v].size()%2==1) {
	    		a[v]=g[v][g[v].size()/2];
	    		g[u].push_back(a[v]);
	    		sort(g[u].begin(),g[u].end());
			}else{
				int l=g[v][g[v].size()/2-1],r=g[v][g[v].size()/2];
				if (g[u].size()%2==1) {
	    			int t=g[u][g[u].size()/2];
	    			if (t>r) a[v]=r;
	    			else if (t<l) a[v]=l;
	    			else a[v]=t;
	    			g[u].push_back(a[v]);
	    			sort(g[u].begin(),g[u].end());
				}else{
					int l1=g[u][g[u].size()/2-1],r1=g[u][g[u].size()/2];
					if (r<l1) a[v]=r;
					else if (l>r1) a[v]=l;
					else a[v]=(l+r)/2;
				}
			}
			
		}else{
			dfs(v,u);
			if (g[v].size()%2==1) {
	    		a[v]=g[v][g[v].size()/2];
	    		g[u].push_back(a[v]);
	    		sort(g[u].begin(),g[u].end());
			}else{
				int l=g[v][g[v].size()/2-1],r=g[v][g[v].size()/2];
				if (g[u].size()%2==1) {
	    			int t=g[u][g[u].size()/2];
	    			if (t>r) a[v]=r;
	    			else if (t<l) a[v]=l;
	    			else a[v]=t;
	    			g[u].push_back(a[v]);
	    			sort(g[u].begin(),g[u].end());
				}else{
					int l1=g[u][g[u].size()/2-1],r1=g[u][g[u].size()/2];
					if (r<l1) a[v]=r;
					else if (l>r1) a[v]=l;
					else a[v]=(l+r)/2;
				}
			}
		}
	}
	if (!a[u]) a[u]=g[u][g[u].size()/2];
}
/*
 * Adds |a[x]-a[y]| to ans for every edge (x,y) reachable from u without
 * stepping onto fa. Each edge is counted once, from its parent side.
 *
 * Implemented with an explicit stack instead of recursion so that a
 * path-shaped tree cannot overflow the call stack, and the difference is
 * formed in 64 bits so it cannot overflow int.
 *
 * u  : node to start from.
 * fa : neighbour of u that must not be entered (0 for none).
 */
void get_ans (int u,int fa) {
	std::vector<std::pair<int,int> > todo;   // (node, node it was reached from)
	todo.push_back(std::make_pair(u,fa));
	while (!todo.empty()) {
		const int x=todo.back().first;
		const int from=todo.back().second;
		todo.pop_back();
		for (int i=head[x];i;i=e[i].nxt) {
			const int y=e[i].v;
			if (y==from) continue;
			const long long d=static_cast<long long>(a[x])-a[y];
			ans+=d<0?-d:d;
			todo.push_back(std::make_pair(y,x));
		}
	}
}
/*
 * Reads the tree from starria.in, lets dfs pick values for nodes m+1..n,
 * and writes the sum computed by get_ans to starria.out.
 *
 * Input as consumed here: n and m, then n-1 edges (two endpoints each),
 * then the m values of nodes 1..m.
 */
int main () {
	freopen("starria.in","r",stdin);
	freopen("starria.out","w",stdout);
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	cin >> n >> m;
	for (int i=1;i<n;i++) {
		// Named x/y so the global value array a[] is not shadowed.
		int x,y;
		cin >> x >> y;
		add(x,y);
		add(y,x);
		s[x]++,s[y]++;
	}
	for (int i=1;i<=m;i++) {
		cin >> a[i];
	}

	// For every node without a given value, collect the values of its
	// neighbours that do have one, in ascending order.
	for (int i=m+1;i<=n;i++) {
		for (int j=head[i];j;j=e[j].nxt) {
			const int v=e[j].v;
			if (v<=m) g[i].push_back(a[v]);
		}
		sort(g[i].begin(),g[i].end());
	}

	if (m<n) {
		// Node m+1 exists and has no given value, so it serves as the root.
		dfs(m+1,0);
		get_ans(m+1,0);
	} else if (n>=1) {
		// Every value is given: there is nothing to choose and node m+1
		// does not exist, so only evaluate the sum, starting from node 1.
		get_ans(1,0);
	}
	cout << ans << '\n';
	return 0;
}