首页 > 技术文章 > 【回溯】leetcode回溯算法

zjkstudy 2020-03-03 00:52 原文

回溯算法介绍

回溯算法实际上一个类似枚举的搜索尝试过程,主要是在搜索尝试过程中寻找问题的解,当发现已不满足求解条件时,就“回溯”返回,尝试别的路径。回溯法是一种选优搜索法,按选优条件向前搜索,以达到目标。但当探索到某一步时,发现原先选择并不优或达不到目标,就退回一步重新选择,这种走不通就退回再走的技术为回溯法,而满足回溯条件的某个状态的点称为“回溯点”。回溯法可以理解成递归的一种特殊形式。

最经典的回溯问题是八皇后问题

回溯法体现的是走不通路就换条路走的思想,有点类似枚举搜索。我们枚举所有的解,找到满足期望的解。为了有规律地枚举所有可能的解,避免遗漏和重复,我们把问题求解的过程分为多个阶段。每个阶段,我们都会面对一个岔路口,我们先随意选一条路走,当发现这条路走不通的时候(不符合期望的解),就回退到上一个岔路口,另选一种走法继续走。

回溯问题常用的解决方法是递归

leetcode例题

78题 子集

题目描述:78题

给定一组不含重复元素的整数数组 nums,返回该数组所有可能的子集(幂集)。

说明:解集不能包含重复的子集。

求解思路与代码:

在做这道题时,我首先想到的是迭代的方法。例如给定的例子是nums=[1,2,3],则它的子集有8个。先从空集开始,先加入1,找出1与空集的组合;再加入2,找出2与1和空集的组合,这样迭代下去。

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        res = [[]]
        for i in nums:
            tmp = [[i] + num for num in res]
            res = res + tmp
            # print("i:",i)
            # print("tmp:",tmp)
            # print("res:",res)
        return res

这个解法如果把中间过程输出就比较清楚了:

# 当输入为nums=[1,2,3]时
i: 1
tmp: [[1]]
res: [[], [1]]

i: 2
tmp: [[2], [2, 1]]
res: [[], [1], [2], [2, 1]]

i: 3
tmp: [[3], [3, 1], [3, 2], [3, 2, 1]]
res: [[], [1], [2], [2, 1], [3], [3, 1], [3, 2], [3, 2, 1]]

回溯算法

从幂级的定义可以看出,幂级是长度从0到n所有子集的组合。以nums=[1,2,3]为例,分别包括长度为0,1,2,3的子集。

长度为0的子集有:[]
长度为1的子集有:[1];[2];[3]
长度为2的子集有:[1,2];[1,3];[2,3]
长度为3的子集有:[1,2,3]

以生长长度为2的子集为例:

  1. 先取nums[0] = 1为第一个元素;
  2. 将nums[1] = 2添加到当前子集并保存;
  3. 弹出nums[1] ,并且添加nums[2] ;保存
  4. 将nums[1] = 2 作为第一个元素,添加nums[2]

回溯官方题解:

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        def backtrack(first = 0, curr = []):
            if len(curr) == k:  
                output.append(curr[:])
            else:
	            for i in range(first, n):
	                curr.append(nums[i])
	                backtrack(i + 1, curr)
	                curr.pop()
        output = []
        n = len(nums)
        for k in range(n + 1):
            backtrack()
        return output

官方的题解很好懂,找准递归三要素,终止条件、返回值、本层递归做的事即可。

这里本来自己写了,但是看到大神的解法还是跪了。leetcode题解

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        res = []
        n = len(nums)
        
        def helper(i, tmp):
            res.append(tmp)
            for j in range(i, n):
                helper(j + 1,tmp + [nums[j]] )
        helper(0, [])
        return res 

推荐阅读