// This program implements a small generic connection pool and demonstrates it
// with a local TCP server that prints whatever it receives.
package main

import (
	"errors"
	"fmt"
	"net"
	"os"
	"sync"
	"time"
)

type (
	// Config configures a Pool.
	Config struct {
		// InitCap is the number of connections NewPool creates up front.
		InitCap int
		// MaxCap is the capacity of the idle-connection queue. Get still creates
		// new connections when the queue is empty, so MaxCap does not bound the
		// number of connections handed out.
		MaxCap int
		// InitConn creates a new connection. Required.
		InitConn func() (interface{}, error)
		// Close closes a connection. Required.
		Close func(interface{}) error
		// Validate, if non-nil, returns an error for a connection that must not
		// be reused; such connections are closed and discarded.
		Validate func(interface{}) error
		// Timeout, if positive, is how long a connection may sit idle in the
		// pool before Get discards it.
		Timeout time.Duration
	}

	// Pool keeps idle connections for reuse. Its methods are safe for
	// concurrent use; the Config hooks may therefore be called concurrently.
	Pool struct {
		mu sync.Mutex // guards conns and poolMap
		// conns queues idle connections; it is nil once the pool is released.
		conns chan *conn
		// poolMap records which connections are currently idle in conns so the
		// same connection is never queued twice. Connections are used as map
		// keys, so they must be comparable values (e.g. pointers).
		poolMap map[interface{}]struct{}

		// The fields below are set once by NewPool and never modified, so they
		// may be read without holding mu.
		initConn func() (interface{}, error)
		close    func(interface{}) error
		validate func(interface{}) error
		timeout  time.Duration
	}

	// conn wraps an idle connection with the time it entered the pool.
	conn struct {
		connect    interface{}
		createTime time.Time
	}
)

var (
	// ErrClosed is returned by Get after the pool has been released.
	ErrClosed = errors.New("pool is closed")
	// ErrInvalidCap is returned by NewPool for inconsistent InitCap/MaxCap.
	ErrInvalidCap = errors.New("invalid cap settings")
	// ErrInvalidInit is returned by NewPool when InitConn is missing or fails.
	ErrInvalidInit = errors.New("invalid init func settings")
	// ErrInvalidClose is returned by NewPool when Close is missing.
	ErrInvalidClose = errors.New("invalid close func settings")
	// ErrInvalidConn is returned by Put for a nil or invalid connection.
	ErrInvalidConn = errors.New("invalid conn")
)

// addr is the address the demo server listens on and the demo client dials.
const addr string = "127.0.0.1:8230"

// NewPool validates config (which must be non-nil) and returns a pool holding
// config.InitCap freshly created connections.
//
// It returns ErrInvalidCap unless 0 <= InitCap <= MaxCap and MaxCap > 0,
// ErrInvalidInit if InitConn is nil or any initial InitConn call fails (the
// connections created so far are closed), and ErrInvalidClose if Close is nil.
func NewPool(config *Config) (*Pool, error) {
	if config.InitCap < 0 || config.MaxCap <= 0 || config.InitCap > config.MaxCap {
		return nil, ErrInvalidCap
	}
	if config.InitConn == nil {
		return nil, ErrInvalidInit
	}
	if config.Close == nil {
		return nil, ErrInvalidClose
	}

	pool := &Pool{
		conns:    make(chan *conn, config.MaxCap),
		poolMap:  make(map[interface{}]struct{}, config.MaxCap),
		initConn: config.InitConn,
		close:    config.Close,
		validate: config.Validate,
		timeout:  config.Timeout,
	}

	for i := 0; i < config.InitCap; i++ {
		connect, err := pool.initConn()
		if err != nil {
			// Close whatever was created so far; only record successful conns.
			pool.Release()
			return nil, ErrInvalidInit
		}
		pool.poolMap[connect] = struct{}{}
		pool.conns <- &conn{connect: connect, createTime: time.Now()}
	}
	return pool, nil
}

// Get returns an idle connection if one is usable, otherwise a new connection
// from InitConn (whose error is returned as-is). Idle connections that have
// exceeded Timeout or fail Validate are closed and skipped. Get returns
// ErrClosed after Release.
func (pool *Pool) Get() (interface{}, error) {
	for {
		pool.mu.Lock()
		if pool.conns == nil {
			pool.mu.Unlock()
			return nil, ErrClosed
		}
		var wrapConn *conn
		select {
		case wrapConn = <-pool.conns:
			// The caller now owns this connection, so it may be Put back later.
			delete(pool.poolMap, wrapConn.connect)
		default:
		}
		pool.mu.Unlock()

		if wrapConn == nil {
			// No idle connection: create a fresh one. This runs outside the lock
			// so a slow InitConn (e.g. a network dial) does not block others.
			connect, err := pool.initConn()
			if err != nil {
				return nil, err
			}
			return connect, nil
		}
		if pool.usable(wrapConn) {
			return wrapConn.connect, nil
		}
		// Expired or failed validation: drop it and try the next idle one.
		// A close failure is ignored because the connection is discarded anyway.
		_ = pool.close(wrapConn.connect)
	}
}

