当我深入的学习和了解了 GORM,XORM 后,我还是觉得它们不够简洁和优雅,有些笨重,有很大的学习成本。本着学习和探索的目的,于是我自己实现了一个简单且优雅的 go 语言版本的 ORM。

如上面的导语所示,GORM 算是 Golang 里面 ORM 库的头牌,它功能虽然很强大,但是我觉得它有很深的学习成本,对于新人而言,纯使用没有啥问题,但是当遇到一些复杂的查询的时候,就会捉襟见肘了,因为它的内部实现太复杂了,以至于你很难摸透它。于是,本着一边学习一边探索的目的,我从基础原理开始讲起,到一步一步实现,继而完成整个简单且优雅的 MySQL ORM。

一、前置学习

1. 为什么要用 ORM

for next
ORMORM
提供更加方便快捷的curd方法去和数据库产生交互

2. Golang 里面是如何原生连接 MySQL 的

说完了啥是 ORM,以及为啥用 ORM 之后,我们再看下 Golang 里面是如何原生连接 MySQL 的,这对于我们开发一个 ORM 帮助很大,只有弄清楚了它们之间交互的原理,我们才能更好的开始造。

原生代码连接 MySQL,一般是如下步骤。

首先是导入 sql 引擎和 mysql 的驱动:

import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)

连接 MySQL :

db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/ApiDB?charset=utf8") //第一个参数数驱动名
if err != nil {
    panic(err.Error())
}

然后,我们快速过一下,如何增删改查:

增:

//方式一:
result, err := db.Exec("INSERT INTO userinfo (username, departname, created) VALUES (?, ?, ?)","lisi","dev","2020-08-04")

//方式二:
stmt, err := db.Prepare("INSERT INTO userinfo (username, departname, created) VALUES (?, ?, ?)")

result2, err := stmt.Exec("zhangsan", "pro", time.Now().Format("2006-01-02"))

删:

//方式一:
result, err := db.Exec("delete from userinfo where uid=?", 10795)

//方式二:
stmt, err := db.Prepare("delete from userinfo where uid=?")

result3, err := stmt.Exec("10795")

改:

//方式一:
result, err := db.Exec("update userinfo set username=? where uid=?", "lisi", 2)

//方式二:
stmt, err := db.Prepare("update userinfo set username=? where uid=?")

result, err := stmt.Exec("lisi", 2)

查:

//单条
var username, departname, status string
err := db.QueryRow("select username, departname, status from userinfo where uid=?", 4).Scan(&username, &departname, &status)
if err != nil {
    fmt.Println("QueryRow error :", err.Error())
}
fmt.Println("username: ", username, "departname: ", departname, "status: ", status)

//多条:
rows, err := db.Query("select username, departname, status from userinfo where username=?", "yang")
if err != nil {
    fmt.Println("QueryRow error :", err.Error())
}

//定义一个结构体,存放数据模型
type UserInfo struct {
    Username   string `json:"username"`
    Departname string `json:"departname"`
    Status    string `json:"status"`
}

//初始化
var user []UserInfo

for rows.Next() {
    var username1, departname1, status1 string
    if err := rows.Scan(&username1, &departname1, &status1); err != nil {
        fmt.Println("Query error :", err.Error())
    }
    user = append(user, UserInfo{Username: username1, Departname: departname1, Status: status1})
}
ExecPrepareExec

所以我在想?我是不是可以基于原生代码库的这个优势,自己开发 1 个 ORM 呢,第一:它能提供了各式各样的方法来提高开发效率,第二:底层直接转换拼接成最终的 SQL,去调用这个原生的组件,来和 MySQL 去交互。这样岂不是一箭双雕,既能提高开发效率,又能保持足够的高效和简单。完美!

说干就干吧!

3. ORM 框架构想

PrepareEexc"github.com/go-sql-driver/mysql"
smallorm

然后,整个调用过程采用链式的方法,这样比较方便,比如这样子:

db.Where().Where().Order().Limit().Select()

其次,暴露的 CURD 方法,使用起来要简单,名字要清晰,无歧义,不要搞一大堆复杂的间接调用。

OK,我们梳理一下,sql 里面常用到的一些 curd 的方法,把他们整理成 ORM 的一个个方法,并按照这个一步一步来实现,如下:

ConnectTableInsert/ReplaceWhereDeleteUpdateSelectExec/QueryFieldLimitCount/Max/Min/Avg/SumOrderGroupHavingGetLastSqlBegin/Commit/Rollback/
Insert/Replace/Delete/Select/Update

所以,我们可以畅享一下,这个完成后的 ORM,是如何调用的:

增:

type User1 struct {
    Username   string `sql:"username"`
    Departname string `sql:"departname"`
    Status     int64  `sql:"status"`
}

user2 := User1{
    Username:   "EE",
    Departname: "22",
    Status:     1,
}

// insert into userinfo (username,departname,status) values ('EE', '22', 1)

id, err := e.Table("userinfo").Insert(user2)

删:

// delete from userinfo where (uid = 10805)

result1, err := e.Table("userinfo").Where("uid", "=", 10805).Delete()

改:

// update userinfo set departname=110 where (uid = 10805)

result1, err := e.Table("userinfo").Where("uid", "=", 10805).Update("departname", 110)

查:

// select uid, status from userinfo where (departname like '%2') or (status=1)  order by uid desc limit 1

result, err := e.Table("userinfo").Where("departname", "like", "%2").OrWhere("status", 1).Order("uid", "desc").Limit(1).Field("uid, status").Select()

//select uid, status from userinfo where (uid in (1,2,3,4,5)) or (status=1)  order by uid desc limit 1

result, err := e.Table("userinfo").Where("uid", "in", []int{1,2,3,4,5}).OrWhere("status", 1).Order("uid", "desc").Limit(1).Field("uid, status").SelectOne()


type User1 struct {
    Username   string `sql:"username"`
    Departname string `sql:"departname"`
    Status     int64  `sql:"status"`
}

user2 := User1{
    Username:   "EE",
    Departname: "22",
    Status:     1,
}

user3 := User1{
    Username:   "EE",
    Departname: "22",
    Status:     2,
}

// select * from userinfo where (Username='EE' and Departname='22' and Status=1) or (Username='EE' and Departname='22' and Status=2)  limit 1
id, err := e.Table("userinfo").Where(user2).OrWhere(user3).SelectOne()

二、开始造

Connect
sql.Open("mysql", dsn)
SmallormEngine
type SmallormEngine struct {
   Db           *sql.DB
   TableName    string
   Prepare      string
   AllExec      []interface{}
   Sql          string
   WhereParam   string
   LimitParam   string
   OrderParam   string
   OrWhereParam string
   WhereExec    []interface{}
   UpdateParam  string
   UpdateExec   []interface{}
   FieldParam   string
   TransStatus  int
   Tx           *sql.Tx
   GroupParam   string
   HavingParam  string
}

因为我们这 ORM 的底层本质是 SQL 拼接,所以,我们需要把各种操作方法生成的数据,都保存到这个结构体的各个变量上,方便最后一步生成 SQL。

Db*sql.DBTx*sql.Tx

接下来就可以写连接操作了:

//新建Mysql连接
func NewMysql(Username string, Password string, Address string, Dbname string) (*SmallormEngine, error) {
    dsn := Username + ":" + Password + "@tcp(" + Address + ")/" + Dbname + "?charset=utf8&timeout=5s&readTimeout=6s"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        return nil, err
    }

    //最大连接数等配置,先占个位
   //db.SetMaxOpenConns(3)
   //db.SetMaxIdleConns(3)

    return &SmallormEngine{
        Db:         db,
        FieldParam: "*",
    }, nil
}
NewMysql用户名密码ip和端口数据库名

其次,如何实现链式的方式调用呢?只需要在每个方法返回实例本身即可,比如:

func (e *SmallormEngine) Where (name string) *SmallormEngine {
   return e
}

func (e *SmallormEngine) Limit (name string) *SmallormEngine {
   return e
}

这样我们就可以链式的调用了:

e.Where().Where().Limit()
Table/GetTable

我们需要 1 个设置和读取数据库表名字的方法,因为我们所有的 CURD 都是基于某张表的:

//设置表名
func (e *SmallormEngine) Table(name string) *SmallormEngine {
   e.TableName = name

   //重置引擎
   e.resetSmallormEngine()
   return e
}

//获取表名
func (e *SmallormEngine) GetTable() string {
   return e.TableName
}
Table()SmallormEngine
Insert/Replace

2.1 单个数据插入

下面就是本 ORM 第一个重头戏和挑战点了,如何往数据库里插入数据?在如何用 ORM 实现本功能之前,我们先回忆下上面讲的原生的代码是如何插入的:

PrepareExec
stmt, err := db.Prepare("INSERT INTO userinfo (username, departname, created) VALUES (?, ?, ?)")

result2, err := stmt.Exec("zhangsan", "pro", time.Now().Format("2006-01-02"))

我们分析下它的做法:

Prepare??Exec?

ok,妥了,整明白了。那我们就按照这 2 部拆分数据即可。

Insert[field1:value1,field2:value2,field3:value3]MapStructMapStruct结构体

由于 go 里面的数据都得是先定义类型,再去初始化 1 个值,所以,大致的调用过程是这样的:

