본문 바로가기

study/알고리즘

[알고리즘/백준] 1967번 - 트리의 지름

728x90

카테고리 : Tree / DFS(그래프 탐색) 

1967번 트리의 지름 https://www.acmicpc.net/problem/1967

 

 

 


접근 방법

 

먼저 tree의 정보를 딕셔너리를 통해 저장해주었다. p는 부모 노드, c는 자식 노드, w는 가중치 정보이다. 

 

tree = defaultdict(list)

n = int(input())
for _ in range(n - 1):
    p, c, w = map(int, input().split())
    tree[p].append((c, w))

 

DFS를 통해 트리를 탐색한다. 루트 노드의 번호가 항상 1이라고 가정되어있기 때문에 항상 루트부터 시작할 수 있다.

tree는 아까 저장해두었던 딕셔너리이고 i는 노드 번호이다.

 

def DFS(tree, i):
    global result
    sum_list = []

    if not tree[i]:
        return 0

    for c, w in tree[i]:
        prev = DFS(tree, c)
        sum_list.append(prev + w)

    sum_list.sort()
    result = max(result, sum(sum_list[-2:]))
    return sum_list[-1]

 

루트 노드부터 시작해서 자식이 없는 leaf 노드까지 재귀를 통해 내려간다. 

아래의 조건에서 자식이 없는 노드는 0을 반환하고 끝나게 된다.

 

if not tree[i]:
        return 0

 

자식이 없는 노드 끝까지 왔다가 다시 루트 노드까지 올라갈 때 가중치를 더해주게 된다.

이전 노드까지 더해진 가중치의 합인 prev에 현재 내가 선택한 자식 노드의 가중치를 더해 sum_list에 저장한다.

 

for c, w in tree[i]:
    prev = DFS(tree, c)
    sum_list.append(prev + w)

 

그중 가중치가 가장 큰 값을 찾기 위해 정렬한 후 가장 큰 값을 반환한다.

 

처음에는 항상 루트 노드를 기준으로 트리의 지름이 나올 거라고 생각했었지만

예제에서도 나와있듯이 다른 노드에서 트리의 지름이 나올 수도 있다.

 

따라서 재귀를 돌면서 가장 큰 트리의 지름을 result에 저장해준다.

sum_list의 가장 큰 값(마지막 값)과 두 번째로 큰 값을 더한 것이 그 노드의 지름이 되기 때문에

그 값과 원래 result에 저장되어있는 값을 비교하여 result 값을 변경해 준다.

result는 함수 밖에서 선언된 전역 변수라 global로 전역 변수를 알려주었다.

 

sum_list.sort()
result = max(result, sum(sum_list[-2:]))

 

그리고 작성한 코드가 재귀로 진행되기 때문에 그냥 제출하면 런타임 에러가 발생한다.

파이썬의 재귀 최대 깊이의 기본 설정이 1000번이기 때문에 그렇다고 한다.

 

따라서 아래의 코드를 추가해 재귀 최대 깊이를 100000번으로 다시 설정해주었다.

 

sys.setrecursionlimit(100000)

 


최종 제출 코드

 

from collections import defaultdict
import sys
input = sys.stdin.readline
sys.setrecursionlimit(100000)


def DFS(tree, i):
    global result
    sum_list = []

    if not tree[i]:
        return 0

    for c, w in tree[i]:
        prev = DFS(tree, c)
        sum_list.append(prev + w)

    sum_list.sort()
    result = max(result, sum(sum_list[-2:]))
    return sum_list[-1]



tree = defaultdict(list)

n = int(input())
for _ in range(n - 1):
    p, c, w = map(int, input().split())
    tree[p].append((c, w))

result = 0
DFS(tree, 1)
print(result)