本文代码已上传github,欢迎交流。

最近在学习go语言,正好有遇到需要使用缓存的地方,于是决定自己造个轮子。主要特性如下:

  • 线程安全;
  • 支持被动触发的过期时间;
  • 支持key和value任意类型;
  • 基于双向链表和hash表实现;

双向链表的插入、删除和元素移动效率非常高,LRU缓存通常都有大量的以上操作。使用hash表来存储每个key对应的元素的指针,避免每次查询缓存都需要遍历整个链表,提高效率。

被动的过期的时间表示并不会主动的删除缓存中已经过期的元素,而是在需要使用的时候才去检查是否过期,如果过期的话再去删除。

数据结构

每个缓存的元素至少包含两个:缓存的关键字key、缓存的数据data;为了支持过期时间,每个元素还要有一个值来表示其过期时间;另外基于双向链表实现,还需要指向前一个元素和后一个元素的指针;于是,每个缓存元素的结构定义:

type elem struct {
    key        interface{}
    data       interface{}
    expireTime int64
    next       *elem
    pre        *elem
}

那么对于整个缓存来说,事实上就是一个个元素组成的列表,但是为了更高效的查询,使用一个hash表来存放key对应的元素的指针,提升查询效率,于是cache的结构定义:

type lrucache struct {
    maxSize   int
    elemCount int
    elemList  map[interface{}]*elem
    first     *elem
    last      *elem
    mu        sync.Mutex
}

保存链表首尾元素的指针是为了在淘汰元素和插入元素的时候更高效。

基本方法

一个缓存基本的方法应该包括新建缓存、添加元素、删除元素、查询元素。

新建缓存

新建一个缓存实际上就是新建一个lrucache结构体,并对里面的元素进行初始化:

// New create a new lrucache
// size: max number of element
func New(size int) (*lrucache, error) {
    newCache := new(lrucache)
    newCache.maxSize = size
    newCache.elemCount = 0
    newCache.elemList = make(map[interface{}]*elem)
    return newCache, nil
}

入参表示这个缓存最多能存放的元素的个数,当到达最大个数的时候就开始淘汰最久没使用的元素。

添加元素

Set
// Set create or update an element using key
//      key:    The identity of an element
//      value:  new value of the element
//      ttl:    expire time, unit: second
func (c *lrucache) Set(key interface{}, value interface{}, ttl ...int) error {

    // Ensure ttl are correct
    if len(ttl) > 1 {
        return errors.New("wrong para number, 2 or 3 expected but more than 3 received")
    }
    var elemTTL int64
    if len(ttl) == 1 {
        elemTTL = int64(ttl[0])
    } else {
        elemTTL = -1
    }

    c.mu.Lock()
    defer c.mu.Unlock()

    if e, ok := c.elemList[key]; ok {
        e.data = value
        if elemTTL == -1 {
            e.expireTime = elemTTL
        } else {
            e.expireTime = time.Now().Unix() + elemTTL
        }
        c.mvKeyToFirst(key)
    } else {
        if c.elemCount+1 > c.maxSize {
            if c.checkExpired() <= 0 {
                c.eliminationOldest()
            }
        }
        newElem := &elem{
            key:        key,
            data:       value,
            expireTime: -1,
            pre:        nil,
            next:       c.first,
        }
        if elemTTL != -1 {
            newElem.expireTime = time.Now().Unix() + elemTTL
        }
        if c.first != nil {
            c.first.pre = newElem
        }
        c.first = newElem
        c.elemList[key] = newElem

        c.elemCount++
    }
    return nil
}

如果一个key已经存在就更新它所对应的值,并将这个key对应的元素移动到链表的最前面;如果key不存在就需要新建一个链表元素,流程如下:

由于采用的是过期时间是被动触发的方式,因此在元素满的时候并不能确定是否存在过期的元素,因此目前采用的方式是,当满了之后每次新增元素就去遍历的检查一次过期的元素,时间复杂度为O(n),感觉这种实现方式不太好,但是目前没想到更好的实现方式。

上面使用到的内部方法实现如下:

// updateKeyPtr 更新对应key的指针,放到链表的第一个
func (c *lrucache) mvKeyToFirst(key interface{}) {
    elem := c.elemList[key]
    if elem.pre == nil {
        // 当key是第一个元素时,不做动作
        return
    } else if elem.next == nil {
        // 当key不是第一个元素,但是是最后一个元素时,提到第一个元素去
        elem.pre.next = nil

        c.last = elem.pre

        elem.pre = nil
        elem.next = c.first
        c.first = elem

    } else {
        elem.pre.next = elem.next
        elem.next.pre = elem.pre

        elem.next = c.first
        elem.pre = nil
        c.first = elem
    }
}

func (c *lrucache) eliminationOldest() {
    if c.last == nil {
        return
    }
    if c.last.pre != nil {
        c.last.pre.next = nil
    }
    key := c.last.key
    c.last = c.last.pre
    delete(c.elemList, key)
}

func (c *lrucache) deleteByKey(key interface{}) {
    if v, ok := c.elemList[key]; ok {
        if v.pre == nil && v.next == nil {
            // 当key是第一个元素时,清空元素列表,充值指针和元素计数
            c.elemList = make(map[interface{}]*elem)
            c.elemCount = 0
            c.last = nil
            c.first = nil
            return
        } else if v.next == nil {
            // 当key不是第一个元素,但是是最后一个元素时,修改前一个元素的next指针并修改c.last指针
            v.pre.next = v.next
            c.last = v.pre
        } else if v.pre == nil {
            c.first = v.next
            c.first.pre = nil
        } else {
            // 中间元素,修改前后指针
            v.pre.next = v.next
            v.next.pre = v.pre
        }
        delete(c.elemList, key)
        c.elemCount--
    }
}

// 遍历链表,检查并删除已经过期的元素
func (c *lrucache) checkExpired() int {
    now := time.Now().Unix()
    tmp := c.first
    count := 0
    for tmp != nil {
        if tmp.expireTime != -1 && now > tmp.expireTime {
            c.deleteByKey(tmp.key)
            count++
        }
        tmp = tmp.next
    }
    return count
}

获取元素

Get
// Get Get the value of a cached element by key. If key do not exist, this function will return nil and a error msg
//      key:    The identity of an element
//      return:
//          value:  the cached value, nil if key do not exist
//          err:    error info, nil if value is not nil
func (c *lrucache) Get(key interface{}) (value interface{}, err error) {
    if v, ok := c.elemList[key]; ok {
        if v.expireTime != -1 && time.Now().Unix() > v.expireTime {
            // 如果过期了
            c.deleteByKey(key)
            return nil, errors.New("the key was expired")
        }
        c.mvKeyToFirst(key)
        return v.data, nil
    }
    return nil, errors.New("no value found")
}

删除元素

Delete
// Delete delete an element
func (c *lrucache) Delete(key interface{}) error {
    c.mu.Lock()
    defer c.mu.Unlock()
    if _, ok := c.elemList[key]; !ok {
        return errors.New(fmt.Sprintf("key %T do not exist", key))
    }
    c.deleteByKey(key)
    return nil
}
Set