본문 바로가기

Problem Solving/Baekjoon Online Judge (Diamond)

백준 22487 - Do use segment tree (C++)

문제

  • 문제 링크
  •   
       BOJ 22487 - Do use segment tree 
      
  • 문제 요약
  • $N$개의 정점으로 이루어진 트리와 두가지 유형으로 이루어진 $Q$개의 쿼리가 주어진다.
    각 쿼리를 알맞게 처리해보자.
  • 제한
  • TL : $2$ sec, ML : $512$ MB

  • $1 ≤ N ≤ 200,000$
  • $1 ≤ Q ≤ 100,000$
  • $-10,000 ≤ w_i ≤ 10,000$

알고리즘 분류

  • 구현 (implemantation)
  • 자료 구조 (data structures)
  • 트리 (trees)
  • heavy-light 분할 (heavy-light decomposition)
  • 세그먼트 트리 (segment tree)
  • 느리게 갱신되는 세그먼트 트리 (lazy propagation)

풀이

BOJ 16993 - 연속합과 쿼리
BOJ 15561 - 구간 합 최대? 2

간단히 말하면, 위와 같은 문제들을 트리 위에서 풀어야 한다는 것이다.
흔히 말하는 금광 세그, 구조체 세그와 같은 유형이다.

선형 구조가 아닌 트리 구조임에 따라, 경로 상 쿼리를 선형 구조처럼 풀 수 있게 해주는 $hld$를 이용하자.
또, 구간 단위 업데이트 쿼리를 처리해야 하므로 $lazy$를 적절히 뿌려주도록 하자.

앞서 언급한 문제들을 풀어봤다면 세그먼트 트리 상에서 관리해야 할 값을 알고 있을 것이다.
이부분에 관한 설명은 생략.



우리가 구하고자 하는건 '최대 연속합' 이다. 즉, 연속성을 유지해야 한다.

$hld$에서 경로 상 쿼리를 처리하는 방식을 떠올려보자.
두 정점이 같은 체인에 속할 때까지, 매 순간마다 깊이가 더 깊은 체인를 끌어올려 주며 진행한다.

이에 따라, 단순히 위 과정에서 $merge$를 수행한다면 연속성을 보장받지 못하게 된다.

따라서 두 정점 $a_i$, $b_i$에 대해 $a_i$부터 이어져오는 최대 연속합, $b_i$부터 이어져오는 최대 연속합을 각각 관리해줄 것을 생각하자.



위 과정에서 가장 중요한 것은 일관된 방향(순서)을 잡아줘야 한다는 것이다.
나는 이런 저런 시행 착오를 거치다 최종적으로 아래의 코드가 탄생했다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
ll HLDQuery(int p, int q)
{
    T x(seg[0]), y(seg[0]);
    while (top[p] ^ top[q])
    {
        if (D[top[p]] > D[top[q]])
            x = merge(SegQuery(11, n, in[top[p]], in[p]), x), p = P[top[p]];
        else
            y = merge(SegQuery(11, n, in[top[q]], in[q]), y), q = P[top[q]];
    }
    if (D[p] > D[q])
        swap(p, q), swap(x, y);
    swap(x.ls, x.rs);
    return merge(x, merge(SegQuery(11, n, in[p], in[q]), y)).ms;
}
cs

뭔가 복잡하게 결론이 난 것 처럼 보이긴 한데.. 좀 더 깔끔한 방식이 있을 것도 같긴 하다.