type User struct {
    Username   string `sql:"username"`
    Departname string `sql:"departname"`
    Status     int64  `sql:"status"`
}

user2 := User{
    Username:   "EE",
    Departname: "22",
    Status:     1,
}

id, err := e.Table("userinfo").Insert(user2)
sql:"xxx"Tag标签
user2
sql:"xxx"(username, departname, status)?
stmt, err := db.Prepare("INSERT INTO userinfo (username, departname, status) VALUES (?, ?, ?)")
user2Exec
result2, err := stmt.Exec("EE", "22", 1)
user2(username, departname, status)反射反射
reflect.TypeOfreflect.ValueOf
type User struct {
    Username   string `sql:"username"`
    Departname string `sql:"departname"`
    Status     int64  `sql:"status"`
}

user2 := User{
    Username:   "EE",
    Departname: "22",
    Status:     1,
}


//反射出这个结构体变量的类型
t := reflect.TypeOf(user2)

//反射出这个结构体变量的值
v := reflect.ValueOf(user2)

fmt.Printf("==== print type ====\n%+v\n", t)
fmt.Printf("==== print value ====\n%+v\n", v)

我们打印看看,结果是啥?

==== print type ====
main.User

==== print value ====
{Username:EE Departname:22 Status:1}
Usert.NumField()t.Field(i)
//反射type和value
t := reflect.TypeOf(user2)
v := reflect.ValueOf(user2)

//字段名
var fieldName []string

//问号?占位符
var placeholder []string

//循环判断
for i := 0; i < t.NumField(); i++ {

  //小写开头,无法反射,跳过
  if !v.Field(i).CanInterface() {
    continue
  }

  //解析tag,找出真实的sql字段名
  sqlTag := t.Field(i).Tag.Get("sql")
  if sqlTag != "" {
    //跳过自增字段
    if strings.Contains(strings.ToLower(sqlTag), "auto_increment") {
      continue
    } else {
      fieldName = append(fieldName, strings.Split(sqlTag, ",")[0])
      placeholder = append(placeholder, "?")
    }
  } else {
    fieldName = append(fieldName, t.Field(i).Name)
    placeholder = append(placeholder, "?")
  }

  //字段的值
  e.AllExec = append(e.AllExec, v.Field(i).Interface())
}

//拼接表,字段名,占位符
e.Prepare =  "insert into " + e.GetTable() + " (" + strings.Join(fieldName, ",") + ") values(" + strings.Join(placeholder, ",") + ")"
t.NumField()t.Field(i).Tag.Get("sql")sql:"xxx"t.Field(i).Namev.Field(i).Interface()e.GetTable()Db.Prepare
e.Prepare =  "INSERT INTO userinfo (username, departname, status) VALUES (?, ?, ?)"
stmt.Exece.AllExecinterface
//申明stmt类型
var stmt *sql.Stmt

//第一步:Db.prepare
stmt, err = e.Db.Prepare(e.Prepare)

//第二步:执行exec,注意这是stmt.Exec
result, err := stmt.Exec(e.AllExec...)
if err != nil {
  //TODO
}

//获取自增ID
id, _ := result.LastInsertId()
stmt.Exec(e.AllExec...)三个点
stmt.Exec(e.AllExec...)
↓
↓
↓
stmt.Exec("EE", "22", 1)

到此为止,我们成功通过反射和拼接的办法,将 1 个结构体变量,按照 2 步操作法成功的进行了拆分,实现了插入数据。

insertreplaceinsertreplaceinsertType
//用insertType抽象出来,它的值为:insert, replace
e.Prepare = insertType + " into " + e.GetTable() + " (" + strings.Join(fieldName, ",") + ") values(" + strings.Join(placeholder, ",") + ")"

完整的 insert 函数如下:

//插入
func (e *SmallormEngine) Insert(data interface{}) (int64, error) {
    return e.insertData(data, "insert")
}

//替换插入
func (e *SmallormEngine) Replace(data interface{}) (int64, error) {
    return e.insertData(data, "replace")
}

//插入数据子方法
func (e *SmallormEngine) insertData(data interface{}, insertType string) (int64, error) {

  //反射type和value
  t := reflect.TypeOf(data)
  v := reflect.ValueOf(data)

  //字段名
  var fieldName []string

  //问号?占位符
  var placeholder []string

  //循环判断
  for i := 0; i < t.NumField(); i++ {

    //小写开头,无法反射,跳过
    if !v.Field(i).CanInterface() {
      continue
    }

    //解析tag,找出真实的sql字段名
    sqlTag := t.Field(i).Tag.Get("sql")
    if sqlTag != "" {
      //跳过自增字段
      if strings.Contains(strings.ToLower(sqlTag), "auto_increment") {
        continue
      } else {
        fieldName = append(fieldName, strings.Split(sqlTag, ",")[0])
        placeholder = append(placeholder, "?")
      }
    } else {
      fieldName = append(fieldName, t.Field(i).Name)
      placeholder = append(placeholder, "?")
    }

    //字段值
    e.AllExec = append(e.AllExec, v.Field(i).Interface())
  }

  //拼接表,字段名,占位符
  e.Prepare = insertType + " into " + e.GetTable() + " (" + strings.Join(fieldName, ",") + ") values(" + strings.Join(placeholder, ",") + ")"

  //prepare
  var stmt *sql.Stmt
  stmt, err = e.Db.Prepare(e.Prepare)
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  //执行exec,注意这是stmt.Exec
  result, err := stmt.Exec(e.AllExec...)
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  //获取自增ID
  id, _ := result.LastInsertId()
  return id, nil
}

//自定义错误格式
func (e *SmallormEngine) setErrorInfo(err error) error {
  _, file, line, _ := runtime.Caller(1)
  return errors.New("File: " + file + ":" + strconv.Itoa(line) + ", " + err.Error())
}

2.2 多个数据批量插入

上面单个插入的原理已经弄的透透的了,接下来我们来看下:批量插入,sql 里面其实是支持批量插入的,这样效率会高很多的,我们先看下,原始的 sql 语句是怎么批量插入的:

INSERT INTO userinfo (username, departname, created) VALUES ("EE", "22", 1),("aa", "rd", 0),("bb", "ty", 1)
("xx", "yy", "zz")

OK, 那我们看下 go 原生的批量插入代码是怎么弄的:

stmt, err := db.Prepare("INSERT INTO userinfo (username, departname, status) VALUES (?, ?, ?),(?, ?, ?),(?, ?, ?)")

result2, err := stmt.Exec("a1", "1", 1, "a2", "b2", 1, "a3", "b3", 0)
(),
PrepareExec
1. 批量插入,传入的数据就是一个切片数组了,`[]struct` 这样的数据类型了。
2. 我们得先用反射算出,这个数组有多少个元素。这样好算出 VALUES 后面有几个`()`的占位符。
3. 搞2个for循环,外面的for循环,得出这个子元素的type和value。里面的第二个for循环,就和单个插入的反射操作一样了,就是算出每一个子元素有几个字段,反射出field名字,以及对应`()`里面有几个?问号占位符。
4. 2层for循环把切片里面的每个元素的每个字段的value放入到1个统一的AllExec中。

OK,直接上代码吧:

//批量插入
func (e *SmallormEngine) BatchInsert(data interface{}) (int64, error) {
    return e.batchInsertData(data, "insert")
}

//批量替换插入
func (e *SmallormEngine) BatchReplace(data interface{}) (int64, error) {
    return e.batchInsertData(data, "replace")
}


//批量插入
func (e *SmallormEngine) batchInsertData(batchData interface{}, insertType string) (int64, error) {

  //反射解析
  getValue := reflect.ValueOf(batchData)

  //切片大小
  l := getValue.Len()

  //字段名
  var fieldName []string

  //占位符
  var placeholderString []string

  //循环判断
  for i := 0; i < l; i++ {
    value := getValue.Index(i) // Value of item
    typed := value.Type()      // Type of item
    if typed.Kind() != reflect.Struct {
      panic("批量插入的子元素必须是结构体类型")
    }

    num := value.NumField()

    //子元素值
    var placeholder []string
    //循环遍历子元素
    for j := 0; j < num; j++ {

      //小写开头,无法反射,跳过
      if !value.Field(j).CanInterface() {
        continue
      }

      //解析tag,找出真实的sql字段名
      sqlTag := typed.Field(j).Tag.Get("sql")
      if sqlTag != "" {
        //跳过自增字段
        if strings.Contains(strings.ToLower(sqlTag), "auto_increment") {
          continue
        } else {
          //字段名只记录第一个的
          if i == 1 {
            fieldName = append(fieldName, strings.Split(sqlTag, ",")[0])
          }
          placeholder = append(placeholder, "?")
        }
      } else {
        //字段名只记录第一个的
        if i == 1 {
          fieldName = append(fieldName, typed.Field(j).Name)
        }
        placeholder = append(placeholder, "?")
      }

      //字段值
      e.AllExec = append(e.AllExec, value.Field(j).Interface())
    }

    //子元素拼接成多个()括号后的值
    placeholderString = append(placeholderString, "("+strings.Join(placeholder, ",")+")")
  }

  //拼接表,字段名,占位符
  e.Prepare = insertType + " into " + e.GetTable() + " (" + strings.Join(fieldName, ",") + ") values " + strings.Join(placeholderString, ",")

  //prepare
  var stmt *sql.Stmt
  var err error
  stmt, err = e.Db.Prepare(e.Prepare)
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  //执行exec,注意这是stmt.Exec
  result, err := stmt.Exec(e.AllExec...)
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  //获取自增ID
  id, _ := result.LastInsertId()
  return id, nil
}


