本文首发于“雨夜随笔”公众号,欢迎关注。

最近在看Golang官方库中的排序算法,不得不说官方有很多优化的点非常有意思,也很值得思考和学习,那么话不多少,让我们直接开始学习Golang的排序算法:

在Golang中引入排序是在 sort 包下面,使用方法是:

import "sort"

func FunctionWithSort() {
    sort.Sort(data)
}

那么我们进入sort包看一下 Sort() 的逻辑:

// Sort sorts data.
// It makes one call to data.Len to determine n, and O(n*log(n)) calls to
// data.Less and data.Swap. The sort is not guaranteed to be stable.
func Sort(data Interface) {
  n := data.Len()
  quickSort(data, 0, n, maxDepth(n))
}

Golang中注释也说了目前使用的排序算法不是稳定的,也就是说对于相同的元素,并不能保证排序后的顺序和排序前一样。

我们同时也看到Sort()调用了一个名为“快速排序”的函数。那么我们在看这个函数之前,先关心一下他参数中有一个maxDepth()的函数,这个是有什么用呢?我们先来看一下:

// maxDepth returns a threshold at which quicksort should switch
// to heapsort. It returns 2*ceil(lg(n+1)).
func maxDepth(n int) int {
  var depth int
  for i := n; i > 0; i >>= 1 {
    depth++
  }
  return depth * 2
}

从注释中,我们看到这个是用来判断是否可以采用堆排序。这个depth的值是用来决定是否应该采用堆排序。而如何不乘以2,即使构成完全二叉树的深度,但是乘以2,这个是什么意思呢?这个意思就是构成最深的完全二叉树的深度,不得不说这段很秒。

堆排序

那么我们就来看看堆排序,了解这个的可以跳过这一段。我喜欢先抓最核心的来讲,在我看来,堆排序的关键就是建立一个大顶堆,然后将根节点交换到堆数据的最后,然后用剩下的元素继续建立大顶堆,然后不断重复上述步骤,直至最后将最后一个元素换到数据的开始,排序也就完成了。

func heapSort(data Interface, a, b int) {
  first := a
  lo := 0
  hi := b - a

  // Build heap with greatest element at top.
  for i := (hi - 1) / 2; i >= 0; i-- {
    siftDown(data, i, hi, first)
  }

  // Pop elements, largest first, into end of data.
  for i := hi - 1; i >= 0; i-- {
    data.Swap(first, first+i)
    siftDown(data, lo, i, first)
  }
}

而代码中调用的shiftDown()就是建立大顶堆的部分。结果上面代码我们可以看出大顶堆是从后往前建立了,也就是从树的底部往上开始逐步建立大顶堆。

// siftDown implements the heap property on data[lo, hi).
// first is an offset into the array where the root of the heap lies.
func siftDown(data Interface, lo, hi, first int) {
  root := lo
  for {
    child := 2*root + 1
    if child >= hi {
      break
    }
    if child+1 < hi && data.Less(first+child, first+child+1) { // 保证交换的是最大的子节点
      child++
    }
    if !data.Less(first+root, first+child) { // 父节点已经是最大了,不需要交换
      return
    }
    data.Swap(first+root, first+child) // 交换父子节点
    root = child
  }
}

上述就是Golang中进行堆排序的过程,Golang为什么要用堆排序呢,因为堆排序的时间复杂度非常稳定,平均情况就是O(nlgn),但是因为堆排序需要保存全部的数据,对于空间需求更大,所以只在数据量不大的情况下使用即可。那我们再来继续看一下Golang中的“快速排序”。

快速排序

func quickSort(data Interface, a, b, maxDepth int) {
  for b-a > 12 { // Use ShellSort for slices <= 12 elements
    ...
  }
  if b-a > 1 {
    // Do ShellSort pass with gap 6
    // It could be written in this simplified form cause b-a <= 12
    for i := a + 6; i < b; i++ {
      if data.Less(i, i-6) {
        data.Swap(i, i-6)
      }
    }
    insertionSort(data, a, b)
  }
}

