首页 > 技术文章 > 二分查找详解

beeblog72 2020-06-19 11:35 原文

看到一个大佬的博客详解二分查找算法,有一段内容让我深有感触:

我周围的人几乎都认为二分查找很简单,但事实真的如此吗?二分查找真的很简单吗?并不简单。看看 Knuth 大佬(发明 KMP 算法的那位)怎么说的:

Although the basic idea of binary search is comparatively straightforward, the details can be surprisingly tricky...

这句话可以这样理解:思路很简单,细节是魔鬼。

这两个月刷题遇到不少要用到二分查找的题。当年学数据结构的时候觉得这是一个相当直观且好理解的算法,但是真正刷题时觉得这个算法需要注意的坑还是挺多的。最普通的应用就是找某个元素的索引(数组有序且不重复),再复杂一些的还有找某个元素最左边或最右边的索引。更高端的有对数组的索引或者数组中整数的取值范围进行二分查找,不过这一块还是万变不离其宗,查找的范围依旧是[left, right],难点在于要怎么找到二分查找的对象。

二分查找基本框架

def binarySearch(arr: List[int], target: int):
    n = len(arr)
    left, right = 0, ...  # 左右边界的赋值可变
    while left ... right:  # 需要注意有没有等号
        mid = left + (right - left) // 2
        if arr[mid] == target:
            ...  # 要不要直接return
        elif arr[mid] < target:
            left = ...  # 要不要加一
        elif arr[mid] > target:
            right = ...  # 要不要减一
    return ...  # 有返回mid的,有left的各种

上面一段代码中的...部分是需要根据题目需要修改的地方,也就是二分查找的细节所在。另外,计算mid的公式也可以写成mid = (left + right) // 2,按上面那样写是为了防止溢出(虽然在Python里并不会有整型溢出的问题,不过最好养成这个习惯)。

找一个数的索引

这是二分查找最简单的一种应用,只要学习过数据结构肯定闭着眼睛都能写出来。

def binarySearch(arr: List[int], target: int):
    n = len(arr)
    left, right = 0, n - 1
    while left <= right:
        mid = left + (right - left) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        elif arr[mid] > target:
            right = mid - 1
    return -1  # 没有找到,返回 -1

这里有几个地方需要注意。

左右指针的赋值

左右指针的初始化决定了搜索区间是开区间还是闭区间

左右指针初始化为left, right = 0, n - 1,也就是说搜索区间是一个闭区间,即[0, n - 1]。而当mid处的值不是目标值时,就要把mid从搜索区间中去除,继续在两边的某个闭区间中搜索。因此,左右指针的更新规则为left = mid + 1right = mid - 1

终止条件

搜索区间为空时就应该跳出循环

上面的代码中,while循环的条件是left <= right,也就是说当left == right + 1时跳出while循环。实际上,这个终止条件与前面说的闭区间相对应,当left == right时,闭区间内仍有一个索引的位置需要搜索;当left == right + 1时,[right + 1, right]已经是一个空集,意味着已经没有索引需要搜索了,因此就跳出循环。

开区间

爷就是喜欢开区间,那咋办嘛

开区间情况下,左右指针应初始化为left, right = 0, n,对应的搜索区间是一个开区间,即[0, n)

  • arr[mid] < target时,同样地,要把mid从搜索区间中去除,注意此时是开区间。于是left = mid + 1,对应的搜索区间为[mid + 1, right);
  • arr[mid] > target时,right = mid, 对应的搜索区间为[left, mid)

此时,while循环的条件应为left < right,也就是说当left == right时跳出while循环。对应的搜索区间为[left, left),显然这个区间为空,即搜索完毕。完整代码如下。

def binarySearch(arr: List[int], target: int):
    n = len(arr)
    left, right = 0, n
    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        elif arr[mid] > target:
            right = mid
    return -1  # 没有找到,返回 -1

局限性

这个版本的二分查找至少还有两个需求无法满足:

  • 如果target在数组中多次出现,我们想要找到它的左边界(即最早出现时的索引)或者右边界要怎么改呢?
  • 如果target在数组中不出现,我们想要找到它插入到该数组中应该在的位置要怎么改呢?

寻找左侧边界

基于开区间版本,修改几个位置即可。

def binarySearch(arr: List[int], target: int):
    n = len(arr)
    if n == 0:  # 特判
    	return -1
    left, right = 0, n
    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] == target:
            right = mid  # 修改
        elif arr[mid] < target:
            left = mid + 1
        elif arr[mid] > target:
            right = mid
    return left  # 修改

代码还可以进一步简化,但是本文主要是为了弄清原理,这么写看起来更直观一些。基于这三处修改,我们来看看为什么这个版本可以返回左侧边界。