//自定义错误格式
func (e *SmallormEngine) setErrorInfo(err error) error {
  _, file, line, _ := runtime.Caller(1)
  return errors.New("File: " + file + ":" + strconv.Itoa(line) + ", " + err.Error())
}

开始总结一下上面这一坨关键的地方。

首先是获取这个切片的大小,用于第一个 for 循环。可以通过下面的 2 行代码:

//反射解析
getValue := reflect.ValueOf(batchData)

//切片大小
l := getValue.Len()
value := getValue.Index(i)v := reflect.ValueOf(data)
typed := value.Type()t := reflect.TypeOf(data)
i==1
placeholderString多个()

这样,批量插入,批量替换插入的逻辑就完成了。

2.3 单个和批量合二为一

为了使我们的 ORM 足够的优雅和简单,我们可以把单个插入和批量插入,搞成 1 个方法暴露出去。那怎么识别出传入的数据是单个结构体,还是切片结构体呢?还是得用反射:

reflect.ValueOf(data).Kind()
StructSliceArray
//插入
func (e *SmallormEngine) Insert(data interface{}) (int64, error) {

  //判断是批量还是单个插入
  getValue := reflect.ValueOf(data).Kind()
  if getValue == reflect.Struct {
    return e.insertData(data, "insert")
  } else if getValue == reflect.Slice || getValue == reflect.Array {
    return e.batchInsertData(data, "insert")
  } else {
    return 0, errors.New("插入的数据格式不正确,单个插入格式为: struct,批量插入格式为: []struct")
  }
}


//替换插入
func (e *SmallormEngine) Replace(data interface{}) (int64, error) {
  //判断是批量还是单个插入
  getValue := reflect.ValueOf(data).Kind()
  if getValue == reflect.Struct {
    return e.insertData(data, "replace")
  } else if getValue == reflect.Slice || getValue == reflect.Array {
    return e.batchInsertData(data, "replace")
  } else {
    return 0, errors.New("插入的数据格式不正确,单个插入格式为: struct,批量插入格式为: []struct")
  }
}

OK,完成。

Where

3.1 结构体参数调用

Where
select * from userinfo where status = 1
delete from userinfo where status = 1 or departname != "aa"
update userinfo set departname = "bb" where status = 1 and departname = "aa"
Where
=, !=, like, <, >andor
PrepareExce
stmt, err := db.Prepare("delete from userinfo where uid=?")
result3, err := stmt.Exec("10795")

stmt, err := db.Prepare("update userinfo set username=? where uid=?")
result, err := stmt.Exec("lisi", 2)

所以,where 部分的拆分,其实也是分 2 部来走。和插入的 2 步走的逻辑是一样的。大致的调用过程如下:

type User struct {
    Username   string `sql:"username"`
    Departname string `sql:"departname"`
    Status     int64  `sql:"status"`
}

user2 := User{
    Username:   "EE",
    Departname: "22",
    Status:     1,
}

result1, err1 := e.Table("userinfo").Where(user2).Delete()
result2, err2 := e.Table("userinfo").Where(user2).Select()
WhereParamWhereExec
Insert
func (e *SmallormEngine) Where(data interface{}) *SmallormEngine {

    //反射type和value
    t := reflect.TypeOf(data)
    v := reflect.ValueOf(data)

    //字段名
    var fieldNameArray []string

    //循环解析
    for i := 0; i < t.NumField(); i++ {

      //首字母小写,不可反射
      if !v.Field(i).CanInterface() {
        continue
      }

      //解析tag,找出真实的sql字段名
      sqlTag := t.Field(i).Tag.Get("sql")
      if sqlTag != "" {
        fieldNameArray = append(fieldNameArray, strings.Split(sqlTag, ",")[0]+"=?")
      } else {
        fieldNameArray = append(fieldNameArray, t.Field(i).Name+"=?")
      }

      //反射出Exec的值。
      e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
    }

    //拼接
    e.WhereParam += strings.Join(fieldNameArray, " and ")
    return e
}

这样,我们就可以调用 Where()反复,转换成生成了 2 个暂存变量。我们打印下这 2 个值看看:

WhereParam = "username=? and departname=? and Status=?"
WhereExec = []interface{"EE", "22", 1}
and
e.Table("userinfo").Where(user2).Where(user3).XXX
e.WhereParam

先判断理一下,是否为空,如果不为空,则说明这是第二次调用了,我们用 "and (" 来做隔离。

//多次调用判断
if e.WhereParam != "" {
  e.WhereParam += " and ("
} else {
  e.WhereParam += "("
}

//结束拼接的时候,加上结束括号") "。

e.WhereParam += strings.Join(fieldNameArray, " and ") + ") "

这样,就达到了我们的目的了。我们看下多次调用后的打印结果:

WhereParam = "(username=? and departname=? and status=?) and (username=? and departname=? and status=?)"
WhereExec = []interface{"EE", "22", 1, "FF", "33", 0}
=

3.2 单个字符串参数的调用

WhereInsert
Where("uid", "=", 1234)
Where("uid", ">=", 1234)
Where("uid", "in", []int{2, 3, 4})
非and!=likenot inin

OK,那我们开始写一下,这种方式怎么判断呢?对比传入结构体的方式更简单:

方法有 3 个参数,第一个是需要查询的字段,第 2 个是比较符,第三个是查询的值。

func (e *SmallormEngine) Where(fieldName string, opt string, fieldValue interface{}) *SmallormEngine {

    //区分是操作符in的情况
    data2 := strings.Trim(strings.ToLower(fieldName.(string)), " ")
    if data2 == "in" || data2 == "not in" {
      //判断传入的是切片
      reType := reflect.TypeOf(fieldValue).Kind()
      if reType != reflect.Slice && reType != reflect.Array {
        panic("in/not in 操作传入的数据必须是切片或者数组")
      }

      //反射值
      v := reflect.ValueOf(fieldValue)
      //数组/切片长度
      dataNum := v.Len()
      //占位符
      ps := make([]string, dataNum)
      for i := 0; i < dataNum; i++ {
        ps[i] = "?"
        e.WhereExec = append(e.WhereExec, v.Index(i).Interface())
      }

      //拼接
      e.WhereParam += fieldName.(string) + " " + fieldValue + " (" + strings.Join(ps, ",") + ")) "

    } else {
      e.WhereParam += fieldName.(string) + " " + fieldValue.(string) + " ?) "
      e.WhereExec = append(e.WhereExec, fieldValue)
    }

    return e
}
inin (?,?,?)

所以,我们把这 2 种方式,拼接一下,融合成 1 种方式,智能的去判断即可,下面是完整的代码:

//传入and条件
func (e *SmallormEngine) Where(data ...interface{}) *SmallormEngine {

  //判断是结构体还是多个字符串
  var dataType int
  if len(data) == 1 {
    dataType = 1
  } else if len(data) == 2 {
    dataType = 2
  } else if len(data) == 3 {
    dataType = 3
  } else {
    panic("参数个数错误")
  }

  //多次调用判断
  if e.WhereParam != "" {
    e.WhereParam += " and ("
  } else {
    e.WhereParam += "("
  }

  //如果是结构体
  if dataType == 1 {
    t := reflect.TypeOf(data[0])
    v := reflect.ValueOf(data[0])

    //字段名
    var fieldNameArray []string

    //循环解析
    for i := 0; i < t.NumField(); i++ {

      //首字母小写,不可反射
      if !v.Field(i).CanInterface() {
        continue
      }

      //解析tag,找出真实的sql字段名
      sqlTag := t.Field(i).Tag.Get("sql")
      if sqlTag != "" {
        fieldNameArray = append(fieldNameArray, strings.Split(sqlTag, ",")[0]+"=?")
      } else {
        fieldNameArray = append(fieldNameArray, t.Field(i).Name+"=?")
      }

      e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
    }

    //拼接
    e.WhereParam += strings.Join(fieldNameArray, " and ") + ") "

  } else if dataType == 2 {
    //直接=的情况
    e.WhereParam += data[0].(string) + "=?) "
    e.WhereExec = append(e.WhereExec, data[1])
  } else if dataType == 3 {
    //3个参数的情况

    //区分是操作符in的情况
    data2 := strings.Trim(strings.ToLower(data[1].(string)), " ")
    if data2 == "in" || data2 == "not in" {
      //判断传入的是切片
      reType := reflect.TypeOf(data[2]).Kind()
      if reType != reflect.Slice && reType != reflect.Array {
        panic("in/not in 操作传入的数据必须是切片或者数组")
      }

      //反射值
      v := reflect.ValueOf(data[2])
      //数组/切片长度
      dataNum := v.Len()
      //占位符
      ps := make([]string, dataNum)
      for i := 0; i < dataNum; i++ {
        ps[i] = "?"
        e.WhereExec = append(e.WhereExec, v.Index(i).Interface())
      }

      //拼接
      e.WhereParam += data[0].(string) + " " + data2 + " (" + strings.Join(ps, ",") + ")) "

    } else {
      e.WhereParam += data[0].(string) + " " + data[1].(string) + " ?) "
      e.WhereExec = append(e.WhereExec, data[2])
    }
  }

  return e
}
..interface{}len()data[0]
Where
// where uid = 123
e.Table("userinfo").Where("uid", 123)

