回顾
上一节我们构建了一个 Redis 连接池,最终通过如下方式发送命令:
func main() {
client := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
})
client.SendCommand("SET hello world")
发送命令的方式是通过 SendCommand 函数传递完整的命令字符串,但作为一个客户端,我们更习惯使用
Set(key string, value interface{}, expiration time.Duration)
这种写法。
最终实现后的效果
func main() {
	opts := &redis.Options{Addr: "localhost:6379"}
	client := redis.NewClient(opts)
	// Store "world" under key "hello"; an expiration of 0 adds no EX/PX option.
	client.Set("hello", "world", 0)
}
实现流程图
代码结构
| internal
|pool
|conn.go 连接池每个连接的具体结构
|pool.go 连接池
|proto
|writer.go 将命令参数按 Redis 协议(RESP)编码并写入连接
|util
|strconv.go 类型转化函数
|log.go 日志处理函数
|command.go 命令参数处理文件
|commands.go 客户端命令实现文件
|options.go 配置文件
|redis.go redis客户端对外主文件,Client结构体存放的位置
实现代码
redis.go
package redis
import(
"github.com/learn-go/redis/internal/pool"
"github.com/learn-go/redis/internal/proto"
)
// Client is the public Redis client. It owns the connection pool and,
// through the embedded cmdable, exposes typed command methods such as Set.
type Client struct {
	// opt holds the client configuration.
	opt *Options
	// connPool is the pool that connections are borrowed from.
	connPool pool.Pooler
	// cmdable provides the typed command methods (e.g. Set).
	cmdable
	// process executes a single command; init sets it to defaultProcess.
	process func(Cmder) error
}
// init wires up command execution. The embedded cmdable routes every command
// through c.Process, and process defaults to defaultProcess.
func (c *Client) init() {
	// c.Process is a bound method value that reads c.process at call time,
	// so the two assignments are independent of each other's order.
	c.cmdable.setProcessor(c.Process)
	c.process = c.defaultProcess
}
// Process runs cmd through the client's current process hook and returns
// the resulting error.
func (c *Client) Process(cmd Cmder) error {
	run := c.process
	return run(cmd)
}
// defaultProcess borrows a connection from the pool, writes cmd to it, and
// hands the connection back to the pool. Any error from getting the
// connection or writing is also recorded on cmd via setErr. On success it
// returns cmd.Err().
//
// NOTE(review): no reply is read from the server yet, and a connection
// whose write failed is still returned to the pool for reuse. Confirm this
// is acceptable at this stage of the tutorial.
func (c *Client) defaultProcess(cmd Cmder) error {
	cn, err := c.getConn()
	if err != nil {
		cmd.setErr(err)
		return err
	}

	// Encode the command and flush it to the server.
	writeErr := cn.WithWriter(c.opt.WriteTimeout, func(wr *proto.Writer) error {
		return writeCmd(wr, cmd)
	})

	// The connection goes back to the pool whether or not the write succeeded.
	c.releaseConn(cn)
	if writeErr != nil {
		cmd.setErr(writeErr)
		return writeErr
	}
	return cmd.Err()
}
// releaseConn hands cn back to the client's connection pool.
func (c *Client) releaseConn(cn *pool.Conn) {
	p := c.connPool
	p.Put(cn)
}
// getConn borrows a connection from the pool. On failure it returns a nil
// connection together with the pool's error.
func (c *Client) getConn() (*pool.Conn, error) {
	cn, err := c.connPool.Get()
	if err == nil {
		return cn, nil
	}
	return nil, err
}
// NewClient builds a Client from opt, creating its connection pool and
// wiring up command execution.
//
// opt.init() is invoked on the caller's *Options. Presumably it fills in
// defaults, which the caller would then see (Options is defined elsewhere;
// TODO confirm).
func NewClient(opt *Options) *Client {
	opt.init()
	client := &Client{
		opt:      opt,
		connPool: newConnPool(opt),
	}
	client.init()
	return client
}
commands.go
package redis
import(
"time"
"github.com/learn-go/redis/internal"
)
// usePrecise reports whether dur must be expressed in milliseconds. That is
// the case unless dur is a whole number of seconds and at least one second.
// Set uses this to choose between "px" and "ex".
func usePrecise(dur time.Duration) bool {
	wholeSeconds := dur >= time.Second && dur%time.Second == 0
	return !wholeSeconds
}
// formatMs converts dur to whole milliseconds, truncating any remainder.
// A positive duration below one millisecond truncates to 0, so it is logged.
func formatMs(dur time.Duration) int64 {
	if tooSmall := dur > 0 && dur < time.Millisecond; tooSmall {
		internal.Logf(
			"specified duration is %s, but minimal supported value is %s",
			dur, time.Millisecond,
		)
	}
	ms := dur / time.Millisecond
	return int64(ms)
}
// formatSec converts dur to whole seconds, truncating any remainder.
// A positive duration below one second truncates to 0, so it is logged.
func formatSec(dur time.Duration) int64 {
	if tooSmall := dur > 0 && dur < time.Second; tooSmall {
		internal.Logf(
			"specified duration is %s, but minimal supported value is %s",
			dur, time.Second,
		)
	}
	secs := dur / time.Second
	return int64(secs)
}
// Cmdable lists the typed Redis commands offered by the client.
type Cmdable interface {
	// Set stores value under key. In the cmdable implementation, a positive
	// expiration adds an "ex"/"px" option to the SET command.
	Set(key string, value interface{}, expiration time.Duration)
}
// cmdable implements the typed command methods. Each method builds a
// Cmder and hands it to process.
type cmdable struct {
	// process executes a built command; it is installed via setProcessor.
	process func(cmd Cmder) error
}
// setProcessor installs fn as the function that every command method of
// c uses to execute its command.
func (c *cmdable) setProcessor(fn func(Cmder) error) {
	c.process = fn
}
// Set issues "set key value". When expiration is positive, it also appends
// either "px <milliseconds>" (expiration is not a whole number of seconds)
// or "ex <seconds>" (it is). When expiration is zero or negative, no
// expiration option is sent.
//
// NOTE(review): the error returned by process is discarded because the
// Cmdable signature has no return value, so callers cannot observe failures.
func (c *cmdable) Set(key string, value interface{}, expiration time.Duration) {
	// Capacity 5 leaves room for the optional two-element expiration pair,
	// so the append below never has to reallocate.
	args := make([]interface{}, 3, 5)
	args[0] = "set"
	args[1] = key
	args[2] = value
	if expiration > 0 {
		if usePrecise(expiration) {
			args = append(args, "px", formatMs(expiration))
		} else {
			args = append(args, "ex", formatSec(expiration))
		}
	}
	cmd := NewCmd(args...)
	c.process(cmd)
}
command.go
package redis
import(
"fmt"
"strings"
"github.com/learn-go/redis/internal/proto"
)
// Cmder is the interface implemented by every command value.
type Cmder interface {
	// Name returns the command name (baseCmd lower-cases it).
	Name() string
	// Args returns the full argument list, command name first.
	Args() []interface{}
	// stringArg returns the argument at the given position as a string,
	// or "" when it is out of range or not a string (per baseCmd).
	stringArg(int) string
	// setErr records an error on the command.
	setErr(error)
	// Err returns the error recorded on the command, if any.
	Err() error
}
// writeCmd encodes each command's arguments into wr, in order, and stops at
// the first error. It does not call Flush; the caller is responsible for that.
func writeCmd(wr *proto.Writer, cmds ...Cmder) error {
	for i := range cmds {
		if err := wr.WriteArgs(cmds[i].Args()); err != nil {
			return err
		}
	}
	return nil
}
// cmdString renders cmd's arguments separated by spaces. If cmd carries an
// error, ": <error>" is appended. Otherwise, if val is non-nil, ": <val>" is
// appended, with []byte values rendered as text.
func cmdString(cmd Cmder, val interface{}) string {
	var b strings.Builder
	for i, arg := range cmd.Args() {
		if i > 0 {
			b.WriteByte(' ')
		}
		b.WriteString(fmt.Sprint(arg))
	}
	s := b.String()

	if err := cmd.Err(); err != nil {
		return s + ": " + err.Error()
	}
	if val == nil {
		return s
	}
	if raw, ok := val.([]byte); ok {
		return s + ": " + string(raw)
	}
	return s + ": " + fmt.Sprint(val)
}
//-----------------------------------------
// baseCmd holds the state shared by every command: its arguments (command
// name first) and the error, if any, recorded while running it.
type baseCmd struct {
	_args []interface{}
	err   error
}

