给定n个未排序的数,找到第k小的数。这里只讨论复杂度为O(n)的分治算法。

基于QuickSelect的分治算法

参考K’th Smallest/Largest Element in Unsorted Array | Set 1,直接看只知道大意,还是每行代码都分析一下吧。(感觉这里有点烦的是这个下标的意义,一开始搞懂然后这个看代码会舒服很多)。

代码如下所示:

# 这个算法的思想类似冒泡排序,把所有小于数组最后一个元素的数放在它左边,所有大于的放在右边,然后返回这个元素在数组中的位置
# 一眼看过去并不知道这个算法实现的思路

def partition(arr, l, r): 
  
    x = arr[r] 
    i = l 
    for j in range(l, r): 
        if (arr[j] <= x): 
            arr[i], arr[j] = arr[j], arr[i] 
            i += 1
    arr[i], arr[r] = arr[r], arr[i] 
    return i 

import sys 
  
def kthSmallest(arr, l, r, k): 
  
    # If k is smaller than number of  
    # elements in array 
    if (k > 0 and k <= r - l + 1): 
      
        # Partition the array around last  
        # element and get position of pivot 
        # element in sorted array 
        pos = partition(arr, l, r) 
  
        # If position is same as k 
        if (pos - l == k - 1): 
            return arr[pos] 
        if (pos - l > k - 1): # If position is more,  
                              # recur for left subarray 
            return kthSmallest(arr, l, pos - 1, k) 
  
        # Else recur for right subarray 
        return kthSmallest(arr, pos + 1, r, 
                            k - pos + l - 1) 
  
    # If k is more than number of 
    # elements in array 
    return sys.maxsize 

对partition函数的分析

看了挺久还是没看懂那个partition函数,就动手模拟了一下:
moni.jpg

总结一下动手模拟的经验:

  • 在循环体内设定好一个界限,和确定模拟所需要的变量
  • 每一步模拟都写好对应的一个编号

对比图片,不难发现,有这么几个规律:

  • 如果j一路扫描过去,指向的元素都比x大的话,i一直指向第一个元素,这样在最后一步的替换就可直接把最后一个元素放在最前面,算法正确。
  • 如果j一路扫描过去,指向的元素都比x小的话,那么i将一直等于j,这样最后一步的替换无效,也保证了算法的正确性
  • 如果j一路扫描过去,碰到某个比小于等于x的数,那么这个元素将向左移动,同时i会增加一位

根据第二个不难推测出,i的意义是记录小于等于最后一个元素的个数,这样最后一步的交换也就有了意义。

可是怎么保证小于等于x的数都放在最后一个元素的前面?

终于明白了!

你潜意识里把这个算法跟冒泡排序的思想等价,但并不是这样的!冒泡排序中j=i+1,然而这里j>=i+1。

所以这个算法的思想很简单啊!就是把小的数往i的前面扔,i则代表x的位置。最后返回的意义是原本数组最后一个元素在数组中的下标。

这个算法显然可以节省空间。。

太垃圾了,搞了这么久…

对kthSmallest的分析

这个比较好理解。

  • 注意是pos - l 不是pos - 1 即可
  • 最后k - pos + l - 1可以用方程 x - (pos+1) + pos - l = k - 1 来理解

算法复杂度分析

分析如下:

  • partition这个函数的复杂度为线性
  • kthSmallest这个函数看运气

结论是:

  • 期望复杂度为O(n)
  • 最差复杂度为O(n^2),即最后一个数每次都是最大的,如升序数组。

基于QuickSelect的分治算法

这个算法可以保证最差复杂度也是O(n)。

参考K’th Smallest/Largest Element in Unsorted Array | Set 3 (Worst Case Linear Time)

代码如下:

import sys
# Returns k'th smallest element in arr[l..r]  
# in worst case linear time.  
# ASSUMPTION: ALL ELEMENTS IN ARR[] ARE DISTINCT  
def kthSmallest(arr, l, r, k):  
      
    # If k is smaller than number of  
    # elements in array  
    if (k > 0 and k <= r - l + 1):  
          
        # Number of elements in arr[l..r]  
        n = r - l + 1
  
        # Divide arr[] in groups of size 5,  
        # calculate median of every group 
        # and store it in median[] array.  
        median = [] 

        
        i = 0
        while (i < n // 5): 
            median.append(findMedian(arr, l + i * 5, 5)) 
            i += 1
  
        # For last group with less than 5 elements  
        if (i * 5 < n): 
            median.append(findMedian(arr, l + i * 5,  
                                              n % 5)) 
            i += 1
  
        # Find median of all medians using recursive call.  
        # If median[] has only one element, then no need  
        # of recursive call 
        if i == 1: 
            medOfMed = median[i - 1] 
        else: 
            medOfMed = kthSmallest(median, 0,  
                                   i - 1, i // 2) 
  
        # Partition the array around a medOfMed 
        # element and get position of pivot  
        # element in sorted array  
        pos = partition(arr, l, r, medOfMed) 
  
        # If position is same as k  
        if (pos - l == k - 1):  
            return arr[pos]  
        if (pos - l > k - 1): # If position is more,  
                              # recur for left subarray  
            return kthSmallest(arr, l, pos - 1, k)  
  
        # Else recur for right subarray  
        return kthSmallest(arr, pos + 1, r,  
                           k - pos + l - 1)  
  
    # If k is more than the number of  
    # elements in the array  
    return sys.maxsize

def swap(arr, a, b):  
    temp = arr[a]  
    arr[a] = arr[b]  
    arr[b] = temp  
  
# It searches for x in arr[l..r],   
# and partitions the array around x.  
def partition(arr, l, r, x): 
    for i in range(l, r): 
        if arr[i] == x: 
            swap(arr, r, i) 
            break
  
    x = arr[r]  
    i = l  
    for j in range(l, r):  
        if (arr[j] <= x):  
            swap(arr, i, j)  
            i += 1
    swap(arr, i, r)  
    return i  
  
# A simple function to find  
# median of arr[] from index l to l+n 
def findMedian(arr, l, n): 
    lis = [] 
    for i in range(l, l + n): 
        lis.append(arr[i]) 
          
    # Sort the array  
    lis.sort() 
  
    # Return the middle element 
    return lis[n // 2] 
  
# Driver Code  
if __name__ == '__main__':  
  
    arr = [12, 3, 5, 7, 4, 19, 26]  
    n = len(arr)  
    k = 3
    print("K'th smallest element is",  
           kthSmallest(arr, 0, n - 1, k)) 

代码中的子函数都挺好理解的,直接分析kthSmallest。

kthSmallest 分析

  1. 按照5来划分,直接得到每组的中位数。
i = 0
while (i < n // 5): 
    median.append(findMedian(arr, l + i * 5, 5)) 
    i += 1
if (i * 5 < n): 
            median.append(findMedian(arr, l + i * 5,  
                                              n % 5)) 
            i += 1
  1. 找到中位数的中位数
if i == 1: 
            medOfMed = median[i - 1] 
        else: 
            medOfMed = kthSmallest(median, 0,  
                                   i - 1, i // 2) 
  1. 利用中位数作为pivot执行第一个算法。
pos = partition(arr, l, r, medOfMed) 
  
# If position is same as k  
if (pos - l == k - 1):  
    return arr[pos]  
if (pos - l > k - 1): # If position is more,  
                        # recur for left subarray  
    return kthSmallest(arr, l, pos - 1, k)  

# Else recur for right subarray  
return kthSmallest(arr, pos + 1, r,  
                    k - pos + l - 1)  

改进算法复杂度分析

该算法最差情况的复杂度也是O(n)。证明如下:

  1. 假设该算法复杂度为T(n)
  2. 第一步复杂度为O(n)
  3. 第二步复杂度为T(n/5)
  4. 第三步中partition复杂度为O(n),而选中的pivot可以保证比它大或比它小的数都有3n/10 – 6个,故下一次递归规模至多为 7n/10 + 6
  5. 综上可推得该算法最差情况复杂度为O(n)

结语

第二种算法也可以将5换作为其他数字

第二种算法中分组数只能是5或比5更大的数字,否则不能保证最差复杂度为线性,且容易证明分组数越大,算法性能越高。


我很好奇