[알고리즘] 할로윈의 양아치
난이도
Gold 3
문제
Trick or Treat!!
10월 31일 할로윈의 밤에는 거리의 여기저기서 아이들이 친구들과 모여 사탕을 받기 위해 돌아다닌다. 올해 할로윈에도 어김없이 많은 아이가 할로윈을 즐겼지만 단 한 사람, 일찍부터 잠에 빠진 스브러스는 할로윈 밤을 즐길 수가 없었다. 뒤늦게 일어나 사탕을 얻기 위해 혼자 돌아다녀 보지만 이미 사탕은 바닥나 하나도 얻을 수 없었다.
단단히 화가 난 스브러스는 거리를 돌아다니며 다른 아이들의 사탕을 빼앗기로 마음을 먹는다. 다른 아이들보다 몸집이 큰 스브러스에게 사탕을 빼앗는 건 어렵지 않다. 또한, 스브러스는 매우 공평한 사람이기 때문에 한 아이의 사탕을 뺏으면 그 아이 친구들의 사탕도 모조리 뺏어버린다. (친구의 친구는 친구다?!)
사탕을 빼앗긴 아이들은 거리에 주저앉아 울고 $K$명 이상의 아이들이 울기 시작하면 울음소리가 공명하여 온 집의 어른들이 거리로 나온다. 스브러스가 어른들에게 들키지 않고 최대로 뺏을 수 있는 사탕의 양을 구하여라.
스브러스는 혼자 모든 집을 돌아다녔기 때문에 다른 아이들이 받은 사탕의 양을 모두 알고 있다. 또한, 모든 아이는 스브러스를 피해 갈 수 없다.
입력 첫째 줄에 정수 $N$, $M$, $K$가 주어진다. $N$은 거리에 있는 아이들의 수, $M$은 아이들의 친구 관계 수, $K$는 울음소리가 공명하기 위한 최소 아이의 수이다. ($1 \leq N \leq 30\ 000$, $0 \leq M \leq 100\ 000$, $1 \leq K \leq \min $( N, 3000 )
둘째 줄에는 아이들이 받은 사탕의 수를 나타내는 정수 $c_1, c_2, \cdots, c_N$이 주어진다. ($1 \leq c_i \leq 10\ 000$)
셋째 줄부터 $M$개 줄에 갈쳐 각각의 줄에 정수 $a$, $b$가 주어진다. 이는 $a$와 $b$가 친구임을 의미한다. 같은 친구 관계가 두 번 주어지는 경우는 없다. ( $1 \leq a, b \leq N$, $a \neq b$)
출력
스브러스가 어른들에게 들키지 않고 아이들로부터 뺏을 수 있는 최대 사탕의 수를 출력한다.
예제 입력
10 6 6
9 15 4 4 1 5 19 14 20 5
1 3
2 5
4 9
6 2
7 8
6 10
예제 출력
57
해설 및 후기
union-find와 dp(냅색 문제)를 복합적으로 이용했어야 하는 문제이다. 꽤나 어려웠다. 우선 각 간선을 받는 대로 union하여 집합을 구성했고, 그렇게 되면 같은 집합은 같은 parent를 갖게 될 것이다.
그러면 다음으로 각 parent를 순회하며 루트 번호를 찾고, 각 루트에 합을 더한다. 또한 cnt도 증가시킨다. 그렇게 되면 [[집합의 합, 아이들의 수], …] 형태의 리스트를 완성시킬 수 있다.
다음으로는 이 리스트를 활용해 냅색 문제를 구현한다. 중복이 없어야 하므로, 각 집합을 먼저 순회하며, 해당 집합을 dp[i]가 선택할지 선택하지 않을 지를 결정한다. 기존의 방식은 집합 리스트 x k 의 이차원 dp였으나, 시간 제한을 통과하기 어려워 이와 같은 방법을 사용하게 되었다.
제출 코드
import sys
sys.setrecursionlimit(10**7)
n,m,k = map(int, sys.stdin.readline().rstrip().split())
c = list(map(int, sys.stdin.readline().rstrip().split()))
cLst = [[] for _ in range(n)]
isVisit = [False for _ in range(n)]
parent = [i for i in range(n)]
def find(x):
if(x != parent[x]):
parent[x] = find(parent[x])
return parent[x]
def union(x, y):
tX = find(x)
tY = find(y)
if(tX < tY):
parent[tY] = tX
else:
parent[tX] = tY
for i in range(m):
a,b = map(int, sys.stdin.readline().rstrip().split())
union(a-1,b-1)
for i in range(n):
find(i)
uniChild = [[0,0] for i in range(n)]
for i in range(n):
uniChild[parent[i]][0] += c[i]
uniChild[parent[i]][1] += 1
ns = []
for i in range(n):
if(uniChild[i][1] != 0):
ns.append([uniChild[i][0], uniChild[i][1]])
nsLen = len(ns)
dp = [0 for _ in range(k)]
for i in range(nsLen):
targetVal, targetWei = ns[i]
for j in reversed(range(targetWei, k)):
dp[j] = max(dp[j], dp[j - targetWei] + targetVal)
print(dp[k-1])