이제 경로 단위 업데이트 쿼리만이 남았는데, 위 과정을 도출해내는 게 핵심이지 이건 사실상 별 게 없다.
유의 사항으로 다음을 꼽을 수 있겠다.

  • 존재할 수 있는 가중치의 범위. ($lazy$의 기저 세팅 값을 따로 지정)
  • $lazy$값을 노드에 뿌려줄 때, 구간 합을 제외한 나머지는 $max(lz[n], lz[n] * (e - s + 1))$ 의 값을 취하자.

  • 두번째 유의 사항은 당연한 것이,
    양수만 있다면 모르겠지만 음수의 경우 포함하는 구간의 개수를 최소화하는 것이 이득이기 때문이다.

    자세한 건 아래 코드를 참조하자.

    전체 코드


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    #include<bits/stdc++.h>
    #define ll long long
    #define N 200001
    using namespace std;
     
    struct T { ll ls, rs, ps, ms; } seg[1 << 19];
    const int inf(-1e9);
     
    vector <int> Gr[N], G[N];
    vector <int> lz(1 << 19, inf);
     
    int top[N], in[N], C[N];
    int D[N], P[N], S[N];
    int n, q;
     
    T merge(T L, T R) { return { max({L.ls, L.ps + R.ls}), max({R.rs, R.ps + L.rs}), L.ps + R.ps, max({ L.ms, R.ms, L.rs + R.ls }) }; }
    void Input()
    {
        ios_base::sync_with_stdio(0); cin.tie(0);
        seg[0= { inf, inf, 0, inf };
        cin >> n >> q;
     
        for (int i{}; i < n; cin >> C[++i]);
        for (int i{}; ++< n;)
        {
            int s, e; cin >> s >> e;
            Gr[s].push_back(e);
            Gr[e].push_back(s);
        }
    }
    void HLDInit(int now)
    {
        S[now] = 1;
        for (int& next : Gr[now])
            if (!S[next])
            {
                D[next] = D[now] + 1;
                P[next] = now;
                HLDInit(next);
                S[now] += S[next];
                G[now].push_back(next);
                if (S[next] > S[G[now][0]]) swap(G[now][0], G[now].back());
            }
    }
    int cnt{};
    void ETT(int now)
    {
        in[now] = ++cnt;
        for (int& next : G[now])
        {
            top[next] = next == G[now][0] ? top[now] : next;
            ETT(next);
        }
    }
    void Propagation(int n, int s, int e)
    {
        if (lz[n] ^ inf)
        {
            if (s ^ e)
                lz[n << 1= lz[n], lz[n << 1 | 1= lz[n];
            seg[n].ps = lz[n] * (e - s + 1);
            seg[n].ls = seg[n].rs = seg[n].ms = max(seg[n].ps, (ll)lz[n]);
            lz[n] = inf;
        }
    }
    T SegUpdate(int n, int s, int e, int l, int r, int w)
    {
        Propagation(n, s, e);
        if (s > r || e < l) return seg[n];
        if (s >= l && e <= r)
        {
            lz[n] = w;
            Propagation(n, s, e);
            return seg[n];
        }
        return seg[n] = merge(SegUpdate(n << 1, s, (s + e) >> 1, l, r, w), SegUpdate(n << 1 | 1, ((s + e) >> 1+ 1, e, l, r, w));
    }
    T SegQuery(int n, int s, int e, int l, int r)
    {
        Propagation(n, s, e);
        if (s > r || e < l) return seg[0];
        if (s >= l && e <= r) return seg[n];
        return merge(SegQuery(n << 1, s, (s + e) >> 1, l, r), SegQuery(n << 1 | 1, ((s + e) >> 1+ 1, e, l, r));
    }
    void HLDUdate(int p, int q, int w)
    {
        while (top[p] ^ top[q])
        {
            if (D[top[p]] < D[top[q]]) swap(p, q);
            SegUpdate(11, n, in[top[p]], in[p], w);
            p = P[top[p]];
        }
        if (D[p] > D[q]) swap(p, q);
        SegUpdate(11, n, in[p], in[q], w);
    }
    ll HLDQuery(int p, int q)
    {
        T x(seg[0]), y(seg[0]);
        while (top[p] ^ top[q])
        {
            if (D[top[p]] > D[top[q]])
                x = merge(SegQuery(11, n, in[top[p]], in[p]), x), p = P[top[p]];
            else
                y = merge(SegQuery(11, n, in[top[q]], in[q]), y), q = P[top[q]];
        }
        if (D[p] > D[q])
            swap(p, q), swap(x, y);
        swap(x.ls, x.rs);
        return merge(x, merge(SegQuery(11, n, in[p], in[q]), y)).ms;
    }
    void Solve()
    {
        for (int i(1); i <= n; i++)
            SegUpdate(11, n, in[i], in[i], C[i]);
        for (int o, i, j, k; q--;)
        {
            cin >> o >> i >> j >> k;
            if (o & 1)
                HLDUdate(i, j, k);
            else
                cout << HLDQuery(i, j) << '\n';
        }
    }
    int main()
    {
        Input();
        HLDInit(1);
        ETT(1);
        Solve();
    }
    cs


    comment

    $AC$로의 길은 험난했지만.. 첫 $Diamond$ $3$ 문제 격파라 기억에 남을 것 같다.