酝酿情绪中······

package main

import (
    "io"
    "log"
    "net"
    "fmt"
    "bufio"
    "bytes"
    "strconv"
    "strings"
    "crypto/sha1"
    "encoding/json"
    "encoding/base64"
    "encoding/binary"
)

// WebSocket is the chat server. It owns the listening socket and the list of
// clients that are currently registered.
type WebSocket struct {
    // Listener is the socket on which Loop accepts incoming connections.
    Listener net.Listener

    // Clients lists every client that has completed the opening handshake and
    // has not yet been removed by Client.Close.
    Clients []*Client
}

// Client is one connected peer of a WebSocket server.
type Client struct {
    // Conn is the underlying network connection to the peer.
    Conn net.Conn

    // Nickname is the display name; it is taken from the first non-empty
    // message the peer sends and is empty until then.
    Nickname string

    // Shook reports whether the opening handshake has already been completed.
    Shook bool

    // Server points back to the server whose client list this peer belongs to.
    Server *WebSocket

    // Id is a numeric identifier; Loop derives it from the peer's remote port.
    Id int
}

// Msg is the payload that is JSON-encoded and broadcast to every client.
type Msg struct {
    // Data is the text to display: a chat line or a join/leave notice.
    Data string

    // Num is the number of registered clients at the time the message was built.
    Num int
}


// Handle serves one connection: it performs the opening handshake and, only
// if that succeeds, enters the frame-reading loop until the peer goes away.
func (c *Client) Handle() {
    if c.Handshake() {
        c.Read()
    }
}

// maxFramePayload caps how many payload bytes a single frame, or a message
// reassembled from fragments, may carry. The length field comes straight from
// the peer, so without a bound a single frame header could make the server
// allocate an arbitrary amount of memory.
const maxFramePayload = 1 << 20

// Control-frame opcodes that Read treats specially (RFC 6455, section 5.2).
const (
    opcodeClose = 0x8
    opcodePing  = 0x9
    opcodePong  = 0xA
)

// Read receives frames from the peer until the connection ends, then
// unregisters the client and closes the connection.
//
// The first non-empty message becomes the client's nickname and is announced
// as a join notice; every later message is broadcast verbatim. Each broadcast
// is a JSON-encoded Msg carrying the current number of registered clients.
//
// Fragmented messages are reassembled before being broadcast, pings are
// answered with a pong, and pongs are ignored. Any read error, an oversized
// frame or message, or a close frame ends the loop.
//
// NOTE(review): Server.Clients is read here without synchronization while
// other connection goroutines may modify it; guarding it needs a lock on the
// WebSocket type.
func (c *Client) Read() {
    defer c.Conn.Close()

    // pending accumulates the payload of a fragmented message until the
    // frame with the FIN bit arrives.
    var pending []byte
    for {
        fin, opcode, payload, err := readFrame(c.Conn)
        if err != nil {
            // A plain EOF on a frame boundary is an ordinary disconnect.
            if err != io.EOF {
                log.Print("websocket read: ", err)
            }
            c.Close()
            return
        }

        switch opcode {
        case opcodeClose:
            log.Print("Connection closed")
            c.Close()
            return
        case opcodePing:
            // Control payloads are limited to 125 bytes, which is also what
            // keeps the pong's length inside the 7-bit length field.
            if len(payload) > 125 {
                log.Print("websocket read: oversized ping frame")
                c.Close()
                return
            }
            pong := append([]byte{0x80 | opcodePong, byte(len(payload))}, payload...)
            if _, err := c.Conn.Write(pong); err != nil {
                log.Print("websocket write: ", err)
                c.Close()
                return
            }
            continue
        case opcodePong:
            continue
        }

        if uint64(len(pending))+uint64(len(payload)) > maxFramePayload {
            log.Print("websocket read: fragmented message too large")
            c.Close()
            return
        }
        pending = append(pending, payload...)
        if fin == 0 {
            // More fragments of this message are still to come.
            continue
        }
        text := string(pending)
        pending = pending[:0]

        if c.Nickname == "" {
            c.Nickname = text
            text = c.Nickname + ",加入"
        }
        data, err := json.Marshal(&Msg{text, len(c.Server.Clients)})
        if err != nil {
            // Drop this one message rather than taking the whole server down.
            log.Print("websocket read: ", err)
            continue
        }
        c.WriteAll(data)
    }
}