可以看到“快速排序”分为两个部分,一个是长度大于12的,一个是长度小于等于12的,我们先来看一下小于等于12的这一部分,从注释中我们看出这一部分使用了希尔排序,那我们来看一下这个希尔排序。

希尔排序

希尔排序时一种改进后的插入排序,因为插入排序对已经排好序的数据操作时更为有效,所以希尔排序先通过一定的间隔将元素划分成几个区域来先进行排序,然后逐步缩小间隔进行排序,最后采用插入排序,因为已经基本都排好了,所以插入排序的效率就很高。

而Golang中希尔排序中使用的Gap值是6,也就是间隔6位的为一组,先进行排序,然后不同于一般的希尔排序将Gap值减半,而是直接进行插入排序。我们来看一下Golang中的插入排序。

插入排序

插入排序的关键在于将未排序的元素和已排序元素从后往前依次比较,找到相应的位置进行插入。

// Insertion sort
func insertionSort(data Interface, a, b int) {
  for i := a + 1; i < b; i++ {
    for j := i; j > a && data.Less(j, j-1); j-- {
      data.Swap(j, j-1)
    }
  }
}

从上面的代码我们可以很轻易的看到插入排序的过程。那么我们再看看长度大于12的排序算法。

func quickSort(data Interface, a, b, maxDepth int) {
  for b-a > 12 { // Use ShellSort for slices <= 12 elements
    if maxDepth == 0 {
      heapSort(data, a, b)
      return
    }
    maxDepth--
    mlo, mhi := doPivot(data, a, b)
    // Avoiding recursion on the larger subproblem guarantees
    // a stack depth of at most lg(b-a).
    if mlo-a < b-mhi {
      quickSort(data, a, mlo, maxDepth)
      a = mhi // i.e., quickSort(data, mhi, b)
    } else {
      quickSort(data, mhi, b, maxDepth)
      b = mlo // i.e., quickSort(data, a, mlo)
    }
  }
  if b-a > 1 {
    // Do ShellSort pass with gap 6
    // It could be written in this simplified form cause b-a <= 12
    ...
}

首先在maxDepth为0的情况下,使用堆排序,这个是什么意思呢,就是当递归到最大深度的时候,使用堆排序。那么在不为零的时候我们可以看出使用的就是快速排序,不过在快速排序中,又进行了一步优化,也就是找中位数 doPivot() 这个方法,我们来看一下Golang是如何做的。

寻找中位数

因为快速排序的关键就是找到一个合适的分界值,最好的当然就是中位数,这样可以将元素平均的分为两个部分,整体分割的次数显然也会减少,最后能够降低整体耗费的时间。所以快速排序如果可以很快捷的找到中位数,那么能够大大的增加排序的效率,我们来看看这一段代码:

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
  m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
  if hi-lo > 40 {
    // Tukey's ``Ninther,'' median of three medians of three.
    s := (hi - lo) / 8
    medianOfThree(data, lo, lo+s, lo+2*s)
    medianOfThree(data, m, m-s, m+s)
    medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
  }
  medianOfThree(data, lo, m, hi-1)

  // Invariants are:
  // data[lo] = pivot (set up by ChoosePivot)
  // data[lo < i < a] < pivot
  // data[a <= i < b] <= pivot
  // data[b <= i < c] unexamined
  // data[c <= i < hi-1] > pivot
  // data[hi-1] >= pivot
  pivot := lo
  a, c := lo+1, hi-1

  for ; a < c && data.Less(a, pivot); a++ {
  }
  b := a
  for {
    for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
    }
    for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
    }
    if b >= c {
      break
    }
    // data[b] > pivot; data[c-1] <= pivot
    data.Swap(b, c-1)
    b++
    c--
  }
  // If hi-c<3 then there are duplicates (by property of median of nine).
  // Let's be a bit more conservative, and set border to 5.
  protect := hi-c < 5
  if !protect && hi-c < (hi-lo)/4 {
    // Lets test some points for equality to pivot
    dups := 0
    if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
      data.Swap(c, hi-1)
      c++
      dups++
    }
    if !data.Less(b-1, pivot) { // data[b-1] = pivot
      b--
      dups++
    }
    // m-lo = (hi-lo)/2 > 6
    // b-lo > (hi-lo)*3/4-1 > 8
    // ==> m < b ==> data[m] <= pivot
    if !data.Less(m, pivot) { // data[m] = pivot
      data.Swap(m, b-1)
      b--
      dups++
    }
    // if at least 2 points are equal to pivot, assume skewed distribution
    protect = dups > 1
  }
  if protect {
    // Protect against a lot of duplicates
    // Add invariant:
    // data[a <= i < b] unexamined
    // data[b <= i < c] = pivot
    for {
      for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
      }
      for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
      }
      if a >= b {
        break
      }
      // data[a] == pivot; data[b-1] < pivot
      data.Swap(a, b-1)
      a++
      b--
    }
  }
  // Swap pivot into middle
  data.Swap(pivot, b-1)
  return b - 1, c
}

