复习完回溯法后,继续复习下递归。
明确函数定义(最关键):先想清楚参数代表什么、返回值代表什么含义。换句话说,先确定「这个函数对外承诺做什么」,后面所有逻辑都围绕这个承诺展开。
明确终止条件:什么时候不再往下递归,直接返回结果。
明确递推关系:假设子问题已经被正确解决,如何利用它推出当前问题的解。
一般而言,可以这样思考问题:
定义子问题:把规模从n变成更小的n-1、n/2等。
定义递归边界(base case):递归什么时候停?停的时候返回什么?
返回值含义:递归函数承诺返回什么?(不要陷入无尽递归,要相信函数能做到)
复杂度:是否需要记忆化或者剪枝来进行优化?
一定要相信递归,详细他他可以做到,只关心当前这一层,不要试图在大脑中展开递归的每一级调用。假设递归函数对更小的输入已经能正确工作,然后专注思考「拿到子问题的解之后,我这一层该怎么处理」
线性递归:单分支,规模每次减去k
典型题目:
阶乘、斐波那契
反转链表、遍历链表
数组、字符串递归处理
基本都长这个样子:
def f(x):
if base_case(x):
return base_value
return combine(x, f(smaller(x)))例1:反转链表
class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next
def reverse_list(head: ListNode) -> ListNode:
# 终止条件:空 或 只剩一个节点
if head is None or head.next is None:
return head
# 相信递归:先把 head.next 开始的子链表反转好
new_head = reverse_list(head.next)
# 此时 head.next 是反转后子链表的「尾巴」,让它指回 head
head.next.next = head
head.next = None # 断开原指向,否则成环
return new_head # 新头结点一路向上返回二分递归(分治/两分支),拆成左右两半
把大问题拆成同类的小问题、解决小问题、把小问题的解合并成大问题的解
典型题目:
归并排序、快速排序
二叉树
数组区间问题
通用模版:
def solve(l, r):
if l == r:
return base
mid = (l + r) // 2
left = solve(l, mid)
right = solve(mid+1, r)
return merge(left, right)例题2:归并排序:
def merge_sort(nums):
if len(nums) <= 1:
return nums[:] # 0 或 1 个元素,天然有序
mid = len(nums) // 2
left = merge_sort(nums[:mid]) #拆出左半,并相信他会排好
right = merge_sort(nums[mid:]) #拆出右半,并相信他会排好
return merge(left, right) #combine 合并两个有序数组
def merge(left, right):
out = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]: # <= 保证稳定性
out.append(left[i])
i += 1
else:
out.append(right[j])
j += 1
out.extend(left[i:])
out.extend(right[j:])
return out树形递归
对每个字树做同样的事,搞明白递归函数要返回什么。
常见的有:
返回高度/最大值/是否满足条件
返回节点指针
典型题目:
二叉树:最大深度、是否平衡、路径和、LCA、序列化/反序列化
N叉树同理
例3:二叉树的最大深度
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
def maxDepth(root: TreeNode) -> int:
if root is None:
return 0
return 1 + max(maxDepth(root.left), maxDepth(root.right))回溯
回溯也是应用了递归:递归+试错+撤销选择
典型题目:
全排列/组合/子集
括号生成
N 皇后、数独
路径搜索(DFS)
模板:
res = []
path = []
def dfs(state):
if done(state):
res.append(path.copy())
return
for choice in choices(state):
path.append(choice) # 做选择
dfs(next_state(state, choice))
path.pop() # 撤销选择例题4:全排列
def permute(nums):
res = []
used = [False] * len(nums)
path = []
def dfs():
if len(path) == len(nums):
res.append(path.copy())
return
for i in range(len(nums)):
if used[i]:
continue
used[i] = True
path.append(nums[i])
dfs()
path.pop()
used[i] = False
dfs()
return res记忆化递归
中间有大量重叠的计算,可以用缓存暂存。
典型题目:
斐波那契
爬楼梯、打家劫舍
背包、划分、区间 DP
例题5:爬楼梯
from functools import lru_cache
def climbStairs(n: int) -> int:
@lru_cache(None)
def f(i):
if i == 0:
return 1
if i < 0:
return 0
return f(i - 1) + f(i - 2)
return f(n)区间递归
区间递归(Interval Recursion)本质上就是区间 DP 的记忆化搜索形式。
归并排序你已经知道从中间切最好;而区间递归里你不知道哪个 k 最优,所以得全试一遍。这个「在区间内枚举一个点」的循环,就是区间递归的标志。
两种分解方式
① 分割型(枚举切分点 k): 把 [l,r] 在 k 处切成 [l,k] 和 [k+1,r],两段独立求解再合并。 思维:「最后一次操作把哪两段合起来?」 代表:石子合并、矩阵链乘。
② 端点/关键点型:
看两端: 比较
s[l]和s[r],收缩到[l+1,r-1]/[l+1,r]/[l,r-1]。代表:最长回文子序列。枚举最后处理的点 k: 在
[l,r]内枚举「最后一个被处理的元素」。代表:戳气球。 思维:「哪个元素最后处理?它处理时左右边界是固定的。」
模板
from functools import lru_cache
@lru_cache(maxsize=None)
def solve(l, r):
# 1. 终止条件(区间小到可直接回答)
if l >= r: # 具体是 l>r / l==r / l+1==r,视题目而定
return 基本值
# 2. 枚举区间内的分割点 / 关键点 k,合并子区间答案
best = 初始值
for k in range(l, r): # k 的范围视分解方式而定
best = 更优(best, 合并(solve(l, k), solve(k+1, r), 当前代价))
return best例:端点收缩型,最长回文子序列
solve(l, r) = s[l..r] 的最长回文子序列长度。两端相等就同时收缩并 +2,否则丢一端取较优。
def longestPalindromeSubseq(s: str) -> int:
from functools import lru_cache
@lru_cache(maxsize=None)
def solve(l, r):
if l > r:
return 0
if l == r: # 单字符本身是长度 1 的回文
return 1
if s[l] == s[r]: # 两端相等,同时收缩
return 2 + solve(l + 1, r - 1)
# 两端不等,丢掉一端,取较优
return max(solve(l + 1, r), solve(l, r - 1))
return solve(0, len(s) - 1)例:分割型,石子合并
n 堆石子排成一排,每次只能合并相邻两堆,代价 = 两堆之和。合并到只剩一堆,求最小总代价。 solve(l, r) = 合并 [l,r] 所有石子的最小代价。枚举切分点 k,左右各自合好再合并;最后一次合并的代价恒等于整个区间之和(用前缀和 O(1) 求)。
def stone_merge(stones: list[int]) -> int:
n = len(stones)
prefix = [0] * (n + 1)
for i in range(n):
prefix[i + 1] = prefix[i] + stones[i] # 前缀和
from functools import lru_cache
@lru_cache(maxsize=None)
def solve(l, r):
if l == r: # 单堆无需合并
return 0
best = float('inf')
for k in range(l, r): # 枚举切分点
best = min(best, solve(l, k) + solve(k + 1, r))
return best + (prefix[r + 1] - prefix[l]) # 加上本层合并代价
return solve(0, n - 1)