// where uid not in (2,3,4)
e.Table("userinfo").Where("uid", "not in", []int{2, 3, 4})

// where uid in (2,3,4)
e.Table("userinfo").Where("uid", "in", []int{2, 3, 4})

// where uid like '%2%'
e.Table("userinfo").Where("uid", "like", "%2%")

// where uid >= 123
e.Table("userinfo").Where("uid", ">=", 123)

// where (uid >= 123) and (name = 'vv')
e.Table("userinfo").Where("uid", ">=", 123).Where("name", "vv")
OrWhere
Whereand
where (uid >= 123) or (name = 'vv')
where (uid = 123 and name = 'vv') or (uid = 456 and name = 'bb')
OrWhereParamWherewhereParamWhereExecor
func (e *SmallormEngine) OrWhere(data ...interface{}) *SmallormEngine {

  ...

  //判断使用顺序
  if e.WhereParam == "" {
    panic("WhereOr必须在Where后面调用")
  }

  //WhereOr条件
  e.OrWhereParam += " or ("

  ...

  return e
}
OrWhereWhere

也是一样,有三种调用方式:

OrWhere("uid", 1234) //默认是等于
OrWhere("uid", ">=", 1234)
OrWhere(uidStruct) //传入1个结构体,结构体之间用and连接

看下使用效果:

// where (uid = 123) or (name = "vv")
e.Table("userinfo").Where("uid", 123).OrWhere("name", "vv")

// where (uid not in (2,3,4)) or (uid not in (5,6,7))
e.Table("userinfo").Where("uid", "not in", []int{2, 3, 4}).OrWhere("uid", "not in", []int{5, 6, 7})

// where (uid like '%2') or (uid like '%5%')
e.Table("userinfo").Where("uid", "like", "%2").OrWhere("uid", "like", "%5%")

// where (uid >= 123) or (uid <= 454)
e.Table("userinfo").Where("uid", ">=", 123).OrWhere("uid", "<=", 454)

// where (username = "EE" and departname = "22" and status = 1) or (name = 'vv') or (status = 1)

type User struct {
    Username   string `sql:"username"`
    Departname string `sql:"departname"`
    Status     int64  `sql:"status"`
}

user2 := User{
    Username:   "EE",
    Departname: "22",
    Status:     1,
}

e.Table("userinfo").Where(user2).OrWhere("name", "vv").OrWhere("status", 1)

为了使这个方法更简单的被使用,不搞复杂,这种方式的 or 关系,实质上是针对于多次调用 where 之间的,是不支持同一个 where 里面的数据是 or 关系的。那如果需要的话,可以这样调用:

// where (username = "EE") or (departname = "22") or (status = 1)

e.Table("userinfo").Where(username, "EE").OrWhere("departname", "22").OrWhere("status", 1)
Delete
WhereOrWhereDeleteDeleteWherePrepareExec

我们看下具体是怎么写:

//删除
func (e *SmallormEngine) Delete() (int64, error) {

  //拼接delete sql
  e.Prepare = "delete from " + e.GetTable()

  //如果where不为空
  if e.WhereParam != "" || e.OrWhereParam != "" {
    e.Prepare += " where " + e.WhereParam + e.OrWhereParam
  }

  //limit不为空
  if e.LimitParam != "" {
    e.Prepare += "limit " + e.LimitParam
  }

  //第一步:Prepare
  var stmt *sql.Stmt
  var err error
  stmt, err = e.Db.Prepare(e.Prepare)
  if err != nil {
    return 0, err
  }

  e.AllExec = e.WhereExec

  //第二步:执行exec,注意这是stmt.Exec
  result, err := stmt.Exec(e.AllExec...)
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  //影响的行数
  rowsAffected, err := result.RowsAffected()
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  return rowsAffected, nil
}
Inserte.Prepare

这样看下调用方式和结果:

// delete from userinfo where (uid >= 123) or (uid <= 454)
rowsAffected, err := e.Table("userinfo").Where("uid", ">=", 123).OrWhere("uid", "<=", 454).Delete()

nice!

Update
DeleteWhere
update userinfo set status = 1 where (uid >= 123) or (uid <= 454)
status=1
e.Table("userinfo").Where("uid", 123).Update("status", 1)

e.Table("userinfo").Where("uid", 123).Update(user2)
Where
UpdateInsert

直接上代码吧:

//更新
func (e *SmallormEngine) Update(data ...interface{}) (int64, error) {

  //判断是结构体还是多个字符串
  var dataType int
  if len(data) == 1 {
    dataType = 1
  } else if len(data) == 2 {
    dataType = 2
  } else {
    return 0, errors.New("参数个数错误")
  }

  //如果是结构体
  if dataType == 1 {
    t := reflect.TypeOf(data[0])
    v := reflect.ValueOf(data[0])

    var fieldNameArray []string
    for i := 0; i < t.NumField(); i++ {

      //首字母小写,不可反射
      if !v.Field(i).CanInterface() {
        continue
      }

      //解析tag,找出真实的sql字段名
      sqlTag := t.Field(i).Tag.Get("sql")
      if sqlTag != "" {
        fieldNameArray = append(fieldNameArray, strings.Split(sqlTag, ",")[0]+"=?")
      } else {
        fieldNameArray = append(fieldNameArray, t.Field(i).Name+"=?")
      }

      e.UpdateExec = append(e.UpdateExec, v.Field(i).Interface())
    }
    e.UpdateParam += strings.Join(fieldNameArray, ",")

  } else if dataType == 2 {
    //直接=的情况
    e.UpdateParam += data[0].(string) + "=?"
    e.UpdateExec = append(e.UpdateExec, data[1])
  }

  //拼接sql
  e.Prepare = "update " + e.GetTable() + " set " + e.UpdateParam

  //如果where不为空
  if e.WhereParam != "" || e.OrWhereParam != "" {
    e.Prepare += " where " + e.WhereParam + e.OrWhereParam
  }

  //limit不为空
  if e.LimitParam != "" {
    e.Prepare += "limit " + e.LimitParam
  }

  //prepare
  var stmt *sql.Stmt
  var err error
  stmt, err = e.Db.Prepare(e.Prepare)
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  //合并UpdateExec和WhereExec
  if e.WhereExec != nil {
    e.AllExec = append(e.UpdateExec, e.WhereExec...)
  }

  //执行exec,注意这是stmt.Exec
  result, err := stmt.Exec(e.AllExec...)
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  //影响的行数
  id, _ := result.RowsAffected()
  return id, nil
}
e.WhereExec...UpdateExec
cannot use []interface{} literal (type []interface{}) as type interface{} in append
array_merge
$a1=array("red","green");
$a2=array("blue","yellow");
print_r(array_merge($a1,$a2));   // Array ( [0] => red [1] => green [2] => blue [3] => yellow )

6. 查询

PrepareExecQueryRowQuery
Scan()
//单条
var username, departname, status string
err := db.QueryRow("select username, departname, status from userinfo where uid=?", 4).Scan(&username, &departname, &status)
if err != nil {
    fmt.Println("QueryRow error :", err.Error())
}
fmt.Println("username: ", username, "departname: ", departname, "status: ", status)

再看多条的查询,第一步,得先把查询的数据结构先定义出来,再实例化 1 个多维的数组,再通过 for 循环去给这个数组赋值,值得注意的是这个数据结构的字段数得和 select 出来的字段数保持一致,不然就会丢失。PHP 程序员再次直呼好家伙。

//多条:
rows, err := db.Query("select username, departname, created from userinfo where username=?", "yang")
if err != nil {
    fmt.Println("QueryRow error :", err.Error())
}

//定义一个结构体,存放数据模型
type UserInfo struct {
    Username   string `json:"username"`
    Departname string `json:"departname"`
    Created    string `json:"created"`
}

//初始化
var user []UserInfo

for rows.Next() {
    var username1, departname1, created1 string
    if err := rows.Scan(&username1, &departname1, &created1); err != nil {
        fmt.Println("Query error :", err.Error())
    }
    user = append(user, UserInfo{Username: username1, Departname: departname1, Created: created1})
}

麻烦归麻烦,我们还是需要抽丝剥茧,我们还是得找出规律,用我们自定义的方法,去生成符合这样格式的数据。所以,查询又会是另一个难点和挑战点。

QueryRowQueryfor nextlimit 1

下面开始吧。

Select()

考虑到要提前定义 1 个数据结构,再初始化成 1 个数组,真的是太麻烦了,我想着能不能啥都不传呢?直接按照数据表里的字段名,直接给我输出 1 个同名字的 map 切片呢?试一试吧。

比如这样,userinfo 表里面有 4 个字段:"uid, username, departname, status",我们像下面这样查询,然后就可以返回 1 个 map 的数组切片,岂不是美滋滋?

result, err := e.Table("userinfo").Where("status", 1).Select()

返回为:

