回顾

    上一篇博客我们构建了一个简单的同步执行的能发送命令到redis服务器的例子,这次我们尝试构建一个连接池,可以让GO并发执行redis命令

最终的目录结构

其中 redis.go是供外部调用的主文件,options.go是可配置的相关参数,internal/pool下是跟连接池相关的

最终想要实现成的结果

package main

import (
	"github.com/learn-go/redis"

)

func main() {
	
	client := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
	})

	client.SendCommand("SET hello world")

}

实现流程图

1. NewClient过程

2.SendCommand过程

 

具体实现代码

redis.go

package redis

import(
	"github.com/learn-go/redis/internal/pool"
)

type Client struct {
	opt      *Options
	connPool pool.Pooler

}

//发送命令
func (c *Client) SendCommand(cmd string) error {
	cn, err := c.getConn()
	if err != nil {
		return err
	}

	cn.WithWrite(cmd)
	c.releaseConn(cn)

	return err

}
//释放连接
func (c *Client) releaseConn(cn *pool.Conn){
	c.connPool.Put(cn)
}
//获取一个可用连接
func (c *Client) getConn() (*pool.Conn, error) {
	cn, err := c.connPool.Get()
	if err != nil {
		return nil, err
	}
	return cn, nil
}

//生成一个redis客户端
func NewClient(opt *Options) *Client {
	opt.init()
	c := Client{
		opt:      opt,
		connPool: newConnPool(opt),
	}

	return &c   
}

options.go

package redis

import(
	"time"
	"runtime"
	"net"
	"github.com/learn-go/redis/internal/pool"
)

type Options struct {
	// The network type, either tcp or unix.
	// Default is tcp.

	//network类型,
	Network string
	// 服务器地址 格式是host:port  localhost:6379
	Addr string

	//连接函数
	Dialer func() (net.Conn, error)

	// Dial timeout for establishing new connections.
	// Default is 5 seconds.
	// 连接超时时间 默认是5秒
	DialTimeout time.Duration

	//从连接池获取连接的等待时间
	PoolTimeout time.Duration

	// Maximum number of socket connections.
	// Default is 10 connections per every CPU as reported by runtime.NumCPU.
	//连接池存储连接的数量
	//默认为 10 * runtime.NumCPU
	PoolSize int

	//失败重连次数
	//默认不重连
	MaxRetries int


	//重连最小时间间隔 默认 8毫秒 
	//-1是禁止
	MinRetryBackoff time.Duration

	//重连最大时间间隔 默认 512毫秒 
	//-1是禁止
	MaxRetryBackoff time.Duration
}

func (opt *Options) init() {
	if opt.Network == "" {
		opt.Network = "tcp"
	}
	if opt.Addr == "" {
		opt.Addr = "localhost:6379"
	}
	if opt.Dialer == nil {
		opt.Dialer = func() (net.Conn, error) {
			netDialer := &net.Dialer{
				Timeout:   opt.DialTimeout,
				KeepAlive: 5 * time.Minute,
			}
			return netDialer.Dial(opt.Network, opt.Addr)
		}
	}
	if opt.PoolSize == 0 {
		opt.PoolSize = 10 * runtime.NumCPU()
	}
	if opt.DialTimeout == 0 {
		opt.DialTimeout = 5 * time.Second
	}
	if opt.PoolTimeout == 0 {
		opt.PoolTimeout = 5 * time.Second
	}
}

func newConnPool(opt *Options) *pool.ConnPool {
	return pool.NewConnPool(&pool.Options{
		Dialer:             opt.Dialer,
		PoolSize:           opt.PoolSize,
		PoolTimeout:        opt.PoolTimeout,
	})
}

pool.go

package pool

import(
	"sync"
	"net"
	"time"
	"errors"
	"sync/atomic"
)

var ErrClosed = errors.New("redis: client is closed")
var ErrPoolTimeout = errors.New("redis: connection pool timeout")

//定时触发事件,用来设置从连接池获取连接的超时时间
var timers = sync.Pool{
	New: func() interface{} {
		t := time.NewTimer(time.Hour)
		t.Stop()
		return t
	},
}

type Pooler interface {
	NewConn() (*Conn, error)
	CloseConn(*Conn) error

	Get() (*Conn, error)
	Put(*Conn)
	Remove(*Conn)

	Len() int
	IdleLen() int
}

type Options struct {
	//和服务器建立连接的方法
	Dialer  func() (net.Conn, error)

	//连接池上线
	PoolSize           int
	//最少的空闲连接数
	MinIdleConns       int

	//连接最大持续时间(连接使用超过这个时间,认为连接时一个太陈旧的连接)
	MaxConnAge         time.Duration
	//获取连接池的一个连接超时时间
	PoolTimeout        time.Duration

	//一个连接多久未使用认为该连接时一个空闲连接
	IdleTimeout        time.Duration

	
}

