一、背景
1、客户两个部门,A部门在平台上训练模型,发布预测服务;B部门为新成立业务部门,需要去调用A的预测服务。B部门对发出的请求、返回都有明确的接口规范。A的预测服务,不同模型请求、返回参数不同,没有确定的格式,不会根据B要求做修改。
二、方案
考虑在用户请求入口加一个proxy服务,功能:
①作为B的访问入口
②解析B发出的请求,筛选平台需要的字段
③将B的请求代理到后方预测服务上
④接收预测服务response,加上B需要字段
三、实现
使用ReverseProxy做代理,用ModifyResponse去修改resp。
省略部分代码,结构体仅作例子用
package apimodel
type SvcRequest struct {
UserTag string `json:"userTag"`//B用户要求的response必带字段,要在request里带着
Data map[string]interface{} `json:"data"` //B发出的请求,因为不同模型请求json不同,只能用map[string]interface{}
}
type AiProxyRequest struct {
Data map[string]interface{} `json:"data"` //实际发送到预测服务的请求,剥离用户自定义字段,userTag
}
type SvcResponse struct {
UserTag string `json:"userTag"` //B用户要求的response必带字段
Result map[string]interface{} `json:"result"` //预测服务返回的结果,因为不同预测输出的json不一样,目前在外面套了一层result
}
const CtxSvcRequest = "SvcRequest" //context中的k-v的key名
main.go
func main() {
flag.InitFlags()
err := config.GetConfig(flag.ConfigFile)
if err != nil {
log.Logger.Error("Get config failed,err msg [%s]", err)
exitError()
}
router := gin.New()
handler.RegisterRoutes(router)
log.Logger.Info("start server at %s", config.ServerConfig.InferenceConfig.ListenPort)
if err := router.Run(config.ServerConfig.InferenceConfig.ListenPort); err != nil {
log.Logger.Error("Start server failed,err:%s", err)
exitError()
}
}
handler.go
func RegisterRoutes(router *gin.Engine) {
//所有请求handler指定为ProxyRequest
router.Any("/*action", ProxyRequest)
}
func ProxyRequest(c *gin.Context) {
var svcRequest apimodel.SvcRequest
var err error
operator := service.GetOperator()
//这里因为svcRequest里用的map[string]interface{},不能用shouldBindJson,只能先把body读出来Unmarshal
//但是body是io.ReadCloser,需要rebuild一下,再set回去
buf, err := ioutil.ReadAll(c.Request.Body)
if err != nil {
AbortWithResponse(c, err)
return
}
c.Request.Body = ioutil.NopCloser(bytes.NewReader(buf))
json.Unmarshal(buf, &svcRequest)
log.Logger.Info("ProxyRequest %v", svcRequest)
err = operator.GenRequest(c.Request, svcRequest)
if err != nil {
AbortWithResponse(c, err)
return
}
proxy, err := operator.GetProxy()
if err != nil {
AbortWithResponse(c, err)
return
}
//这里将svcRequest写到了c.Request.Context中,注意Context不能直接编辑,需要用WithContext生成新的request
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), apimodel.CtxSvcRequest, svcRequest))
proxy.ServeHTTP(c.Writer, c.Request)
}
func (p ProxyOperator) GetProxy() (httputil.ReverseProxy, error) {
var ServerProxy httputil.ReverseProxy
//转发请求配置
ServerProxy = httputil.ReverseProxy{
Director: func(req *http.Request) {
req.URL.Scheme = config.ServerConfig.SvcConfig.ProxyMethodScheme
req.URL.Host = config.ServerConfig.AiProxyConfig.ServerIp + ":" +
config.ServerConfig.AiProxyConfig.ServerPort
req.Host = config.ServerConfig.AiProxyConfig.ServerIp + ":" +
config.ServerConfig.AiProxyConfig.ServerPort
},
//配置rewriteBody方法,对response进行编辑
ModifyResponse: rewriteBody,
}
return ServerProxy, nil
}
func (p ProxyOperator) GenRequest(req *http.Request, svcRequest apimodel.SvcRequest) error {
log.Logger.Info("Marshal svcRequest.Data ")
//只用到data字段
content, err := json.Marshal(svcRequest.Data)
if err != nil {
log.Logger.Error("Marshal svcRequest.Data failed,err:%s", err)
return err
}
req.ContentLength = int64(len(content))
req.Body = ioutil.NopCloser(bytes.NewReader(content))
return nil
}
func rewriteBody(resp *http.Response) error {
var buffer bytes.Buffer
var svcResp apimodel.SvcResponse
var inferneceReq apimodel.SvcRequest
respBody, err := ioutil.ReadAll(resp.Body) //Read html
if err != nil {
log.Logger.Error("Read resp.Body failed,err:%s", err)
return err
}
err = resp.Body.Close()
if err != nil {
log.Logger.Error("Close resp.Body failed,err:%s", err)
return err
}
//拼接result
prefix := []byte("{\"result\":")
tail := []byte("}")
buffer.Write(prefix)
buffer.Write(respBody)
buffer.Write(tail)
if err := json.Unmarshal(buffer.Bytes(), &svcResp); err != nil {
log.Logger.Error("Unmarshal svcResp failed,err:%s", err)
return err
}
//从ctx中获取前面set进去的request,注意resp.Request.Body已经为空了,借助ctx传值
inferneceReq = resp.Request.Context().Value(apimodel.CtxSvcRequest).(apimodel.SvcRequest)
svcResp.UserTag = inferneceReq.UserTag
content, err := json.Marshal(svcResp)
if err != nil {
log.Logger.Error("Marshal svcResp failed,err:%s", err)
return err
}
resp.Body = ioutil.NopCloser(bytes.NewReader(content))
resp.ContentLength = int64(len(content))
resp.Header.Set("Content-Length", strconv.Itoa(len(content)))
return nil
}