最近在练树形 DP,正好看到 这一道 虚标的紫题,但本蒟蒻不会写,想出来了便记录一下
题面 & 思考
先看题面,一颗有根树,选定 k 个节点作为 ”伐木场“,求运送木料最小费用。注意木料费用是 dis * wood
。
最开始想到简单的树形背包,状态转移方程:
1
| f[i][k] = min(f[j][s] + cost, f[i][k]);
|
但是注意到,如果某个后代节点如果是伐木场,cost 不需要计算,所以状态中还需要存储后代的伐木场情况。
但是后代中锯木厂不止一个。考虑到树有唯一的父亲,且同一深度祖先唯一,可以记录最近的伐木场祖先,作为状态的一部分。
有 f[i][j][k]
即 i 节点 最近的伐木场祖先为 j,后代(不算自己)有 k 个是伐木场。
但是由于 f 自己也可能是伐木场,转移方程不同,需要分类讨论,于是改成:f[i][j][k][0/1]
其中 0 代表自己不是伐木场, 1 表示自己是伐木场。
状态转移方程
因为我们需要枚举祖先,在 DFS 时需要记录 Fa 数组:
1 2 3
| DFS 开始时:fa[++tot] = x;
结束时:tot--;
|
简单记录祖先 stack。
回溯后,对于当前节点 x
和 子节点 y
每个祖先 ff
:
1 2 3 4
| 先对每个 k 赋初值,对于每种伐木场个数 l: f[x][ff][l][0] += f[y][ff][0][0]; --> 当前节点不是伐木场,子节点 y 最近伐木场为 l,对任意 f,赋最大值,即 子结点中没有伐木场。
f[x][ff][l][1] += f[y][x][0][0]; --> 当前节点是伐木场,子节点 y 最近伐木场为 当前节点,对任意 f,赋最大值,即 子结点中没有伐木场。
|
然后就是树形背包:
1 2 3 4
| f[x][ff][l][0] = min(f[x][ff][l][0], f[x][ff][l-s][0] + f[y][ff][s][0]); --> 当前节点不是伐木场,对 y 节点分配 s 个伐木场个数 f[x][ff][l][1] = min(f[x][ff][l][0], f[x][ff][l-s][1] + f[y][x][s][0]); --> 当前节点是伐木场,对 y 节点分配 s 个伐木场个数
唯一的不同是 y 节点最近祖先为 x
|
注意在枚举 l 时,由于是01背包,需要倒过来枚举,不然状态计算会叠加。(很容易理解)
做完了?好像还差一个 cost!
考虑到 祖先节点 ff
cost 的贡献,因为是 当前节点 W 乘以 到 ff
的总距离!
需要计算总距离,维护 dep
数组,即当前节点到根节点距离,即可!很容易理解。ff
到 x
的距离就是 dep[x] - dep[ff]
对于当前节点 x
每一个祖先 ff
:
1 2 3 4 5 6 7 8
| if(l>=1){ --> 因为 l 需要 -1 ,要分类讨论。 f[x][ff][l][0] = min(f[x][ff][l][0] + W[x] * (dep[x]-dep[ff]), f[x][ff][l][1]); 合并 0 和 1,因为回溯之后,1 的状态不再被使用,便于下一步计算。 -1 是因为上文 f[x][ff][l][1] = min(f[x][ff][l][0], f[x][ff][l-s][1] + f[y][x][s][0]); 时,s+l-s = l 但是当前节点也是伐木场,所以状态更新是 l+1 的,要 -1 获取正确结果 }else{ l = 0 时,当前节点不可能为伐木场,直接添加到 ff 增加的贡献 f[x][ff][l][0] += W[x] * (dep[x]-dep[ff]); }
|
这里再来聊一聊合并。为了便于讨论,上文背包转移的时候,我们没有考虑 y
的 1 的情况,而是在每次 y
回溯时 合并 0 和 1。这有点类似于滚动数组?回溯后 0
不再表示之前的意义,而是我们最初设计的状态:i 节点 最近的伐木场祖先为 j,后代(或自己)有 k 个是伐木场!
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
| #include <cmath> #include <cstdio> #include <iostream> #include <cstring>
using namespace std;
struct Node { int to, next, d; } NDS[201];
int head[201], vis[201]; int cnt = 0; long long W[201]; int tot = 0;
void add(int a, int b, int d) { NDS[cnt].to = b; NDS[cnt].next = head[a]; NDS[cnt].d = d; head[a] = cnt++; }
int fa[201]; long long f[201][201][52][2]; long long dep[201]; long long n, k;
void dp(int x) { fa[++tot] = x; vis[x] = 1; for (int i = head[x]; i != -1; i = NDS[i].next) { int y = NDS[i].to; if (vis[y]) continue; dep[y] = dep[x] + NDS[i].d; dp(y); for (int j = tot; j >= 1; j--) { int ff = fa[j]; for (int l = k; l >= 0; l--) { f[x][ff][l][0] += f[y][ff][0][0]; f[x][ff][l][1] += f[y][x][0][0]; for(int s = l; s>=0; s--){ f[x][ff][l][0] = min(f[x][ff][l][0], f[x][ff][l-s][0] + f[y][ff][s][0]); f[x][ff][l][1] = min(f[x][ff][l][1], f[x][ff][l-s][1] + f[y][x][s][0]); } } } } for (int j = 1; j <= tot; j++) { int ff = fa[j]; for (int l = k; l >= 0; l--) { if(l>=1){ f[x][ff][l][0] = min(f[x][ff][l][0] + W[x] * (dep[x]-dep[ff]), f[x][ff][l-1][1]); }else{ f[x][ff][l][0] += W[x] * (dep[x]-dep[ff]); }
} }
tot--; }
int main() { memset(head, -1, sizeof(head)); cin >> n >> k; for (int i = 1; i <= n; i++) { int w, v, d; cin >> w >> v >> d; W[i] = w; add(i, v, d); add(v, i, d); } dp(0); cout << f[0][0][k][0] << endl; return 0; }
|