rpc/session.go
package rpc
import (
"encoding/binary"
"io"
"net"
)
// 编写数据会话中读写
// 会话连接的结构体
type Session struct {
conn net.Conn
}
// 创建新连接
func NewSession(conn net.Conn) *Session {
return &Session{conn: conn}
}
// 向连接中写数据
func (s Session) Write(data []byte) error {
// 4字节头+数据长度切片
buf := make([]byte, 4+len(data))
// 写入头部数据,记录数据长度
// binary 只认固定长度的类型,所以使用了uint32,而不是直接写入
binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
copy(buf[:4], data)
_, err := s.conn.Write(buf)
if err != nil {
return err
}
return nil
}
// 从连接中读数据
func (s Session) Read() ([]byte, error) {
// 读取头部长度
header := make([]byte, 4)
// 按头部长度, 读取头部数据
_, err := io.ReadFull(s.conn, header)
if err != nil {
return nil, err
}
// 读取数据长度
dataLen := binary.BigEndian.Uint32(header)
// 按照数据长度去读取数据
data := make([]byte, dataLen)
_, err = io.ReadFull(s.conn, data)
if err != nil {
return nil, err
}
return data, nil
}
rpc/session_test.go
package rpc
import (
"fmt"
"net"
"sync"
"testing"
)
func TestSession_ReadWrite(t *testing.T) {
// 定义监听IP和端口
addr := "127.0.0.1:8080"
// 定义传输的数据
my_data := "hello world"
// 等待组
wg := sync.WaitGroup{}
// 协程,一个读,一个写
wg.Add(2)
// 写数据协程
go func() {
defer wg.Done()
// 创建tcp连接
lis, err := net.Listen("tcp", addr)
if err != nil {
t.Fatal(err)
}
conn,_ := lis.Accept()
s := Session{conn: conn}
// 写数据
err = s.Write([]byte(my_data))
if err != nil {
t.Fatal(err)
}
}()
// 读数据协程
go func() {
defer wg.Done()
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
s := Session{conn: conn}
// 读数据
data, err := s.Read()
if err != nil {
t.Fatal(err)
}
if string(data) != my_data {
t.Fatal(err)
}
fmt.Println(string(data))
}()
wg.Wait()
}
rpc/codec.go
package rpc
import (
"bytes"
"encoding/gob"
)
// 定义数据格式和编解码
type RPCData struct {
// 访问的函数
Name string
// 访问时传的参数
Args []interface{}
}
// 编码
func encode(data RPCData) ([]byte, error) {
var buf bytes.Buffer
// 得到字节数组的编码器
bufEnc := gob.NewEncoder(&buf)
// 对数据进行编码
bufEnc.Encode(data)
if err := bufEnc.Encode(data); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// 解码
func decode(b []byte) (RPCData, error) {
buf := bytes.NewBuffer(b)
// 返回字节数组的解码器
bufDec := gob.NewDecoder(buf)
var data RPCData
// 对数据解码
if err := bufDec.Decode(&data); err != nil {
return data, nil
}
return data, nil
}
rpc/server.go
package rpc
import (
"fmt"
"net"
"reflect"
)
// 声明服务端
type Server struct {
// 地址
addr string
// 服务端维护的函数名到函数反射值的map
funcs map[string]reflect.Value
}
// 创建服务端对象
func NewServer(addr string) *Server {
return &Server{addr: addr, funcs:make(map[string]reflect.Value)}
}
// 服务端绑定注册方法
// 将函数名与函数真正实现对应起来
// 第一个参数为函数名, 第二个传入真正的函数
func (s *Server) Register(rpcName string, f interface{}) {
if _, ok := s.funcs[rpcName]; ok {
return
}
// map中没有值,则将映射添加进map,便于调用
fVal := reflect.ValueOf(f)
s.funcs[rpcName] = fVal
}
// 服务端等待调用
func (s *Server) Run() {
// 监听
lis, err := net.Listen("tcp", s.addr)
if err != nil {
fmt.Printf("监听%s err:%v", s.addr, err)
return
}
for {
// 拿到连接
conn, err := lis.Accept()
if err != nil {
fmt.Printf("accept err:%v", err)
return
}
// 创建会话
srvSession := NewSession(conn)
// RPC 读取数据
b, err := srvSession.Read()
if err != nil {
fmt.Printf("read err:%v", err)
return
}
// 对数据解码
rpcData, err := decode(b)
if err != nil {
fmt.Printf("decode err:%v", err)
return
}
// 根据读取到的数据的Name,得到调用的函数名
f, ok := s.funcs[rpcData.Name]
if !ok {
fmt.Printf("函数名%s不存在", rpcData.Name)
}
// 解析遍历客户端出来的参数, 放到一个数组中
inArgs := make([]reflect.Value, 0, len(rpcData.Args))
for _, arg := range rpcData.Args {
inArgs = append(inArgs, reflect.ValueOf(arg))
}
// 反射调用方法,传入参数
out := f.Call(inArgs)
// 解析遍历执行结果,放到一个数组中
outArgs := make([]interface{}, 0, len(out))
for _, o := range out {
outArgs = append(outArgs, o.Interface())
}
// 包装数据返回给客户端
respRPCData := RPCData{rpcData.Name, outArgs}
// 编码
respBytes, err := encode(respRPCData)
if err != nil {
fmt.Printf("encode err: %v", err)
return
}
// 使用RPC写出数据
err = srvSession.Write(respBytes)
if err != nil {
fmt.Printf("session write err:%v", err)
return
}
}
}
rpc/client.go
package rpc
import (
"net"
"reflect"
)
// 声明客户端
type Client struct {
conn net.Conn
}
// 创建客户端对象
func NewClient(conn net.Conn) *Client {
return &Client{conn: conn}
}
// 实现通用的RPC客户端
// 绑定RPC使用的方法
// 传入访问的函数名
// 函数具体实现在Server端, Client只有函数原型
// 使用MakeFunc() 完成原型到函数的调用
// fPtr指向函数原型
func (c *Client) callRPC(rpcName string, fPtr interface{}) {
// 通过反射,获取fPtr未初始化的函数原型
fn := reflect.ValueOf(fPtr).Elem()
// 另一个函数,是对第一个函数参数操作
f := func(args []reflect.Value) []reflect.Value {
// 处理输入的参数
inArgs := make([]interface{}, 0, len(args))
for _, arg := range args{
inArgs = append(inArgs, arg.Interface())
}
// 创建连接
cliSession := NewSession(c.conn)
// 编码数据
reqRPC := RPCData{Name: rpcName, Args: inArgs}
b, err := encode(reqRPC)
if err != nil {
panic(nil)
}
// 写出数据
err = cliSession.Write(b)
if err != nil {
panic(nil)
}
// 读响应数据
respBytes, err := cliSession.Read()
if err != nil {
panic(err)
}
// 解码数据
respRPC, err := decode(respBytes)
if err != nil {
panic(err)
}
// 处理服务端返回的数据
outArgs := make([]reflect.Value, 0, len(respRPC.Args))
for i, arg := range respRPC.Args {
// 必须进行nil转换
if arg == nil {
// 必须填充一个真正的类型,不能是nil
outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i)))
continue
}
}
return outArgs
}
v := reflect.MakeFunc(fn.Type(), f)
// 为函数fPtr赋值
fn.Set(v)
}
rpc/simple_tpc_test.go
package rpc
import (
"encoding/gob"
"fmt"
"net"
"testing"
)
// 用户查询
// 用于测试的结构体
type User struct {
Name string
Age int
}
// 用于测试查询用户的方法
func queryUser(uid int) (User, error) {
user := make(map[int]User)
user[0] = User{"zs", 20}
user[1] = User{"ls", 21}
user[2] = User{"ww", 22}
// 模拟查询用户
if u, ok := user[uid]; ok {
return u, nil
}
return User{}, fmt.Errorf("id %d not in user db", uid)
}
func TestRPC(t *testing.T) {
// 需要对interface可能产生的类型进行注册
gob.Register(User{})
addr := "127.0.0.1:8080"
// 创建服务端
srv := NewServer(addr)
// 将方法注册到服务端
srv.Register("queryUser", queryUser)
// 服务端等待调用
go srv.Run()
// 客户端获取连接
conn , err := net.Dial("tcp", addr)
if err != nil {
t.Error(err)
}
// 创建客户端
cli := NewClient(conn)
// 声明函数原型
var query func(int) (User error)
cli.callRPC("queryUser", &query)
// 得到查询结果
u, err := query(1)
if err != nil {
t.Fatal(err)
}
fmt.Println(u)
}