P3354 Riv 河流 题解

最近在练树形 DP,正好看到 这一道 虚标的紫题,但本蒟蒻不会写,想出来了便记录一下

题面 & 思考

先看题面,一颗有根树,选定 k 个节点作为 ”伐木场“,求运送木料最小费用。注意木料费用是 dis * wood

最开始想到简单的树形背包,状态转移方程:

1
f[i][k] = min(f[j][s] + cost, f[i][k]);

但是注意到,如果某个后代节点如果是伐木场,cost 不需要计算,所以状态中还需要存储后代的伐木场情况。

但是后代中锯木厂不止一个。考虑到树有唯一的父亲,且同一深度祖先唯一,可以记录最近的伐木场祖先,作为状态的一部分。

f[i][j][k] 即 i 节点 最近的伐木场祖先为 j,后代(不算自己)有 k 个是伐木场。

但是由于 f 自己也可能是伐木场,转移方程不同,需要分类讨论,于是改成:f[i][j][k][0/1] 其中 0 代表自己不是伐木场, 1 表示自己伐木场。

状态转移方程

因为我们需要枚举祖先,在 DFS 时需要记录 Fa 数组:

1
2
3
DFS 开始时:fa[++tot] = x;

结束时:tot--;

简单记录祖先 stack。

回溯后,对于当前节点 x 和 子节点 y 每个祖先 ff

1
2
3
4
先对每个 k 赋初值,对于每种伐木场个数 l:
f[x][ff][l][0] += f[y][ff][0][0]; --> 当前节点不是伐木场,子节点 y 最近伐木场为 l,对任意 f,赋最大值,即 子结点中没有伐木场。

f[x][ff][l][1] += f[y][x][0][0]; --> 当前节点是伐木场,子节点 y 最近伐木场为 当前节点,对任意 f,赋最大值,即 子结点中没有伐木场。

然后就是树形背包:

1
2
3
4
f[x][ff][l][0] = min(f[x][ff][l][0], f[x][ff][l-s][0] + f[y][ff][s][0]); --> 当前节点不是伐木场,对 y 节点分配 s 个伐木场个数
f[x][ff][l][1] = min(f[x][ff][l][0], f[x][ff][l-s][1] + f[y][x][s][0]); --> 当前节点是伐木场,对 y 节点分配 s 个伐木场个数

唯一的不同是 y 节点最近祖先为 x

注意在枚举 l 时,由于是01背包,需要倒过来枚举,不然状态计算会叠加。(很容易理解)

做完了?好像还差一个 cost!

考虑到 祖先节点 ff cost 的贡献,因为是 当前节点 W 乘以 到 ff 的总距离!

需要计算总距离,维护 dep 数组,即当前节点到根节点距离,即可!很容易理解。ffx 的距离就是 dep[x] - dep[ff]

对于当前节点 x 每一个祖先 ff

1
2
3
4
5
6
7
8
if(l>=1){ --> 因为 l 需要 -1 ,要分类讨论。
f[x][ff][l][0] = min(f[x][ff][l][0] + W[x] * (dep[x]-dep[ff]), f[x][ff][l][1]);
合并 01,因为回溯之后,1 的状态不再被使用,便于下一步计算。
-1 是因为上文 f[x][ff][l][1] = min(f[x][ff][l][0], f[x][ff][l-s][1] + f[y][x][s][0]); 时,s+l-s = l 但是当前节点也是伐木场,所以状态更新是 l+1 的,要 -1 获取正确结果
}else{
l = 0 时,当前节点不可能为伐木场,直接添加到 ff 增加的贡献
f[x][ff][l][0] += W[x] * (dep[x]-dep[ff]);
}

这里再来聊一聊合并。为了便于讨论,上文背包转移的时候,我们没有考虑 y 的 1 的情况,而是在每次 y 回溯时 合并 0 和 1。这有点类似于滚动数组?回溯后 0 不再表示之前的意义,而是我们最初设计的状态:i 节点 最近的伐木场祖先为 j,后代(或自己)有 k 个是伐木场!

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <cmath>
#include <cstdio>
#include <iostream>
#include <cstring>

using namespace std;

struct Node {
int to, next, d;
} NDS[201];

int head[201], vis[201];
int cnt = 0;
long long W[201];
int tot = 0;

void add(int a, int b, int d) {
NDS[cnt].to = b;
NDS[cnt].next = head[a];
NDS[cnt].d = d;
head[a] = cnt++;
}

int fa[201];
long long f[201][201][52][2];
long long dep[201];
long long n, k;

void dp(int x) {
fa[++tot] = x;
vis[x] = 1;
for (int i = head[x]; i != -1; i = NDS[i].next) {
int y = NDS[i].to;
if (vis[y]) continue;
dep[y] = dep[x] + NDS[i].d;
dp(y);
for (int j = tot; j >= 1; j--) {
int ff = fa[j];
for (int l = k; l >= 0; l--) {
f[x][ff][l][0] += f[y][ff][0][0];
f[x][ff][l][1] += f[y][x][0][0];
for(int s = l; s>=0; s--){
f[x][ff][l][0] = min(f[x][ff][l][0], f[x][ff][l-s][0] + f[y][ff][s][0]);
f[x][ff][l][1] = min(f[x][ff][l][1], f[x][ff][l-s][1] + f[y][x][s][0]);
}
}
}
}
for (int j = 1; j <= tot; j++) {
int ff = fa[j];
for (int l = k; l >= 0; l--) {
if(l>=1){
f[x][ff][l][0] = min(f[x][ff][l][0] + W[x] * (dep[x]-dep[ff]), f[x][ff][l-1][1]);
}else{
f[x][ff][l][0] += W[x] * (dep[x]-dep[ff]);
}
// cout << x << " "<<ff << " "<< l << " " << f[x][ff][l][0] << " " << endl;
}
}
// cout << dep[x] << endl;
tot--;
}

int main() {
// Type your code here
memset(head, -1, sizeof(head));
cin >> n >> k;
for (int i = 1; i <= n; i++) {
int w, v, d;
cin >> w >> v >> d;
W[i] = w;
add(i, v, d);
add(v, i, d); // 链式前向星
}
dp(0);
cout << f[0][0][k][0] << endl; // 输出结果,注意是 合并后,所以是 0
return 0;
}