「BZOJ-4860」「BJOI-2017」树的难题

给你一棵 $n$ 个点的无根树。

树上的每条边具有颜色。一共有 $m$ 种颜色,编号为 $1$ 到 $m$。第 $i$ 种颜色的权值为 $c_i$。

对于一条树上的简单路径,路径上经过的所有边按顺序组成一个颜色序列,序列可以划

分成若干个相同颜色段。定义路径权值为颜色序列上每个同颜色段的颜色权值之和。

请你计算,经过边数在 $l$ 到 $r$ 之间的所有简单路径中,路径权值的最大值。

输入格式

输入文件名为 journey.in。

第一行,四个整数 $n, m, l, r$。

第二行,$n$ 个整数 $c_1, c_2, \cdots , c_m$,由空格隔开。依次表示每个颜色的权值。

接下来 $n-1$ 行,每行三个整数 $u, v, c$,表示点 $u$ 和点 $v$ 之间有一条颜色为 $c$ 的边。

输出格式

输出文件名为 journey.out。

输出一行,一个整数,表示答案。

样例

样例输入 1

1
2
3
4
5
6
5 3 1 4
-1 -5 -2
1 2 1
1 3 1
2 4 2
2 5 3

样例输出 1

1
-1

样例解释 1

颜色权值均为负,最优路径为 $(1, 2)$ 或 $(1, 3)$。

样例输入 2

1
2
3
4
5
6
7
8
9
8 4 3 4
-7 9 6 1
1 2 1
1 3 2
1 4 1
2 5 1
5 6 2
3 7 1
3 8 3

样例输出 2

1
11

样例解释 2

最优路径为 $(3, 1, 2, 5, 6)$,其颜色序列为 $(2, 1, 1, 2)$。

数据规模与约定

对于 $100 \%$ 的数据,$1 \leq n, m \leq 2 * 10^5, 1 \leq l \leq r \leq n, |c_i| \leq 10^4$。

保证树上至少存在一条经过边数在 $l$ 到 $r$ 之间的路径

链接

BZOJ-4860

题解

点分,每次统计经过根的路径,删去根,分别处理每个子树,先将每个子树不同深度内的最优路径找出,然后合并不同子树的路径就好了。

先同种颜色合并,再不同颜色合并,合并的时候,将所有子树按深度从小到大排序,记录之前子树的最大值和次大值,对当前子树处理,用线段树(单调队列)优化就可以了。

我并没有想出可以用单调队列做,于是写了个线段树,时间复杂度为 $O(n \text{ log}^2n)$,若用单调队列则为 $O(n \text{ log }n)$。

我 zz 写的线段树这里就不贴了,贴一下 wuvin 的单调队列,%wuvin

代码

