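// 我的做法是,对于每条新边,记录树中与之对应的路径。然后对于每条树边,统计被对应的次数。最后记录每个点到树根的路径上,有多少个1(设为q[i])。对于新边(x,y),它对答案的贡献就是q[x] + q[y] - 2q[lca(x,y)]。除了这些,答案还应加上树中0边的数量 * m。
//
// Approach: for every new (non-tree) edge (x, y), mark the tree path between
// x and y, and count for every tree edge how many new edges cover it. Let
// q[i] be the number of tree edges covered exactly once on the path from
// vertex i to the root. New edge (x, y) then contributes
// q[x] + q[y] - 2 * q[lca(x, y)] to the answer: these are the once-covered
// tree edges on its own path, and each of them can be cut together with
// (x, y). In addition, a tree edge covered by no new edge can be cut together
// with any of the m new edges, so the answer also gains
// (number of uncovered tree edges) * m.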
/**//*************************************************************************
Author: WHU_GCC
Created Time: 2007-10-6 12:45:53
File Name: pku3417.cpp
Description:
************************************************************************/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Debug helper: prints "<expression text>: <value>" followed by a newline.
#define out(x) (cout << #x << ": " << x << endl)
// 0x7FFFFFFF: the largest value of a 32-bit signed int.
const int maxint = 2147483647;
// Capacity of the per-vertex arrays (vertices are numbered from 1).
const int maxn = 100000 + 10;
// Per-vertex record. Within this file only `pre` is ever written (and none of
// the fields is read back); it looks like a leftover from an earlier version.
struct node_t { int father, value, pre; };
// Singly linked adjacency-list entry: the neighbouring vertex `v` and the
// next entry of the same list (null at the end).
struct tmp_t { int v; tmp_t *next; };
int n, m;                   // number of vertices; number of new (non-tree) edges
node_t p[maxn];             // legacy per-vertex records (in this file only p[i].pre is written)
tmp_t *u[maxn];             // adjacency lists of the tree, indexed by vertex
int edge_u[maxn];           // first endpoint of the i-th new edge
int edge_v[maxn];           // second endpoint of the i-th new edge
int edge_lca[maxn];         // LCA of the i-th new edge's endpoints
// Sparse table for range-minimum over a[]: d[j][i] points at the minimum of
// a[i .. i + 2^j - 1]. Its columns index the same positions as euler[]/a[]
// (an Euler tour of up to 2 * maxn - 1 entries), so it is sized like them
// instead of by a separate hard-coded constant. Levels 0..19 are plenty,
// since 2^18 > 2 * maxn.
int *d[20][maxn * 2];
int euler[maxn * 2];        // Euler tour: euler[t] is the vertex at tour step t
int a[maxn * 2];            // a[t] = depth of euler[t]
int high[maxn];             // depth of every vertex
int stt[maxn];              // first position of every vertex in euler[]
int tim;                    // number of entries written to euler[]
// Returns whichever pointer addresses the smaller int. When the two values
// are equal, the second pointer is returned.
int *mmin(int *a, int *b)
{
    return (*b <= *a) ? b : a;
}
// Builds the sparse table d[][] over a[1..n]: d[j][i] points at the minimum
// of a[i .. i + 2^j - 1] (ties resolved by mmin).
void make_rmq(int n)
{
    int i, j;
    for (i = 1; i <= n; i++)
        d[0][i] = &a[i];
    // Build level j while a window of 2^j elements still fits in [1, n].
    // This exact integer test replaces log(n) / log(2.0), which can round
    // below an integer when n is a power of two and silently skip a level.
    for (j = 1; (1 << j) <= n; j++)
        for (i = 1; i + (1 << j) - 1 <= n; i++)
            d[j][i] = mmin (d[j - 1][i], d[j - 1][i + (1 << (j - 1))]);
}
// Range-minimum query over a[min(i,j) .. max(i,j)]. Returns a pointer into
// a[] at the minimal element. The sparse table must already have been built
// by make_rmq over a range containing the query.
int *rmq(int i,int j)
{
    if (i > j)
        std::swap (i, j);
    // k = floor(log2(j - i + 1)), computed exactly in integers instead of
    // calling floating-point log twice per query.
    int k = 0;
    while ((2 << k) <= j - i + 1)
        k++;
    // Two (possibly overlapping) windows of length 2^k cover [i, j].
    return mmin (d[k][i], d[k][j - (1 << k) + 1]);
}
void mmd (int f, int v, int h)
{
euler[++tim] = v;
high[v] = h;
tmp_t *pt;
for (pt = u[v]; pt; pt = pt->next)
if (pt->v != f)
{
mmd (v, pt->v, h + 1);
euler[++tim] = v;
}
}
// Prepares LCA queries on the tree rooted at vertex 1 (depth 1):
// - records the Euler tour in euler[1..tim] via mmd;
// - stores each vertex's first tour position in stt[];
// - copies the per-step depths into a[];
// - builds the RMQ sparse table over a[1..tim].
void build ()
{
    tim = 0;
    mmd (1, 1, 1);
    memset (stt, 0, sizeof (stt));
    for (int pos = 1; pos <= tim; ++pos)
    {
        const int vertex = euler[pos];
        if (!stt[vertex])
            stt[vertex] = pos;      // keep the first occurrence only
        a[pos] = high[vertex];
    }
    make_rmq (tim);
}
// Lowest common ancestor of vertices u and v. It is the shallowest vertex on
// the Euler-tour segment between the first occurrences of u and v.
int lca (int u, int v)
{
    const int *shallowest = rmq (stt[u], stt[v]);
    const int pos = static_cast<int>(shallowest - a);
    return euler[pos];
}
int f[maxn];    // difference marks: +1 per new-edge endpoint at x, -2 per new edge whose LCA is x
int g[maxn];    // g[x] = sum of f over x's subtree (cover count of the edge x - parent(x))
int q[maxn];    // q[x] = number of vertices y on the path x .. root with g[y] == 1

// Sets g[x] = sum of f over the subtree of x for every vertex x in the tree
// hanging off `now` (entered from `father`), and returns g[now].
// Iterative: first a pre-order listing, then a reverse sweep. This avoids
// recursion depth proportional to the tree height.
int dfs(int father, int now)
{
    std::vector<std::pair<int, int> > order;    // (vertex, parent) in pre-order
    std::vector<std::pair<int, int> > pending(1, std::make_pair(now, father));
    while (!pending.empty())
    {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        g[cur.first] = f[cur.first];
        for (tmp_t *t = u[cur.first]; t; t = t->next)
            if (t->v != cur.second)
                pending.push_back(std::make_pair(t->v, cur.first));
    }
    // Every vertex appears after all of its ancestors in `order`, so walking it
    // backwards folds each completed subtree into its parent. Index 0 is the
    // start vertex, whose own total is final.
    for (std::size_t k = order.size(); k-- > 1; )
        g[order[k].second] += g[order[k].first];
    return g[now];
}
// Pending vertex for dfs1: the vertex, the neighbour it was entered from, and
// the count accumulated on the path above it.
struct dfs1_frame_t
{
    int v;
    int parent;
    int sum;
};

// For every vertex x in the tree hanging off `now` (entered from `father`),
// sets q[x] = sum + number of vertices y on the path now .. x (inclusive)
// with g[y] == 1. Uses an explicit stack instead of recursion so deep trees
// cannot overflow the call stack.
void dfs1(int father, int now, int sum)
{
    std::vector<dfs1_frame_t> pending;
    dfs1_frame_t start = { now, father, sum };
    pending.push_back(start);
    while (!pending.empty())
    {
        const dfs1_frame_t cur = pending.back();
        pending.pop_back();
        q[cur.v] = cur.sum + (g[cur.v] == 1);
        for (tmp_t *t = u[cur.v]; t; t = t->next)
            if (t->v != cur.parent)
            {
                // Children inherit the running count including cur.v itself.
                dfs1_frame_t child = { t->v, cur.v, q[cur.v] };
                pending.push_back(child);
            }
    }
}
int main()
{
while (scanf("%d%d", &n, &m) != EOF)
{
memset(u, 0, sizeof(u));
int i;
for (i = 0; i < n - 1; i++)
{
int t1, t2;
scanf("%d%d", &t1, &t2);
tmp_t *p = new tmp_t;
p->v = t2;
p->next = u[t1];
u[t1] = p;
p = new tmp_t;
p->v = t1;
p->next = u[t2];
u[t2] = p;
}
memset(p, 0, sizeof(p));
for (i = 1; i <= n; i++)
p[i].pre = i;
build();
memset(f, 0, sizeof(f));
for (i = 0; i < m; i++)
{
int t1, t2;
scanf("%d%d", &t1, &t2);
edge_u[i] = t1;
edge_v[i] = t2;
int t = lca(t1, t2);
edge_lca[i] = t;
f[t1]++;
f[t2]++;
f[t] -= 2;
}
memset(g, 0, sizeof(g));
dfs(1, 1);
memset(q, 0, sizeof(q));
dfs1(1, 1, 0);
int sum = 0;
for (i = 2; i <= n; i++)
if (g[i] == 0)
sum++;
int ans = 0;
for (i = 0; i < m; i++)
ans += q[edge_u[i]] + q[edge_v[i]] - 2 * q[edge_lca[i]];
printf("%d\n", ans + sum * m);
}
return 0;
}