这一段代码看起来很长,但是我们可以分部分来看,首先来看第一部分,也就是长度大于40的情况,这一种情况在Golang看来就是数据比较多的情况了,那么它调用了三次medianOfThree(),这个是什么呢?我们来看一下:

John Tukey’s median of medians

Tukey的这个解决办法理解起来很简单,我们来举个例子,比如我们想要求一下序列的中位数: y_1,y_2,y_3,...,y_9 ,那么我们可以将数组分成三个部分: y_a 是前三个数的中位数,

y_b 是中间三个数的中位数, y_c 是最后三个数的中位数。那么再找到这三个数的中位数,则认为这个数是整个序列的中位数。看到这里很多人肯定有疑问了,因为这个方法肯定不对,而且很容易就举到例子了,比如3, 1, 4, 4, 5, 9, 9, 8, 2. 然后

yA = median( 3, 1, 4 ) = 3

yB = median( 4, 5, 9 ) = 5

yC = median( 9, 8, 2 ) = 8

然后median( 3, 5, 8 ) = 5,但是实际上这个序列的中位数是4。

这就涉及到Tukey所要解决的问题究竟是什么了,Tukey并不是为了解决寻找中位数这个问题,而是如何在大数据中找到一个接近中位数的数字,这个需求其实有很大的用处,因为很多时候我们可以牺牲一部分精确值,而节省我们很多的工作,比如这次的快速排序,一个精确的中位数是会节省我们很多的工作,但是寻找这个精确的中位数在数据量很大时则会浪费很多工作,那么我们可以采用Tukey的方式,第一会节省很多内存,因为我们不需要一下子帮所有的数据都读入进来,而是按照我们设定的步长来读取。第二是节省很多的时间,因为数据长度的减少,我们排序的时间也会缩短很多。而且如果需要更为精确的中位数,我们可以将步长继续缩短。所以Tukey的这个方案就很适合快速排序分界值的查找。

那么我们来看一下medianOfThree()这个代码,其实很简单, 就是将中位数放在m1上。:

// medianOfThree moves the median of the three values data[m0], data[m1], data[m2] into data[m1].
func medianOfThree(data Interface, m1, m0, m2 int) {
  // sort 3 elements
  if data.Less(m1, m0) {
    data.Swap(m1, m0)
  }
  // data[m0] <= data[m1]
  if data.Less(m2, m1) {
    data.Swap(m2, m1)
    // data[m0] <= data[m2] && data[m1] < data[m2]
    if data.Less(m1, m0) {
      data.Swap(m1, m0)
    }
  }
  // now data[m0] <= data[m1] <= data[m2]
}

那么经过上面的筛选,我们将“中位数”放到了lo上,我们用pivot标记起来。那么接下来这个函数干了什么呢,就是将数据不是分成两部分,而是三部分,也就是得到小于pivot,等于pivot和大于pivot这三个部分,同时将两个分界值返回。而这个寻找过程是一个逐渐逼近的过程,我们来分别看一下:

for ; a < c && data.Less(a, pivot); a++ {
  }
  b := a
  for {
    for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
    }
    for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
    }
    if b >= c {
      break
    }
    // data[b] > pivot; data[c-1] <= pivot
    data.Swap(b, c-1)
    b++
    c--
  }