为什么需要特判

实际上也不是所有情况都需要特判,因为很多时候我们不会在空数组中搜索,那不是吃饱了撑的吗。主要是为了和返回的索引区分开来,避免返回值与空数组的情况发生混淆,具体往下看就明白了。

返回值的意义

这个版本的代码除了当target在数组中多次出现时返回它的左侧边界,它的返回值还有一个含义就是当前数组中比target小的元素个数。因此返回值的范围是[0, n],如果没有特判,数组为空时也返回0,会发生混淆。

为什么可以找到左边界

关键在于这段代码:

if arr[mid] == target:
    right = mid

当找到target时,不返回索引,而是继续向左搜索,即在[left, mid)中搜索。

那么问题来了,如果此时mid已经是左边界了,继续在[left, mid)中搜索不会出错吗?事实上,每次左指针的更新规则为left = mid + 1且while循环的条件是left < right。也就是说最终仍然会搜索到[left, mid),此时left == mid

如果target不在数组中怎么返回-1

上面说过其返回值的含义是当前数组中比target小的元素个数,其范围是[0, n]。如果返回值为n,那显然target比数组中的所有元素都大,target不在数组中。而返回值在[0, n -1]中时,target同样也可能不在数组中,需要判断arr[left] == target

return left if left < n and nums[left] == target else -1

寻找右侧边界

基于上面的代码,修改一下就行。

def binarySearch(arr: List[int], target: int):
    n = len(arr)
    if n == 0:  # 特判
    	return -1
    left, right = 0, n
    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] == target:
            left = mid + 1  # 修改
        elif arr[mid] < target:
            left = mid + 1
        elif arr[mid] > target:
            right = mid
    return left - 1  # 修改

为什么要返回left - 1

注意到,当找到target时,不返回索引,而是继续向右搜索,left = mid + 1。因此,while循环结束时的左指针一定指向target右边的第一个元素。因此要返回left - 1

返回值的意义

返回值还有个意义就是返回小等于target的最大元素的索引,其取值范围为[-1, n - 1]。要注意的是,当target比数组中的所有元素都小时,返回的是-1

如果target不在数组中怎么返回-1

如果返回值为-1,那显然target比数组中的所有元素都小,target不在数组中。而返回值在[0, n - 1]中时,target同样也可能不在数组中,需要再判断arr[left - 1] == target

return left - 1 if left > 0 and nums[left - 1] == target else -1

寻找插入位置

上面提到,寻找左侧边界版本的返回值同样代表了数组中比target小的元素个数,索引不就来了吗?

def binarySearch(arr: List[int], target: int):
    n = len(arr)
    left, right = 0, n
    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] == target:
            right = mid  # 修改
        elif arr[mid] < target:
            left = mid + 1
        elif arr[mid] > target:
            right = mid
    return left  # 修改

需要注意的是,这里不需要特判,数组为空时直接往里放就完事了。

例子

搜索左右边界

34. 在排序数组中查找元素的第一个和最后一个位置
题目要求:
给定一个按照升序排列的整数数组nums,和一个目标值target。找出给定目标值在数组中的开始位置和结束位置。

你的算法时间复杂度必须是\(O(\log n)\)级别。如果数组中不存在目标值,返回[-1, -1]

class Solution:
    def lower(self, nums: List[int], target: int):
        n = len(nums)
        left, right = 0, n
        while left < right:
            mid = left + (right - left) // 2
            if nums[mid] >= target:
                right = mid
            elif nums[mid] < target:
                left = mid + 1
        return left if left < n and nums[left] == target else -1
    
    def higher(self, nums: List[int], target: int):
        n = len(nums)
        left, right = 0, n
        while left < right:
            mid = left + (right - left) // 2
            if nums[mid] > target:
                right = mid
            elif nums[mid] <= target:
                left = mid + 1
        return left - 1

    def searchRange(self, nums: List[int], target: int) -> List[int]:
        if not nums:
            return [-1, -1]
        
        lo = self.lower(nums, target)
        # 二分查找的次数少一半,大概能快点?
        if lo == -1:
            return [-1, -1]
        hi = self.higher(nums, target)
        return [lo, hi]

1300. 转变数组后最接近目标值的数组和
题目要求:
给你一个整数数组 arr 和一个目标值 target ,请你返回一个整数 value ,使得将数组中所有大于 value 的值变成 value 后,数组的和最接近 target (最接近表示两者之差的绝对值最小)。

如果有多种使得和最接近 target 的方案,请你返回这些整数中的最小值。

请注意,答案不一定是 arr 中的数字。

