#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
// Capacity of every per-node table below; node ids index them directly.
constexpr int N = 500500;
// Number of tree nodes (main reads n - 1 edges and roots the tree at node n).
int n;
// Nodes 1..m have fixed values read from the input.
int m;
// Undirected adjacency lists of the tree.
std::vector<int> tree[N];
// Value interval [l[v], r[v]] of node v: a single point for v <= m,
// otherwise the optimal interval filled in by dfs.
int l[N];
int r[N];
// Accumulated minimum total edge cost (the program's answer).
long long res = 0;
void dfs (int now, int dad) {
vector<int> t;
for (auto son : tree[now]) {
if (son == dad) {
continue;
}
dfs (son, now);
t.push_back(l[son]);
t.push_back(r[son]);
}
if (now > m) {
sort (t.begin(), t.end());
int num = t.size() / 2;
l[now] = t[num - 1];
r[now] = t[num];
}
for (auto son : tree[now]) {
if (son == dad) {
continue;
}
if (r[son] < l[now]) {
res += l[now] - r[son];
}
else if (l[now] < l[son]) {
res += l[son] - l[now];
}
}
}
int main () {
freopen ("starria.in", "r", stdin);
freopen ("starria.out", "w", stdout);
cin >> n >> m;
int u, v;
for (int i = 1; i < n; i++) {
cin >> u >> v;
tree[u].push_back(v);
tree[v].push_back(u);
}
for (int i = 1; i <= m; i++) {
cin >> l[i];
r[i] = l[i];
}
dfs (n, 0);
cout << res << endl;
return 0;
}