본문 바로가기

Problem Solving/Baekjoon Online Judge (Platinum)

백준 23979 - 트리의 재구성 (C++)

문제

  • 문제 링크
  •   
       BOJ 23979 - 트리의 재구성 
      
  • 문제 요약
  • $N$개의 정점으로 이루어진 트리와, 아래와같은 $Q$개의 쿼리가 주어진다.
    각 쿼리의 답을 출력하자.

  • $u$ $v$ $w$ $a$ $b$ : $u$에서 $v$로 가는 비용이 $w$인 간선을 잇고, 지웠을 때 트리가 되는 간선 중 새로 추가한 간선을 제외한 비용이 가장 큰 간선 하나를 지운다. 재구성된 트리의 $a$에서 $b$로 가는 경로의 비용을 출력한다.
  • 제한
  • TL : $1$ sec, ML : $512$ MB

  • $2 ≤ N, Q ≤ 200,000$
  • $1 ≤ C_i, W_i ≤ 200,000$

알고리즘 분류

  • 자료 구조 (data_structures)
  • 트리 (trees)
  • 최소 공통 조상 (lowest common ancestor)
  • 희소 배열 (sparse table)

풀이

트리에서 두 정점을 잇는 경로 상의 최대 가중치를 가진 간선을 빠르게 찾아야 한다.
나아가 가중치뿐만 아니라 어떤 정점을 잇는 간선이었는지까지 알 수 있어야 할 듯 하다.

