// A connection pool implemented in about 200 lines of Go.
package main
import (
"errors"
"fmt"
"net"
"os"
"sync"
"time"
)
// conn wraps a pooled connection together with the moment it was placed into
// the pool.
type conn struct {
	connect interface{}
	// createTime is set to time.Now() when the wrapper is built, i.e. when
	// the connection is put into the idle channel.
	createTime time.Time
}

// Config carries the settings NewPool uses to build a Pool.
type Config struct {
	// InitCap is the number of connections created up front by NewPool.
	InitCap int
	// MaxCap is the maximum number of idle connections the pool can hold.
	MaxCap int
	// InitConn creates a new connection. Required.
	InitConn func() (interface{}, error)
	// Close closes a connection. Required.
	Close func(interface{}) error
	// Validate, when non-nil, is called to check a connection; a non-nil
	// error means the connection is unusable.
	Validate func(interface{}) error
	// Timeout is how long a connection may sit in the pool before it is
	// discarded. Values <= 0 disable the check.
	Timeout time.Duration
}

// Pool is a pool of idle connections backed by a buffered channel.
type Pool struct {
	mu    sync.Mutex
	conns chan *conn
	// poolMap records which connections are currently in conns, so the same
	// connection cannot be added twice.
	poolMap  map[interface{}]int
	initConn func() (interface{}, error)
	close    func(interface{}) error
	validate func(interface{}) error
	timeout  time.Duration
}
// Sentinel errors returned by the pool.
var (
	// ErrInvalidCap reports inconsistent InitCap/MaxCap settings.
	ErrInvalidCap = errors.New("invalid cap settings")
	// ErrInvalidInit reports a missing or failing connection factory.
	ErrInvalidInit = errors.New("invalid init func settings")
	// ErrInvalidClose reports a missing close function.
	ErrInvalidClose = errors.New("invalid close func settings")
	// ErrInvalidConn reports a nil connection or one that failed validation.
	ErrInvalidConn = errors.New("invalid conn")
	// ErrClosed reports that the pool has been released.
	ErrClosed = errors.New("pool is closed")
)
// addr is the TCP address the demo server listens on and the demo client
// dials.
const addr string = "127.0.0.1:8230"
// NewPool validates config and returns a Pool pre-filled with config.InitCap
// connections created by config.InitConn.
//
// It returns ErrInvalidCap when the capacities are inconsistent (InitCap < 0,
// MaxCap <= 0 or InitCap > MaxCap), ErrInvalidInit when InitConn is nil or
// fails while pre-filling, and ErrInvalidClose when Close is nil. If
// pre-filling fails, every connection opened so far is closed before
// returning.
func NewPool(config *Config) (*Pool, error) {
	if config.InitCap < 0 || config.MaxCap <= 0 || config.InitCap > config.MaxCap {
		return nil, ErrInvalidCap
	}
	if config.InitConn == nil {
		return nil, ErrInvalidInit
	}
	if config.Close == nil {
		return nil, ErrInvalidClose
	}
	pool := &Pool{
		conns:    make(chan *conn, config.MaxCap),
		poolMap:  make(map[interface{}]int, config.InitCap),
		initConn: config.InitConn,
		close:    config.Close,
		validate: config.Validate,
		timeout:  config.Timeout,
	}
	for i := 0; i < config.InitCap; i++ {
		connect, err := pool.initConn()
		if err != nil {
			// Close the channel first so the drain loop below terminates
			// once the buffered connections are consumed; ranging over an
			// open channel would block forever.
			close(pool.conns)
			for wrapConn := range pool.conns {
				// Best-effort cleanup: the init failure is what gets
				// reported to the caller.
				_ = pool.close(wrapConn.connect)
			}
			// The sentinel is returned (rather than err) to stay compatible
			// with callers comparing against ErrInvalidInit.
			return nil, ErrInvalidInit
		}
		// Only register connections that were created successfully.
		pool.poolMap[connect] = 1
		pool.conns <- &conn{connect: connect, createTime: time.Now()}
	}
	return pool, nil
}
// Get returns a connection from the pool.
//
// Idle connections are taken from the pool first. A connection that has been
// in the pool longer than the configured timeout (when timeout > 0), or that
// fails the validate function (when one is set), is closed and skipped. When
// no idle connection is available a new one is created with the init
// function; this path is not limited by the pool capacity.
//
// It returns ErrClosed if the pool has been released, ErrInvalidInit if no
// init function is available, or the error returned by the init function.
func (pool *Pool) Get() (interface{}, error) {
	// Snapshot the fields Release mutates while holding the lock, so the
	// rest of the method neither races with Release nor calls a callback
	// that has been set to nil in the meantime.
	pool.mu.Lock()
	conns := pool.conns
	initConn := pool.initConn
	validate := pool.validate
	closeConn := pool.close
	timeout := pool.timeout
	pool.mu.Unlock()

	if conns == nil {
		return nil, ErrClosed
	}
	for {
		select {
		case wrapConn := <-conns:
			if wrapConn == nil {
				// A nil value is what a receive on a closed channel yields.
				return nil, ErrClosed
			}
			// The connection has left the idle channel whatever happens
			// next, so drop its bookkeeping entry. delete is a no-op if the
			// map has already been cleared.
			pool.mu.Lock()
			delete(pool.poolMap, wrapConn.connect)
			pool.mu.Unlock()

			// Discard connections that sat in the pool for too long.
			if timeout > 0 && wrapConn.createTime.Add(timeout).Before(time.Now()) {
				_ = closeConn(wrapConn.connect) // best effort: the connection is dropped anyway
				continue
			}
			if validate != nil {
				if err := validate(wrapConn.connect); err != nil {
					_ = closeConn(wrapConn.connect) // best effort: the connection is dropped anyway
					continue
				}
			}
			return wrapConn.connect, nil
		default:
			if initConn == nil {
				return nil, ErrInvalidInit
			}
			// Create the connection outside the mutex so a slow init
			// function does not block other pool operations.
			connect, err := initConn()
			if err != nil {
				return nil, err
			}
			return connect, nil
		}
	}
}
// Put hands connect back to the pool.
//
// A nil connection, or one rejected by the validate function, yields
// ErrInvalidConn (a rejected connection is closed first). A connection that
// is already in the pool is ignored and nil is returned. If the pool is full
// or has been released, the connection is closed and the close function's
// result is returned; if no close function is available any more, ErrClosed
// is returned instead.
func (pool *Pool) Put(connect interface{}) error {
	if connect == nil {
		return ErrInvalidConn
	}

	// Snapshot the callbacks under the lock: Release mutates them, and
	// calling a nil func value would panic.
	pool.mu.Lock()
	validate := pool.validate
	closeConn := pool.close
	pool.mu.Unlock()

	if validate != nil {
		if err := validate(connect); err != nil {
			if closeConn != nil {
				_ = closeConn(connect) // best effort: the connection is rejected anyway
			}
			return ErrInvalidConn
		}
	}

	pool.mu.Lock()
	if pool.conns == nil {
		pool.mu.Unlock()
		if closeConn == nil {
			return ErrClosed
		}
		return closeConn(connect)
	}
	if _, ok := pool.poolMap[connect]; ok {
		// Already pooled; adding it again would hand the same connection to
		// two callers.
		pool.mu.Unlock()
		return nil
	}
	select {
	case pool.conns <- &conn{connect: connect, createTime: time.Now()}:
		pool.poolMap[connect] = 1
		pool.mu.Unlock()
		return nil
	default:
		pool.mu.Unlock()
		// The pool is full: close the connection instead of keeping it.
		return closeConn(connect)
	}
}
// Release closes every idle connection in the pool and marks the pool as
// closed. Calling it again on an already released pool is a no-op.
func (pool *Pool) Release() {
	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.conns == nil {
		// Already released; closing or ranging over a nil channel would
		// panic or block forever.
		return
	}
	// Close the channel before draining it: a range over an open channel
	// never terminates, it blocks once the buffer is empty.
	close(pool.conns)
	for wrapConn := range pool.conns {
		_ = pool.close(wrapConn.connect) // best effort: keep closing the rest
	}
	pool.conns = nil
	pool.initConn = nil
	pool.validate = nil
	pool.poolMap = nil
	// pool.close is kept on purpose: Put uses it to close connections that
	// are handed back after the pool has been released.
}
// Len returns the number of idle connections currently held by the pool.
// It returns 0 once the pool has been released.
func (pool *Pool) Len() int {
	// pool.conns is reassigned by Release under the mutex, so read it under
	// the same lock.
	pool.mu.Lock()
	defer pool.mu.Unlock()
	return len(pool.conns)
}
// The code below is demo/test code for the pool.