// usable reports whether an idle connection is within the idle timeout (if
// one is configured) and passes the optional Validate hook.
func (pool *Pool) usable(c *conn) bool {
	if pool.timeout > 0 && time.Since(c.createTime) > pool.timeout {
		return false
	}
	return pool.validate == nil || pool.validate(c.connect) == nil
}

// Put returns connect to the pool.
//
// A nil connect yields ErrInvalidConn; one rejected by Validate is closed and
// also yields ErrInvalidConn. A connection that is already idle in the pool is
// ignored. If the pool has been released or its idle queue is full, the
// connection is closed and Close's error is returned.
func (pool *Pool) Put(connect interface{}) error {
	if connect == nil {
		return ErrInvalidConn
	}
	if pool.validate != nil {
		if err := pool.validate(connect); err != nil {
			_ = pool.close(connect) // rejected anyway; ErrInvalidConn is the result
			return ErrInvalidConn
		}
	}

	pool.mu.Lock()
	if pool.conns == nil {
		pool.mu.Unlock()
		return pool.close(connect)
	}
	if _, dup := pool.poolMap[connect]; dup {
		pool.mu.Unlock()
		return nil
	}
	select {
	case pool.conns <- &conn{connect: connect, createTime: time.Now()}:
		pool.poolMap[connect] = struct{}{}
		pool.mu.Unlock()
		return nil
	default:
		pool.mu.Unlock()
		// The idle queue is full: close the surplus connection.
		return pool.close(connect)
	}
}

// Release closes every idle connection (best effort; Close errors are ignored)
// and marks the pool closed: afterwards Get returns ErrClosed and Put closes
// the connection it is given. Calling Release more than once is a no-op.
func (pool *Pool) Release() {
	pool.mu.Lock()
	conns := pool.conns
	pool.conns = nil
	pool.poolMap = nil
	pool.mu.Unlock()

	if conns == nil {
		return
	}
	// Put only sends while holding mu and seeing a non-nil pool.conns, so no
	// one can send on conns any more. Closing it first lets the range below
	// stop once the buffered connections are drained instead of blocking.
	close(conns)
	for wrapConn := range conns {
		_ = pool.close(wrapConn.connect)
	}
}

// Len returns the number of idle connections in the pool (0 after Release).
func (pool *Pool) Len() int {
	pool.mu.Lock()
	defer pool.mu.Unlock()
	return len(pool.conns)
}

// The code below is a demo harness for the pool.

// main starts the demo server, runs the client against it, and then waits so
// the server goroutines can print what they received.
func main() {
	go server()
	// Give the TCP server time to start listening.
	time.Sleep(2 * time.Second)
	client()
	fmt.Println("服务退出")
	time.Sleep(20 * time.Second)
}

// client exercises the pool against the demo server.
func client() {
	initConn := func() (interface{}, error) {
		return net.Dial("tcp", addr)
	}
	// Named closeConn rather than close to avoid shadowing the builtin.
	closeConn := func(v interface{}) error {
		return v.(net.Conn).Close()
	}

	// A pool with 5 initial connections and an idle queue of at most 30.
	poolConfig := &Config{
		InitCap:  5,
		MaxCap:   30,
		InitConn: initConn,
		Close:    closeConn,
		Timeout:  15 * time.Second,
	}
	p, err := NewPool(poolConfig)
	if err != nil {
		fmt.Println("err=", err)
		return
	}
	defer p.Release()

	// Take a connection from the pool.
	v, err := p.Get()
	if err != nil {
		fmt.Println("err=", err)
		return
	}
	// Take a second one and close it directly instead of leaking it; it is
	// deliberately not returned, so the pool ends up one connection short.
	if extra, err := p.Get(); err == nil {
		_ = closeConn(extra) // demo only; a close failure is not interesting here
	}

	nc := v.(net.Conn)
	if _, err := nc.Write([]byte("guoqiang")); err != nil {
		fmt.Println("err=", err)
	}

	// Return the connection. The repeated Puts show that a connection already
	// idle in the pool is not queued a second time.
	for i := 0; i < 3; i++ {
		if err := p.Put(v); err != nil {
			fmt.Println("err=", err)
		}
	}

	// Report how many connections are idle in the pool.
	fmt.Println("len=", p.Len())
}

// server listens on addr and serves every accepted connection concurrently.
// It exits the process if it cannot listen.
func server() {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Error listening: ", err)
		os.Exit(1)
	}
	defer l.Close()
	fmt.Println("Listening on ", addr)
	for {
		nc, err := l.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			fmt.Println("Error accepting: ", err)
			continue
		}
		// Serve each connection in its own goroutine so an idle pooled
		// connection cannot block the accept loop.
		go handleConn(nc)
	}
}

// handleConn prints every message read from nc until the peer closes the
// connection or a read fails, then closes nc.
func handleConn(nc net.Conn) {
	defer nc.Close()
	buffer := make([]byte, 20480)
	for {
		n, err := nc.Read(buffer)
		if n > 0 {
			fmt.Printf("Received message %s -> %s message: %s\n", nc.RemoteAddr(), nc.LocalAddr(), buffer[:n])
		}
		if err != nil {
			return
		}
	}
}