我的做法是,对于每条新边,记录树中与之对应的路径。然后对于每条树边,统计被对应的次数。最后记录每个点到树根的路径上,有多少个1(设为q[i])。对于新边(x,y),它对答案的贡献就是q[x] + q[y] - 2q[lca(x,y)]。除了这些,答案还应加上树中0边的数量 * m。


/*************************************************************************
Author: WHU_GCC
Created Time: 2007-10-6 12:45:53
File Name: pku3417.cpp
Description: 
***********************************************************************
*/

#include 
<iostream>
#include 
<cmath>
using namespace std;
#define out(x) (cout<<#x<<": "<<x<<endl)
const int maxint=0x7FFFFFFF;

const int maxn = 100010;

struct node_t
{
    
int father;
    
int value;
    
int pre;
}
;

struct tmp_t
{
    
int v;
    tmp_t 
*next;
}
;

int n, m;
node_t p[maxn];
tmp_t 
*u[maxn];
int edge_u[maxn];
int edge_v[maxn];
int edge_lca[maxn];
int *d[20][200010];
int euler[maxn * 2];
int a[maxn * 2];
int high[maxn];
int stt[maxn];
int tim;
int *mmin(int *a, int *b)
{
    
if (*< *b)
        
return a;
    
return b;
}

void make_rmq(int n)
{
    
int i, j;
    
for (i = 1; i <= n; i++)
         d[
0][i] = &a[i];
    
for (j = 1; j <= log((double)n) / log (2.0);j++)
        
for (i = 1; i + (1 << j) - 1 <= n; i++)
            d[j][i] 
= mmin (d[j - 1][i], d[j - 1][i + (1 << (j - 1))]);
}

int *rmq(int i,int j)
{
    
if (i > j)
        swap (i, j);
    
int k = (int)(log (j - i + 1.) / log (2.0));
    
return mmin (d[k][i], d[k][j - (1 << k) + 1]);
}

void mmd (int f, int v, int h)
{
    euler[
++tim] = v;
    high[v] 
= h;
    tmp_t 
*pt;
    
for (pt = u[v]; pt; pt = pt->next)
        
if (pt->!= f)
        
{
            mmd (v, pt
->v, h + 1);
            euler[
++tim] = v;
        }

}


void build ()
{
    tim 
= 0;
    mmd (
111);
    memset (stt, 
0sizeof (stt));
    
int i;
    
for (i = 1; i <= tim; ++i)
    
{
        
if (stt[euler[i]] == 0)
            stt[euler[i]] 
= i;
        a[i] 
= high[euler[i]];
    }

    make_rmq (tim);
}

int lca (int u, int v)
{
    
return euler[rmq (stt[u], stt[v]) - a];
}


int f[maxn];
int g[maxn];
int q[maxn];

int dfs(int father, int now)
{
    
int ret = 0;
    tmp_t 
*t;
    
for (t = u[now]; t; t = t->next)
        
if (t->!= father)
            ret 
+= dfs(now, t->v);
    
return g[now] = ret + f[now];
}


void dfs1(int father, int now, int sum)
{
    q[now] 
= sum + (g[now] == 1);
    tmp_t 
*t;
    
for (t = u[now]; t; t = t->next)
        
if (t->!= father)
            dfs1(now, t
->v, sum + (g[now] == 1));
}


int main()
{
    
while (scanf("%d%d"&n, &m) != EOF)
    
{
        memset(u, 
0sizeof(u));
        
int i;
        
for (i = 0; i < n - 1; i++)
        
{
            
int t1, t2;
            scanf(
"%d%d"&t1, &t2);
            tmp_t 
*= new tmp_t;
            p
->= t2;
            p
->next = u[t1];
            u[t1] 
= p;
            
            p 
= new tmp_t;
            p
->= t1;
            p
->next = u[t2];
            u[t2] 
= p;
        }

        memset(p, 
0sizeof(p));
        
for (i = 1; i <= n; i++)
            p[i].pre 
= i;
        build();
        memset(f, 
0sizeof(f));
        
for (i = 0; i < m; i++)
        
{
            
int t1, t2;
            scanf(
"%d%d"&t1, &t2);
            edge_u[i] 
= t1;
            edge_v[i] 
= t2;
            
int t = lca(t1, t2);
            edge_lca[i] 
= t;
            f[t1]
++;
            f[t2]
++;
            f[t] 
-= 2;
        }

        memset(g, 
0sizeof(g));
        dfs(
11);
        memset(q, 
0sizeof(q));
        dfs1(
110);
        
int sum = 0;
        
for (i = 2; i <= n; i++)
            
if (g[i] == 0)
                sum
++;
        
int ans = 0;
        
for (i = 0; i < m; i++)
            ans 
+= q[edge_u[i]] + q[edge_v[i]] - 2 * q[edge_lca[i]];
        printf(
"%d\n", ans + sum * m);
    }

    
return 0;
}
posted on 2007-10-06 20:53 Felicia 阅读(494) 评论(0)  编辑 收藏 引用 所属分类: 图论

只有注册用户登录后才能发表评论。
网站导航: 博客园   IT新闻   BlogJava   知识库   博问   管理