Logo

Binary Tree Maximum Path Sum

LeetCode의 Binary Tree Maximum Path Sum 문제를 함께 풀어보도록 하겠습니다.

문제

이진 트리에서 경로란 연결이 되어 있는 다수의 노드로 구성되며 경로 상에는 꼭 최상위 노드가 있어야할 필요는 없다. 이진 트리가 주어졌을 때, 노드들의 값의 합이 최대가 되는 경로를 구하라.

예제

  • 입력
 -10
 / \
9  20
  /  \
 15   7

  • 결과

42


  • 설명
  20
 /  \
15   7

노드들의 값의 합이 최대가 되는 경로는 위와 같으며, 노드들의 값의 합은 20 + 15 + 7 = 42가 된다.

풀이

문제에서 주어진 트리에서 어떻게 노드의 합이 최대인 경로를 구할 수 있을지 생각해보겠습니다.

 -10
 / \
9  20
  /  \
 15   7

20은 총 3개의 노드와 연결이 되어 있는데요. 이 셋 중에 하나랑은 연결을 끊어야 갈라지지 않는 유효한 경로를 구성할 수 있습니다.

 -10
 / \
9  20
     \
      7

예를 들어, 15를 버리면 위와 같은 경로가 구성되고,

 -10
 / \
9  20
  /  
 15   

7을 버리면 위와 같은 경로가 구성되며

   20
  /  \
 15   7

-10을 버리면 위와 같은 경로가 구성됩니다.

이를 통해 어떤 트리에서 노드의 합이 최대인 경로를 구하려면 크게 2가지 선택 사항이 주어진다는 것을 알 수 있습니다.

첫 번째는, 좌측이나 우측 자식 트리와 연결을 끊는 것인데요. 둘 중에 최대 합이 적은 쪽을 포기하는 것이 유리할 것입니다.

두 번째는, 부모 노드와 연결을 끊는 것인데요. 부모 노드를 포함한 경로의 최대 합이 좌측이나 우측 자식 트리의 최대 합보다 작다면 부모 노드를 포기하는 것이 유리할 것입니다.

첫 번째 과정은 일반적인 트리 문제를 풀 때 처럼 재귀적으로 접근할 수 있습니다. 하위 트리를 상대로 재귀 함수를 호출한 결과를 상위 트리에서 비교해서 단순히 큰 쪽을 선택하면 됩니다.

   P
 /  \
L    R

즉, 위와 같은 트리가 있다면 부모 노드 P을 상대로 재귀 함수를 호출한 결과는 자식 노드 LR을 상대로 재귀 함수를 호출한 결과로 부터 도출할 수 있습니다.

F(P) = P.VAL + MAX(F(L) + F(R))

두 번째 과정은 첫 번째 과정 대비 생각하기가 좀 더 어렵게 느껴질 수 있는데요. 그 것은 노드의 합이 최대인 경로가 반드시 최상위 노드를 포함하라는 법이 없기 때문입니다. 즉, 노드의 합이 최대인 경로는 주어진 트리 내에서 어디든지 일어날 수 있습니다.

문제에서 주어진 예제가 이를 보여주는 좋은 사례인데요.

 -10
 / \
9  20
  /  
 15   

우측 노드와 연결을 끊으면, 경로 상의 노드의 합이 9 + -10 + 20 + 15 = 34이 되지만,

   20
  /  \
 15   7

부모 노드와 연결을 끊으면, 경로 상의 노드의 합이 15 + 20 + 7 = 42가 됩니다.

만약에 -10의 좌측 자식이 9 대신에 90이었으면 어땠을까요?

 -10
 / \
90 20
  /  
 15   

그러면 당연히 우측 노드와 연결을 끊는 편이 더 나았을 것입니다. (90 + -10 + 20 + 15 = 115)

따라서 주어진 트리 내에서 얻을 수 있는 가장 큰 노드의 합은 전체 트리를 상대로 연쇄적인 재귀 호출이 종료될 때 까지 단정을 지을 수가 없습니다.

이렇게 여러 번의 함수 내에서 갱신되야하는 값은 함수 외부에서 선언한 전역 변수에 저장하면 쉽게 처리가 가능합니다.

그럼 지금까지 설명한 2개의 과정을 처리해주는 코드를 작성해도록 할께요.

class Solution:
    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        max_sum = root.val

        def dfs(node):
            if not node:
                return 0

            left_max = max(dfs(node.left), 0)
            right_max = max(dfs(node.right), 0)

            nonlocal max_sum
            max_sum = max(node.val + left_max + right_max, max_sum)

            return node.val + max(left_max, right_max)

        dfs(root)
        return max_sum

문제에서 주어진 트리를 상대로 재귀 함수가 어떻게 호출되는지 시각화하면 다음과 같습니다.

 -10
 / \
9  20
  /  \
 15   7
dfs(-10) => -10 + max(9, 35) = 25 (output["max"] = -10 + 9 + 35 = 34)
    dfs(9) => 9 + max(0, 0) = 9 (output["max"] = 9 + 0 + 0 = 9)
        dfs(null) => 0
        dfs(null) => 0
    dfs(20) => 20 + max(15, 7) = 35 (output["max"] = 20 + 15 + 7 = 42) 👉 최대 합
        dfs(15) => 15 + max(0, 0) = 15 (output["max"] = 15 + 0 + 0)
            dfs(null) => 0
            dfs(null) => 0
        dfs(7) => 7 + max(0, 0) = 7 (output["max"] = 7 + 0 + 0)
            dfs(null) => 0
            dfs(null) => 0

트리의 노드 개수를 n, 트리의 높이를 h라고 했을 때, 이 재귀 알고리즘의 시간 복잡도는 O(n), 공간 복잡도는 O(h) 입니다.

마치면서

코드 자체는 별로 복잡하지 않지만 알고리즘을 한 번에 생각해내기는 결코 쉽지 않은 트리 문제였습니다.