golang 实现 websocket

package websocket

import (
	"bytes"
	"github.com/gorilla/websocket"
	"log"
	"net/http"
	"time"
)

// Byte sequences used when framing text messages. Slices cannot be declared
// as constants in Go, so these stay package-level variables.
var (
	newline = []byte{'\n'}
	space   = []byte{' '}
)

// Connection timing and size limits. They are fixed values, so they are
// declared as constants and cannot be reassigned at run time.
const (
	// Time allowed to write a message to the peer.
	writeWait = 10 * time.Second

	// Time allowed to read the next pong message from the peer.
	pongWait = 60 * time.Second

	// Send pings to peer with this period. Must be less than pongWait.
	pingPeriod = (pongWait * 9) / 10

	// Maximum message size allowed from peer.
	maxMessageSize = 512
)

// upgrader is the shared websocket.Upgrader that Register uses to turn an HTTP
// request into a WebSocket connection. Only the read and write buffer sizes
// are configured here; every other field keeps its zero value.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}

// WsClientGroup defines a chat group: only clients that belong to the same
// group can chat with each other. All fields are read and written by the
// HandleRun loop.
type WsClientGroup struct {
	// broadcast carries data to fan out to every member; the HandleRun
	// goroutine listens on this channel.
	broadcast   chan []byte
	// clientEnter receives clients that want to join the group.
	clientEnter chan *WsClient
	// clientExit receives clients that are leaving the group.
	clientExit  chan *WsClient
	// clients is the set of current members; a client is a member while it is
	// a key of this map.
	clients     map[*WsClient]bool
}

// HandleRun is the group's event loop. It never returns, so callers run it in
// its own goroutine. It handles three kinds of events:
//
//   - a client joining: the client is added to the member set;
//   - a client leaving: the client is removed and its send channel is closed;
//   - a broadcast: the message is queued on every member's send channel.
//
// The member map is only touched from this loop.
func (g *WsClientGroup) HandleRun() {
	log.Printf("handle run")
	for {
		select {
		case joined := <-g.clientEnter:
			// Register the new member.
			g.clients[joined] = true

		case leaving := <-g.clientExit:
			// A client that is no longer a member (for example one already
			// dropped by the broadcast branch) is ignored, so its send
			// channel is never closed twice.
			if _, member := g.clients[leaving]; !member {
				continue
			}
			close(leaving.send)
			delete(g.clients, leaving)

		case payload := <-g.broadcast:
			log.Printf("broadcastMsg, %s", payload)
			for member := range g.clients {
				select {
				case member.send <- payload:
					// Queued without blocking; move on to the next member.
					continue
				default:
				}
				// The send could not proceed without blocking (for example
				// the member's send buffer is full), so drop the member
				// instead of stalling the whole group.
				delete(g.clients, member)
				close(member.send)
			}
		}
	}
}

// NewClientGroup returns an empty group with an initialised member map and
// unbuffered broadcast, enter and exit channels. Nothing is processed until
// the caller starts the group's HandleRun loop.
func NewClientGroup() *WsClientGroup {
	group := new(WsClientGroup)
	group.clients = make(map[*WsClient]bool)
	group.broadcast = make(chan []byte)
	group.clientEnter = make(chan *WsClient)
	group.clientExit = make(chan *WsClient)
	return group
}

// WsClient ties one WebSocket connection to the group it is registered with.
type WsClient struct {
	// send queues outbound messages. The group's HandleRun loop writes to it
	// and closes it when the client is removed; WriteLoopGroup drains it.
	send   chan []byte
	// conn is the underlying WebSocket connection.
	conn   *websocket.Conn
	// Groups is the group this client belongs to.
	Groups *WsClientGroup
}