이는 기존 $LCA$를 위한 전처리에서 추가적인 $Sparse$ $Table$ 을 정의해주어 $O(logN)$ 수준으로 찾아낼 수 있다.
구체적으로 아래와 같다.

  • $dp[x][y]$ : 정점 $x$$2^y$번째 부모 정점.
  • $mdp[x][y]$ : 정점 $x$$x$$2^y$번째 부모 정점을 잇는 경로 상에, 가장 큰 가중치를 가진 간선의 가중치와 그 간선이 잇는 두 정점 중 깊이가 더 깊은 정점.

  • 이들은 $O(NlogN)$의 복잡도로 전처리해 줄 수 있다.



    이제 쿼리를 어떻게 잘 케웍해 볼 생각을 하며 끄적여 보는데, 무슨 예외가 가지치기마냥 계속해서 나왔다.

    결국 계산을 편하게 할 수 있도록 $u, v, a, b$를 적절히 재배열해 해결할 수 있었다.


    임의의 쿼리에서, $u→v$의 $mdp[][]$ 결과$x$라고 하자. $u→v$는 정점 $u$에서 정점 $v$로의 경로라는 뜻이다.
    $x_w$는 $u→v$에서 가중치가 가장 큰 간선의 가중치, $x_u$는 그 간선이 잇는 두 정점 중 깊이가 더 깊은 정점이다.

    우선 $u→v$에서 간선이 어찌 되던간에 $a→b$에 영향이 없을 조건은 간단하다.
    바로 $x_u$가 $a→b$ 위에 존재하지 않는 것.

    이는 $lca(x_u,$ $a)$, $lca(x_u,$ $b)$, $lca(x_u,$ $lca(a,$ $b))$의 결과를 가지고 적절히 따져주어 판단할 수 있다.

    영향이 없다면 단순히 $dist[a] + dist[b] - 2 * dist[lca(a, b)]$의 식으로 답을 계산할 수 있다.



    이제 영향이 있을 때의 식을 정리하자.
    나는 앞서 말했듯 일관된 계산을 위해 다음의 일반성을 유지하였다.

  • $x_u$는 반드시 $u$→$lca(u, v)$ 위에 존재한다.
  • $x_u$는 반드시 $a$→$lca(a, b)$ 위에 존재한다.

  • 이로 인해 $u$는 반드시 $a$와의 관계만을, $v$는 반드시 $b$와의 관계만을 따져줄 수 있게 된다.

    즉, 이 경우 답은 $dist(u, a)$ $+$ $dist(v, b)$ $+$ $w$ 로 쉽게 정리된다.

    자세한 건 아래 코드를 참고하자.

    전체 코드


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    #include<bits/stdc++.h>
    #define N 100001
    using namespace std;
     
    vector <pair<intint>> Gr[N];
    struct S { int u, w; } mdp[N][18];
    int n, q, dep[N], dp[N][18];
    long long dist[N];
     
    void Input()
    {
        ios_base::sync_with_stdio(0); cin.tie(0);
        cin >> n >> q;
     
        for (int u, v, w, i{}; ++< n;)
        {
            cin >> u >> v >> w;
            Gr[u].emplace_back(v, w);
            Gr[v].emplace_back(u, w);
        }
    }
    void LCAInit(int now, int depth)
    {
        dep[now] = depth;
        for(auto& [next, weight] : Gr[now])
            if (!dep[next])
            {
                dist[next] = dist[now] + weight;
                mdp[next][0= { next, weight };
                dp[next][0= now;
                
                LCAInit(next, depth + 1);
            }
    }
    void TableDP()
    {
        for(int j(1); j < 18; j++)
            for (int i(1); i <= n; i++)
            {
                dp[i][j] = dp[dp[i][j - 1]][j - 1];
                mdp[i][j] = mdp[dp[i][j - 1]][j - 1];
     
                if (mdp[i][j].w < mdp[i][j - 1].w)
                    mdp[i][j] = mdp[i][j - 1];
            }
    }
    pair <S, int> getLCA(int u, int v)
    {
        if (dep[u] < dep[v]) swap(u, v);
        S temp{};
     
        int diff(dep[u] - dep[v]);
        for (int i{}; diff; diff >>= 1, i++)
            if (diff & 1)
            {
                if (temp.w < mdp[u][i].w) temp = mdp[u][i];
                u = dp[u][i];
            }
        if (u ^ v)
        {
            for(int i(17); !!~i; i--)
                if (dp[u][i] ^ dp[v][i])
                {
                    if (temp.w < mdp[u][i].w) temp = mdp[u][i];
                    if (temp.w < mdp[v][i].w) temp = mdp[v][i];
                    u = dp[u][i], v = dp[v][i];
                }
            if (temp.w < mdp[u][0].w) temp = mdp[u][0];
            if (temp.w < mdp[v][0].w) temp = mdp[v][0];
            u = dp[u][0];
        }
        return { temp, u };
    }
    int Lifting(int u, int diff)
    {
        for (int i{}; diff; diff >>= 1, i++)
            if (diff & 1)
                u = dp[u][i];
        return u;
    }
    void Query()
    {
        while (q--)
        {
            int u, v, w, a, b;
            cin >> u >> v >> w >> a >> b;
     
            auto [x, lca_uv](getLCA(u, v));
            auto [y, lca_ab](getLCA(a, b));
     
            if (dep[v] >= dep[x.u] && Lifting(v, dep[v] - dep[x.u]) == x.u) swap(u, v);
     
            if ((getLCA(a, x.u).second == x.u || getLCA(b, x.u).second == x.u) && lca_ab ^ x.u && getLCA(lca_ab, x.u).second == lca_ab)
            {
                if (dep[b] >= dep[x.u] && Lifting(b, dep[b] - dep[x.u]) == x.u) swap(a, b);
                
                auto [p, lca_ua](getLCA(u, a));
                auto [q, lca_vb](getLCA(v, b));
     
                cout << dist[u] + dist[a] + dist[v] + dist[b] - 2 * (dist[lca_ua] + dist[lca_vb]) + w << '\n';
            }
            else
                cout << dist[a] + dist[b] - 2 * dist[lca_ab] << '\n';
        }
    }
    int main()
    {
        Input();
        LCAInit(11);
        TableDP();
        Query();
    }
    cs


    comment

    재밌었다.