wuvin 的单调队列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
/* @author wuvin */
#include <bits/stdc++.h>
#define N 200005
using namespace std;
struct edge {
int t, col, next;
} e[N * 2];
int ecnt, head[N];
void addedge(int f, int t, int c) {
e[++ecnt] = (edge){t, c, head[f]};
head[f] = ecnt;
e[++ecnt] = (edge){f, c, head[t]};
head[t] = ecnt;
}
int n, m, l, r;
int val[N];
int sz[N], vis[N];
int cen, ans;
void findc(int u, int fr, int tot) {
sz[u] = 1;
int mx = 0;
for (int i = head[u]; i; i = e[i].next) {
if (e[i].t != fr && vis[e[i].t] != 1) {
findc(e[i].t, u, tot);
sz[u] += sz[e[i].t];
mx = max(mx, sz[e[i].t]);
}
}
if (max(mx, tot - sz[u]) <= tot / 2) cen = u;
}
int deep[N];
void getd(int u, int fr) {
deep[u] = 1;
sz[u] = 1;
for (int i = head[u]; i; i = e[i].next) {
if (vis[e[i].t] == 0 && e[i].t != fr) {
getd(e[i].t, u);
deep[u] = max(deep[e[i].t] + 1, deep[u]);
sz[u] += sz[e[i].t];
}
}
}
edge bd[N]; /* best edge */
vector<edge> v[N];
int sor[N], scnt;
bool cmp(int a, int b) { return deep[bd[a].t] < deep[bd[b].t]; }
bool cmpe(edge a, edge b) { return deep[a.t] < deep[b.t]; }
int best1[N], best2[N], b1sz, b2sz; /*-2e9*/
int buf[N], bsz;
void DFS(int u, int fr, int frc, int v, int num) {
bsz = max(bsz, num);
buf[num] = max(buf[num], v);
if (num >= l && num <= r) ans = max(ans, v);
for (int i = head[u]; i; i = e[i].next) {
if (vis[e[i].t] == 0 && e[i].t != fr) {
if (e[i].col == frc)
DFS(e[i].t, u, frc, v, num + 1);
else
DFS(e[i].t, u, e[i].col, v + val[e[i].col], num + 1);
}
}
}
/*dddl*/
int vi[N], ti[N];
int headx, tail;
void add(int vv, int tnow, int tt) {
while (headx > tail && vi[headx] <= vv) headx--;
while (headx > tail && ti[tail + 1] <= tnow) tail++;
vi[++headx] = vv;
ti[headx] = tt;
}
long long query(int tnow) {
while (headx > tail && ti[tail + 1] <= tnow) tail++;
return headx > tail ? vi[tail + 1] : -2e9;
}
void solve(int u, int totsz) {
vis[u] = 1;
if (totsz < l) return;
/*--solving---*/
for (int i = head[u]; i; i = e[i].next)
if (vis[e[i].t] != 1) getd(e[i].t, u);
for (int i = head[u]; i; i = e[i].next)
if (vis[e[i].t] != 1) {
v[e[i].col].push_back(e[i]);
if (v[e[i].col].size() == 1) sor[++scnt] = e[i].col;
if (bd[e[i].col].t == 0 || deep[e[i].t] >= deep[bd[e[i].col].t])
bd[e[i].col] = e[i];
}
sort(sor + 1, sor + scnt + 1, cmp);
for (int i = 1; i <= scnt; i++) {
int c = bd[sor[i]].col;
sort(v[c].begin(), v[c].end(), cmpe);
for (int j = 0; j < v[c].size(); j++) {
DFS(v[c][j].t, u, v[c][j].col, val[v[c][j].col], 1);
/*solve small*/
int l1 = l - 1, r1 = r - 1;
headx = tail = 0;
for (int k = min(b2sz, r1); k >= max(l1, 1); k--)
add(best2[k], -k, -(k - (r - l + 1)));
for (int k = 1; k <= bsz && r1 >= 1; k++, l1--, r1--) {
if (l1 >= 1 && l1 <= b2sz)
add(best2[l1], -l1, -l1 + (r - l + 1));
ans = max(ans * 1ll, query(-l1) + buf[k] - val[c]);
}
for (int k = 1; k <= bsz; k++)
best2[k] = max(best2[k], buf[k]), buf[k] = -2e9;
b2sz = max(b2sz, bsz);
bsz = 0;
}
/*--solve big*/
int l1 = l - 1, r1 = r - 1;
headx = tail = 0;
for (int k = min(b1sz, r1); k >= max(l1, 1); k--)
add(best1[k], -k, -(k - (r - l + 1)));
for (int k = 1; k <= b2sz && r1 >= 1; k++, l1--, r1--) {
if (l1 >= 1 && l1 <= b1sz) add(best1[l1], -l1, -l1 + (r - l + 1));
ans = max(ans * 1LL, query(-l1) + best2[k]);
}
for (int k = 1; k <= b2sz; k++)
best1[k] = max(best1[k], best2[k]), best2[k] = -2e9;
b1sz = max(b1sz, b2sz);
b2sz = 0;
}
/*-----clear--*/
for (int i = head[u]; i; i = e[i].next)
if (vis[e[i].t] != 1)
v[e[i].col].clear(), bd[e[i].col] = (edge){0, 0, 0};
scnt = 0;
for (int i = 1; i <= b1sz; i++) best1[i] = -2e9;
for (int i = 1; i <= b2sz; i++) best2[i] = -2e9;
b1sz = b2sz = 0;
/*------------*/
for (int i = head[u]; i; i = e[i].next) {
if (vis[e[i].t] == 0) {
int tot = sz[e[i].t];
findc(e[i].t, u, tot);
solve(cen, tot);
}
}
}
void readin() {
scanf("%d%d%d%d", &n, &m, &l, &r);
for (int i = 1; i <= m; i++) scanf("%d", &val[i]);
for (int i = 2; i <= n; i++) {
int f, t, c;
scanf("%d%d%d", &f, &t, &c);
addedge(f, t, c);
}
for (int i = 1; i <= n; i++) best1[i] = best2[i] = buf[i] = -2e9;
best1[0] = best2[0] = buf[0] = 0;
ans = -2e9;
}
int main() {
freopen("journey.in", "r", stdin);
freopen("journey.out", "w", stdout);
readin();
findc(1, 0, n);
solve(cen, n);
cout << ans << endl;
return 0;
}

分享到