rpc/session.go

package rpc

import (
	"encoding/binary"
	"io"
	"net"
)

// 编写数据会话中读写

// 会话连接的结构体
type Session struct {
	conn net.Conn
}
// 创建新连接
func NewSession(conn net.Conn) *Session {
	return &Session{conn: conn}
}
// 向连接中写数据
func (s Session) Write(data []byte) error {
	// 4字节头+数据长度切片
	buf := make([]byte, 4+len(data))
	// 写入头部数据,记录数据长度
	// binary 只认固定长度的类型,所以使用了uint32,而不是直接写入
	binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
	copy(buf[:4], data)
	_, err := s.conn.Write(buf)
	if err != nil {
		return err
	}
	return nil
}

// 从连接中读数据
func (s Session) Read() ([]byte, error) {
	// 读取头部长度
	header := make([]byte, 4)
	// 按头部长度, 读取头部数据
	_, err := io.ReadFull(s.conn, header)
	if err != nil {
		return nil, err
	}
	// 读取数据长度
	dataLen := binary.BigEndian.Uint32(header)
	// 按照数据长度去读取数据
	data := make([]byte, dataLen)
	_, err = io.ReadFull(s.conn, data)
	if err != nil {
		return nil, err
	}
	return data, nil
}

rpc/session_test.go

package rpc

import (
	"fmt"
	"net"
	"sync"
	"testing"
)

func TestSession_ReadWrite(t *testing.T) {
	// 定义监听IP和端口
	addr := "127.0.0.1:8080"
	// 定义传输的数据
	my_data := "hello world"
	// 等待组
	wg := sync.WaitGroup{}
	// 协程,一个读,一个写
	wg.Add(2)
	// 写数据协程
	go func() {
		defer wg.Done()
		// 创建tcp连接
		lis, err := net.Listen("tcp", addr)
		if err != nil {
			t.Fatal(err)
		}
		conn,_ := lis.Accept()
		s := Session{conn: conn}
		// 写数据
		err = s.Write([]byte(my_data))
		if err != nil {
			t.Fatal(err)
		}
	}()
	// 读数据协程
	go func() {
		defer wg.Done()
		conn, err := net.Dial("tcp", addr)
		if err != nil {
			t.Fatal(err)
		}
		s := Session{conn: conn}
		// 读数据
		data, err := s.Read()
		if err != nil {
			t.Fatal(err)
		}
		if string(data) != my_data {
			t.Fatal(err)
		}
		fmt.Println(string(data))
	}()
	wg.Wait()
}

 

rpc/codec.go

package rpc

import (
	"bytes"
	"encoding/gob"
)

// 定义数据格式和编解码
type RPCData struct {
	// 访问的函数
	Name string
	// 访问时传的参数
	Args []interface{}
}

