回顾

     上一节我们构建了一个redis连接池,最终通过如下的方式发送命令


func main() {
	
	client := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
	})
 
	client.SendCommand("SET hello world")
 

发送命令的方式是通过SendCommand函数传递完整的命令字符串,但是作为一个客户端我们更习惯

Set(key string, value interface{}, expiration time.Duration) 

这种写法。

 

最终实现后的效果

 

// main shows the target usage: create a client and issue SET through the
// typed Set helper instead of a raw command string.
func main() {
	// Build the options first, then hand them to the client constructor.
	opts := &redis.Options{Addr: "localhost:6379"}
	client := redis.NewClient(opts)

	// 0 is passed as the expiration argument.
	client.Set("hello", "world", 0)
}

实现流程图

 

 

代码结构

| internal

      |pool

              |conn.go                                       连接池每个连接的具体结构

              |pool.go                                        连接池

      |proto

               |writer.go                                       协议写入器(将命令参数按redis协议编码后写入连接)

      |util

               |strconv.go                                      类型转化函数

      |log.go                                                      日志处理函数

|command.go                                                  命令参数处理文件

|commands.go                                                 客户端命令实现文件

|options.go                                                       配置文件

|redis.go                                                           redis客户端对外主文件,Client结构体存放的位置

 

实现代码

redis.go

package redis

import(
	"github.com/learn-go/redis/internal/pool"
	"github.com/learn-go/redis/internal/proto"
)

// Client is the Redis client type. It bundles the options, the connection
// pool, the embedded command helpers and the function that processes commands.
type Client struct {
	// Configuration options supplied to NewClient.
	opt      *Options
	// Connection pool that getConn and releaseConn borrow from and return to.
	connPool pool.Pooler
	// Embedded command set; its methods are promoted onto Client.
	cmdable
	// Function invoked by Process for every command.
	process           func(Cmder) error
}


// init wires the client together: the embedded cmdable is pointed at
// Process, and defaultProcess becomes the function Process dispatches to.
// The two assignments are independent because Process reads c.process only
// when it is called.
func (c *Client) init() {
	c.cmdable.setProcessor(c.Process)
	c.process = c.defaultProcess
}


// Process runs cmd through the function currently stored in c.process and
// returns that function's error.
func (c *Client) Process(cmd Cmder) error {
	return c.process(cmd)
}


// defaultProcess borrows a connection, writes cmd to it and returns the
// connection to the pool.
//
// A failure to obtain a connection or to write the command is recorded on
// cmd via setErr and returned. Otherwise the result is whatever cmd.Err()
// reports.
func (c *Client) defaultProcess(cmd Cmder) error {
	cn, err := c.getConn()
	if err != nil {
		cmd.setErr(err)
		return err
	}

	// Encode the command and send it to the server.
	writeErr := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
		return writeCmd(wr, cmd)
	})

	// The connection goes back to the pool whether or not the write succeeded.
	c.releaseConn(cn)

	if writeErr != nil {
		cmd.setErr(writeErr)
		return writeErr
	}
	return cmd.Err()
}

// releaseConn hands cn back to the connection pool via Put.
func (c *Client) releaseConn(cn *pool.Conn){
	c.connPool.Put(cn)
}

// getConn obtains a connection from the pool. When the pool reports an
// error the returned connection is always nil.
func (c *Client) getConn() (*pool.Conn, error) {
	cn, err := c.connPool.Get()
	if err == nil {
		return cn, nil
	}
	return nil, err
}


// NewClient initialises opt, builds a Client around it with a connection
// pool created by newConnPool, wires up command processing and returns the
// client.
func NewClient(opt *Options) *Client {
	opt.init()

	client := &Client{opt: opt}
	client.connPool = newConnPool(opt)
	client.init()

	return client
}

commands.go

package redis
import(
	"time"
	"github.com/learn-go/redis/internal"
)

// usePrecise reports whether dur cannot be expressed as a whole number of
// seconds of at least one second, i.e. whether millisecond precision is
// needed to represent it.
func usePrecise(dur time.Duration) bool {
	wholeSeconds := dur >= time.Second && dur%time.Second == 0
	return !wholeSeconds
}

// formatMs converts dur to a whole number of milliseconds.
//
// A positive duration shorter than one millisecond is logged and rounded up
// to 1 instead of being truncated to 0, because the result is used as an
// expiration argument and a zero expiration is not a valid value to send.
func formatMs(dur time.Duration) int64 {
	if dur > 0 && dur < time.Millisecond {
		internal.Logf(
			"specified duration is %s, but minimal supported value is %s",
			dur, time.Millisecond,
		)
		return 1
	}
	return int64(dur / time.Millisecond)
}