//type:

[]map[string]string

//value:

[map[departname:v status:1 uid:123 username:yang] map[departname:n status:0 uid:456 username:small]]
Db.QueryColumns()

比如:

rows, err := db.Query("select uid, username, departname, status from userinfo where username=?", "yang")
if err != nil {
    fmt.Println("Query error :", err.Error())
}

column, err := rows.Columns()
if err != nil {
   fmt.Println("rows.Columns error :", err.Error())
}

fmt.Println(column)

我们看下返回值:

[uid username departname stauts]
rows.Scan()Scan()
Scan()
for rows.Next() {
    var uid1, username1, departname1, status1 string
    rows.Scan(&uid1, &username1, &departname1, &status1)
    fmt.Println(uid1,username1,departname1,status1)
}

这样我们打印这 4 个变量,他们就都有值了:

1 yang v 0
12 yi b 1
....
Columns
//读出查询出的列字段名
column, err := rows.Columns()
if err != nil {
  return nil, e.setErrorInfo(err)
}

//values是每个列的值,这里获取到byte里
values := make([][]byte, len(column))

//因为每次查询出来的列是不定长的,用len(column)定住当次查询的长度
scans := make([]interface{}, len(column))

for i := range values {
  scans[i] = &values[i]
}
valuesscans

一一对应:

// 打印column的值
[uid username departname stauts]

// 打印values的值
[[] [] [] []]

//打印scans的值
[0xc000056180 0xc000056198 0xc0000561b0 0xc0000561c8]

Scan()
for rows.Next() {

  rows.Scan(scans[0], scans[1],scans[2], scans[3])

}
scans[0]uid1scans[3]status1scans[0]values[0]values[0]
// 打印column的值
[uid username departname stauts]

// 打印scans的值
[0xc000056180 0xc000056198 0xc0000561b0 0xc0000561c8]

// 打印values的值
[1 yang v 0]

然后,我们再通过这 3 个切片的下标的映射,就能将表字段和值对应起来,拼接成 1 个 map。

scansscans[0],scans[1].....scans[n]
results := make([]map[string]string, 0)
for rows.Next() {
  if err := rows.Scan(scans...); err != nil {
    return nil, e.setErrorInfo(err)
  }

  //每行数据
  row := make(map[string]string)

  //循环values数据,通过相同的下标,取column里面对应的列名,生成1个新的map
  for k, v := range values {
    key := column[k]
    row[key] = string(v)
  }

  //添加到map切片中
  results = append(results, row)
}
rows.Scan(scans...)
rows.Scan(scans[0], scans[1],scans[2], scans[3])
↓↓↓
↓↓↓
rows.Scan(scans...)
scan

好了,我们看下这个方法,完整的代码:

//查询多条,返回值为map切片
func (e *SmallormEngine) Select() ([]map[string]string, error) {

  //拼接sql
  e.Prepare = "select * from " + e.GetTable()

  //如果where不为空
  if e.WhereParam != "" || e.OrWhereParam != "" {
    e.Prepare += " where " + e.WhereParam + e.OrWhereParam
  }

  e.AllExec = e.WhereExec


  //query
  rows, err := e.Db.Query(e.Prepare, e.AllExec...)
  if err != nil {
    return nil, e.setErrorInfo(err)
  }

  //读出查询出的列字段名
  column, err := rows.Columns()
  if err != nil {
    return nil, e.setErrorInfo(err)
  }

  //values是每个列的值,这里获取到byte里
  values := make([][]byte, len(column))

  //因为每次查询出来的列是不定长的,用len(column)定住当次查询的长度
  scans := make([]interface{}, len(column))

  for i := range values {
    scans[i] = &values[i]
  }

  results := make([]map[string]string, 0)
  for rows.Next() {
    if err := rows.Scan(scans...); err != nil {
      //query.Scan查询出来的不定长值放到scans[i] = &values[i],也就是每行都放在values里
      return nil, e.setErrorInfo(err)
    }

    //每行数据
    row := make(map[string]string)

    //循环values数据,通过相同的下标,取column里面对应的列名,生成1个新的map
    for k, v := range values {
      key := column[k]
      row[key] = string(v)
    }

    //添加到map切片中
    results = append(results, row)
  }

  return results, nil
}

这样,我们就能非常方便的查询数据了,但是这个方法,有 2 个小的影响的地方,1. 就是最后返回的 map 切片,里面的 key 名都是数据库的字段名(可能都是小字母头),如果要映射成首字母大写的结构,需要我们自己去写方法。2. 他会把数据库表的所有字段的类型都会转换成字符串类型的,理论上影响也不大。

SelectOne()
limit 1
//查询1条
func (e *SmallormEngine) SelectOne() (map[string]string, error) {

  //limit 1 单个查询
  results, err := e.Limit(1).Select()
  if err != nil {
    return nil, e.setErrorInfo(err)
  }

  //判断是否为空
  if len(results) == 0 {
    return nil, nil
  } else {
    return results[0], nil
  }
}
Limit()limit 1SelectOne

这样,我们就可以很方便的查询单条数据了:

result, err := e.Table("userinfo").Where("status", 1).SelectOne()

返回为:

//type:

map[string]string

//value:

map[departname:v status:1 uid:123 username:yang]
Find()

这个方法其实是对原生 go 查询的一个简单包装,毕竟还是有很多人是喜欢先定义好数据结构,然后通过引用赋值的,当然在大分部的 go 的 ORM 里面,也是这么实现查询操作的。

//定义好结构体
type User struct {
    Uid        int    `sql:"uid,auto_increment"`
    Username   string `sql:"username"`
    Departname string `sql:"departname"`
    Status     int64  `sql:"status"`
}

//实例化切片
var user1 []User

// select * from userinfo where status=1
err := e.Table("userinfo").Where("status", 2).Find(&user1)

if err != nil {
    fmt.Println(err.Error())
} else {
    fmt.Printf("%#v", user1)
}

看下打印的数据

[]smallorm.User{smallorm.User2{Uid:131733, Username:"EE2", Departname:"223", Status:2}, smallorm.User{Uid:131734, Username:"EE2", Departname:"223", Status:2}, smallorm.User{Uid:131735, Username:"EE2", Departname:"223", Status:2}}

我们先在脑海中理一下大致的一个调用和逻辑处理过程:

&Find()Find()

这么看来,第 3 步是最复杂的,它需要获取传入的结构体切片里面的每一个值,并且还得把查询出来的结果给它全部赋上,Word 天,感觉好难啊!!!这题不会做啊。

