package middlewares
import (
	"bytes"
	"crypto/aes"
	crand "crypto/rand"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net/http"

	"github.com/gin-gonic/gin"

	"public/encrypt"
)
// EncryptParam is the encrypted envelope exchanged with clients. Encrypt binds
// it from incoming requests (json/form tags) when request decryption is on,
// and marshals it to JSON as the body of encrypted responses.
type EncryptParam struct {
// Key is the RSA-encrypted AES key, base64 encoded with the standard alphabet.
Key string `json:"key" form:"key"`
// EncryptedData is the AES-ECB ciphertext of the payload, base64 encoded with
// the URL-safe alphabet.
EncryptedData string `json:"encrypted_data" form:"encrypted_data"`
}
// EncryptResponseWriter wraps gin's ResponseWriter and diverts body writes
// into Buff, so the Encrypt middleware can inspect and optionally encrypt the
// complete response after the handlers have run. Methods not overridden in
// this file (headers, status, ...) are promoted from the embedded writer.
type EncryptResponseWriter struct {
gin.ResponseWriter
// Buff accumulates the handler output until Encrypt forwards it.
Buff *bytes.Buffer
}
// Write captures p in the buffer instead of sending it to the client; the
// Encrypt middleware later decides whether to forward it as-is or encrypted.
func (e *EncryptResponseWriter) Write(p []byte) (int, error) {
	return e.Buff.Write(p)
}

// WriteString buffers s exactly like Write. Without this override the
// embedded gin.ResponseWriter's WriteString would be promoted, and any caller
// using it, directly or through io.WriteString (which prefers a WriteString
// method when one exists), would bypass the buffer and send unencrypted
// output straight to the client.
func (e *EncryptResponseWriter) WriteString(s string) (int, error) {
	return e.Buff.WriteString(s)
}
func Encrypt() gin.HandlerFunc {
return func(c *gin.Context) {
encryptType := c.Request.Header.Get("bb-encrypt")
version := c.Request.Header.Get("bb-encrypt-ver")
if encryptType == "" || encryptType == "none" {
return
}
encryptWriter := &EncryptResponseWriter{c.Writer, bytes.NewBuffer(make([]byte, 0))}
c.Writer = encryptWriter
// 解密请求
if encryptType == "request" || encryptType == "all" {
param := EncryptParam{}
if err := c.Bind(¶m); err != nil {
c.AbortWithStatus(http.StatusBadRequest)
common.Log.Errorf("Bind EncryptParam err: %s", err)
return
}
if param.Key == "" || param.EncryptedData == "" {
c.AbortWithStatus(http.StatusBadRequest)
common.Log.Error("EncryptedData is empty")
return
}
cert, err := Dao.GetCert(version, "server") // 此处是从数据库读取certs, 也可以本地读取文件
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
common.Log.Errorf("GetCert err: %s", err)
return
}
key, err := RsaDecryptData(cert.PrivateKey, param.Key)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
common.Log.Errorf("RsaDecryptData err: %s", err)
return
}
data, err := AesDecryptData(key, param.EncryptedData)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
common.Log.Errorf("AesDecryptData err: %s", err)
return
}
if c.Request.Method == http.MethodGet {
c.Request.URL.RawQuery = data
} else {
c.Request.Body = ioutil.NopCloser(bytes.NewBuffer([]byte(data)))
}
common.Log.Infof("%v-middlewares-decrypt raw: %v", c.Request.URL.Path, data)
}
c.Next()
normalReturn := func() {
if _, err := encryptWriter.ResponseWriter.Write(encryptWriter.Buff.Bytes()); err != nil {
common.Log.Error(err.Error())
}
}
if encryptWriter.Status() != http.StatusOK { // 不成功, 直接返回
normalReturn()
return
}
encryptWriter.Header().Set("bb-encrypted", version)
encryptWriter.Header().Set("bb-encrypt-ver", "0")
// 加密返回
if encryptType == "response" || encryptType == "all" {
randomKey := RandStringRunes(16)
cert, err := Dao.GetCert(version, "client") // 此处是从数据库读取certs, 也可以本地读取文件
if err != nil {
common.Log.Errorf("GetCert err: %s", err)
return
}
key, err := RsaEncryptData(cert.PublicKey, randomKey)
if err != nil {
common.Log.Errorf("RsaEncryptData err: %s", err)
return
}
encryptedData, err := AesEncryptData(randomKey, encryptWriter.Buff.String())
if err != nil {
common.Log.Errorf("AesEncryptData err: %s", err)
return
}
data, err := json.Marshal(EncryptParam{Key: key, EncryptedData: encryptedData})
if err != nil {
common.Log.Error(err.Error())
} else {
common.Log.Infof("%v-middlewares-encrypt raw: %v", c.Request.URL.Path, encryptWriter.Buff.String())
encryptWriter.Header().Set("bb-encrypt-ver", "1")
if _, err := encryptWriter.ResponseWriter.Write(data); err != nil {
common.Log.Error(err.Error())
}
}
} else {
normalReturn()
return
}
}
}
// RsaEncryptData encrypts data with RSA using publicKey and returns the
// ciphertext base64-encoded with the standard alphabet. publicKey must be in
// whatever format encrypt.BytesToPublicKey accepts (presumably PEM; confirm).
// On error the returned string is empty.
func RsaEncryptData(publicKey, data string) (res string, err error) {
	parsedKey, keyErr := encrypt.BytesToPublicKey([]byte(publicKey))
	if keyErr != nil {
		return "", keyErr
	}
	cipherText, encErr := encrypt.EncryptWithPublicKey([]byte(data), parsedKey)
	if encErr != nil {
		return "", encErr
	}
	return base64.StdEncoding.EncodeToString(cipherText), nil
}
// RsaDecryptData reverses RsaEncryptData. It first decodes data from standard
// base64, then parses privateKey (in whatever format
// encrypt.BytesToPrivateKey accepts) and RSA-decrypts the bytes, returning the
// plaintext as a string. On error the returned string is empty.
func RsaDecryptData(privateKey, data string) (res string, err error) {
	cipherText, decodeErr := base64.StdEncoding.DecodeString(data)
	if decodeErr != nil {
		return "", decodeErr
	}
	parsedKey, keyErr := encrypt.BytesToPrivateKey([]byte(privateKey))
	if keyErr != nil {
		return "", keyErr
	}
	plainText, decErr := encrypt.DecryptWithPrivateKey(cipherText, parsedKey)
	if decErr != nil {
		return "", decErr
	}
	return string(plainText), nil
}
// AesEncryptData encrypts data with AES in ECB mode under key and returns the
// ciphertext base64-encoded with the URL-safe alphabet. The error result is
// always nil.
func AesEncryptData(key, data string) (res string, err error) {
	cipherText := encrypt.AesEncryptECB([]byte(data), []byte(key))
	return base64.URLEncoding.EncodeToString(cipherText), nil
}
// AesDecryptData reverses AesEncryptData. It decodes data from URL-safe
// base64, then decrypts it with AES in ECB mode under key and returns the
// plaintext as a string. On error the returned string is empty.
//
// data originates from untrusted clients (the Encrypt middleware), and
// encrypt.AesDecryptECB reports no error. Input that cannot be ECB ciphertext
// (not a whole number of AES blocks) is therefore rejected here rather than
// handed to it.
// NOTE(review): ECB mode reveals repeated plaintext blocks; presumably kept for
// compatibility with existing clients.
func AesDecryptData(key, data string) (res string, err error) {
	eData, err := base64.URLEncoding.DecodeString(data)
	if err != nil {
		return "", err
	}
	if len(eData)%aes.BlockSize != 0 {
		return "", fmt.Errorf("aes ciphertext length %d is not a multiple of the block size %d", len(eData), aes.BlockSize)
	}
	return string(encrypt.AesDecryptECB(eData, []byte(key))), nil
}
// letterRunes is the alphabet RandStringRunes draws from.
var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