// readFrame reads exactly one WebSocket frame from r and returns its FIN bit,
// its opcode and its unmasked payload. It fails if the stream ends inside the
// frame or if the declared payload length exceeds maxFramePayload. The error
// is io.EOF only when the stream ends cleanly before a new frame starts.
func readFrame(r io.Reader) (fin, opcode byte, payload []byte, err error) {
    head := make([]byte, 2)
    if _, err = io.ReadFull(r, head); err != nil {
        return 0, 0, nil, err
    }
    fin = head[0] >> 7
    opcode = head[0] & 0xf
    mask := head[1] >> 7
    length := uint64(head[1] & 0x7f)

    // 126 and 127 are escape values: the real length follows as a 16-bit or
    // 64-bit big-endian integer.
    switch length {
    case 126:
        ext := make([]byte, 2)
        if _, err = io.ReadFull(r, ext); err != nil {
            return 0, 0, nil, fmt.Errorf("reading 16-bit payload length: %v", err)
        }
        var short uint16
        if err = binary.Read(bytes.NewReader(ext), binary.BigEndian, &short); err != nil {
            return 0, 0, nil, fmt.Errorf("decoding 16-bit payload length: %v", err)
        }
        length = uint64(short)
    case 127:
        ext := make([]byte, 8)
        if _, err = io.ReadFull(r, ext); err != nil {
            return 0, 0, nil, fmt.Errorf("reading 64-bit payload length: %v", err)
        }
        if err = binary.Read(bytes.NewReader(ext), binary.BigEndian, &length); err != nil {
            return 0, 0, nil, fmt.Errorf("decoding 64-bit payload length: %v", err)
        }
    }
    fmt.Printf("fin: %d, opcode: %d, mask: %d, length: %d\n", fin, opcode, mask, length)

    // Reject before allocating: length is attacker-controlled.
    if length > maxFramePayload {
        return 0, 0, nil, fmt.Errorf("frame payload of %d bytes exceeds the %d byte limit", length, maxFramePayload)
    }

    var key [4]byte
    if mask == 1 {
        if _, err = io.ReadFull(r, key[:]); err != nil {
            return 0, 0, nil, fmt.Errorf("reading masking key: %v", err)
        }
    }
    payload = make([]byte, length)
    if _, err = io.ReadFull(r, payload); err != nil {
        return 0, 0, nil, fmt.Errorf("reading payload: %v", err)
    }
    if mask == 1 {
        // Masking is a repeating 4-byte XOR over the payload.
        for i := range payload {
            payload[i] ^= key[i%4]
        }
    }
    return fin, opcode, payload, nil
}

// WriteAll sends data as one frame to every client currently registered with
// the server, including the caller itself if it is still in the list.
func (c *Client) WriteAll(data []byte) {
    peers := c.Server.Clients
    for i := 0; i < len(peers); i++ {
        peers[i].Write(data)
    }
}

// Close unregisters the client from its server and tells the remaining
// clients that it has left. It does nothing when the client is not in the
// server's list. The network connection itself is not closed here.
func (c *Client) Close() {
    // Locate this client in the server's list.
    idx := -1
    for i, peer := range c.Server.Clients {
        if peer == c {
            idx = i
            break
        }
    }
    if idx < 0 {
        return
    }

    // The notice reports the head count as it will be after the removal.
    farewell := &Msg{
        c.Nickname + ",离开",
        len(c.Server.Clients) - 1,
    }
    payload, err := json.Marshal(farewell)
    if err != nil {
        log.Fatal(err)
    }

    // Remove first, so the departing client does not receive its own notice.
    c.Server.Clients = append(c.Server.Clients[:idx], c.Server.Clients[idx+1:]...)
    c.WriteAll(payload)
}

// Write sends data to the peer as a single, final, unmasked text frame.
//
// The payload length is encoded in the shortest form the framing allows:
// inline for fewer than 126 bytes, as a 16-bit big-endian value up to 65535
// bytes, and as a 64-bit big-endian value beyond that.
//
// It returns true when the whole frame was handed to the connection and
// false when the write failed; the failure is logged.
func (c *Client) Write(data []byte) bool {
    n := len(data)

    // One allocation: at most 10 header bytes plus the payload.
    frame := make([]byte, 0, 10+n)
    frame = append(frame, 0x81) // FIN bit set, opcode 1 (text)
    switch {
    case n < 126:
        frame = append(frame, byte(n))
    case n <= 0xffff:
        frame = append(frame, 126, 0, 0)
        binary.BigEndian.PutUint16(frame[2:], uint16(n))
    default:
        // Every remaining Go slice length fits the 64-bit field.
        frame = append(frame, 127, 0, 0, 0, 0, 0, 0, 0, 0)
        binary.BigEndian.PutUint64(frame[2:], uint64(n))
    }
    frame = append(frame, data...)

    // The frame goes out in one Write call so that frames written from
    // different goroutines are not interleaved byte-wise.
    if _, err := c.Conn.Write(frame); err != nil {
        log.Print("websocket write: ", err)
        return false
    }
    return true
}