// ReadLoopGroup is the client's read loop; run it in its own goroutine.
//
// It reads messages from the WebSocket connection and forwards each one to
// the group's broadcast channel. Incoming messages are limited to
// maxMessageSize bytes, and the read deadline is renewed to pongWait every
// time a pong arrives. When a read fails the loop ends: the client announces
// its exit to the group and the connection is closed.
func (cli *WsClient) ReadLoopGroup() {
	defer func() {
		cli.Groups.clientExit <- cli
		// exit and close connection
		cli.conn.Close()
	}()
	cli.conn.SetReadLimit(int64(maxMessageSize))
	cli.conn.SetReadDeadline(time.Now().Add(pongWait))
	cli.conn.SetPongHandler(func(string) error {
		// Each pong renews the read deadline.
		cli.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})
	log.Printf("begin read loop ")
	for {
		_, msg, err := cli.conn.ReadMessage()
		log.Printf("begin read msg")
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("websocket client error: %+v", err)
			}
			return
		}
		// WriteLoopGroup joins queued messages with a newline and the page
		// splits incoming frames on newlines, so a newline inside one message
		// would make it appear as several messages. Replace embedded newlines
		// with spaces before trimming.
		msg = bytes.TrimSpace(bytes.Replace(msg, newline, space, -1))
		log.Printf("receive client msg = [%s]", msg)
		// The message could be routed here: the front end could send a type
		// and the back end could decide between broadcast and private chat.
		// For now every message is broadcast.
		cli.Groups.broadcast <- msg
	}
}
// WriteLoopGroup is the client's write loop; run it in its own goroutine.
//
// It sends messages queued on cli.send to the peer as text messages and sends
// a ping every pingPeriod. Messages that are already waiting in the queue are
// batched into the same WebSocket message, separated by newlines. The loop
// ends when the send channel is closed (a close message is sent to the peer
// first) or when any write fails; in both cases the ticker is stopped and the
// connection is closed.
func (cli *WsClient) WriteLoopGroup() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		log.Printf("exit writeLoopGroup")
		ticker.Stop()
		cli.conn.Close()
	}()
	for {
		select {
		case msg, ok := <-cli.send:
			// Refresh the write deadline before every write. The deadline is
			// an absolute time, so one set for an earlier ping would still be
			// in effect here and could already have expired, making this
			// write fail.
			cli.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// The send channel was closed: tell the peer we are done.
				cli.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			w, err := cli.conn.NextWriter(websocket.TextMessage)
			if err != nil {
				return
			}
			if _, err = w.Write(msg); err != nil {
				return
			}
			// Append whatever is already queued to the current message.
			n := len(cli.send)
			for i := 0; i < n; i++ {
				if _, err = w.Write(newline); err != nil {
					return
				}
				if _, err = w.Write(<-cli.send); err != nil {
					return
				}
			}
			if err = w.Close(); err != nil {
				return
			}
		case <-ticker.C:
			cli.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := cli.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				// Write error or timeout: give up on this connection.
				return
			}
		}
	}
}

// Register upgrades the HTTP request to a WebSocket connection, wraps it in a
// WsClient with a 256-message send buffer, hands the client to groups and
// starts the client's read and write loops in their own goroutines.
//
// If the upgrade fails the error is logged and nothing else happens.
func Register(groups *WsClientGroup, w http.ResponseWriter, r *http.Request) {
	wsConn, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		log.Printf("error info %+v", upgradeErr)
		return
	}

	member := new(WsClient)
	member.conn = wsConn
	member.send = make(chan []byte, 256)
	member.Groups = groups

	// With the unbuffered channel created by NewClientGroup, this send waits
	// until the group's HandleRun loop receives the client.
	groups.clientEnter <- member
	log.Printf("enter client ")

	go member.ReadLoopGroup()
	go member.WriteLoopGroup()
}


main.go 测试代码

package main

import (
	"log"
	"net/http"
	"websocket/websocket"
)