type ConnPool struct {
	opt *Options

	//建立连接错误次数(原子操作) 如果建立连接错误次数超过配置的opt.Poolsize则会单独生成协程无限重试
	dialErrorsNum uint32 // atomic

	lastDialErrorMu sync.RWMutex
	lastDialError   error

	//控制同时连接数达到上限,就需要等待
	queue chan struct{}

	//锁
	connsMu      sync.Mutex

	//当前连接池内连接的个数
	poolSize     int

	//所有连接存放数组
	conns        []*Conn
	//空闲连接存放数组
	idleConns    []*Conn
	//空闲连接个数
	idleConnsLen int

	_closed uint32 // atomic

}
//连接池存在的连接数
func (p *ConnPool) Len() int {
	p.connsMu.Lock()
	n := len(p.conns)
	p.connsMu.Unlock()
	return n
}
//连接池现空闲的连接数
func (p *ConnPool) IdleLen() int {
	p.connsMu.Lock()
	n := p.idleConnsLen
	p.connsMu.Unlock()
	return n
}
//新建一个不在连接池内的连接
func (p *ConnPool) NewConn() (*Conn, error) {
	return p._NewConn(false)
}
//从连接池 获取一个可用连接
func (p *ConnPool) Get() (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}

	err := p.waitTurn()
	if err != nil {
		return nil, err
	}
	for {
		p.connsMu.Lock()
		//从连接池获取一个连接
		cn := p.popIdle()
		p.connsMu.Unlock()

		if cn == nil {
			break
		}

		if p.isStaleConn(cn) {
			_ = p.CloseConn(cn)
			continue
		}

		return cn, nil
	}
	newcn, err := p._NewConn(true)
	if err != nil {
		p.freeTurn()
		return nil, err
	}

	return newcn, nil
}
//释放一个再用连接到连接池
func (p *ConnPool) Put(cn *Conn) {
    //如果该连接不属于连接池直接remove掉
	if !cn.pooled {
		p.Remove(cn)
		return
	}

	p.connsMu.Lock()
	p.idleConns = append(p.idleConns, cn)
    //空闲连接数+1
	p.idleConnsLen++
	p.connsMu.Unlock()
    //占用的服务已结束
	p.freeTurn()
}
//移除一个连接
func (p *ConnPool) Remove(cn *Conn) {
	p.removeConn(cn)
	p.freeTurn()
	_ = p.closeConn(cn)
}

func (p *ConnPool) _NewConn(pooled bool) (*Conn, error) {
	cn, err := p.newConn(pooled)
	if err != nil {
		return nil, err
	}

	p.connsMu.Lock()
	p.conns = append(p.conns, cn)
	if pooled {
		if p.poolSize < p.opt.PoolSize {
			p.poolSize++
		} else {
			cn.pooled = false
		}
	}
	p.connsMu.Unlock()
	return cn, nil
}

func (p *ConnPool) getTurn() {
	p.queue <- struct{}{}
}

func (p *ConnPool) waitTurn() error {
	select {
	case p.queue <- struct{}{}:
		return nil
	default:
		timer := timers.Get().(*time.Timer)
		//设置超时时间
		timer.Reset(p.opt.PoolTimeout)

		select {
		case p.queue <- struct{}{}:
			if !timer.Stop() {
				<-timer.C
			}
			timers.Put(timer)
			return nil
		case <-timer.C:
			timers.Put(timer)
			return ErrPoolTimeout
		}
	}
}

func (p *ConnPool) freeTurn() {
	<-p.queue
}
//弹出一个空闲连接
func (p *ConnPool) popIdle() *Conn {
	if len(p.idleConns) == 0 {
		return nil
	}

	idx := len(p.idleConns) - 1
	cn := p.idleConns[idx]
	p.idleConns = p.idleConns[:idx]
	p.idleConnsLen--
	return cn
}
//关闭连接池
func (p *ConnPool) closed() bool {
	return atomic.LoadUint32(&p._closed) == 1
}

func NewConnPool(opt *Options) *ConnPool {
	p := &ConnPool{
		opt: opt,

		queue:     make(chan struct{}, opt.PoolSize),
		conns:     make([]*Conn, 0, opt.PoolSize),
		idleConns: make([]*Conn, 0, opt.PoolSize),
	}
	for i := 0; i < opt.MinIdleConns; i++ {
		p.checkMinIdleConns()
	}
	return p
}