这个过程就是保证:

(lo,a]: < pivot

(a,b): <= pivot

[c-1,hi-1): > pivot

而data[b] > pivot; data[c-1] <= pivot

这个只是第一步,简单的将数据进行筛选,我们注意到现在[b,c-1]还没有筛选过,那么接下来我们就是将所有等于pivot的移到这个区间来。

// If hi-c<3 then there are duplicates (by property of median of nine).
  // Let's be a bit more conservative, and set border to 5.
  protect := hi-c < 5 //其实如果大于pivot的少于3个就认为等于pivot的不止一个了,这里Golang将这个数放大到5个,是为了剪枝
  if !protect && hi-c < (hi-lo)/4 { // 如果上面大于5,并不意味着没有重复,要看一下大于pivot的这段是不是占到整个序列的1/4,如果没有,还是可能存在pivot重复
    // Lets test some points for equality to pivot
    dups := 0
    if !data.Less(pivot, hi-1) { // 因为data[hi-1]肯定大于等于pivot,这个是之前medianOfThree决定的,所以data[hi-1] = pivot
      data.Swap(c, hi-1)
      c++
      dups++
    }
    if !data.Less(b-1, pivot) { //因为data[b-1]肯定小于等于pivot,所以这里data[b-1] = pivot
      b--
      dups++
    }
    // m-lo = (hi-lo)/2 > 6
    // b-lo > (hi-lo)*3/4-1 > 8
    // ==> m < b ==> data[m] <= pivot
    if !data.Less(m, pivot) { //这里可以看一下上面的解释,得出data[m] = pivot
      data.Swap(m, b-1)
      b--
      dups++
    }
    // if at least 2 points are equal to pivot, assume skewed distribution
    protect = dups > 1 // 如果大于1,则认为存在等于pivot的数
  }
  if protect { // 存在pivot重复的话要移动这些数据,也就是进一步细分[a,b)这个区间
    // Protect against a lot of duplicates
    // Add invariant:
    // data[a <= i < b] unexamined
    // data[b <= i < c] = pivot
    for {
      for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
      }
      for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
      }
      if a >= b {
        break
      }
      // data[a] == pivot; data[b-1] < pivot
      data.Swap(a, b-1)
      a++
      b--
    }
  }
  // Swap pivot into middle
  data.Swap(pivot, b-1)

具体的可以参考上面的注释。那么这些都明白了之后,最后就剩下一个标准的快速排序。

func quickSort(data Interface, a, b, maxDepth int) {
  ...
    // Avoiding recursion on the larger subproblem guarantees
    // a stack depth of at most lg(b-a).
    if mlo-a < b-mhi {
      quickSort(data, a, mlo, maxDepth)
      a = mhi // i.e., quickSort(data, mhi, b)
    } else {
      quickSort(data, mhi, b, maxDepth)
      b = mlo // i.e., quickSort(data, a, mlo)
    }
  }
  ...
}

即使到了这一步,Golang还是希望继续优化,也就是先排长度小的,因为中位数只是个近似值,所以先排长度小的可以减少快速排序的时间复杂度。

总结

Golang的排序源码看完之后,其实有很大的启示,我们简单总结一下。

1:就是边界条件的判定,这个是非常重要的,因为没有判断边界条件的代码是不安全的,很容易出现崩溃和异常。而下面的这个代码就是防止大数溢出的典型:

m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.

2:就是不断的优化,这里面不断的变换排序方式,就是在根据数据量的特点,保证内存和时间都能做到最少。而且即使是最后的快速排序,也在不断的进行优化。这个真的是让我感到非常的受益。

3:就是子问题拆分,这里面不同的数据量进行不同排序方式,就是将一个复杂的情况拆成合理的情况,然后选用最为合理的方式。

4:代码的风格,无论是代码的简洁性和封装性,都在源码中得到的一定的展现。使得源码的可读性非常强。

看完源码,不得不说学到了很多。也对自己以后写代码有了一定的思考,源码的作者带来的其实不止上面这么多,即使写这篇文章的时候阅读代码,收益又有很多,真的是需要多读源码了。