[알고리즘] 할로윈의 양아치

난이도

Gold 3

문제

Trick or Treat!!

10월 31일 할로윈의 밤에는 거리의 여기저기서 아이들이 친구들과 모여 사탕을 받기 위해 돌아다닌다. 올해 할로윈에도 어김없이 많은 아이가 할로윈을 즐겼지만 단 한 사람, 일찍부터 잠에 빠진 스브러스는 할로윈 밤을 즐길 수가 없었다. 뒤늦게 일어나 사탕을 얻기 위해 혼자 돌아다녀 보지만 이미 사탕은 바닥나 하나도 얻을 수 없었다.

단단히 화가 난 스브러스는 거리를 돌아다니며 다른 아이들의 사탕을 빼앗기로 마음을 먹는다. 다른 아이들보다 몸집이 큰 스브러스에게 사탕을 빼앗는 건 어렵지 않다. 또한, 스브러스는 매우 공평한 사람이기 때문에 한 아이의 사탕을 뺏으면 그 아이 친구들의 사탕도 모조리 뺏어버린다. (친구의 친구는 친구다?!)

사탕을 빼앗긴 아이들은 거리에 주저앉아 울고 $K$명 이상의 아이들이 울기 시작하면 울음소리가 공명하여 온 집의 어른들이 거리로 나온다. 스브러스가 어른들에게 들키지 않고 최대로 뺏을 수 있는 사탕의 양을 구하여라.

스브러스는 혼자 모든 집을 돌아다녔기 때문에 다른 아이들이 받은 사탕의 양을 모두 알고 있다. 또한, 모든 아이는 스브러스를 피해 갈 수 없다.

입력 첫째 줄에 정수 $N$, $M$, $K$가 주어진다. $N$은 거리에 있는 아이들의 수, $M$은 아이들의 친구 관계 수, $K$는 울음소리가 공명하기 위한 최소 아이의 수이다. ($1 \leq N \leq 30\ 000$, $0 \leq M \leq 100\ 000$, $1 \leq K \leq \min $( N, 3000 )

둘째 줄에는 아이들이 받은 사탕의 수를 나타내는 정수 $c_1, c_2, \cdots, c_N$이 주어진다. ($1 \leq c_i \leq 10\ 000$)

셋째 줄부터 $M$개 줄에 갈쳐 각각의 줄에 정수 $a$, $b$가 주어진다. 이는 $a$와 $b$가 친구임을 의미한다. 같은 친구 관계가 두 번 주어지는 경우는 없다. ( $1 \leq a, b \leq N$, $a \neq b$)

출력

스브러스가 어른들에게 들키지 않고 아이들로부터 뺏을 수 있는 최대 사탕의 수를 출력한다.

예제 입력

10 6 6
9 15 4 4 1 5 19 14 20 5
1 3
2 5
4 9
6 2
7 8
6 10

예제 출력

57

해설 및 후기

union-find와 dp(냅색 문제)를 복합적으로 이용했어야 하는 문제이다. 꽤나 어려웠다. 우선 각 간선을 받는 대로 union하여 집합을 구성했고, 그렇게 되면 같은 집합은 같은 parent를 갖게 될 것이다.

그러면 다음으로 각 parent를 순회하며 루트 번호를 찾고, 각 루트에 합을 더한다. 또한 cnt도 증가시킨다. 그렇게 되면 [[집합의 합, 아이들의 수], …] 형태의 리스트를 완성시킬 수 있다.

다음으로는 이 리스트를 활용해 냅색 문제를 구현한다. 중복이 없어야 하므로, 각 집합을 먼저 순회하며, 해당 집합을 dp[i]가 선택할지 선택하지 않을 지를 결정한다. 기존의 방식은 집합 리스트 x k 의 이차원 dp였으나, 시간 제한을 통과하기 어려워 이와 같은 방법을 사용하게 되었다.

제출 코드

import sys

sys.setrecursionlimit(10**7)
n,m,k = map(int, sys.stdin.readline().rstrip().split())
c = list(map(int, sys.stdin.readline().rstrip().split()))
cLst = [[] for _ in range(n)]
isVisit = [False for _ in range(n)]
parent = [i for i in range(n)]


def find(x):
    if(x != parent[x]):
        parent[x] = find(parent[x])
    return parent[x]

def union(x, y):
    tX = find(x)
    tY = find(y)
    if(tX < tY):
        parent[tY] = tX
    else:
        parent[tX] = tY

for i in range(m):
    a,b = map(int, sys.stdin.readline().rstrip().split())
    union(a-1,b-1)


for i in range(n):
    find(i)
uniChild = [[0,0] for i in range(n)]

for i in range(n):
    uniChild[parent[i]][0] += c[i]
    uniChild[parent[i]][1] += 1

ns = []
for i in range(n):
    if(uniChild[i][1] != 0):
        ns.append([uniChild[i][0], uniChild[i][1]])

nsLen = len(ns)
dp = [0 for _ in range(k)] 
for i in range(nsLen):
    targetVal, targetWei = ns[i]
    for j in reversed(range(targetWei, k)):
        dp[j] = max(dp[j], dp[j - targetWei] + targetVal)

print(dp[k-1])