// formatSec converts dur to a whole number of seconds.
//
// A positive duration shorter than one second is logged and rounded up to 1
// instead of being truncated to 0, because the result is used as an
// expiration argument and a zero expiration is not a valid value to send.
func formatSec(dur time.Duration) int64 {
	if dur > 0 && dur < time.Second {
		internal.Logf(
			"specified duration is %s, but minimal supported value is %s",
			dur, time.Second,
		)
		return 1
	}
	return int64(dur / time.Second)
}

// Cmdable lists the typed command helpers offered to callers.
type Cmdable interface {
	// Set issues a SET for key/value; a positive expiration adds an expiry option.
	Set(key string, value interface{}, expiration time.Duration) 
}


// cmdable implements the command helpers on top of a pluggable processing
// function.
type cmdable struct {
	// process receives every command built by the helper methods.
	process func(cmd Cmder) error
}

// setProcessor installs fn as the function the command helpers send their
// commands to.
func (c *cmdable) setProcessor(fn func(Cmder) error) {
	c.process = fn
}

// Set builds a "set key value" command and passes it to the processor.
//
// When expiration is positive an expiry option is appended: "px" with the
// duration in milliseconds if usePrecise reports that sub-second precision
// is needed, otherwise "ex" with the duration in seconds. A zero or negative
// expiration adds no option.
func (c *cmdable) Set(key string, value interface{}, expiration time.Duration) {
	// At most five arguments are sent: "set", key, value and an optional
	// option/value pair, so size the backing array for all of them up front.
	args := make([]interface{}, 3, 5)
	args[0] = "set"
	args[1] = key
	args[2] = value
	if expiration > 0 {
		if usePrecise(expiration) {
			args = append(args, "px", formatMs(expiration))
		} else {
			args = append(args, "ex", formatSec(expiration))
		}
	}

	// Set has no return value, so the processor's error is deliberately
	// dropped here.
	_ = c.process(NewCmd(args...))
}

command.go

package redis

import(
	"fmt"
	"strings"
	"github.com/learn-go/redis/internal/proto"
)

// Cmder is the view of a command that the processing code works with.
type Cmder interface {
	// Name returns the command name.
	Name() string
	// Args returns the full argument list; writeCmd hands it to the protocol writer.
	Args() []interface{}
	// stringArg returns the argument at the given position as a string.
	stringArg(int) string
	// setErr records an error on the command.
	setErr(error)


	// Err returns the error recorded on the command, if any.
	Err() error
}


// writeCmd encodes each command's argument list with wr, in order, and
// stops at the first write error.
func writeCmd(wr *proto.Writer, cmds ...Cmder) error {
	for i := range cmds {
		if err := wr.WriteArgs(cmds[i].Args()); err != nil {
			return err
		}
	}
	return nil
}

// cmdString renders cmd as its space-separated arguments, followed by
// ": <error>" when the command carries an error, or by ": <val>" when val is
// non-nil. A []byte val is shown as text rather than as a byte list.
func cmdString(cmd Cmder, val interface{}) string {
	args := cmd.Args()
	parts := make([]string, 0, len(args))
	for _, a := range args {
		parts = append(parts, fmt.Sprint(a))
	}
	text := strings.Join(parts, " ")

	// An error takes precedence over the value.
	if err := cmd.Err(); err != nil {
		return text + ": " + err.Error()
	}
	if val == nil {
		return text
	}
	if raw, ok := val.([]byte); ok {
		return text + ": " + string(raw)
	}
	return text + ": " + fmt.Sprint(val)
}

//-----------------------------------------

// baseCmd holds the state shared by all commands: the argument list, whose
// first element is the command name, and the error recorded for the command.
type baseCmd struct {
	_args []interface{}
	err   error
}

// Err returns the error recorded with setErr, or nil when none was recorded.
func (cmd *baseCmd) Err() error {
	return cmd.err
}

// Args returns the command's argument list (not a copy).
func (cmd *baseCmd) Args() []interface{} {
	return cmd._args
}

// stringArg returns the argument at pos when it is a string. It returns ""
// when pos is out of range or the argument has another type.
func (cmd *baseCmd) stringArg(pos int) string {
	if pos < 0 || pos >= len(cmd._args) {
		return ""
	}
	s, _ := cmd._args[pos].(string)
	return s
}

// setErr records e as the command's error.
func (cmd *baseCmd) setErr(e error) {
	cmd.err = e
}

// Name returns the lower-cased command name taken from the first argument.
//
// When the first argument is a string it is also replaced in the argument
// list by its lower-cased form. When there are no arguments, or the first
// argument is not a string, Name returns "" and leaves the argument list
// untouched, so a non-string command name is never overwritten.
func (cmd *baseCmd) Name() string {
	if len(cmd._args) == 0 {
		return ""
	}
	name, ok := cmd._args[0].(string)
	if !ok {
		return ""
	}
	// Cmd name must be lower cased.
	name = strings.ToLower(name)
	cmd._args[0] = name
	return name
}

