【题意】:给出一棵树,每个节点有一个权值,要你对这棵树cut k次,每cut一次在两棵树中选择任意一棵继续操作,问cut完k次后,求最后得到的树的最小/最大权值和。
【题解】:很明显的树dp。刚开始状态没设好,把问题搞得好复杂。
设状态dp[i][j]表示以i为根的这棵子树cut j次的最小/最大权值和(包括i这个节点)。
那么对于i的每棵子树都有两种选择,就是保留这棵子树或者丢弃这棵子树。
这里的转移就是一个简单的背包。
关键是最后统计答案,如果答案包括根,那么就是dp[root][k];否则,就需要枚举各个节点的情况,取最优值。
【代码】:
1 #include "iostream"
2 #include "cstdio"
3 #include "cstring"
4 #include "algorithm"
5 #include "vector"
6 #include "queue"
7 #include "cmath"
8 #include "string"
9 #include "cctype"
10 #include "map"
11 #include "iomanip"
12 using namespace std;
13 #define pb push_back
14 #define mp make_pair
15 #define fi first
16 #define se second
17 #define sof(x) sizeof(x)
18 #define lc(x) (x << 1)
19 #define rc(x) (x << 1 | 1)
20 #define lowbit(x) (x & (-x))
21 #define ll long long
22 #define maxn 1010
23 const int inf = 1 << 26;
24 vector<int> vec[maxn];
25 int val[maxn];
26 int dp[maxn][25][2];
27 int n, k;
28 int cnt[maxn];
29 int dfs(int u, int fa) {
30 cnt[u] = 1;
31 int size = vec[u].size();
32 dp[u][0][0] = dp[u][0][1] = 0;
33 for(int i = 1; i <= k; i++) {
34 dp[u][i][0] = inf;
35 dp[u][i][1] = -inf;
36 }
37 for(int i = 0; i < size; i++) {
38 int v = vec[u][i];
39 if(v == fa) continue;
40 cnt[u] += dfs(v, u);
41 for(int j = k; j >= 0; j--) {
42 dp[u][j][0] += dp[v][0][0];
43 dp[u][j][1] += dp[v][0][1];
44 for(int jj = 1; jj <= j; jj++) {
45 dp[u][j][0] = min(dp[u][j][0], dp[u][j-jj][0] + dp[v][jj][0]);
46 dp[u][j][1] = max(dp[u][j][1], dp[u][j-jj][1] + dp[v][jj][1]);
47 }
48 for(int jj = 0; jj < cnt[v] && jj < j; jj++) {
49 dp[u][j][0] = min(dp[u][j][0], dp[u][j-jj-1][0]);
50 dp[u][j][1] = max(dp[u][j][1], dp[u][j-jj-1][1]);
51 }
52 }
53 }
54 for(int i = 0; i <= k; i++) {
55 dp[u][i][0] += val[u];
56 dp[u][i][1] += val[u];
57 }
58 return cnt[u];
59 }
60
61 void solve() {
62 memset(cnt, 0, sof(cnt));
63 dfs(1, -1);
64 int ans1 = dp[1][k][0], ans2 = dp[1][k][1];
65 for(int i = 2; i <= n; i++) {
66 for(int j = 1; j <= k && n - cnt[i] >= j ; j++) {
67 ans1 = min(ans1, dp[i][k-j][0]);
68 ans2 = max(ans2, dp[i][k-j][1]);
69 }
70 }
71 cout << ans1 << " " << ans2 << endl;
72 }
73
74 int main() {
75 while(cin >> n >> k) {
76 for(int i = 1; i <= n; i++) {
77 cin >> val[i];
78 vec[i].clear();
79 }
80 int a, b;
81 for(int i = 1; i < n; i++) {
82 cin >> a >> b;
83 vec[a].pb(b);
84 vec[b].pb(a);
85 }
86 solve();
87 }
88 return 0;
89 }
90