后来在我大量翻阅 GORM 的源码以及查看 go 反射的文档后,我渐渐的有了头绪,这题也太简单了吧(逃

Selecttag:sql:"xx"
//读出查询出的列字段名
column, err := rows.Columns()
if err != nil {
  return e.setErrorInfo(err)
}

//values是每个列的值,这里获取到byte里
values := make([][]byte, len(column))

//因为每次查询出来的列是不定长的,用len(column)定住当次查询的长度
scans := make([]interface{}, len(column))

for i := range values {
  scans[i] = &values[i]
}
values
//原始struct的切片值
destSlice := reflect.ValueOf(result).Elem()

//原始单个struct的类型
destType := destSlice.Type().Elem()
User
fmt.Printf("%+v\n", destSlice)
fmt.Printf("%+v", destType)

[]
main.User

ok,我们就成功解析出了传入的结构体是长啥样的了,然后就可以再根据一系列 for 循环和各种神奇的 go 反射方法来继续:

destType.NumField(); //获取到User结构体的字段数,这里返回:4

destType.Field(i).Tag.Get("sql")  //获取到User结构体的第i个字段的tag值,比如返回:`username`

destType.Field(i).Name  // //获取到User结构体的第i个字段的名字,比如返回:`Username`

再通过这几个反射给赋值:

dest := reflect.New(destType).Elem()  // 根据类型生成1个新的值,返回:{Uid:0 Username: Departname: Status:0}

dest.Field(i).SetString(value) //给第i个元素,附值,类型是string类型

reflect.Append(destSlice, dest) // 将dest值添加到destSlice切片中。

destSlice.Set(reflect.Append(destSlice, dest)) //将最后得到的切片完全赋值给本身。

或许这一顿反射操作已经把你搞晕了,说实话,我也晕了。现在看下完整的函数:

//查询多条,返回值为struct切片
func (e *SmallormEngine) Find(result interface{}) error {

  if reflect.ValueOf(result).Kind() != reflect.Ptr {
    return e.setErrorInfo(errors.New("参数请传指针变量!"))
  }

  if reflect.ValueOf(result).IsNil() {
    return e.setErrorInfo(errors.New("参数不能是空指针!"))
  }

  //拼接sql
  e.Prepare = "select * from " + e.GetTable()


  e.AllExec = e.WhereExec

  //query
  rows, err := e.Db.Query(e.Prepare, e.AllExec...)
  if err != nil {
    return e.setErrorInfo(err)
  }

  //读出查询出的列字段名
  column, err := rows.Columns()
  if err != nil {
    return e.setErrorInfo(err)
  }

  //values是每个列的值,这里获取到byte里
  values := make([][]byte, len(column))

  //因为每次查询出来的列是不定长的,用len(column)定住当次查询的长度
  scans := make([]interface{}, len(column))

  //原始struct的切片值
  destSlice := reflect.ValueOf(result).Elem()

  //原始单个struct的类型
  destType := destSlice.Type().Elem()

  for i := range values {
    scans[i] = &values[i]
  }

  //循环遍历
  for rows.Next() {

    dest := reflect.New(destType).Elem()

    if err := rows.Scan(scans...); err != nil {
      //query.Scan查询出来的不定长值放到scans[i] = &values[i],也就是每行都放在values里
      return e.setErrorInfo(err)
    }

    //遍历一行数据的各个字段
    for k, v := range values {
      //每行数据是放在values里面,现在把它挪到row里
      key := column[k]
      value := string(v)

      //遍历结构体
      for i := 0; i < destType.NumField(); i++ {

        //看下是否有sql别名
        sqlTag := destType.Field(i).Tag.Get("sql")
        var fieldName string
        if sqlTag != "" {
          fieldName = strings.Split(sqlTag, ",")[0]
        } else {
          fieldName = destType.Field(i).Name
        }

        //struct里没这个key
        if key != fieldName {
          continue
        }

        //反射赋值
        if err := e.reflectSet(dest, i, value); err != nil {
          return err
        }
      }
    }
    //赋值
    destSlice.Set(reflect.Append(destSlice, dest))
  }

  return nil
}
reflectSet
//反射赋值
func (e *SmallormEngine) reflectSet(dest reflect.Value, i int, value string) error {
  switch dest.Field(i).Kind() {
  case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    res, err := strconv.ParseInt(value, 10, 64)
    if err != nil {
      return e.setErrorInfo(err)
    }
    dest.Field(i).SetInt(res)
  case reflect.String:
    dest.Field(i).SetString(value)
  case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
    res, err := strconv.ParseUint(value, 10, 64)
    if err != nil {
      return e.setErrorInfo(err)
    }
    dest.Field(i).SetUint(res)
  case reflect.Float32:
    res, err := strconv.ParseFloat(value, 32)
    if err != nil {
      return e.setErrorInfo(err)
    }
    dest.Field(i).SetFloat(res)
  case reflect.Float64:
    res, err := strconv.ParseFloat(value, 64)
    if err != nil {
      return e.setErrorInfo(err)
    }
    dest.Field(i).SetFloat(res)
  case reflect.Bool:
    res, err := strconv.ParseBool(value)
    if err != nil {
      return e.setErrorInfo(err)
    }
    dest.Field(i).SetBool(res)
  }
  return nil
}
switch dest.Field(i).Kind() casestrconv.xxx()SetXXX()
FindOne()
Limit 1
//查询单条,返回值为struct切片
func (e *SmallormEngine) FindOne(result interface{}) error {

  //取的原始值
  dest := reflect.Indirect(reflect.ValueOf(result))

  //new一个类型的切片
  destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem()

  //调用
  if err := e.Limit(1).Find(destSlice.Addr().Interface()); err != nil {
    return err
  }

  //判断返回值长度
  if destSlice.Len() == 0 {
    return e.setErrorInfo(errors.New("NOT FOUND"))
  }

  //取切片里的第0个数据,并复制给原始值结构体指针
  dest.Set(destSlice.Index(0))
  return nil
}
Find()FindOne()
Find()

OK,我们调用试一下:

//定义好结构体
type User struct {
    Uid        int    `sql:"uid,auto_increment"`
    Username   string `sql:"username"`
    Departname string `sql:"departname"`
    Status     int64  `sql:"status"`
}

//实例化数据
var user1 User

// select * from userinfo where status=1
err := e.Table("userinfo").Where("status", 2).FindOne(&user1)

if err != nil {
    fmt.Println(err.Error())
} else {
    fmt.Printf("%#v", user1)
}

看下打印的数据

smallorm.User{Uid:131733, Username:"EE2", Departname:"223", Status:2}
Field
select *
e.Table("userinfo").Where("status", 2).Field("uid,status").Select()

由于是采用链式的调用方式,而且它本身也没有数据属性,所以是可以放在中间部分的任何位置的:

e.Table("userinfo").Field("uid,status").Where("status", 2).Select()
SmallormEngineFieldParam
//设置查询字段
func (e *SmallormEngine) Field(field string) *SmallormEngine {
  e.FieldParam = field
  return e
}
Select/Find
e.Prepare = "select " + e.FieldParam + " from " + e.GetTable()
e.FieldParamNewMysqlField()select *

值得注意的是,我们是直接裸传的,并没有对传入的字段做检验和判断,这个优化将在第二版本中展开。

Limit
limit 1limit 0,9limit 10,19limit
e.Table("userinfo").Where("status", 2).Limit(1).Select()
e.Table("userinfo").Where("status", 2).Limit(0, 9).Select()

我们来看下怎么实现这 2 种方式的调用:

//limit分页
func (e *SmallormEngine) Limit(limit ...int64) *SmallormEngine {
  if len(limit) == 1 {
    e.LimitParam = strconv.Itoa(int(limit[0]))
  } else if len(limit) == 2 {
    e.LimitParam = strconv.Itoa(int(limit[0])) + "," + strconv.Itoa(int(limit[1]))
  } else {
    panic("参数个数错误")
  }
  return e
}
LimitParamFind/Select
//limit不为空
if e.LimitParam != "" {
  e.Prepare += " limit " + e.LimitParam
}

这样我们就往 prepare 中增加好了 limit 的语句。

Count/Max/Min/Avg/Sum
  1. Count() //获取总数
  2. Max() //获取最大值
  3. Min() //获取最小值
  4. Avg() //获取平均值
  5. Sum() //获取总和
select *select Xxxx(*)db.QueryRow()

我们来看下怎么写,首先第一步,设置 2 个参数,分别对应于具体的聚合函数,以及需要聚合的字段名。

name 对应于具体的聚合函数,param 则对应于具体的字段:

func (e *SmallormEngine) aggregateQuery(name, param string) (interface{}, error) {

  e.Prepare = "select " + name + "(" + param + ") as cnt from " + e.GetTable()

}

这样,我们这个通用方法的主体给完成了,我们想实现对应的聚合查询功能,只需要传递 2 个参数即可。

接下来,我们看下查询部分:

//执行绑定
var cnt interface{}

//queryRows
err := e.Db.QueryRow(e.Prepare, e.AllExec...).Scan(&cnt)
if err != nil {
  return nil, e.setErrorInfo(err)
}
cnt

下面是完整的代码:

//聚合查询
func (e *SmallormEngine) aggregateQuery(name, param string) (interface{}, error) {

  //拼接sql
  e.Prepare = "select " + name + "(" + param + ") as cnt from " + e.GetTable()

  //如果where不为空
  if e.WhereParam != "" || e.OrWhereParam != "" {
    e.Prepare += " where " + e.WhereParam + e.OrWhereParam
  }

  //limit不为空
  if e.LimitParam != "" {
    e.Prepare += " limit " + e.LimitParam
  }

  e.AllExec = e.WhereExec

  //生成sql
  e.generateSql()

  //执行绑定
  var cnt interface{}

  //queryRows
  err := e.Db.QueryRow(e.Prepare, e.AllExec...).Scan(&cnt)
  if err != nil {
    return nil, e.setErrorInfo(err)
  }

  return cnt, err
}

OK,这样,我们就完成了聚合函数的通用主体部分,接下来就是各自的差异部分了。

Count
Count()count
//总数
func (e *SmallormEngine) Count() (int64, error) {
  count, err := e.aggregateQuery("count", "*")
  if err != nil {
    return 0, e.setErrorInfo(err)
  }
  return count.(int64), err
}
count.(xxx)
Max
Max()max
//最大值
func (e *SmallormEngine) Max(param string) (string, error) {
  max, err := e.aggregateQuery("max", param)
  if err != nil {
    return "0", e.setErrorInfo(err)
  }
  return string(max.([]byte)), nil
}

之所以返回值用 string 类型,是因为取最大值,有时候不限制在 int 类型的表字段取最大值,有时候也会有时间最大值等,所以返回 string 是最合适的。

Min
Min()min
//最小值
func (e *SmallormEngine) Min(param string) (string, error) {
  min, err := e.aggregateQuery("min", param)
  if err != nil {
    return "0", e.setErrorInfo(err)
  }

  return string(min.([]byte)), nil
}
Avg
Avg()avg
//平均值
func (e *SmallormEngine) Avg(param string) (string, error) {
  avg, err := e.aggregateQuery("avg", param)
  if err != nil {
    return "0", e.setErrorInfo(err)
  }

  return string(avg.([]byte)), nil
}
Sum
Sum()sum
//总和
func (e *SmallormEngine) Sum(param string) (string, error) {
  sum, err := e.aggregateQuery("sum", param)
  if err != nil {
    return "0", e.setErrorInfo(err)
  }
  return string(sum.([]byte)), nil
}

接下来,来快速的调用看看:

//select count(*) as cnt from userinfo where (uid >= 10805)
cnt, err := e.Table("userinfo").Where("uid", ">=", 10805).Count()


//select max(uid) as cnt from userinfo where (uid >= 10805)
max, err := e.Table("userinfo").Where("uid", ">=", 10805).Max('uid')


//select min(uid) as cnt from userinfo where (uid >= 10805)
min, err := e.Table("userinfo").Where("uid", ">=", 10805).Count()


//select avg(uid) as cnt from userinfo where (uid >= 10805)
avg, err := e.Table("userinfo").Where("uid", ">=", 10805).Avg("uid")


// select sum(uid) as cnt from userinfo where (uid >= 10805)
sum, err := e.Table("userinfo").Where("uid", ">=", 10805).Sum("uid")

Order
ascdesc
//查询结果按照uid倒序
select * from userinfo where (uid >= 10805) order by uid desc

//查询结果按照uid正序
select * from userinfo where (uid >= 10805) order by uid asc

//查询结果,先按照uid正序,再按照status倒序
select * from userinfo where (uid >= 10805) order by uid asc,status desc

所以,我们也把这个操作,用一个单独的方法给暴露出来,方便排序,调用方式如下:

sum, err := e.Table("userinfo").Where("uid", ">=", 10805).Order("uid", "desc").Select()
sum, err := e.Table("userinfo").Where("uid", ">=", 10805).Order("uid","asc", "status", "desc").Select()
order xxx xxx,xx,xxe.OrderParamFind/Select

看下,具体是怎么实现的:

//order排序
func (e *SmallormEngine) Order(order ...string) *SmallormEngine {
  orderLen := len(order)
  if orderLen%2 != 0 {
    panic("order by参数错误,请保证个数为偶数个")
  }

  //排序的个数
  orderNum := orderLen / 2

  //多次调用的情况
  if e.OrderParam != "" {
    e.OrderParam += ","
  }

  for i := 0; i < orderNum; i++ {
    keyString := strings.ToLower(order[i*2+1])
    if keyString != "desc" && keyString != "asc" {
      panic("排序关键字为:desc和asc")
    }
    if i < orderNum-1 {
      e.OrderParam += order[i*2] + " " + order[i*2+1] + ","
    } else {
      e.OrderParam += order[i*2] + " " + order[i*2+1]
    }
  }

  return e
}

唯一复杂的地方,就是判断参数是偶数个数的,然后,按照二分查找法,进行多个排序规则的拼接,这个地方也是有其他的算法进行拼接。

Find/Selecte.Prepare
//order by不为空
if e.OrderParam != "" {
  e.Prepare += " order by " + e.OrderParam
}
Group

分组也是我们平时用的非常多的,它用于我们对某 1 个或者几个字段进行分组,然后查询这个分组后的数据,写法很简单,直接上代码:

//group分组
func (e *SmallormEngine) Group(group ...string) *SmallormEngine {
  if len(group) != 0 {
    e.GroupParam = strings.Join(group, ",")
  }
  return e
}
Field(count(*) as c)
result,err := e.Table("userinfo").Where("departname", "like", "2%").Field("departname, count(*) as c").Group("departname", "status").Select()

Find/Selecte.Prepare
//group 不为空
if e.GroupParam != "" {
  e.Prepare += " group by " + e.GroupParam
}
Having

Having 用于在使用 Group 分组后的过滤查询,它的作用和 where 其实是一模一样的,都是过滤,只不过 Having 只能用于 group 之后,对 select 后面的参数进行过滤,比如这个 sql:

我们想查询出按照 status 分组后,uid 的总数大于 5 的数据:

select status, count(uid) as c from userinfo where (uid >= 10805) group by status having c >= 5

所以,既然绑定的方式和 where 是一模一样的,我们可以看下怎么调用的:

result,err := e.Table("userinfo").Where("", "like", "2%").Field("status, count(uid) as c ").Group(status").Having("c",">=", 5).Select()
result,err := e.Table("userinfo").Where("departname", "like", "2%").Field("status, count(uid) as c ").Group(status").Having("c", 5).Select()


type User struct {
    Status     int64  `sql:"status"`
}

user2 := User1{
    Status:     1,
}
result,err := e.Table("userinfo").Where("departname", "like", "2%").Field("status, count(uid) as c ").Group(status").Having(user2).Select()
Where
//having过滤
func (e *SmallormEngine) Having(having ...interface{}) *SmallormEngine {

  //判断是结构体还是多个字符串
  var dataType int
  if len(having) == 1 {
    dataType = 1
  } else if len(having) == 2 {
    dataType = 2
  } else if len(having) == 3 {
    dataType = 3
  } else {
    panic("having个数错误")
  }

  //多次调用判断
  if e.HavingParam != "" {
    e.HavingParam += "and ("
  } else {
    e.HavingParam += "("
  }

  //如果是结构体
  if dataType == 1 {
    t := reflect.TypeOf(having[0])
    v := reflect.ValueOf(having[0])

    var fieldNameArray []string
    for i := 0; i < t.NumField(); i++ {

      //小写开头,无法反射,跳过
      if !v.Field(i).CanInterface() {
        continue
      }

      //解析tag,找出真实的sql字段名
      sqlTag := t.Field(i).Tag.Get("sql")
      if sqlTag != "" {
        fieldNameArray = append(fieldNameArray, strings.Split(sqlTag, ",")[0]+"=?")
      } else {
        fieldNameArray = append(fieldNameArray, t.Field(i).Name+"=?")
      }

      e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
    }
    e.HavingParam += strings.Join(fieldNameArray, " and ") + ") "

  } else if dataType == 2 {
    //直接=的情况
    e.HavingParam += having[0].(string) + "=?) "
    e.WhereExec = append(e.WhereExec, having[1])
  } else if dataType == 3 {
    //3个参数的情况
    e.HavingParam += having[0].(string) + " " + having[1].(string) + " ?) "
    e.WhereExec = append(e.WhereExec, having[2])
  }

  return e
}

HavingParamWhereExec
Find/Selecte.Prepare
//having
if e.HavingParam != "" {
  e.Prepare += " having " + e.HavingParam
}

OK,我们来试一下怎么调用:

//select uid, status, count(uid) as b from userinfo where (departname like '2%')  group by uid,status having (status=1)  order by uid desc,status asc

result,err := e.Table("userinfo").Where("departname", "like", "2%").Order("uid", "desc", "status", "asc").Field("uid, status, count(uid) as b").Group("uid", "status").Having("status",1).Select()
if err != nil {
    fmt.Println(err.Error())
    return
}
fmt.Println("result is :", result)
GetLastSql

我们上面的所有的方法,其实本质上都是组装成原生 sql 语法的拼装,有时候,我们其实是想知道最后生成的 sql 到底是啥,或者查询报错了,想看下最后生成的 sql 是否有语法错误,我们 ORM 也提供了这个方法,用于查询本次执行最后生成的 sql 语句。

e.Preparee.AllExece.Prepare
//生成完成的sql语句
func (e *SmallormEngine) generateSql() {
  e.Sql = e.Prepare
  for _, i2 := range e.AllExec {
    switch i2.(type) {
    case int:
      e.Sql = strings.Replace(e.Sql, "?", strconv.Itoa(i2.(int)), 1)
    case int64:
      e.Sql = strings.Replace(e.Sql, "?", strconv.FormatInt(i2.(int64), 10), 1)
    case bool:
      e.Sql = strings.Replace(e.Sql, "?", strconv.FormatBool(i2.(bool)), 1)
    default:
      e.Sql = strings.Replace(e.Sql, "?", "'"+i2.(string)+"'", 1)
    }
  }
}

这个替换做的比较简陋,只对基础的 int 和 bool 型做了类型转换,其他类型都当做 sql 里的字符串处理,需要加单引号。

e.SqlGetLastSql
//获取最后执行生成的sql
func (e *SmallormEngine) GetLastSql() string {
  return e.Sql
}

值得注意的是,这个是打印最后一次生成的 sql,如果你有多次 CURD 操作,记得每次去调用:

sum, err := e.Table("userinfo").Where("uid", ">=", 10805).Order("uid", "desc").Select()

fmt.Println(e.GetLastSql()) //select * from userinfo where (uid >= 10805) order by uid asc


sum, err := e.Table("userinfo").Where("uid", ">=", 10805).Order("uid","asc", "status", "desc").Select()

fmt.Println(e.GetLastSql()) //select * from userinfo where (uid >= 10805) order by uid asc,status desc
Exec/Query

本次 ORM 也提供了裸调 sql 的方法,虽然不是推荐使用,但是有时候确实是有这样的需求的使用场景的。

Exec
Exec
result, err := db.Exec("INSERT INTO userinfo (username, departname, created) VALUES (?, ?, ?)","lisi","dev","2020-08-04")

其实,你是可以不传后面的几个参数,不使用问号占位符的,第一个参数直接传完整的 sql 即可,像这样:

result, err := db.Exec("INSERT INTO userinfo (username, departname, created) VALUES ('lisi', 'dev', '2021-11-04')")
Exec
//直接执行增删改sql
func (e *SmallormEngine) Exec(sql string) (id int64, err error) {
  result, err := e.Db.Exec(sql)
  e.Sql = sql
  if err != nil {
    return 0, e.setErrorInfo(err)
  }

  //区分是insert还是其他(update,delete)
  if strings.Contains(sql, "insert") {
    lastInsertId, _ := result.LastInsertId()
    return lastInsertId, nil
  } else {
    rowsAffected, _ := result.RowsAffected()
    return rowsAffected, nil
  }
}

我们通过判断 sql 是的语句是新增还是其他,因为新增的话一般情况是要返回自增 ID 的,而其他情况需要返回影响的行数。

这样,我们就可以很方便的调用原生的 sql 语句了:

//result, err:= e.Exec("insert into userinfo(username,departname,created,status) values('dd', '31','2020-10-02',1)");

//result, err := e.Exec("delete from userinfo where username='dd'")

result, err := e.Exec("update userinfo set username='dd' where uid = 132733")

fmt.Println(err)
fmt.Println(result)
fmt.Println(e.GetLastSql())
Query
Query
result, err := db.Query("SELECT * FROM userinfo limit 1")
Select
//直接执行查sql
func (e *SmallormEngine) Query(sql string) ([]map[string]string, error) {
  rows, err := e.Db.Query(sql)
  e.Sql = sql
  if err != nil {
    return nil, e.setErrorInfo(err)
  }

  //读出查询出的列字段名
  column, err := rows.Columns()
  if err != nil {
    return nil, e.setErrorInfo(err)
  }

  //values是每个列的值,这里获取到byte里
  values := make([][]byte, len(column))

  //因为每次查询出来的列是不定长的,用len(column)定住当次查询的长度
  scans := make([]interface{}, len(column))

  for i := range values {
    scans[i] = &values[i]
  }

  //最后得到的map
  results := make([]map[string]string, 0)
  for rows.Next() {
    if err := rows.Scan(scans...); err != nil {
      //query.Scan查询出来的不定长值放到scans[i] = &values[i],也就是每行都放在values里
      return nil, e.setErrorInfo(err)
    }

    row := make(map[string]string) //每行数据
    for k, v := range values {
      //每行数据是放在values里面,现在把它挪到row里
      key := column[k]
      row[key] = string(v)
    }
    results = append(results, row)
  }

  return results, nil
}

OK,我们就可以这样调用了:

result, err := e.Query("SELECT * FROM userinfo limit 1")

fmt.Println(err)
fmt.Println(result)
fmt.Println(e.GetLastSql())
Begin/Commit/Rollback

sql 里的事务操作也是平时业务中用的非常多的,它用于在多次执行增删改的操作的时候,如果其中 1 个出现问题,可以一起回滚数据,确保了数据的一致性。本 ORM 也提供了相应的方法。事务也是通过封装来调用原生 go 代码里面的事务方法。

一共有 3 个方法配合调用:

  1. Begin() // 开启事物
  2. Rollback() // 回滚
  3. Commit() //确认提交执行
Begin

开启事务功能相对简单,只是设置一个标志符即可:

//开启事务
func (e *SmallormEngine) Begin() error {

  //调用原生的开启事务方法
  tx, err := e.Db.Begin()
  if err != nil {
    return e.setErrorInfo(err)
  }
  e.TransStatus = 1
  e.Tx = tx
  return nil
}
Db.Begin()txe.TransStatus = 1

接下来,我们在具体的增删改查的方法里,通过这个标记去判断现在是不是事务状态:

//判断是否是事务
var stmt *sql.Stmt
var err error
if e.TransStatus == 1 {
  stmt, err = e.Tx.Prepare(e.Prepare)
} else {
  stmt, err = e.Db.Prepare(e.Prepare)
}

...

result, err := stmt.Exec(e.AllExec...)
stmt
Rollback

回滚操作表示我们执行出现了问题后,向 mysql 服务器提供回滚指令,它会将这句 sql 执行的结果给还原。实现原来更简单了,直接调用原生的即可:

//事务回滚
func (e *SmallormEngine) Rollback() error {
  e.TransStatus = 0
  return e.Tx.Rollback()
}
Commit

确认提交表示我们所有的执行都是 OK 的,这个时候我们需要向 mysql 服务器发出确认提交指令,它才会真正意义上将 sql 给执行。如果不执行这个指令,实际上数据并不会执行,所以,我们最后一定不要忘记执行确认提交操作。实现原来也很简单了,直接调用原生的即可:

//事务提交
func (e *SmallormEngine) Commit() error {
  e.TransStatus = 0
  return e.Tx.Commit()
}

我们看下一个完整的事务的调用例子:

err0 := e.Begin()

isCommit := true
if err0 != nil {
    fmt.Println(err0.Error())
    os.Exit(1)
}

result1, err1 := e.Table("userinfo").Where("uid", "=", 10803).Update("departname", 110)
if err1 != nil {
    isCommit = false
    fmt.Println(err1.Error())
}

//没找到,删除失败
if result1 <= 0 {
    isCommit = false
    fmt.Println("update 0")
}

fmt.Println("result1 is :", result1)
fmt.Println("sql is :", e.GetLastSql())

result2, err2 := e.Table("userinfo").Where("uid", "=", 10802).Delete()
if err2 != nil {
    isCommit = false
    fmt.Println(err2.Error())
}

if result2 <= 0 {
    isCommit = false
    fmt.Println("delete 0")
}

fmt.Println("result2 is :", result2)
fmt.Println("sql is :", e.GetLastSql())

user1 := User{
    Username:   "EE",
    Departname: "22",
    Created:    "2012-12-12",
    Status:     1,
}

id, err3 := e.Table("userinfo").Insert(user1)
if err3 != nil {
    isCommit = false
    fmt.Println(err3.Error())
}

fmt.Println("id is :", id)
fmt.Println("sql is :", e.GetLastSql())

if isCommit {
    _ = e.Commit()
    fmt.Println("ok")
} else {
    _ = e.Rollback()
    fmt.Println("error")
}
isCommit

到此为止,我们把 ORM 该有的功能基本上实现了 90%以上,也算是一个小而美、优雅且简单的 ORM 框架了。

三、功能测试和性能测试

功能测试必不可少,而且 go 也给我们提供了很简单就可以完成的测试功能,这个可以逐步完善,我们先看下性能测试,我们和 GORM 跑个分测试一下。

数据库的结构如下,表里面有 209w 数据:

CREATE DATABASE `ApiDB`;

USE ApiDB;

CREATE TABLE `userinfo` (
    `uid` int NOT NULL AUTO_INCREMENT,
    `username` varchar(64) DEFAULT NULL,
    `departname` varchar(64) DEFAULT NULL,
    `created` date DEFAULT NULL,
    `status` int NOT NULL,
    PRIMARY KEY (`uid`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4;
SelectUpdate
package smallorm

import (
  "gorm.io/driver/mysql"
  "gorm.io/gorm"
  "testing"
)

func BenchmarkSmallormSelect(b *testing.B) {
  e, _ := NewMysql("root", "123456", "127.0.0.1:3306", "ApiDB")

  type User struct {
    Username   string `gorm:"username"`
    Departname string `gorm:"departname"`
    Created    string `gorm:"created"`
    Status     int64  `gorm:"status"`
  }
  var users[] User

  b.ResetTimer()
  for i := 0; i < b.N; i++ {
    _ = e.Table("userinfo").Where("id", ">=", 50).Limit(100).Find(&users)
  }
  b.StopTimer()
}

func BenchmarkGormSelect(b *testing.B) {
  dsn := "root:123456@tcp(127.0.0.1:3306)/ApiDB?charset=utf8mb4&parseTime=True&loc=Local"
  db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{})

  type User struct {
    Username   string `gorm:"username"`
    Departname string `gorm:"departname"`
    Created    string `gorm:"created"`
    Status     int64  `gorm:"status"`
  }
  var users[] User

  b.ResetTimer()
  for i := 0; i < b.N; i++ {
    db.Table("userinfo").Where("uid >= ?", "50").Limit(50).Find(&users)
  }
  b.StopTimer()
}

func BenchmarkSmallormUpdate(b *testing.B) {
  e, _ := NewMysql("root", "123456", "127.0.0.1:3306", "ApiDB")

  b.ResetTimer()
  for i := 0; i < b.N; i++ {
    _,_ = e.Table("userinfo").Where("id", "=", 15).Update("status", 0)
  }
  b.StopTimer()
}

func BenchmarkGormUpdate(b *testing.B) {
  dsn := "root:123456@tcp(127.0.0.1:3306)/ApiDB?charset=utf8mb4&parseTime=True&loc=Local"
  db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{})

  b.ResetTimer()
  for i := 0; i < b.N; i++ {
    db.Table("userinfo").Where("uid = ?", "15").Update("status", 1)
  }
  b.StopTimer()
}

运行下,看下跑分数据:

go test -bench=. -benchmem

goos: darwin
goarch: amd64
pkg: smallorm
cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
BenchmarkSmallormSelect-12  1296   843769 ns/op    911 B/op    25 allocs/op
BenchmarkGormSelect-12      598    1998827 ns/op   29250 B/op  1058 allocs/op
BenchmarkSmallormUpdate-12  1197   864404 ns/op    727 B/op    21 allocs/op
BenchmarkGormUpdate-12      314    4216470 ns/op   6246 B/op   76 allocs/op
PASS
ok      smallorm        6.880s

这个跑分,大家可以看下。

四、待实现功能

  • [ ] 1. 多表联合查询
  • [ ] 2. 快捷 hash 分表
  • [ ] 3. 其他 sql 引擎的支持(sqlite3,PostgreSQL 等)
  • [ ] 4. 日志、性能、结构、安全的优化