본문 바로가기

Problem Solving/Baekjoon Online Judge (Platinum)

백준 13038 - Tree (C++)

문제

  • 문제 링크
  •   
       BOJ 13038 - Tree 
      
  • 문제 요약
  • $1$ $x$ $y$ : 정점 $x$부터 정점 $y$까지 이르는 경로의 길이를 출력한다.
  • $2$ $x$ : 정점 $x$를 제거한다. 이때, 정점 $x$의 자식들은 $x$의 부모 정점으로 전이된다.

  • $N$개의 정점으로 이루어진 트리와 위와 같은 $Q$개의 쿼리가 주어진다.
    $1$번 쿼리가 주어질 때마다 그에 맞는 답을 출력하자.
  • 제한
  • TL : $2$ sec, ML : $512$ MB

  • $1 ≤ N, Q ≤ 10^6$
  • 반드시 올바른 트리가 입력으로 주어지며, 트리의 루트$(1)$는 제거되지 않음이 보장된다.

알고리즘 분류

  • 자료 구조(data structures)
  • 트리(trees)
  • heavy-light 분할 (heavy-light decomposition)
  • 세그먼트 트리(segment tree)

풀이

임의의 두 정점 사이 거리는, 간선의 가중치가 없어 경로 상 등장하는 정점의 개수를 보고 계산할 수 있다.
또, 어떤 정점이 제거될 때 서브 트리를 부모 정점으로 전이시켜주어 트리의 모양을 유지한다.

이에 따라 두 쿼리를 아래와 같이 생각해 볼 수 있다.

  • 최초 각 정점마다 $1$의 가중치를 부여하자.
  • 그럼 어떤 정점을 제거한다는 것은, 해당 정점의 가중치를 $0$으로 바꿔주는 것과 같다.
  • 각 정점의 상태 및 구간을 세그먼트 트리로 관리한다면, $1$번 쿼리는 단순히 합쿼리를 날리는 것과 같다.

  • 결과적으로 $hld$ + $segment$ $tree$를 이용해 문제를 약 $O(Qlog^2N)$에 풀 수 있다.

    전체 코드


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    #include<bits/stdc++.h>
    #define N 100001
    using namespace std;
     
    vector <int> Gr[N], G[N];
    int par[N], dep[N], sz[N];
    int top[N], in[N];
    int n, seg[1 << 18];
     
    void Input()
    {
        ios_base::sync_with_stdio(0); cin.tie(0);
        cin >> n;
        for (int i(2); i <= n; i++)
        {
            int j; cin >> j;
            Gr[j].push_back(i);
        }
    }
    void hldInit(int now)
    {
        sz[now] = 1;
        for (int& next : Gr[now])
        {
            dep[next] = dep[now] + 1;
            par[next] = now;
            hldInit(next);
            sz[now] += sz[next];
            G[now].push_back(next);
            if (sz[G[now][0]] < sz[next]) swap(G[now][0], G[now].back());
        }
    }
    int cur;
    void hldETT(int now)
    {
        in[now] = ++cur;
        for (int& next : G[now])
        {
            top[next] = next ^ G[now][0] ? next : top[now];
            hldETT(next);
        }
    }
    int segUpdate(int n, int l, int r, int x, int v)
    {
        if (l > x || r < x) return seg[n];
        if (l == r) return seg[n] = v;
        int m(l + r >> 1);
        return seg[n] = segUpdate(n << 1, l, m, x, v) + segUpdate(n << 1 | 1, m + 1, r, x, v);
    }
    int segQuery(int n, int l, int r, int x, int y)
    {
        if (l > y || r < x) return 0;
        if (l >= x && r <= y) return seg[n];
        int m(l + r >> 1);
        return segQuery(n << 1, l, m, x, y) + segQuery(n << 1 | 1, m + 1, r, x, y);
    }
    int hldQuery(int x, int y, int res = 0)
    {
        while (top[x] ^ top[y])
        {
            if (dep[top[x]] < dep[top[y]]) swap(x, y);
            res += segQuery(11, n, in[top[x]], in[x]);
            x = par[top[x]];
        }
        if (dep[x] > dep[y]) swap(x, y);
        return res + segQuery(11, n, in[x] + 1, in[y]);
    }
    void Solve()
    {
        for (int i(1); i <= n; i++)
            segUpdate(11, n, in[i], 1);
     
        int q; cin >> q;
        for (int o, x, y; q--;)
        {
            cin >> o;
            if (o & 1)
            {
                cin >> x >> y;
                cout << hldQuery(x, y) << '\n';
            }
            else
            {
                cin >> x;
                segUpdate(11, n, in[x], 0);
            }
        }
    }
    int main()
    {
        Input();
        hldInit(1);
        hldETT(1);
        Solve();
    }
    cs


    comment