본문 바로가기

Problem Solving/Baekjoon Online Judge (Gold)

백준 2213 - 트리의 독립집합 (C++)

문제

  • 문제 링크
  •   
       BOJ 2213 - 트리의 독립집합 
      
  • 문제 요약
  • $N$개의 정점으로 이루어진 트리가 주어진다.
    이 트리의 '최대 독립집합'의 크기와 이를 구성하는 정점들을 구해보자.
  • 제한
  • TL : $2$ sec, ML : $128$ MB

  • $1 ≤ N, C_i ≤ 10,000$

알고리즘 분류

  • 다이나믹 프로그래밍(dp)
  • 트리(trees)
  • 트리에서의 다이나믹 프로그래밍(dp_tree)

풀이

트리$dp$ 역추적의 바이블같은 문제이다.

임의의 정점이 가질 수 있는 상태는 결국 둘 중 하나다. 최대 독립집합에 포함 되었을 때( $1$ ) ? 아니었을 때( $0$ ) ?
이에 따라 다음과 같이 점화식을 정의해보자.

$dp[i][j] :$ $i$번 정점을 루트로 하는 서브 트리에서, $i$번 정점의 포함 상태가 $j$일 때의 답.

  • $j$ == $1$ 이라면, $i$와 이어진 하위 정점 $k$에 대해선 $dp[k][0]$ 으로 밖에 분기할 수 없다.
  • $j$ == $0$ 이라면, $i$와 이어진 하위 정점 $k$에 대해서 $dp[k][0]$ 은 물론 $dp[k][1]$ 로도 분기할 수 있다.


  • 트리 위 역추적이 처음이라면 살짝 구현이 어색할 수 있는데, 기록된 $dp$ 값들을 $dfs$ 로 차근차근 따라가 주면 된다.

  • $j$ == $1$ 이라면, $i$와 이어진 하위 정점 $k$에 대해선 $dp[k][0]$ 으로 밖에 분기할 수 없다.
  • $j$ == $0$ 이면서, $i$와 이어진 하위 정점 $k$가 $dp[k][1] > dp[k][0]$ 도 만족 한다면 정점 $k$는 최대 독립집합에 포함 된다.

  • 위 두번째 내용은 아래 코드를 보는 것이 이해가 빠를 듯 하다.

    전체 코드


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    #include<bits/stdc++.h>
    #define N 10001
    using namespace std;
     
    vector <int> an, Gr[N];
    int n, r, c[N], dp[N][2];
     
    int f(int p, int q, int o)
    {
        int& t(dp[p][o]);
        if (!~t)
        {
            t = o ? c[p] : 0;
            for (int i : Gr[p])
                if (i ^ q)
                    t += max(f(i, p, 0), o ? 0 : f(i, p, 1));
        }
        return t;
    }
    void Solve(int p, int q, int o)
    {
        if (o) an.push_back(p);
        for (int i : Gr[p])
            if (i ^ q)
                Solve(i, p, o ? 0 : dp[i][1> dp[i][0]);
    }
    int main()
    {
        ios_base::sync_with_stdio(0); cin.tie(0);
        memset(dp, -1sizeof dp); cin >> n;
        for (int i(1); i <= n; cin >> c[i++]);
        for (int i, j, o{}; ++< n;)
            cin >> i >> j, Gr[i].push_back(j), Gr[j].push_back(i);
        cout << max(f(100), f(101)) << '\n';
     
        Solve(10, dp[1][1> dp[1][0]);
        sort(an.begin(), an.end());
     
        for (int i : an) cout << i << ' ';
    }
    cs


    comment