// Err returns the error recorded on the command, or nil.
func (cmd *baseCmd) Err() error {
	return cmd.err
}

// Args returns the command's arguments, command name first. The slice is
// returned as-is, not copied.
func (cmd *baseCmd) Args() []interface{} {
	return cmd._args
}

// stringArg returns the argument at pos if it exists and is a string;
// otherwise it returns "".
func (cmd *baseCmd) stringArg(pos int) string {
	if pos < 0 || pos >= len(cmd._args) {
		return ""
	}
	s, _ := cmd._args[pos].(string)
	return s
}

// setErr records e as the command's error.
func (cmd *baseCmd) setErr(e error) {
	cmd.err = e
}

// Name returns the lower-cased command name.
//
// When the first argument is a string, it is also rewritten in place to its
// lower-cased form. A non-string first argument (for example []byte) is left
// untouched, because overwriting it would corrupt the command; Name then
// returns "". A command with no arguments also yields "".
func (cmd *baseCmd) Name() string {
	if len(cmd._args) == 0 {
		return ""
	}
	name, ok := cmd._args[0].(string)
	if !ok {
		return ""
	}
	// Cmd name must be lower cased.
	name = strings.ToLower(name)
	cmd._args[0] = name
	return name
}
//--------------------------------------------------------------
// Cmd is a generic command; all of its behavior comes from the embedded baseCmd.
type Cmd struct {
	baseCmd
}

