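/*
 * Online-judge submission record (page header copied together with the code):
 *
 * 记录编号 597392 评测结果 AAAAAAAAAAAAAAAAAAAAAAAAA
 * 题目名称 魔法试练场 最终得分 100
 * 用户昵称 Gravatar┭┮﹏┭┮ 是否通过 通过
 * 代码语言 C++ 运行时间 5.212 s
 * 提交时间 2024-11-27 12:13:03 内存使用 20.36 MiB
 * 显示代码纯文本
 *
 * English: submission ID 597392; verdict string: all "A" (presumably one
 * letter per test case, A = accepted); problem: 魔法试练场 ("Magic Trial
 * Ground"); final score: 100; user nickname: ┭┮﹏┭┮ (the "Gravatar" prefix
 * appears to be the avatar image's alt text captured by the copy); passed:
 * yes; language: C++; run time: 5.212 s; submitted: 2024-11-27 12:13:03;
 * memory: 20.36 MiB. The last line, 显示代码纯文本, is the page's
 * "show code as plain text" link.
 */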
#include <bits/stdc++.h>
using namespace std;
// Shorthand used throughout the solution.
using ll = long long;          // type alias instead of a textual macro
#define pii pair<int,ll>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
// Array bound for per-vertex data; presumably n <= 2e5 — TODO confirm limits.
constexpr int N = 2e5+10;

// Reads the next integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 if end of input is reached before any
// digit (the previous version looped forever on EOF).
long long read(){
	long long x = 0;
	int sign = 1;
	int c = getchar();   // int, not char, so EOF stays distinguishable from data
	while(c != EOF && (c < '0' || c > '9')){
		if(c == '-')sign = -1;
		c = getchar();
	}
	for(;c >= '0' && c <= '9';c = getchar())x = x * 10 + (c - '0');
	return x * sign;
}


// ---- Global state shared by the centroid decomposition ----
int n,p;                 // number of vertices, modulus
vector<int>e[N];         // adjacency lists
int rt,sum,cnt,len;      // centroid candidate, current component size, sizes of c[] / del[]
// a: vertex values; siz: subtree sizes within the current component;
// mx: size of the largest piece left if the vertex were removed (mx[0] is the
//     "no candidate" sentinel, set in main);
// premx: maximum vertex value on the path from the current centroid (inclusive).
int a[N],siz[N],mx[N],premx[N];
ll dep[N];               // sum of values from the centroid's neighbour down to the vertex
// Residue counters, indexed by a value in [0, p). The fixed size assumes
// p < 10000010 — TODO confirm against the problem limits.
constexpr int MAXP = 10000010;
int cn[MAXP];
ll ans = 0;
// One collected path endpoint: path-sum residue and maximum value on the path.
struct node{
	ll d,mx;
	// Order by path maximum. Takes the operand by const reference; the name
	// avoids shadowing the global array `a`.
	bool operator < (const node& o)const{return mx < o.mx;}
}c[N],del[N];
bool v[N];               // true once a vertex has been chosen as a centroid

// Computes siz[] for the part of the current component below x (entered from
// fa) and mx[x], the size of the largest piece that would remain if x were
// removed from a component of `sum` vertices. rt is kept pointing at the
// vertex with the smallest such value seen so far (centroid candidate).
// mx[x] is updated incrementally, exactly as nested calls observe it.
void calsize(int x,int fa){
	siz[x] = 1;
	mx[x] = 1;
	for(size_t k = 0;k < e[x].size();++k){
		const int to = e[x][k];
		if(to != fa && !v[to]){
			calsize(to,x);
			siz[x] += siz[to];
			if(siz[to] > mx[x])mx[x] = siz[to];
		}
	}
	const int above = sum - siz[x];
	if(above > mx[x])mx[x] = above;
	if(mx[x] < mx[rt])rt = x;
}
void caldep(int x,int fa){
	c[++cnt] = {dep[x] % p,premx[x]};
	del[++len] = {dep[x] % p,premx[x]};
	for(int y : e[x]){
		if(y == fa || v[y])continue;
		dep[y] = (dep[x] + a[y]) % p;
		premx[y] = max(premx[x],a[y]);
		caldep(y,x);
	}
}
void cal(int x){
	sort(c+1,c+1+cnt);
	cn[0]++;ans++;
	for(int i = 1;i <= cnt;i++){
		ans += cn[((p - c[i].d + c[i].mx + p - a[x]) % p + p) % p];
		cn[c[i].d]++;
		//p | a[x] + cn + c[i].d - c[i].mx
	}
	cn[0]--;
	for(int i = 1;i <= cnt;i++)cn[c[i].d]--;
}
int su = 0;
void dfs(int x){
    ++su;
    assert(su <= n);
	cnt = 0;
	for(int y : e[x]){
		if(v[y])continue;
		len = 0;
		premx[y] = max(a[x],a[y]),dep[y] = a[y];
		caldep(y,0);
		sort(del+1,del+1+len);
		for(int i = 1;i <= len;i++){
			ans -= cn[((p - del[i].d + del[i].mx - a[x]) % p + p) % p];
			cn[del[i].d]++;
		}
		for(int i = 1;i <= len;i++)cn[del[i].d]--;
	}
	cal(x);
	for(int y : e[x]){
		if(v[y])continue;
		sum = siz[y],rt = 0;
		calsize(y,0),calsize(rt,0),v[rt] = 1;
		dfs(rt);
	}
}


// Reads the tree (n, p, n-1 edges, then n vertex values), runs the centroid
// decomposition and prints the number of qualifying paths.
int main(){
	// Input/output go through these files. Fail loudly instead of reading from
	// or writing to a closed stream if either cannot be opened.
	if(!freopen("mplace.in","r",stdin)){
		fprintf(stderr,"cannot open mplace.in\n");
		return 1;
	}
	if(!freopen("mplace.out","w",stdout)){
		fprintf(stderr,"cannot open mplace.out\n");
		return 1;
	}
	n = read();
	p = read();
	for(int i = 1;i < n;i++){
		const int x = read();
		const int y = read();
		e[x].push_back(y);
		e[y].push_back(x);
	}
	for(int i = 1;i <= n;i++)a[i] = read();
	// mx[0] is the sentinel for "no centroid candidate yet" (rt == 0); it must
	// exceed any real mx[] value.
	mx[0] = 1e8;
	sum = n;
	calsize(1,0);    // choose the centroid of the whole tree...
	calsize(rt,0);   // ...and recompute subtree sizes rooted at it
	v[rt] = 1;
	dfs(rt);
	printf("%lld\n",ans);
	return 0;
}