본문 바로가기

Problem Solving/Baekjoon Online Judge (Diamond)

백준 23836 - 어떤 우유의 배달목록 (Hard) (C++)

문제

  • 문제 링크
  •   
       BOJ 23836 - 어떤 우유의 배달목록 (Hard) 
      
  • 문제 요약
  • $N$개의 노드로 이루어진 트리가 주어진다.
    아래 두 쿼리를 수행하는 프로그램을 작성하자.

  • $1$ $u$ $v$ : $u$번 노드부터 $v$번 노드에 이르는 경로 상에서, $i$번째로 등장하는 노드에 $i$를 더한다.
  • $2$ $x$ : 지금까지 $x$번 노드에 더해진 값을 출력한다.
  • 제한
  • TL : $2$ sec, ML : $512$ MB

  • $1 ≤ N, Q ≤ 10^5$
  • $2$번 쿼리는 최소 한 개 이상 주어진다.

알고리즘 분류

  • 자료 구조 (data_structures)
  • 트리 (trees)
  • heavy-light 분할 (heavy-light decomposition)
  • 세그먼트 트리 (segment tree)
  • 스택 (stack)

풀이

Eazy 버전을 풀 때 한 번 언급했었는데, 잊고 있다가 지금에서야 풀었다..

구간에 등차수열을 더하는 문제를 트리 위에서 풀기 전에, 선형 구조에서 풀어본 적이 없다면 이 문제 를 먼저 풀고 오도록 하자.

이제 트리 위에서 풀어야 하는데, 곰곰이 생각해보면 생각보다 귀찮은 요소가 좀 있다.

대부분 경로 단위 쿼리를 처리해야 하므로 $hld$로 트리를 분할하고 체인 단위로 쿼리를 날릴 것이다.
  • 등차수열을 더할텐데, 체인 단위로 올릴 때 연속성을 어떻게 보장할 것인가 ?
  • 수열에 방향성이 있는데, 세그먼트 트리 상에서 어떻게 관리할 것인가 ?




  • 보통 $hld$에서 체인 단위로 올려주며 쿼리를 처리할 때, 매 순간 깊이가 더 낮은 체인을 끌어올려 주었다.
    이를 그대로 적용할 경우, 연속성을 반드시 보장해야 하는 이번 문제에서 연속성을 보장받지 못한다.

    그렇기 때문에 나는 $u$에서부터 이어져오는 구간, $v$에서부터 이어져오는 구간, $u, v$가 한 체인에 속했을 때의 구간으로 나누어 처리하였다.

    $v$에서부터 이어져오는 구간 처리는 처리 순서가 뒤집어져야 하므로 스택을 이용했다.

    하나 또 중요한 것이, 이전 체인에서 $1, 2, 3, 4$를 더했다면 그 다음 체인에서는 $5, 6, 7, ...$를 더해야 한다.
    이는 수열의 시작값을 보정해줄 변수 하나를 추가로 두어 해결할 수 있다.



    마지막으로 하나 또 신경 쓸 것이 있다.

    수열을 항상 $a, a + 1, a + 2, ...$의 형태로 더하면 참 좋겠지만 두 노드의 높이 관계에 따라 $a, a - 1, a - 2, ...$의 형태를 더해야 할 수도 있다.
    (물론 결과적으로 하나의 등차수열을 더하는 것이지만 $ETT$ $Ordering$ 입장에서 봤을 때 그렇다는 것)

    이는 세그먼트 트리에서 정방향, 역방향 수열을 각각 관리해주고, 합쿼리에서 하나로 더해주면 된다.

    자세한 건 아래 코드의 $update$ 함수들을 중점적으로 보면 되겠다.

    전체 코드


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    #include<bits/stdc++.h>
    #define ll long long
    #define N 100001
    using namespace std;
     
    vector <int> edge[N], tree[N];
    int top[N], in[N], par[N], dep[N], sz[N];
     
    ll n, q, seg[1 << 18][4];
     
    void Input()
    {
        ios_base::sync_with_stdio(0); cin.tie(0);
        cin >> n;
        for (int i{}, u, v; ++< n;)
        {
            cin >> u >> v;
            edge[u].push_back(v);
            edge[v].push_back(u);
        }
    }
    void hldInit(int now)
    {
        sz[now] = 1;
        for (int& next : edge[now])
            if (!sz[next])
            {
                dep[next] = dep[now] + 1;
                par[next] = now;
                hldInit(next);
                sz[now] += sz[next];
                tree[now].push_back(next);
     
                if (sz[tree[now][0]] < sz[next]) swap(tree[now][0], tree[now].back());
            }
    }
    void hldETT(int now)
    {
        in[now] = ++q;
        for (int& next : tree[now])
        {
            top[next] = next ^ tree[now][0] ? next : top[now];
            hldETT(next);
        }
    }
    void segUpdate(int n, int s, int e, int l, int r, int op, int cur)
    {
        if (s > r || e < l) return;
        if (s >= l && e <= r)
        {
            if (op & 1)
            {
                seg[n][0+= cur + s - l + 1;
                seg[n][1]++;
            }
            else
            {
                seg[n][2+= cur + r - e + 1;
                seg[n][3]++;
            }
            return;
        }
        int m(s + e >> 1);
        segUpdate(n << 1, s, m, l, r, op, cur);
        segUpdate(n << 1 | 1, m + 1, e, l, r, op, cur);
    }
    ll segQuery(int n, int s, int e, int i)
    {
        if (s > i || e < i) return 0;
     
        ll val(seg[n][0+ seg[n][2]);
        if (s == e) return val;
        val += seg[n][1* (i - s) + seg[n][3* (e - i);
     
        int m(s + e >> 1);
        return segQuery(n << 1, s, m, i) + segQuery(n << 1 | 1, m + 1, e, i) + val;
    }
    int getLCA(int x, int y)
    {
        for (; top[x] ^ top[y]; x = par[top[x]])
            if (dep[top[x]] < dep[top[y]])
                swap(x, y);
        return dep[x] < dep[y] ? x : y;
    }
    void hldUpdate(int u, int v)
    {
        int lca(getLCA(u, v)), cur(-1);
        while (top[u] ^ top[lca])
        {
            segUpdate(11, n, in[top[u]], in[u], 2, cur);
            cur += in[u] - in[top[u]] + 1;
            u = par[top[u]];
        }
        stack <pair<intint>> S;
        while (top[v] ^ top[lca])
        {
            S.push({ top[v], v });
            v = par[top[v]];
        }
     
        if (dep[u] > dep[v])
        {
            segUpdate(11, n, in[v], in[u], 2, cur);
            cur += in[u] - in[v] + 1;
        }
        else
        {
            segUpdate(11, n, in[u], in[v], 1, cur);
            cur += in[v] - in[u] + 1;
        }
     
        while (S.size())
        {
            auto [_py, _y](S.top()); S.pop();
            segUpdate(11, n, in[_py], in[_y], 1, cur);
            cur += in[_y] - in[_py] + 1;
        }
    }
    void Query()
    {
        cin >> q;
        for (int op, u, v; q--;)
        {
            cin >> op >> u;
            if (op & 1)
            {
                cin >> v;
                hldUpdate(u, v);
            }
            else
                cout << segQuery(11, n, in[u]) << '\n';
        }
    }
    int main()
    {
        Input();
        hldInit(1);
        hldETT(1);
        Query();
    }
    cs


    comment

    풀이 방향은 금방 잡았는데 구현이 생각보다 까다로웠다.
    $hld$에 대한 높은 이해도를 요구하는 좋은 문제라고 생각한다.