//判断改连接是否时一个陈旧连接
func (p *ConnPool) isStaleConn(cn *Conn) bool {
	if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
		return false
	}

	now := time.Now()
    //最近一次使用时间已经超过配置的空闲时间
	if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
		return true
	}
    //从创建到现在的时间已经超过最大接受时间
	if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
		return true
	}

	return false
}

//每当空闲连接数小于配置的最小空闲连接数就会创立新的连接,直到不小于配置数
func (p *ConnPool) checkMinIdleConns() {
	if p.opt.MinIdleConns == 0 {
		return
	}
	if p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
		//这里需要注意代码顺序
		//由于创建过程是异步的,因此需要先++
		p.poolSize++
		p.idleConnsLen++
		//异步创建
		go p.addIdleConn()
	}
}


//创建一个空闲连接
func (p *ConnPool) addIdleConn() {
	cn, err := p.newConn(true)
	if err != nil {
		return
	}

	p.connsMu.Lock()
	p.conns = append(p.conns, cn)
	p.idleConns = append(p.idleConns, cn)
	p.connsMu.Unlock()
}

func (p *ConnPool) newConn(pooled bool) (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}

	if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
		return nil, p.getLastDialError()
	}

	netConn, err := p.opt.Dialer()
	if err != nil {
		p.setLastDialError(err)
		//已达错误上限,单独起一个协程一直重试
		if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
			go p.tryDial()
		}
		return nil, err
	}

	cn := NewConn(netConn)
	cn.pooled = pooled
	return cn, nil
}
//无限重试连接,如果能连接成功将错误次数重置为0,可以让下次的newConn方法重新尝试连接
func (p *ConnPool) tryDial() {
	for {
		if p.closed() {
			return
		}

		conn, err := p.opt.Dialer()
		if err != nil {
			p.setLastDialError(err)
			time.Sleep(time.Second)
			continue
		}

		atomic.StoreUint32(&p.dialErrorsNum, 0)
		_ = conn.Close()
		return
	}
}

func (p *ConnPool) setLastDialError(err error) {
	p.lastDialErrorMu.Lock()
	p.lastDialError = err
	p.lastDialErrorMu.Unlock()
}

func (p *ConnPool) getLastDialError() error {
	p.lastDialErrorMu.RLock()
	err := p.lastDialError
	p.lastDialErrorMu.RUnlock()
	return err
}

func (p *ConnPool) CloseConn(cn *Conn) error {
	p.removeConn(cn)
	return p.closeConn(cn)
}

func (p *ConnPool) removeConn(cn *Conn) {
	p.connsMu.Lock()
	for i, c := range p.conns {
		if c == cn {
			p.conns = append(p.conns[:i], p.conns[i+1:]...)
			if cn.pooled {
				p.poolSize--
				p.checkMinIdleConns()
			}
			break
		}
	}
	p.connsMu.Unlock()
}

func (p *ConnPool) closeConn(cn *Conn) error {
	return cn.Close()
}

conn.go

package pool

import (
	"net"
	"sync/atomic"
	"time"
	"fmt"
	"strings"
)


type Conn struct {
	netConn net.Conn


	//是否属于连接池的连接
	pooled    bool

	//连接创建时间
	createdAt time.Time

	//已使用时间(原子更新操作)
	usedAt    atomic.Value

	
}

func NewConn(netConn net.Conn) *Conn {
	cn := &Conn{
		netConn:   netConn,
		createdAt: time.Now(),
	}
	cn.SetUsedAt(time.Now())
	return cn
}

func (cn *Conn) UsedAt() time.Time {
	return cn.usedAt.Load().(time.Time)
}

func (cn *Conn) SetUsedAt(tm time.Time) {
	cn.usedAt.Store(tm)
}

func (cn *Conn) Write(b []byte) (int, error) {
	now := time.Now()
	//更新使用时间
	cn.SetUsedAt(now)
	return cn.netConn.Write(b)
}

func (cn *Conn) RemoteAddr() net.Addr {
	return cn.netConn.RemoteAddr()
}


func (cn *Conn) Close() error {
	return cn.netConn.Close()
}


func (cn *Conn) WithWrite(cmd string) {
	cmd_argv := strings.Fields(cmd)
	protocl_cmd := fmt.Sprintf("*%d\r\n", len(cmd_argv))
	for _, arg := range cmd_argv {
			protocl_cmd += fmt.Sprintf("$%d\r\n", len(arg))
			protocl_cmd += arg
			protocl_cmd += "\r\n"
	}
	//redis服务器接受的命令协议
	fmt.Printf("%q\n", protocl_cmd)
	cn.netConn.Write([]byte(protocl_cmd))
}