给出一棵树的所有边的node pair,树中一共n个节点,计算一个ans数组,ans[i]表示从节点i出发到树中其他所有节点的路径只和
有参考Discussion的思路->https://leetcode.com/problems/sum-of-distances-in-tree/solutions/130567/two-traversals-o-n-python-solution-with-explanation/
step 1. 用dict预处理与每个节点相连的边
step 2. 第一次DFS,预处理以节点i为根的子树的所有节点数n_node[i]以及从i出发,到达其所有子节点的距离和dis[i]
step 3. 第二次DFS,从节点0(或者任意其他节点)开始搜索,ans[0] = dis[0], ans[i] =
ans[j] + n - 2 * n_node[i] (j为i的父节点,DFS时顺便记录)
1 #834
2 #Runtime: 998 ms (Beats 85.71%)
3 #Memory: 70.3 MB (Beats 61.90%)
4
5 class Solution(object):
6 def sumOfDistancesInTree(self, n, edges):
7 """
8 :type n: int
9 :type edges: List[List[int]]
10 :rtype: List[int]
11 """
12 graph_dict = defaultdict(set)
13 for x,y in edges:
14 graph_dict[x].add(y)
15 graph_dict[y].add(x)
16 self.dis = [0] * n
17 self.n_node = [0] * n
18
19 def DFS(node, vis):
20 vis.add(node)
21 t_dis, t_n = 0, 0
22 for i in graph_dict[node]:
23 if i not in vis:
24 tp = DFS(i, vis)
25 t_n += tp[0]
26 t_dis += tp[1] + tp[0]
27 self.dis[node] = t_dis
28 self.n_node[node] = t_n + 1
29 return t_n + 1, t_dis
30
31 DFS(0, set())
32
33 ans = [0] * n
34
35 def SumofTree(node, pre_node, vis):
36 vis.add(node)
37 if pre_node == -1:
38 ans[node] = self.dis[node]
39 else:
40 ans[node] = ans[pre_node] + n - 2 * self.n_node[node]
41 for i in graph_dict[node]:
42 if i not in vis:
43 SumofTree(i, node, vis)
44
45 SumofTree(0, -1, set())
46 return ans
47