// cryptoRandSource is a math/rand Source backed by crypto/rand. Values drawn
// through math/rand's helpers (whose Intn rejects biased samples) are then
// unpredictable. Seed is a no-op: the stream cannot be replayed.
type cryptoRandSource struct{}

// Int63 returns a uniformly distributed non-negative int64 read from
// crypto/rand. It panics if crypto/rand fails, because silently falling back
// to a predictable value would weaken every key generated from it.
func (cryptoRandSource) Int63() int64 {
	var b [8]byte
	if _, err := crand.Read(b[:]); err != nil {
		panic("middlewares: crypto/rand unavailable: " + err.Error())
	}
	return int64(binary.LittleEndian.Uint64(b[:]) &^ (1 << 63))
}

// Seed satisfies rand.Source and is intentionally ignored.
func (cryptoRandSource) Seed(int64) {}

// RandStringRunes returns a random string of n letters from [a-zA-Z].
//
// The Encrypt middleware uses it to generate per-response AES keys, so
// characters are drawn from crypto/rand. math/rand's default generator is
// predictable, and deterministic unless seeded on Go versions before 1.20.
// The alphabet is unchanged, so the output format stays the same.
func RandStringRunes(n int) string {
	rng := rand.New(cryptoRandSource{})
	b := make([]rune, n)
	for i := range b {
		b[i] = letterRunes[rng.Intn(len(letterRunes))]
	}
	return string(b)
}