记录编号 597391 评测结果 AAAAAAAAAAAAAAAAAAAAAAAAA
题目名称 魔法试练场 最终得分 100
用户昵称 GravatardarkMoon 是否通过 通过
代码语言 C++ 运行时间 7.132 s
提交时间 2024-11-27 12:11:11 内存使用 24.19 MiB
显示代码纯文本
#include<bits/stdc++.h>
#define int long long
#define fi first
#define se second
#define mp make_pair
using namespace std;
auto mread = [](){int x;scanf("%lld", &x);return x;};
const int N = 1e5 + 5, M = 1e7 + 5;
int n = mread(), P = mread(), siz[N], a[N];
vector<int> v[N];
bool del[N];
int mi = LONG_LONG_MAX, p = 0, sum, ans, s[M];
struct node{
    int sum, ma;
    bool friend operator <(node a, node b){
        return a.ma < b.ma;
    }
};
vector<node> vn, tmp;
void dfs1(int x, int fa){
    // dfs1 用来求当前子树的每个节点的 siz(任意点为根),以及整个子树的大小 sum
    siz[x] = 1;
    for(int y : v[x]){
        if(y == fa || del[y]){
            continue;
        }
        dfs1(y, x);
        siz[x] += siz[y];
    }
    return;
}
void dfs2(int x, int fa){
    // dfs2 用来求重心
    int ma = 0;
    for(int y : v[x]){
        if(y == fa || del[y]){
            continue;
        }
        dfs2(y, x);
        ma = max(ma, siz[y]);
    }
    ma = max(ma, sum - siz[x]);
    if(ma < mi){
        mi = ma;
        p = x;
    }
    return;
}
void add(int x, int fa, int sum, int ma){
    vn.push_back({sum, ma});
    tmp.push_back({sum, ma});
    for(int y : v[x]){
        if(y == fa || del[y]){
            continue;
        }
        add(y, x, sum + a[y], max(ma, a[y]));
    }
}
signed main(){
    for(int i = 1, x, y; i < n; i ++){
        cin >> x >> y;
        v[x].push_back(y);
        v[y].push_back(x);
    }
    for(int i = 1; i <= n; i ++){
        cin >> a[i];
    }
    int S = 0;
    while(S < n){
        for(int i = 1; i <= n; i ++){
            if(!del[i]){
                mi = LONG_LONG_MAX, p = 0;
                dfs1(i, 0);
                sum = siz[i];
                dfs2(i, 0);
                // p 是重心
                for(int y : v[p]){
                    if(del[y]){
                        continue;
                    }
                    add(y, p, a[p] + a[y], max(a[p], a[y]));
                    sort(tmp.begin(), tmp.end());
                    for(auto t : tmp){
                        ans -= s[(P - (t.sum - t.ma - a[p]) % P + P) % P];
                        s[t.sum % P] ++;
                    }
                    for(auto t : tmp){
                        s[t.sum % P] --;
                    }
                    tmp.clear();
                }
                sort(vn.begin(), vn.end());
                s[a[p] % P] ++;
                for(auto t : vn){
                    ans += s[(P - (t.sum - t.ma - a[p]) % P + P) % P];
                    s[t.sum % P] ++;
                }
                for(auto t : vn){
                    s[t.sum % P] --;
                }
                s[a[p] % P] --;
                vn.clear();
                del[p] = 1;
                S ++;
            }
        }
    }
    printf("%lld", ans + n);
    return 0;
}