// Handshake performs the server side of the WebSocket opening handshake.
//
// It reads the HTTP request headers up to the blank line, takes the value of
// the Sec-WebSocket-Key header (matched case-insensitively), and answers with
// "101 Switching Protocols" and the matching Sec-WebSocket-Accept value. On
// success the client is marked as shaken and appended to the server's client
// list.
//
// It returns true on success, or immediately if the handshake was already
// done. On any failure (read error, missing key, write error) it logs the
// problem, closes the connection and returns false; a request without a key
// is answered with "400 Bad Request" first.
//
// NOTE(review): bytes the buffered reader has already pulled from the
// connection beyond the blank line are not seen by later reads on Conn.
// NOTE(review): Server.Clients is appended to without synchronization.
func (c *Client) Handshake() bool {
    if c.Shook {
        return true
    }

    reader := bufio.NewReader(c.Conn)
    key := ""
    for {
        line, err := reader.ReadString('\n')
        if err != nil {
            // A peer that disconnects mid-handshake must not stop the server.
            log.Print("websocket handshake: ", err)
            c.Conn.Close()
            return false
        }
        line = strings.TrimRight(line, "\r\n")
        if line == "" {
            break
        }
        // Split "Name: value" at the first colon instead of relying on fixed
        // byte offsets, which panic on short or oddly spaced headers.
        sep := strings.IndexByte(line, ':')
        if sep < 0 {
            continue
        }
        if strings.EqualFold(strings.TrimSpace(line[:sep]), "Sec-WebSocket-Key") {
            key = strings.TrimSpace(line[sep+1:])
        }
    }

    if key == "" {
        log.Print("websocket handshake: missing Sec-WebSocket-Key header")
        // Best effort: the connection is being dropped either way.
        c.Conn.Write([]byte("HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n"))
        c.Conn.Close()
        return false
    }

    // The accept token is base64(SHA-1(key + fixed protocol GUID)).
    sum := sha1.Sum([]byte(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
    accept := base64.StdEncoding.EncodeToString(sum[:])

    header := "HTTP/1.1 101 Switching Protocols\r\n" +
        "Connection: Upgrade\r\n" +
        "Sec-WebSocket-Version: 13\r\n" +
        "Sec-WebSocket-Accept: " + accept + "\r\n" +
        "Upgrade: websocket\r\n\r\n"
    if _, err := c.Conn.Write([]byte(header)); err != nil {
        log.Print("websocket handshake: ", err)
        c.Conn.Close()
        return false
    }

    c.Shook = true
    c.Server.Clients = append(c.Server.Clients, c)
    return true
}

// NewWebSocket starts listening for TCP connections on addr and returns a
// server with an empty client list. If the address cannot be bound, the error
// is logged and the process exits.
func NewWebSocket(addr string) *WebSocket {
    listener, err := net.Listen("tcp", addr)
    if err != nil {
        log.Fatal(err)
    }
    server := &WebSocket{
        Listener: listener,
        Clients:  []*Client{},
    }
    return server
}


// Loop accepts connections forever and serves each one on its own goroutine.
// It never returns: a failing Accept is logged and terminates the process.
//
// Each client's Id is the peer's remote port number. The port is taken with
// net.SplitHostPort so that IPv6 peers such as "[::1]:1234" are handled too;
// Id is 0 if the remote address does not yield a numeric port.
func (ws *WebSocket) Loop() {
    for {
        conn, err := ws.Listener.Accept()
        if err != nil {
            log.Fatal(err)
        }

        id := 0
        if _, port, err := net.SplitHostPort(conn.RemoteAddr().String()); err == nil {
            // Best effort: a non-numeric port leaves the identifier at 0.
            id, _ = strconv.Atoi(port)
        }

        client := &Client{
            Conn:     conn,
            Nickname: "",
            Shook:    false,
            Server:   ws,
            Id:       id,
        }
        go client.Handle()
    }
}

// main starts the chat server on the fixed address 192.168.1.2:1993 and
// serves connections until the process exits.
func main() {
    NewWebSocket("192.168.1.2:1993").Loop()
}

  • 相关链接: