문제
- 문제 링크
BOJ 22487 - Do use segment tree
$N$개의 정점으로 이루어진 트리와 두가지 유형으로 이루어진 $Q$개의 쿼리가 주어진다.
각 쿼리를 알맞게 처리해보자.
TL : $2$ sec, ML : $512$ MB
$1 ≤ N ≤ 200,000$ $1 ≤ Q ≤ 100,000$ $-10,000 ≤ w_i ≤ 10,000$
알고리즘 분류
- 구현 (implemantation)
- 자료 구조 (data structures)
- 트리 (trees)
- heavy-light 분할 (heavy-light decomposition)
- 세그먼트 트리 (segment tree)
- 느리게 갱신되는 세그먼트 트리 (lazy propagation)
풀이
BOJ 16993 - 연속합과 쿼리
BOJ 15561 - 구간 합 최대? 2
간단히 말하면, 위와 같은 문제들을 트리 위에서 풀어야 한다는 것이다.
흔히 말하는 금광 세그, 구조체 세그와 같은 유형이다.
선형 구조가 아닌 트리 구조임에 따라, 경로 상 쿼리를 선형 구조처럼 풀 수 있게 해주는 $hld$를 이용하자.
또, 구간 단위 업데이트 쿼리를 처리해야 하므로 $lazy$를 적절히 뿌려주도록 하자.
앞서 언급한 문제들을 풀어봤다면 세그먼트 트리 상에서 관리해야 할 값을 알고 있을 것이다.
이부분에 관한 설명은 생략.
우리가 구하고자 하는건 '최대 연속합' 이다. 즉, 연속성을 유지해야 한다.
$hld$에서 경로 상 쿼리를 처리하는 방식을 떠올려보자.
두 정점이 같은 체인에 속할 때까지, 매 순간마다 깊이가 더 깊은 체인를 끌어올려 주며 진행한다.
이에 따라, 단순히 위 과정에서 $merge$를 수행한다면 연속성을 보장받지 못하게 된다.
따라서 두 정점 $a_i$, $b_i$에 대해 $a_i$부터 이어져오는 최대 연속합, $b_i$부터 이어져오는 최대 연속합을 각각 관리해줄 것을 생각하자.
위 과정에서 가장 중요한 것은 일관된 방향(순서)을 잡아줘야 한다는 것이다.
나는 이런 저런 시행 착오를 거치다 최종적으로 아래의 코드가 탄생했다.
123456789101112131415 ll HLDQuery(int p, int q){ T x(seg[0]), y(seg[0]); while (top[p] ^ top[q]) { if (D[top[p]] > D[top[q]]) x = merge(SegQuery(1, 1, n, in[top[p]], in[p]), x), p = P[top[p]]; else y = merge(SegQuery(1, 1, n, in[top[q]], in[q]), y), q = P[top[q]]; } if (D[p] > D[q]) swap(p, q), swap(x, y); swap(x.ls, x.rs); return merge(x, merge(SegQuery(1, 1, n, in[p], in[q]), y)).ms;} cs
뭔가 복잡하게 결론이 난 것 처럼 보이긴 한데.. 좀 더 깔끔한 방식이 있을 것도 같긴 하다.
이제 경로 단위 업데이트 쿼리만이 남았는데, 위 과정을 도출해내는 게 핵심이지 이건 사실상 별 게 없다.
유의 사항으로 다음을 꼽을 수 있겠다.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | ll HLDQuery(int p, int q) { T x(seg[0]), y(seg[0]); while (top[p] ^ top[q]) { if (D[top[p]] > D[top[q]]) x = merge(SegQuery(1, 1, n, in[top[p]], in[p]), x), p = P[top[p]]; else y = merge(SegQuery(1, 1, n, in[top[q]], in[q]), y), q = P[top[q]]; } if (D[p] > D[q]) swap(p, q), swap(x, y); swap(x.ls, x.rs); return merge(x, merge(SegQuery(1, 1, n, in[p], in[q]), y)).ms; } | cs |
두번째 유의 사항은 당연한 것이,
양수만 있다면 모르겠지만 음수의 경우 포함하는 구간의 개수를 최소화하는 것이 이득이기 때문이다.
자세한 건 아래 코드를 참조하자.
전체 코드
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | #include<bits/stdc++.h> #define ll long long #define N 200001 using namespace std; struct T { ll ls, rs, ps, ms; } seg[1 << 19]; const int inf(-1e9); vector <int> Gr[N], G[N]; vector <int> lz(1 << 19, inf); int top[N], in[N], C[N]; int D[N], P[N], S[N]; int n, q; T merge(T L, T R) { return { max({L.ls, L.ps + R.ls}), max({R.rs, R.ps + L.rs}), L.ps + R.ps, max({ L.ms, R.ms, L.rs + R.ls }) }; } void Input() { ios_base::sync_with_stdio(0); cin.tie(0); seg[0] = { inf, inf, 0, inf }; cin >> n >> q; for (int i{}; i < n; cin >> C[++i]); for (int i{}; ++i < n;) { int s, e; cin >> s >> e; Gr[s].push_back(e); Gr[e].push_back(s); } } void HLDInit(int now) { S[now] = 1; for (int& next : Gr[now]) if (!S[next]) { D[next] = D[now] + 1; P[next] = now; HLDInit(next); S[now] += S[next]; G[now].push_back(next); if (S[next] > S[G[now][0]]) swap(G[now][0], G[now].back()); } } int cnt{}; void ETT(int now) { in[now] = ++cnt; for (int& next : G[now]) { top[next] = next == G[now][0] ? top[now] : next; ETT(next); } } void Propagation(int n, int s, int e) { if (lz[n] ^ inf) { if (s ^ e) lz[n << 1] = lz[n], lz[n << 1 | 1] = lz[n]; seg[n].ps = lz[n] * (e - s + 1); seg[n].ls = seg[n].rs = seg[n].ms = max(seg[n].ps, (ll)lz[n]); lz[n] = inf; } } T SegUpdate(int n, int s, int e, int l, int r, int w) { Propagation(n, s, e); if (s > r || e < l) return seg[n]; if (s >= l && e <= r) { lz[n] = w; Propagation(n, s, e); return seg[n]; } return seg[n] = merge(SegUpdate(n << 1, s, (s + e) >> 1, l, r, w), SegUpdate(n << 1 | 1, ((s + e) >> 1) + 1, e, l, r, w)); } T SegQuery(int n, int s, int e, int l, int r) { Propagation(n, s, e); if (s > r || e < l) return seg[0]; if (s >= l && e <= r) return seg[n]; return merge(SegQuery(n << 1, s, (s + e) >> 1, l, r), SegQuery(n << 1 | 1, ((s + e) >> 1) + 1, e, l, r)); } void HLDUdate(int p, int q, int w) { while (top[p] ^ top[q]) { if (D[top[p]] < D[top[q]]) swap(p, q); SegUpdate(1, 1, n, in[top[p]], in[p], w); p = P[top[p]]; } if (D[p] > D[q]) swap(p, q); SegUpdate(1, 1, n, in[p], in[q], w); } ll HLDQuery(int p, int q) { T x(seg[0]), y(seg[0]); while (top[p] ^ top[q]) { if (D[top[p]] > D[top[q]]) x = merge(SegQuery(1, 1, n, in[top[p]], in[p]), x), p = P[top[p]]; else y = merge(SegQuery(1, 1, n, in[top[q]], in[q]), y), q = P[top[q]]; } if (D[p] > D[q]) swap(p, q), swap(x, y); swap(x.ls, x.rs); return merge(x, merge(SegQuery(1, 1, n, in[p], in[q]), y)).ms; } void Solve() { for (int i(1); i <= n; i++) SegUpdate(1, 1, n, in[i], in[i], C[i]); for (int o, i, j, k; q--;) { cin >> o >> i >> j >> k; if (o & 1) HLDUdate(i, j, k); else cout << HLDQuery(i, j) << '\n'; } } int main() { Input(); HLDInit(1); ETT(1); Solve(); } | cs |
comment
$AC$로의 길은 험난했지만.. 첫 $Diamond$ $3$ 문제 격파라 기억에 남을 것 같다.
'Problem Solving > Baekjoon Online Judge (Diamond)' 카테고리의 다른 글
백준 15249 - Building Bridges (C++) (0) | 2023.04.26 |
---|---|
백준 14180 - 배열의 특징 (C++) (0) | 2023.04.26 |
백준 12795 - 반평면 땅따먹기 (C++) (0) | 2023.04.26 |
백준 4008 - 특공대 (C++) (0) | 2023.04.23 |
백준 18277 - Bliski Brojevi (C++) (0) | 2023.04.23 |