//--------------------------------------------------------------

// Cmd is a generic command; it adds nothing beyond the embedded baseCmd.
type Cmd struct {
	baseCmd
}

// NewCmd returns a Cmd whose argument list is args (stored as given, not
// copied).
func NewCmd(args ...interface{}) *Cmd {
	cmd := new(Cmd)
	cmd._args = args
	return cmd
}

/internal/proto/writer.go

package proto

import (
	"bufio"
	"encoding"
	"fmt"
	"io"
	"strconv"

	"github.com/learn-go/redis/internal/util"
)

// Single-byte type markers of the wire protocol. Only ArrayReply and
// StringReply are used by the Writer in this file; the remaining names
// presumably mark the corresponding reply kinds when reading — confirm
// against the reader.
const (
	ErrorReply  = '-'
	StatusReply = '+'
	IntReply    = ':'
	StringReply = '$' // written before the byte length of every argument
	ArrayReply  = '*' // written before the argument count of a command
)

// Writer encodes command arguments and writes them through a buffered
// writer.
type Writer struct {
	// Buffered destination; data reaches the underlying writer on Flush.
	wr *bufio.Writer

	// Scratch buffer reused by writeLen for length prefixes.
	lenBuf []byte
	// Scratch buffer reused by uint, int and float for number formatting.
	numBuf []byte
}

// NewWriter returns a Writer that buffers its output in front of wr and
// owns two 64-byte scratch buffers for length and number formatting.
func NewWriter(wr io.Writer) *Writer {
	w := new(Writer)
	w.wr = bufio.NewWriter(wr)
	w.lenBuf = make([]byte, 64)
	w.numBuf = make([]byte, 64)
	return w
}

// WriteArgs writes one command: the array marker, the number of arguments
// and then each argument in order. It returns the first error encountered;
// nothing is flushed here.
func (w *Writer) WriteArgs(args []interface{}) error {
	if err := w.wr.WriteByte(ArrayReply); err != nil {
		return err
	}
	if err := w.writeLen(len(args)); err != nil {
		return err
	}
	for i := range args {
		if err := w.writeArg(args[i]); err != nil {
			return err
		}
	}
	return nil
}

// writeLen writes n in decimal followed by "\r\n", formatting it into the
// reusable lenBuf scratch buffer.
func (w *Writer) writeLen(n int) error {
	line := strconv.AppendUint(w.lenBuf[:0], uint64(n), 10)
	line = append(line, '\r', '\n')
	// Keep the (possibly grown) buffer for the next call.
	w.lenBuf = line

	_, err := w.wr.Write(line)
	return err
}

// writeArg writes a single argument according to its dynamic type:
//
//   - nil is written as an empty string;
//   - strings and byte slices are written as they are;
//   - integers and floats are written in decimal text form;
//   - booleans are written as 1 (true) or 0 (false);
//   - an encoding.BinaryMarshaler is written as the bytes it marshals to.
//
// Any other type yields an error, as does a failing MarshalBinary.
func (w *Writer) writeArg(v interface{}) error {
	switch val := v.(type) {
	case nil:
		return w.string("")
	case string:
		return w.string(val)
	case []byte:
		return w.bytes(val)
	case int:
		return w.int(int64(val))
	case int8:
		return w.int(int64(val))
	case int16:
		return w.int(int64(val))
	case int32:
		return w.int(int64(val))
	case int64:
		return w.int(val)
	case uint:
		return w.uint(uint64(val))
	case uint8:
		return w.uint(uint64(val))
	case uint16:
		return w.uint(uint64(val))
	case uint32:
		return w.uint(uint64(val))
	case uint64:
		return w.uint(val)
	case float32:
		return w.float(float64(val))
	case float64:
		return w.float(val)
	case bool:
		var flag int64
		if val {
			flag = 1
		}
		return w.int(flag)
	case encoding.BinaryMarshaler:
		data, err := val.MarshalBinary()
		if err != nil {
			return err
		}
		return w.bytes(data)
	}

	// Every supported type returned above.
	return fmt.Errorf(
		"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)
}

// bytes writes b as a length-prefixed string: the string marker, the byte
// length, the payload and a trailing "\r\n".
func (w *Writer) bytes(b []byte) error {
	if err := w.wr.WriteByte(StringReply); err != nil {
		return err
	}
	if err := w.writeLen(len(b)); err != nil {
		return err
	}
	if _, err := w.wr.Write(b); err != nil {
		return err
	}
	return w.crlf()
}

// string writes s the same way as bytes, after converting it with
// util.StringToBytes.
func (w *Writer) string(s string) error {
	return w.bytes(util.StringToBytes(s))
}

// uint writes n as its base-10 text, formatted into the reusable numBuf.
func (w *Writer) uint(n uint64) error {
	digits := strconv.AppendUint(w.numBuf[:0], n, 10)
	w.numBuf = digits
	return w.bytes(digits)
}

// int writes n as its base-10 text, formatted into the reusable numBuf.
func (w *Writer) int(n int64) error {
	digits := strconv.AppendInt(w.numBuf[:0], n, 10)
	w.numBuf = digits
	return w.bytes(digits)
}

// float writes f in plain decimal notation (no exponent) using the fewest
// digits that represent it exactly as a float64, formatted into numBuf.
func (w *Writer) float(f float64) error {
	text := strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64)
	w.numBuf = text
	return w.bytes(text)
}