class Solution:
    def biSrch(self, arr, tgt, length):
        left, right = 0, length
        while left < right:
            mid = (left + right) // 2
            val = arr[mid]
            if val < tgt:
                left = mid + 1
            elif val >= tgt:
                right = mid
        return left

    def findBestValue(self, arr: List[int], target: int) -> int:
        n = len(arr)
        arr.sort()
        preSum = [0] * (n + 1)
        for i in range(n):
            preSum[i + 1] = preSum[i] + arr[i]

        ans = 0
        diff = target
        for i in range(1, arr[-1] + 1):
            index = self.biSrch(arr, i, n)
            curSum = preSum[index] + (n - index) * i
            if abs(curSum - target) < diff:
                ans, diff = i, abs(curSum - target)

        return ans

搜索插入位置

35. 搜索插入位置
题目要求:
给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。

你可以假设数组中无重复元素。

根据上面的模板,很容易可以写出下面的代码。

class Solution:
    def searchInsert(self, nums: List[int], target: int) -> int:
        n = len(nums)
        left, right = 0, n
        while left < right:
            mid = left + (right - left) // 2
            if nums[mid] == target:
                return mid
            elif nums[mid] > target:
                right = mid
            elif nums[mid] < target:
                left = mid + 1
            
        return left

二分查找的其他应用

287. 寻找重复数
题目要求:
给定一个包含 n + 1 个整数的数组 nums,其数字都在 1 到 n 之间(包括 1 和 n),可知至少存在一个重复的整数。假设只有一个重复的整数,找出这个重复的数。

思路:
对数组中整数的范围进行二分查找,如果mid之前的小于mid的元素数量大于mid,则重复元素一定在mid之前。

class Solution:
    def findDuplicate(self, nums: List[int]) -> int:
        length = len(nums)
        # 每个数都在[left, right]中
        # 即[1, n]中
        left = 1
        right = length - 1

        while left < right:
            mid = left + (right - left) // 2
            cnt = 0
            for num in nums:
                if num <= mid:
                    cnt += 1

            if cnt > mid:
                right = mid
            else:
                left = mid + 1
        return left

1095. 山脉数组中查找目标值
题目要求:
给你一个 山脉数组mountainArr,请你返回能够使得mountainArr.get(index)等于target最小的下标index值。如果不存在这样的下标index,就请返回-1

何为山脉数组?如果数组A是一个山脉数组的话,那它满足如下条件:

首先,A.length >= 3

其次,在0 < i < A.length - 1条件下,存在i使得:

  • A[0] < A[1] < ... A[i-1] < A[i]
  • A[i] > A[i+1] > ... > A[A.length - 1]

你将不能直接访问该山脉数组,必须通过MountainArray接口来获取数据:

  • MountainArray.get(k) - 会返回数组中索引为k的元素(下标从 0 开始)
  • MountainArray.length() - 会返回该数组的长度

注意:
MountainArray.get发起超过100次调用的提交将被视为错误答案。

提示:

  • 3 <= mountain_arr.length() <= 10000
  • 0 <= target <= 10^9
  • 0 <= mountain_arr.get(index) <= 10^9

思路:
题中要求不能超过100次调用,数组长度不超过10000,疯狂明示要用二分查找。而山脉数组以山顶为界限,两边都是有序的,因此要用二分查找关键在于怎么找到山顶。首先将数组二分:

  • 如果mid处的元素比它右边的元素小,那它一定不是山顶,left = mid + 1
  • 否则mid有可能是山顶,right = mid

循环结束后right就是山顶。找到山顶后就可以在两边有序的数组中用二分查找搜索目标元素。

# """
# This is MountainArray's API interface.
# You should not implement it, or speculate about its implementation
# """
#class MountainArray:
#    def get(self, index: int) -> int:
#    def length(self) -> int:

class Solution:
    def findInMountainArray(self, target: int, mountain_arr: 'MountainArray') -> int:
        n = mountain_arr.length()
        left, right = 1, n - 2
        # 循环结束后left就是山顶
        while left <= right:
            top = left + (right - left) // 2
            if mountain_arr.get(top) > mountain_arr.get(top + 1):
                right = top - 1
            else:
                left = top + 1
        top = left
        left, right = 0, top
        while left <= right:
            mid = left + (right - left) // 2
            val = mountain_arr.get(mid)
            if val == target:
                return mid
            elif val > target:
                right = mid - 1
            else:
                left = mid + 1

        left, right = top + 1, n - 1
        while left <= right:
            mid = left + (right - left) // 2
            val = mountain_arr.get(mid)
            if val == target:
                return mid
            elif val < target:
                right = mid - 1
            else:
                left = mid + 1

        return -1

这段代码中的三处二分查找都是基于闭区间的版本。一开始将左右指针初始化为left, right = 1, n - 2是因为山顶只有可能在索引[1, n - 2]中。

推荐阅读