// NewCmd returns a Cmd whose arguments are args (command name first).
// The args slice is stored as-is, not copied.
func NewCmd(args ...interface{}) *Cmd {
	cmd := new(Cmd)
	cmd.baseCmd._args = args
	return cmd
}
/internal/proto/writer.go
package proto
import (
"bufio"
"encoding"
"fmt"
"io"
"strconv"
"github.com/learn-go/redis/internal/util"
)
// Type prefixes of the Redis serialization protocol: the first byte of
// each protocol element.
const (
	ErrorReply  = '-' // error reply
	StatusReply = '+' // status (simple string) reply
	IntReply    = ':' // integer reply
	StringReply = '$' // bulk string
	ArrayReply  = '*' // array
)
// Writer encodes command arguments in the Redis protocol into a buffered
// writer. lenBuf and numBuf are scratch buffers, reused when formatting
// lengths and numbers.
type Writer struct {
	wr     *bufio.Writer
	lenBuf []byte
	numBuf []byte
}

// NewWriter returns a Writer that buffers its output in front of wr.
func NewWriter(wr io.Writer) *Writer {
	w := new(Writer)
	w.wr = bufio.NewWriter(wr)
	w.lenBuf = make([]byte, 64)
	w.numBuf = make([]byte, 64)
	return w
}
// WriteArgs writes args as one protocol array: "*<count>\r\n" followed by
// each argument encoded as a bulk string. The output is buffered; call
// Flush to send it.
func (w *Writer) WriteArgs(args []interface{}) error {
	if err := w.wr.WriteByte(ArrayReply); err != nil {
		return err
	}
	if err := w.writeLen(len(args)); err != nil {
		return err
	}
	for i := range args {
		if err := w.writeArg(args[i]); err != nil {
			return err
		}
	}
	return nil
}
// writeLen writes n in decimal followed by CRLF, reusing lenBuf as scratch space.
func (w *Writer) writeLen(n int) error {
	buf := strconv.AppendUint(w.lenBuf[:0], uint64(n), 10)
	buf = append(buf, '\r', '\n')
	w.lenBuf = buf
	_, err := w.wr.Write(buf)
	return err
}
// writeArg encodes a single argument as a bulk string.
//
// Supported types:
//   - nil, sent as ""
//   - string and []byte
//   - every sized signed and unsigned integer type
//   - float32 and float64
//   - bool, sent as 1 or 0
//   - any value implementing encoding.BinaryMarshaler
//
// Any other type yields an error before anything is written for it.
func (w *Writer) writeArg(v interface{}) error {
	switch v := v.(type) {
	case nil:
		return w.string("")
	case string:
		return w.string(v)
	case []byte:
		return w.bytes(v)
	case int:
		return w.int(int64(v))
	case int8:
		return w.int(int64(v))
	case int16:
		return w.int(int64(v))
	case int32:
		return w.int(int64(v))
	case int64:
		return w.int(v)
	case uint:
		return w.uint(uint64(v))
	case uint8:
		return w.uint(uint64(v))
	case uint16:
		return w.uint(uint64(v))
	case uint32:
		return w.uint(uint64(v))
	case uint64:
		return w.uint(v)
	case float32:
		// Format with bitSize 32 so the shortest text that round-trips the
		// float32 is sent (float32(0.1) -> "0.1"). Formatting the widened
		// float64 would send "0.10000000149011612".
		w.numBuf = strconv.AppendFloat(w.numBuf[:0], float64(v), 'f', -1, 32)
		return w.bytes(w.numBuf)
	case float64:
		return w.float(v)
	case bool:
		if v {
			return w.int(1)
		}
		return w.int(0)
	case encoding.BinaryMarshaler:
		b, err := v.MarshalBinary()
		if err != nil {
			return err
		}
		return w.bytes(b)
	default:
		return fmt.Errorf(
			"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)
	}
}
// bytes writes b as a bulk string: "$<len>\r\n<b>\r\n".
func (w *Writer) bytes(b []byte) error {
	if err := w.wr.WriteByte(StringReply); err != nil {
		return err
	}
	if err := w.writeLen(len(b)); err != nil {
		return err
	}
	if _, err := w.wr.Write(b); err != nil {
		return err
	}
	return w.crlf()
}
// string writes s as a bulk string. The bytes come from util.StringToBytes,
// presumably a no-copy conversion (TODO confirm in internal/util).
func (w *Writer) string(s string) error {
	b := util.StringToBytes(s)
	return w.bytes(b)
}
// uint writes n as a decimal bulk string, using numBuf as scratch space.
func (w *Writer) uint(n uint64) error {
	buf := strconv.AppendUint(w.numBuf[:0], n, 10)
	w.numBuf = buf
	return w.bytes(buf)
}

// int writes n as a decimal bulk string, using numBuf as scratch space.
func (w *Writer) int(n int64) error {
	buf := strconv.AppendInt(w.numBuf[:0], n, 10)
	w.numBuf = buf
	return w.bytes(buf)
}

// float writes f as a bulk string in non-exponent ('f') notation, with the
// minimal number of digits that round-trips a float64.
func (w *Writer) float(f float64) error {
	buf := strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64)
	w.numBuf = buf
	return w.bytes(buf)
}
// crlf writes the protocol line terminator "\r\n".
func (w *Writer) crlf() error {
	if err := w.wr.WriteByte('\r'); err != nil {
		return err
	}
	return w.wr.WriteByte('\n')
}
// Reset discards any buffered, unflushed data and makes the Writer write to wr.
func (w *Writer) Reset(wr io.Writer) {
	w.wr.Reset(wr)
}