// crlf writes the two-byte line terminator "\r\n", one byte at a time.
func (w *Writer) crlf() error {
	if err := w.wr.WriteByte('\r'); err != nil {
		return err
	}
	return w.wr.WriteByte('\n')
}

// Reset redirects the internal buffered writer to wr by calling its Reset
// method.
func (w *Writer) Reset(wr io.Writer) {
	w.wr.Reset(wr)
}

// Flush flushes the internal buffered writer and returns its error.
func (w *Writer) Flush() error {
	return w.wr.Flush()
}

conn.go(增加部分)

package pool

import (
	"net"
	"sync/atomic"
	"time"
	"fmt"
	"github.com/learn-go/redis/internal/proto"
)
// noDeadline is the zero time; setWriteTimeout passes it to SetWriteDeadline
// when no positive timeout is requested.
var noDeadline = time.Time{}

// Conn wraps a network connection together with its protocol writer and
// bookkeeping timestamps.
type Conn struct {
	// Underlying network connection.
	netConn net.Conn

	// Protocol writer created over netConn by NewConn.
	wr       *proto.Writer

	// Whether the connection belongs to the connection pool.
	pooled    bool

	// Time the connection was created.
	createdAt time.Time

	// Time the connection was last used (a time.Time stored atomically).
	usedAt    atomic.Value

	
}

// NewConn wraps netConn in a Conn: it stamps the creation time, attaches a
// protocol writer over the connection and initialises the last-used time.
func NewConn(netConn net.Conn) *Conn {
	cn := &Conn{
		netConn:   netConn,
		createdAt: time.Now(),
		wr:        proto.NewWriter(netConn),
	}
	cn.SetUsedAt(time.Now())
	return cn
}

// UsedAt returns the time most recently stored with SetUsedAt. It panics if
// no time has been stored yet; NewConn always stores one.
func (cn *Conn) UsedAt() time.Time {
	stored := cn.usedAt.Load()
	return stored.(time.Time)
}

// SetUsedAt atomically stores tm as the connection's last-used time.
func (cn *Conn) SetUsedAt(tm time.Time) {
	cn.usedAt.Store(tm)
}
// RemoteAddr returns the remote address of the underlying network
// connection.
func (cn *Conn) RemoteAddr() net.Addr {
	return cn.netConn.RemoteAddr()
}

// Close closes the underlying network connection and returns its error.
func (cn *Conn) Close() error {
	return cn.netConn.Close()
}
// setWriteTimeout marks the connection as used now and sets its write
// deadline: now+timeout for a positive timeout, otherwise no deadline.
func (cn *Conn) setWriteTimeout(timeout time.Duration) error {
	now := time.Now()
	cn.SetUsedAt(now)

	deadline := noDeadline
	if timeout > 0 {
		deadline = now.Add(timeout)
	}
	return cn.netConn.SetWriteDeadline(deadline)
}

// WithWriter applies the write timeout, lets fn encode data into the
// connection's buffered protocol writer and then flushes it to the server.
//
// It returns the error from setting the deadline (wrapped), the error
// returned by fn, or the error from the flush, whichever occurs first.
// Errors are only returned, never printed.
func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) error) error {
	if err := cn.setWriteTimeout(timeout); err != nil {
		return fmt.Errorf("redis: set write deadline: %w", err)
	}

	// fn is typically a closure around writeCmd (command.go) that converts
	// the command arguments into the server protocol.
	if err := fn(cn.wr); err != nil {
		// Encoding stopped part-way. Drop whatever is still buffered instead
		// of flushing it: a truncated command would make the server treat
		// the next command sent on this connection as its continuation.
		// Bytes the buffered writer already pushed out cannot be recalled.
		cn.wr.Reset(cn.netConn)
		return err
	}

	// Send the encoded data to the server.
	return cn.wr.Flush()
}