// 编码
func encode(data RPCData) ([]byte, error) {
	var buf bytes.Buffer
	// 得到字节数组的编码器
	bufEnc := gob.NewEncoder(&buf)
	// 对数据进行编码
	bufEnc.Encode(data)
	if err := bufEnc.Encode(data); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// 解码
func decode(b []byte) (RPCData, error) {
	buf := bytes.NewBuffer(b)
	// 返回字节数组的解码器
	bufDec := gob.NewDecoder(buf)
	var data RPCData
	// 对数据解码
	if err := bufDec.Decode(&data); err != nil {
		return data, nil
	}
	return data, nil
}

rpc/server.go

package rpc

import (
	"fmt"
	"net"
	"reflect"
)

// 声明服务端
type Server struct {
	// 地址
	addr string
	// 服务端维护的函数名到函数反射值的map
	funcs map[string]reflect.Value
}
// 创建服务端对象
func NewServer(addr string) *Server {
	return &Server{addr: addr, funcs:make(map[string]reflect.Value)}
}

// 服务端绑定注册方法
// 将函数名与函数真正实现对应起来
// 第一个参数为函数名, 第二个传入真正的函数
func (s *Server) Register(rpcName string, f interface{}) {
	if _, ok := s.funcs[rpcName]; ok {
		return
	}
	// map中没有值,则将映射添加进map,便于调用
	fVal := reflect.ValueOf(f)
	s.funcs[rpcName] = fVal
}

// 服务端等待调用
func (s *Server) Run() {
	// 监听
	lis, err := net.Listen("tcp", s.addr)
	if err != nil {
		fmt.Printf("监听%s err:%v", s.addr, err)
		return
	}
	for {
		// 拿到连接
		conn, err := lis.Accept()
		if err != nil {
			fmt.Printf("accept err:%v", err)
			return
		}
		// 创建会话
		srvSession := NewSession(conn)
		// RPC 读取数据
		b, err := srvSession.Read()
		if err != nil {
			fmt.Printf("read err:%v", err)
			return
		}
		// 对数据解码
		rpcData, err := decode(b)
		if err != nil {
			fmt.Printf("decode err:%v", err)
			return
		}
		// 根据读取到的数据的Name,得到调用的函数名
		f, ok := s.funcs[rpcData.Name]
		if !ok {
			fmt.Printf("函数名%s不存在", rpcData.Name)
		}
		// 解析遍历客户端出来的参数, 放到一个数组中
		inArgs := make([]reflect.Value, 0, len(rpcData.Args))
		for _, arg := range rpcData.Args {
			inArgs = append(inArgs, reflect.ValueOf(arg))
		}
		// 反射调用方法,传入参数
		out := f.Call(inArgs)
		// 解析遍历执行结果,放到一个数组中
		outArgs := make([]interface{}, 0, len(out))
		for _, o := range out {
			outArgs = append(outArgs, o.Interface())
		}
		// 包装数据返回给客户端
		respRPCData := RPCData{rpcData.Name, outArgs}
		// 编码
		respBytes, err := encode(respRPCData)
		if err != nil {
			fmt.Printf("encode err: %v", err)
			return
		}
		// 使用RPC写出数据
		err = srvSession.Write(respBytes)
		if err != nil {
			fmt.Printf("session write err:%v", err)
			return
		}
	}
}

rpc/client.go

package rpc

import (
	"net"
	"reflect"
)

// 声明客户端
type Client struct {
	conn net.Conn
}
// 创建客户端对象
func NewClient(conn net.Conn) *Client {
	return &Client{conn: conn}
}
// 实现通用的RPC客户端
// 绑定RPC使用的方法
// 传入访问的函数名

// 函数具体实现在Server端, Client只有函数原型
// 使用MakeFunc() 完成原型到函数的调用
// fPtr指向函数原型
func (c *Client) callRPC(rpcName string, fPtr interface{}) {
	// 通过反射,获取fPtr未初始化的函数原型
	fn := reflect.ValueOf(fPtr).Elem()
	// 另一个函数,是对第一个函数参数操作
	f := func(args []reflect.Value) []reflect.Value {
		// 处理输入的参数
		inArgs := make([]interface{}, 0, len(args))
		for _, arg := range args{
			inArgs = append(inArgs, arg.Interface())
		}
		// 创建连接
		cliSession := NewSession(c.conn)
		// 编码数据
		reqRPC := RPCData{Name: rpcName, Args: inArgs}
		b, err := encode(reqRPC)
		if err != nil {
			panic(nil)
		}
		// 写出数据
		err = cliSession.Write(b)
		if err != nil {
			panic(nil)
		}
		// 读响应数据
		respBytes, err := cliSession.Read()
		if err != nil {
			panic(err)
		}
		// 解码数据
		respRPC, err := decode(respBytes)
		if err != nil {
			panic(err)
		}
		// 处理服务端返回的数据
		outArgs := make([]reflect.Value, 0, len(respRPC.Args))
		for i, arg := range respRPC.Args {
			// 必须进行nil转换
			if arg == nil {
				// 必须填充一个真正的类型,不能是nil
				outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i)))
				continue
			}
		}
		return outArgs
	}

	v := reflect.MakeFunc(fn.Type(), f)
	// 为函数fPtr赋值
	fn.Set(v)
}

rpc/simple_tpc_test.go

package rpc

import (
	"encoding/gob"
	"fmt"
	"net"
	"testing"
)

// 用户查询
// 用于测试的结构体
type User struct {
	Name string
	Age int
}

// 用于测试查询用户的方法
func queryUser(uid int) (User, error) {
	user := make(map[int]User)
	user[0] = User{"zs", 20}
	user[1] = User{"ls", 21}
	user[2] = User{"ww", 22}
	// 模拟查询用户
	if u, ok := user[uid]; ok {
		return u, nil
	}
	return User{}, fmt.Errorf("id %d not in user db", uid)
}

func TestRPC(t *testing.T) {
	// 需要对interface可能产生的类型进行注册
	gob.Register(User{})
	addr := "127.0.0.1:8080"
	// 创建服务端
	srv := NewServer(addr)
	// 将方法注册到服务端
	srv.Register("queryUser", queryUser)
	// 服务端等待调用
	go srv.Run()
	// 客户端获取连接
	conn , err := net.Dial("tcp", addr)
	if err != nil {
		t.Error(err)
	}
	// 创建客户端
	cli := NewClient(conn)
	// 声明函数原型
	var query func(int) (User error)
	cli.callRPC("queryUser", &query)
	// 得到查询结果
	u, err := query(1)
	if err != nil {
		t.Fatal(err)
	}
	fmt.Println(u)
}