记录编号 |
597392 |
评测结果 |
AAAAAAAAAAAAAAAAAAAAAAAAA |
题目名称 |
魔法试练场 |
最终得分 |
100 |
用户昵称 |
┭┮﹏┭┮ |
是否通过 |
通过 |
代码语言 |
C++ |
运行时间 |
5.212 s |
提交时间 |
2024-11-27 12:13:03 |
内存使用 |
20.36 MiB |
显示代码纯文本
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for sums, residues and the answer.
using ll = long long;
// Short form of push_back, used when building adjacency lists.
#define pb push_back
// Capacity of the per-vertex arrays; presumably n <= 2e5 (plus slack) — confirm against the statement.
constexpr int N = 200000 + 10;
// Reads one decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 if end of input is reached before
// any digit appears.
long long read(){
    long long x = 0,f = 1;
    int c = getchar();  // int, not char: must be able to hold EOF distinctly
    for(;c < '0' || c > '9';c = getchar()){
        if(c == EOF)return 0;  // truncated input: stop instead of spinning forever
        if(c == '-')f = -1;
    }
    for(;c >= '0' && c <= '9';c = getchar())x = x * 10 + (c - '0');
    return x * f;
}
int n,p;                        // vertex count and modulus
vector<int>e[N];                // adjacency lists of the tree
int rt,sum,cnt,len;             // rt: centroid candidate; sum: size of the component being measured; cnt/len: fill counts of c[]/del[]
int a[N],siz[N],mx[N],premx[N]; // a: vertex weights; siz: subtree size; mx: largest piece left if the vertex is removed (mx[0] is a large sentinel set in main); premx: maximum weight on the path from the centroid
ll dep[N];                      // weight sum from the centroid's child down to the vertex (centroid excluded), reduced mod p when stored
int cn[10000010];               // counter indexed by a residue mod p — NOTE(review): assumes p <= 10000010
ll ans = 0;                     // number of qualifying paths found so far
// One path record gathered around the current centroid.
struct node{
    ll d,mx;                    // d: path weight sum mod p without the centroid; mx: maximum weight on the path including the centroid
    // Orders records by path maximum (takes a reference and avoids shadowing the global a[]).
    bool operator < (const node &o)const{return mx < o.mx;}
}c[N],del[N];                   // c: all records around the centroid; del: records of the child subtree being processed
bool v[N];                      // v[x]: x has already been chosen as a centroid (removed from further work)
// Measures the still-active component containing x, rooted at x with parent fa.
// Fills siz[] with subtree sizes and mx[] with the size of the largest piece
// that would remain if the vertex were deleted (the parent side counts as
// sum - siz). Whenever a vertex beats mx[rt], it becomes the new rt, so rt
// ends up as a centroid. Callers set sum and rt beforehand.
void calsize(int x,int fa){
    int &total = siz[x];
    int &heaviest = mx[x];
    total = heaviest = 1;
    const auto &adj = e[x];
    for(std::size_t i = 0;i < adj.size();++i){
        const int child = adj[i];
        if(child != fa && !v[child]){
            calsize(child,x);
            total += siz[child];
            if(siz[child] > heaviest)heaviest = siz[child];
        }
    }
    const int above = sum - total;  // vertices on the parent's side of x
    if(above > heaviest)heaviest = above;
    if(heaviest < mx[rt])rt = x;
}
void caldep(int x,int fa){
c[++cnt] = {dep[x] % p,premx[x]};
del[++len] = {dep[x] % p,premx[x]};
for(int y : e[x]){
if(y == fa || v[y])continue;
dep[y] = (dep[x] + a[y]) % p;
premx[y] = max(premx[x],a[y]);
caldep(y,x);
}
}
void cal(int x){
sort(c+1,c+1+cnt);
cn[0]++;ans++;
for(int i = 1;i <= cnt;i++){
ans += cn[((p - c[i].d + c[i].mx + p - a[x]) % p + p) % p];
cn[c[i].d]++;
//p | a[x] + cn + c[i].d - c[i].mx
}
cn[0]--;
for(int i = 1;i <= cnt;i++)cn[c[i].d]--;
}
int su = 0;
void dfs(int x){
++su;
assert(su <= n);
cnt = 0;
for(int y : e[x]){
if(v[y])continue;
len = 0;
premx[y] = max(a[x],a[y]),dep[y] = a[y];
caldep(y,0);
sort(del+1,del+1+len);
for(int i = 1;i <= len;i++){
ans -= cn[((p - del[i].d + del[i].mx - a[x]) % p + p) % p];
cn[del[i].d]++;
}
for(int i = 1;i <= len;i++)cn[del[i].d]--;
}
cal(x);
for(int y : e[x]){
if(v[y])continue;
sum = siz[y],rt = 0;
calsize(y,0),calsize(rt,0),v[rt] = 1;
dfs(rt);
}
}
// Reads input from mplace.in and writes the result to mplace.out.
// Input: n and p, then n-1 edges, then the n vertex weights.
// Runs the centroid decomposition and prints the number of qualifying paths.
int main(){
    // freopen closes the original stream even on failure, so stop immediately
    // instead of reading from / writing to a closed stream.
    if(!freopen("mplace.in","r",stdin) || !freopen("mplace.out","w",stdout)){
        perror("freopen");
        return 1;
    }
    n = read(),p = read();
    for(int i = 1;i < n;i++){
        int x = read(),y = read();
        e[x].push_back(y),e[y].push_back(x);
    }
    for(int i = 1;i <= n;i++)a[i] = read();
    mx[0] = 1e8,sum = n;  // mx[0]: sentinel larger than any component size, so rt = 0 is always beaten
    calsize(1,0),calsize(rt,0),v[rt] = 1;
    dfs(rt);
    printf("%lld\n",ans);
    return 0;
}