Posted on 2023-01-15 21:24
Uriel 阅读(47)
评论(0) 编辑 收藏 引用 所属分类:
闲来无事重切Leet Code 、
并查集
给出一棵树中每个节点的值vals,以及各条边edges,问树中有多少条good path,满足:1.首尾节点的值相同;2.path上其他节点的值均小于或等于首尾节点的值。单个节点本身也算一条good path,且同一条path的正反两个方向只计一次。
参考了Discussion:https://leetcode.com/problems/number-of-good-paths/solutions/3053513/python3-union-find-explained/
# LeetCode 2421. Number of Good Paths
# Runtime: 2295 ms (Beats 93.75%)
# Memory: 39 MB (Beats 81.25%)

class UnionFind:
    """Disjoint-set forest over the integer elements 0..n (inclusive).

    ``find`` applies full path compression; ``union`` attaches the root of
    ``x``'s set underneath the root of ``y``'s set.
    """

    def __init__(self, n):
        # Every element starts as its own root. The table has n + 1 slots
        # (indices 0..n), as in the original, so index n is also valid.
        self.parent = list(range(n + 1))

    def find(self, x):
        """Return the root of the set containing ``x``.

        Every node on the path from ``x`` to the root is re-pointed directly
        at the root, so later lookups on any of those nodes take one step.
        """
        parent = self.parent
        # First pass: walk up to the root.
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: compress the whole path, not just the starting node.
        while x != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        The root of ``x``'s set becomes a child of the root of ``y``'s set;
        nothing changes when both are already in the same set.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

class Solution(object):
    def numberOfGoodPaths(self, vals, edges):
        """Count the good paths in a tree (LeetCode 2421).

        A good path starts and ends at nodes with the same value, and every
        node on it has a value no greater than that shared value. A single
        node is a good path by itself, and a path and its reverse count once.

        :type vals: List[int]
        :type edges: List[List[int]]
        :rtype: int
        """
        # Imported locally: the original relied on names pre-imported by the
        # LeetCode judge, which made the file fail with NameError elsewhere.
        from collections import Counter, defaultdict

        uf = UnionFind(len(vals))

        # Group node ids by value straight from ``vals`` so every node is
        # covered, rather than only the nodes that appear in some edge.
        nodes_by_val = defaultdict(list)
        for node, val in enumerate(vals):
            nodes_by_val[val].append(node)

        adjacency = defaultdict(list)
        for x, y in edges:
            adjacency[x].append(y)
            adjacency[y].append(x)

        # Each single node is a good path on its own.
        ans = len(vals)

        # Sweep values in increasing order. After handling value v, every
        # edge whose endpoints both have value <= v has been merged, so two
        # value-v nodes share a root exactly when the path between them has
        # no value above v.
        for val in sorted(nodes_by_val):
            group = nodes_by_val[val]
            for node in group:
                for nb in adjacency[node]:
                    if vals[nb] <= val:
                        uf.union(node, nb)
            # c value-v nodes in one component give C(c, 2) good paths.
            counts = Counter(uf.find(node) for node in group)
            ans += sum(c * (c - 1) // 2 for c in counts.values())
        return ans