// serverHome serves the chat page. It logs the requested URL, answers 404 for
// any path other than "/", answers 405 for any method other than GET, and
// otherwise serves index.html from the working directory.
func serverHome(w http.ResponseWriter, r *http.Request) {
	log.Println(r.URL)

	switch {
	case r.URL.Path != "/":
		http.Error(w, "Not found", http.StatusNotFound)
	case r.Method != http.MethodGet:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	default:
		http.ServeFile(w, r, "index.html")
	}
}

// main wires up the demo server: "/" serves the chat page and "/ws" registers
// each incoming WebSocket connection with a single shared client group whose
// event loop runs in its own goroutine. The server listens on port 8080.
func main() {
	http.HandleFunc("/", serverHome)

	var groups = websocket.NewClientGroup()
	go groups.HandleRun()
	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		// To support several chat rooms, a map from chat room id to group
		// could be added here and the matching group looked up per request.
		websocket.Register(groups, w, r)
	})

	// ListenAndServe only returns on failure (for example when the port is
	// already in use); report the error instead of exiting silently.
	log.Fatal(http.ListenAndServe(":8080", nil))
}


index.html

<!DOCTYPE html>
<html lang="en">
<head>
    <title>Chat Example</title>
    <script type="text/javascript">
        // Set up the chat page once it has loaded: open a WebSocket to /ws on
        // the same host, send the text field's content on form submit, and
        // append every received line to the log element.
        window.onload = function () {
            var conn;
            var msg = document.getElementById("msg");
            var log = document.getElementById("log");

            // Append an element to the log and keep the view pinned to the
            // bottom when it was already scrolled to the bottom.
            function appendLog(item) {
                var doScroll = log.scrollTop > log.scrollHeight - log.clientHeight - 1;
                log.appendChild(item);
                if (doScroll) {
                    log.scrollTop = log.scrollHeight - log.clientHeight;
                }
            }

            // Send the typed message; always return false so the browser does
            // not perform a regular form submission.
            document.getElementById("form").onsubmit = function () {
                if (!conn) {
                    return false;
                }
                if (!msg.value) {
                    return false;
                }
                conn.send(msg.value);
                msg.value = "";
                return false;
            };

            if (window["WebSocket"]) {
                // Match the page's scheme: a page served over https must use
                // wss, because browsers block plain ws connections from
                // secure pages.
                var scheme = document.location.protocol === "https:" ? "wss://" : "ws://";
                conn = new WebSocket(scheme + document.location.host + "/ws");
                conn.onclose = function (evt) {
                    var item = document.createElement("div");
                    item.innerHTML = "<b>Connection closed.</b>";
                    appendLog(item);
                };
                conn.onmessage = function (evt) {
                    // One frame may carry several messages separated by newlines.
                    var messages = evt.data.split('\n');
                    for (var i = 0; i < messages.length; i++) {
                        var item = document.createElement("div");
                        item.innerText = messages[i];
                        appendLog(item);
                    }
                };
            } else {
                var item = document.createElement("div");
                item.innerHTML = "<b>Your browser does not support WebSockets.</b>";
                appendLog(item);
            }
        };
    </script>
    <style type="text/css">
        html {
            overflow: hidden;
        }

        body {
            overflow: hidden;
            padding: 0;
            margin: 0;
            width: 100%;
            height: 100%;
            background: gray;
        }

        #log {
            background: white;
            margin: 0;
            padding: 0.5em 0.5em 0.5em 0.5em;
            position: absolute;
            top: 0.5em;
            left: 0.5em;
            right: 0.5em;
            bottom: 3em;
            overflow: auto;
        }

        #form {
            padding: 0 0.5em 0 0.5em;
            margin: 0;
            position: absolute;
            bottom: 1em;
            left: 0px;
            width: 100%;
            overflow: hidden;
        }

    </style>
</head>
<body>
<div id="log"></div>
<form id="form">
    <input type="submit" value="Send" />
    <input type="text" id="msg" size="64" autofocus />
</form>
</body>
</html>