// Flush sends all buffered data to the underlying io.Writer.
func (w *Writer) Flush() error {
	err := w.wr.Flush()
	return err
}
conn.go(增加部分)
package pool
import (
"net"
"sync/atomic"
"time"
"fmt"
"github.com/learn-go/redis/internal/proto"
)
// noDeadline is the zero time.Time; passing it to SetWriteDeadline clears
// any write deadline.
var noDeadline time.Time

// Conn wraps a network connection together with the protocol writer used
// to send commands over it.
type Conn struct {
	netConn net.Conn
	wr      *proto.Writer
	// pooled reports whether this connection belongs to the connection pool.
	pooled bool
	// createdAt is when the connection was created.
	createdAt time.Time
	// usedAt holds the time.Time of the most recent use; it is accessed
	// through atomic.Value Load/Store.
	usedAt atomic.Value
}
// NewConn wraps netConn, creates its protocol writer, and stamps the
// creation time and the last-used time with the current time.
func NewConn(netConn net.Conn) *Conn {
	cn := &Conn{
		netConn:   netConn,
		createdAt: time.Now(),
		wr:        proto.NewWriter(netConn),
	}
	cn.SetUsedAt(time.Now())
	return cn
}
// UsedAt returns the last-used time stored by SetUsedAt. The type assertion
// panics if nothing has been stored yet; NewConn always stores a value.
func (cn *Conn) UsedAt() time.Time {
	tm := cn.usedAt.Load().(time.Time)
	return tm
}

// SetUsedAt atomically records tm as the connection's last-used time.
func (cn *Conn) SetUsedAt(tm time.Time) {
	cn.usedAt.Store(tm)
}

// RemoteAddr reports the remote address of the underlying network connection.
func (cn *Conn) RemoteAddr() net.Addr {
	nc := cn.netConn
	return nc.RemoteAddr()
}

// Close closes the underlying network connection.
func (cn *Conn) Close() error {
	nc := cn.netConn
	return nc.Close()
}
// setWriteTimeout marks the connection as used now. It then sets the write
// deadline to now+timeout, or clears the deadline when timeout <= 0.
func (cn *Conn) setWriteTimeout(timeout time.Duration) error {
	now := time.Now()
	cn.SetUsedAt(now)
	deadline := noDeadline
	if timeout > 0 {
		deadline = now.Add(timeout)
	}
	return cn.netConn.SetWriteDeadline(deadline)
}
// WithWriter lets fn encode data into the connection's buffered writer and
// then flushes it to the server. timeout is applied as the write deadline
// (see setWriteTimeout).
//
// If fn fails, whatever it already buffered is discarded rather than
// flushed. A truncated command on the wire would be read together with the
// next command sent on this connection, desynchronizing the protocol stream.
// Bytes that the buffer had already flushed on its own while fn was running
// cannot be recalled.
func (cn *Conn) WithWriter(timeout time.Duration, fn func(wr *proto.Writer) error) error {
	// Setting the deadline is deliberately best-effort; its error is ignored.
	_ = cn.setWriteTimeout(timeout)

	// fn is typically writeCmd (command.go), which converts the command
	// arguments into the server protocol.
	if err := fn(cn.wr); err != nil {
		cn.wr.Reset(cn.netConn)
		return err
	}

	// Send the buffered bytes to the server.
	if err := cn.wr.Flush(); err != nil {
		// NOTE(review): debug print to stdout. Consider routing it through
		// the package logger (fmt has no other use in the visible part of
		// this file).
		fmt.Println(err)
		return err
	}
	return nil
}