Binary Tree Maximum Path Sum

Problem

Given the root of a non-empty binary tree, return the maximum path sum of any non-empty path.

A path in a binary tree is a sequence of nodes where each pair of adjacent nodes has an edge connecting them. A node cannot appear in the sequence more than once. The path does not need to include the root.

The path sum of a path is the sum of the values of the nodes in the path.
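To make the definition concrete, here is a brute-force sketch in C++ (the language of the solution below) that follows it literally. It is a hypothetical reference checker, not part of the intended solution. The name `bruteForceMaxPathSum` and the input form (node values `val[i]` plus an undirected list of parent-child `edges`) are assumptions made for illustration. Because it walks every simple path from every node, it takes O(n^2) time.

```cpp
#include <algorithm>
#include <functional>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical brute-force checker that follows the definition literally:
// nodes are 0..n-1 with values val[i], and edges lists every parent-child link.
// It tries every simple path from every start node, so it runs in O(n^2).
int bruteForceMaxPathSum(const std::vector<int>& val,
                         const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(val.size());
    for (auto [a, b] : edges) {
        adj[a].push_back(b);  // edges are undirected: a path may walk
        adj[b].push_back(a);  // up toward the root as well as down
    }
    int best = std::numeric_limits<int>::lowest();
    // In a tree, the only way to revisit a node is to step straight back,
    // so skipping the node we just came from keeps every path simple.
    std::function<void(int, int, int)> walk = [&](int node, int from, int sum) {
        best = std::max(best, sum);  // the path ending at `node` is a candidate
        for (int next : adj[node])
            if (next != from) walk(next, node, sum + val[next]);
    };
    for (int start = 0; start < static_cast<int>(val.size()); ++start)
        walk(start, -1, val[start]);
    return best;
}
```

For instance, the tree in Example 2 below can be passed as values {-15, 10, 20, 15, 5, -5} with edges {0,1}, {0,2}, {2,3}, {2,4}, {3,5}, and the checker returns 40. A slow but obviously correct version like this is handy for cross-checking a faster solution on small random trees.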

Examples

Example 1:

Input: root = [1,2,3]

Output: 6

Explanation: The path is 2 -> 1 -> 3 with a sum of 2 + 1 + 3 = 6.

Example 2:

Input: root = [-15,10,20,null,null,15,5,-5]

Output: 40

Explanation: The path is 15 -> 20 -> 5 with a sum of 15 + 20 + 5 = 40.
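The `root = [...]` inputs appear to use the breadth-first (level-order) encoding common to LeetCode-style problems: values are listed level by level, each existing node consumes the next two slots for its left and right child, and `null` marks a missing child. Below is a minimal decoder sketch, assuming the `TreeNode` struct shown in the solution and using `std::nullopt` in place of `null`. The names `buildTree` and `deleteTree` are hypothetical helpers, not part of the problem.

```cpp
#include <cstddef>
#include <optional>
#include <queue>
#include <vector>

struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

// Hypothetical helper: decode a level-order array where std::nullopt plays
// the role of "null". Slots are consumed two at a time (left, right) for each
// existing node, in the order the nodes were created.
TreeNode* buildTree(const std::vector<std::optional<int>>& level) {
    if (level.empty() || !level[0]) return nullptr;
    TreeNode* root = new TreeNode(*level[0]);
    std::queue<TreeNode*> pending;  // nodes whose children are not read yet
    pending.push(root);
    std::size_t i = 1;
    while (!pending.empty() && i < level.size()) {
        TreeNode* node = pending.front();
        pending.pop();
        if (i < level.size() && level[i]) {
            node->left = new TreeNode(*level[i]);
            pending.push(node->left);
        }
        ++i;
        if (i < level.size() && level[i]) {
            node->right = new TreeNode(*level[i]);
            pending.push(node->right);
        }
        ++i;
    }
    return root;
}

// Releases every node allocated by buildTree.
void deleteTree(TreeNode* node) {
    if (!node) return;
    deleteTree(node->left);
    deleteTree(node->right);
    delete node;
}
```

Decoding Example 2 this way gives -15 at the root with children 10 and 20; the two `null`s make 10 a leaf; 20 has children 15 and 5; and -5 is the left child of 15.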

Constraints

  • 1 <= number of nodes in the tree <= 1000
  • -1000 <= Node.val <= 1000

You should aim for a solution with O(n) time and O(n) space, where n is the number of nodes in the given tree.

Solution

Every path has a unique highest node (the one closest to the root), where it may bend and continue down into both the left and the right subtree. At each node we therefore consider the best path that bends there: the node's value plus the best downward path from each child. If a child's best downward sum is negative, including it can only hurt, so we clamp it to 0, which effectively leaves that branch out. Each node is responsible for checking whether the path that bends at it beats the best sum found so far. Finally, a node returns to its parent its own value plus the better of its two branches, not both, because the parent will extend that path upward and a path cannot fork.
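The same reasoning can be written as a recurrence. Here g(v) stands for the value `runner` returns for node v (the best sum of a path that starts at v and goes down), best(v) for the `curr_sum` compared against `max_path_sum`, and v_L, v_R for v's children, with ∅ for a missing child. The labels g and best are introduced only for this illustration.

```latex
\begin{aligned}
g(\varnothing) &= 0,\\
g(v) &= \mathrm{val}(v) + \max\bigl(0,\ g(v_L),\ g(v_R)\bigr),\\
\mathrm{best}(v) &= \mathrm{val}(v) + \max\bigl(0,\ g(v_L)\bigr) + \max\bigl(0,\ g(v_R)\bigr),\\
\text{answer} &= \max_{v}\ \mathrm{best}(v).
\end{aligned}
```

Tracing Example 2 bottom-up (nodes are named by their values, which are distinct in that tree):

```latex
\begin{aligned}
g(-5) &= -5, & \mathrm{best}(-5) &= -5,\\
g(15) &= 15 + \max(0, -5, 0) = 15, & \mathrm{best}(15) &= 15 + 0 + 0 = 15,\\
g(5) &= 5, & \mathrm{best}(5) &= 5,\\
g(20) &= 20 + \max(0, 15, 5) = 35, & \mathrm{best}(20) &= 20 + 15 + 5 = 40,\\
g(10) &= 10, & \mathrm{best}(10) &= 10,\\
g(-15) &= -15 + \max(0, 10, 35) = 20, & \mathrm{best}(-15) &= -15 + 10 + 35 = 30.
\end{aligned}
```

The largest best value is 40, at node 20, which matches the expected output.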

#include <algorithm>
#include <functional>
#include <limits>

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
 
class Solution {
public:
    int maxPathSum(TreeNode* root) {
 
        int max_path_sum = std::numeric_limits<int>::lowest();
        std::function<int(TreeNode *)> runner;
        runner = [&](TreeNode *node) -> int {
            // Base case: an empty subtree contributes nothing to a path
            if (!node) { return 0; }
            // Best downward path sums starting at each child. A negative sum
            //   can only lower the total, so clamp it to 0, which is the same
            //   as leaving that branch out of the path
            auto left_sum = std::max(runner(node->left), 0);
            auto right_sum = std::max(runner(node->right), 0);
            // Best path whose highest node is the current one: it may bend
            //   here and use both branches
            auto curr_sum = left_sum + right_sum + node->val;
            max_path_sum = std::max(max_path_sum, curr_sum);
            // The parent can extend only one branch (a path cannot fork), so
            //   pass up the current node plus the better of its two branches
            return node->val + std::max(left_sum, right_sum);
        };
        runner(root);
        return max_path_sum;
    }
};
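The recursive `runner` goes as deep as the tree is tall, which is up to n calls for a list-shaped tree. With at most 1000 nodes that is fine, but for much deeper trees an explicit stack avoids relying on the call stack. The following is a hedged sketch of such an iterative variant, not the author's code; `maxPathSumIterative` is a hypothetical name, and it reuses the `TreeNode` definition from the comment above. It applies the same clamp, score, and return steps in post-order.

```cpp
#include <algorithm>
#include <limits>
#include <stack>
#include <unordered_map>
#include <utility>

struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode() : val(0), left(nullptr), right(nullptr) {}
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

// Hypothetical iterative variant of the same recurrence: an explicit stack
// drives a post-order traversal, so tree height no longer limits call depth.
int maxPathSumIterative(TreeNode* root) {
    int max_path_sum = std::numeric_limits<int>::lowest();
    // Best downward path sum starting at each node that has been finished
    std::unordered_map<TreeNode*, int> gain;
    auto gain_of = [&gain](TreeNode* n) { return n ? gain.at(n) : 0; };

    // (node, children_done): a node is visited once to schedule its children
    // and once more, after both subtrees are finished, to compute its gain
    std::stack<std::pair<TreeNode*, bool>> todo;
    if (root) todo.push({root, false});
    while (!todo.empty()) {
        auto [node, children_done] = todo.top();
        todo.pop();
        if (!children_done) {
            todo.push({node, true});
            if (node->left) todo.push({node->left, false});
            if (node->right) todo.push({node->right, false});
            continue;
        }
        // Same steps as runner: clamp negative branches to 0, score the path
        // that bends here, then record the best single-branch extension
        int left_sum = std::max(gain_of(node->left), 0);
        int right_sum = std::max(gain_of(node->right), 0);
        max_path_sum = std::max(max_path_sum, node->val + left_sum + right_sum);
        gain[node] = node->val + std::max(left_sum, right_sum);
    }
    return max_path_sum;
}
```

Time stays O(n) on average, and the stack plus the hash map take O(n) space; the trade-off against the recursive version is a hash lookup per child in exchange for not depending on the call-stack size.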