// main starts the demo TCP server in the background, runs the demo client
// against it and then lingers for a while before exiting.
func main() {
	const (
		startupWait = 2 * time.Second
		lingerTime  = 20 * time.Second
	)

	go server()
	// Wait for the TCP server to start.
	time.Sleep(startupWait)

	client()
	fmt.Println("服务退出")
	time.Sleep(lingerTime)
}
// client demonstrates the pool against the demo server: it builds a pool,
// takes two connections, writes a message on the first one, returns that
// connection three times (the repeated Put calls exercise the duplicate
// check) and prints the number of idle connections.
func client() {
	dial := func() (interface{}, error) { return net.Dial("tcp", addr) }
	closeConn := func(v interface{}) error { return v.(net.Conn).Close() }
	// Create a pool: 5 initial connections, at most 30 idle connections.
	poolConfig := &Config{
		InitCap:  5,
		MaxCap:   30,
		InitConn: dial,
		Close:    closeConn,
		Timeout:  15 * time.Second,
	}
	p, err := NewPool(poolConfig)
	if err != nil {
		fmt.Println("err=", err)
		// p is nil here; using it would panic.
		return
	}
	// Take a connection from the pool.
	v, err := p.Get()
	if err != nil {
		fmt.Println("err=", err)
		return
	}
	// A second connection is taken and never returned to the pool in this
	// demo.
	if _, err := p.Get(); err != nil {
		fmt.Println("err=", err)
	}
	// Do something with the connection.
	c, ok := v.(net.Conn)
	if !ok {
		fmt.Println("err=", ErrInvalidConn)
		return
	}
	if _, err := c.Write([]byte("guoqiang")); err != nil {
		fmt.Println("err=", err)
	}
	// Return the connection to the pool; the repeated calls hand back a
	// connection that is already pooled.
	for i := 0; i < 3; i++ {
		if err := p.Put(v); err != nil {
			fmt.Println("err=", err)
		}
	}
	// Report how many idle connections the pool currently holds.
	current := p.Len()
	fmt.Println("len=", current)
}
// server is the demo TCP server: it listens on addr, accepts connections one
// at a time and prints the first chunk of data read from each of them. The
// process exits with status 1 if the listener cannot be created; the
// function returns if accepting fails.
func server() {
	l, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Error listening: ", err)
		os.Exit(1)
	}
	defer l.Close()
	fmt.Println("Listening on ", addr)

	// Connections are handled sequentially, so one buffer can be reused.
	buffer := make([]byte, 20480)
	for {
		c, err := l.Accept()
		if err != nil {
			fmt.Println("Error accepting: ", err)
			// c is nil on error; stop instead of dereferencing it.
			return
		}
		// NOTE(review): accepted connections are not closed here, as in the
		// original demo; presumably because the client keeps them pooled —
		// confirm before adding a Close.
		n, err := c.Read(buffer)
		if n > 0 {
			// Print only the bytes actually read, not the whole buffer.
			fmt.Printf("Received message %s -> %s message: %s\n", c.RemoteAddr(), c.LocalAddr(), buffer[:n])
		}
		if err != nil {
			fmt.Println("Error reading: ", err)
		}
	}
}