1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
#include "closing.h"

#include <bits/stdc++.h>

#ifdef ORZXKR
#include "grader.cpp"
#endif
// Integer shorthands. i128 is a GCC/Clang extension, used below where
// running totals of very large values could exceed the i64 range.
using i64 = std::int64_t;
using i128 = __int128;

// Capacity of the per-vertex arrays (vertices are stored 1-based).
constexpr int kMaxN = 200005;
// 10^18: stands for an unavailable option and is the upper end of the
// segment-tree value range.
constexpr i64 kInf = 1000000000000000000LL;

int n;                                      // number of vertices
int x;                                      // first special vertex (1-based)
int y;                                      // second special vertex (1-based)
int p[kMaxN];                               // parent array written by the latest dfs
i64 k;                                      // total cost budget
i64 dis1[kMaxN];                            // weighted distance from x
i64 dis2[kMaxN];                            // weighted distance from y
std::vector<std::pair<int, int>> G[kMaxN];  // adjacency lists: (neighbour, weight)
struct SGT { static const int kMaxT = kMaxN * 60; int tot, rt, ls[kMaxT], rs[kMaxT], cnt[kMaxT]; i128 sum[kMaxT];
void clear() { for (int i = 1; i <= tot; ++i) ls[i] = rs[i] = sum[i] = cnt[i] = 0; tot = rt = 0; } void pushup(int x) { sum[x] = sum[ls[x]] + sum[rs[x]]; cnt[x] = cnt[ls[x]] + cnt[rs[x]]; } void update(int &x, i64 l, i64 r, i64 ql, int v) { if (!x) x = ++tot; if (l == r) { sum[x] += v * ql, cnt[x] += v; return; } i64 mid = (l + r) >> 1; if (ql <= mid) update(ls[x], l, mid, ql, v); else update(rs[x], mid + 1, r, ql, v); pushup(x); } int query(int x, i64 l, i64 r, i64 k) { if (!x) return 0; else if (l == r) return !l ? cnt[x] : std::min<int>(cnt[x], k / l); else if (k >= sum[x]) return cnt[x]; i64 mid = (l + r) >> 1; if (k <= sum[ls[x]]) return query(ls[x], l, mid, k); else return cnt[ls[x]] + query(rs[x], mid + 1, r, k - sum[ls[x]]); } } sgt;
// Roots the tree stored in G at u. For every vertex reachable from u without
// stepping onto fa, it records the parent in p[] (p[u] = fa) and sets
// dis[v] = dis[parent] + edge weight. dis[u] itself must already be set by
// the caller.
// Uses an explicit stack instead of recursion: a path-shaped tree with ~2e5
// vertices would otherwise need one call frame per vertex and can overflow
// the call stack.
void dfs(int u, int fa, i64 *dis) {
  p[u] = fa;
  std::vector<int> stk(1, u);
  while (!stk.empty()) {
    const int cur = stk.back();
    stk.pop_back();
    for (const auto &[v, w] : G[cur]) {
      if (v == p[cur]) continue;  // do not walk back to the parent
      p[v] = cur;
      dis[v] = dis[cur] + w;
      stk.push_back(v);
    }
  }
}
int solve1() { std::vector<i64> v1 = {0}, v2 = {0}; for (int i = 1; i <= n; ++i) v1.emplace_back(dis1[i]), v2.emplace_back(dis2[i]); std::sort(v1.begin(), v1.end()), std::sort(v2.begin(), v2.end()); for (int i = 1; i <= n; ++i) v1[i] += v1[i - 1], v2[i] += v2[i - 1]; int ret = 0; for (int i = 0, j = n; i <= n; ++i) { for (; ~j && v1[i] + v2[j] > k; --j) {} if (j >= 0) ret = std::max(ret, i + j); } return ret; }
int solve2() { static int id[kMaxN]; static i64 a[kMaxN], b[kMaxN]; int ret = 0, cnt = 0; i64 k = ::k; for (int i = 1; i <= n; ++i) { a[i] = std::min(dis1[i], dis2[i]); b[i] = std::max(dis1[i], dis2[i]); id[i] = i; } for (int i = x; i; i = p[i]) { ++cnt, k -= std::min(dis1[i], dis2[i]); a[i] = std::max(dis1[i], dis2[i]) - std::min(dis1[i], dis2[i]), b[i] = kInf; } if (k < 0) return 0; sgt.clear(); std::sort(id + 1, id + 1 + n, [&] (int i, int j) { return b[i] < b[j]; }); for (int i = 1; i <= n; ++i) sgt.update(sgt.rt, 0, kInf, a[id[i]], 1); ret = cnt + sgt.query(sgt.rt, 0, kInf, k); for (int i = 1; i <= n; ++i) { sgt.update(sgt.rt, 0, kInf, a[id[i]], -1); sgt.update(sgt.rt, 0, kInf, b[id[i]] - a[id[i]], 1); ++cnt, k -= a[id[i]]; if (k >= 0) ret = std::max(ret, cnt + sgt.query(sgt.rt, 0, kInf, k)); else break; } return ret; }
// Entry point.
// - Stores the problem size and budget in the globals and shifts X, Y and
//   the edge endpoints from 0-based to 1-based.
// - Rebuilds the adjacency lists and computes distances from x and then
//   from y. Running y last leaves p[] rooted at y, which solve2 relies on.
// - Returns the better of the two case evaluations.
int max_score(int N, int X, int Y, long long K, std::vector<int> U,
              std::vector<int> V, std::vector<int> W) {
  n = N;
  x = X + 1;
  y = Y + 1;
  k = K;

  for (int v = 1; v <= n; ++v) G[v].clear();
  for (int e = 0; e + 1 < n; ++e) {
    const int a = U[e] + 1;
    const int b = V[e] + 1;
    G[a].emplace_back(b, W[e]);
    G[b].emplace_back(a, W[e]);
  }

  dis1[x] = 0;
  dis2[y] = 0;
  dfs(x, 0, dis1);
  dfs(y, 0, dis2);

  const int disjoint_best = solve1();
  const int shared_best = solve2();
  return std::max(disjoint_best, shared_best);
}
|