Golang之搭建ORM框架

背景

用过GORM后,我个人觉得不够简洁有些笨重,有较大的学习成本。本着学习和探索的目的,自己实现了个较为简洁优雅的go版本ORM。本文主要从基础原理开始将,再到实现步骤,最终完成这个简洁版的ORM(MySQL版),目的不是说一定要实现这个ORM来用,而是重点在于学习实现的思路,以及可改进方案。

ORM是什么?为什么要用ORM?

不管是用PHP、Go等语言,我们在做需求的时候,应该都使用过ORM。特别在新项目中,我们会用ORM框架去连接数据库,而不是直接用原生去写SQL连接,原因有很多,比如有安全、性能等的考虑,但我觉得更多的还是因为开发效率和懒,因为有些SQL写起来很复杂很累,特别是查询的时候,又是分页,又是结果集,还需要自己去判断和遍历,开发效率真的低。如果有个ORM,数据库config配置完,几个链式调用,结果就出来了,方便多了。

通俗来讲就是:ORM是提供与数据库交互的中间件。ORM提供更加方便的curd方法与数据库产生交互。

Go原生如何去连接MySQL

了解了ORM后,再来看看如何原生连接MySQL,只有清楚他们之间交互的原理,才能更好的搭建ORM(造轮子)。

第一步:导入数据库包和MySQL驱动包


import (
	"database/sql"
	_ "github.com/go-sql-driver/mysql"
)

第二步:连接MySQL


db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4")
if err != nil {
	panic(err)
}

第三步:快速过一遍增删改查


func main() {
	db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4")
	if err != nil {
		panic(err)
	}

	// 增
	// 方式一:
	result, err := db.Exec("INSERT INTO users (username, jobname, created_at) VALUES (?, ?, ?)", "zhangsan", "dev", "2022-02-23")

	// 方式二:
	stmt, err := db.Prepare("INSERT INTO users (username, jobname, created_at) VALUES (?, ?, ?)")

	result2, err := stmt.Exec("zhangsan", "dev", time.Now().Format("2006-01-02"))

	// 删
	// 方式一:
	result3, err := db.Exec("delete from users where id=?", 1)

	// 方式二:
	stmt, err := db.Prepare("delete from users where id=?")

	result4, err := stmt.Exec(1)

	// 改
	// 方式一:
	result, err := db.Exec("update users set username=? where id=?", "lisi", 2)

	// 方式二:
	stmt, err := db.Prepare("update users set username=? where id=?")

	result, err := stmt.Exec("lisi", 2)

	// 查
	// 单条
	var username, jobname, createdAt string
	err := db.QueryRow("select username, jobname, created_at from users where id=?", 4).Scan(&username, &jobname, &createdAt)
	if err != nil {
		fmt.Println("QueryRow error :", err.Error())
	}
	fmt.Println("username: ", username, "jobname: ", jobname, "created_at: ", createdAt)

	// 多条
	rows, err := db.Query("select username, jobname, created_at from users where username=?", "wangwu")
	if err != nil {
		fmt.Println("QueryRow error :", err.Error())
	}

	// 结构体
	type Users struct {
		Username  string `json:"username"`
		JobName   string `json:"jobname"`
		CreatedAt string `json:"created_at"`
	}

	// 初始化
	var user []Users

	// 判断及遍历
	for rows.Next() {
		var username1, jobname1, createdAt1 string
		if err := rows.Scan(&username1, &jobname1, &createdAt1); err != nil {
			fmt.Println("Query error :", err.Error())
		}
		user = append(user, Users{Username: username1, JobName: jobname1, CreatedAt: createdAt1})
	}
}

总结下:Go里面原生连接MySQL的方法,就是直接写sql,简单粗暴点(方式一)就直接Exec执行,想性能高一点就可以麻烦点(方式二)就先Prepare再Exec(此处原因自行进行Benchmark实验可以得出,不做过多赘述)。学习成本是较低,但问题还是麻烦和开发效率低。

所以我在想基于这个原生代码的优势,撸一个ORM:1、能提供各种方法提高开发效率;2、底层直接转换拼接成完整SQL,然后调用原生的组件和MySQL进行交互。这样就可以降低学习其他ORM的成本,也可以提高开发效率,也用得顺手方便。

搭建ORM的构思

原理:简单的SQL拼接,暴露出各种CURD方法,并在底层逻辑拼接成Prepare和Eexc占位符的部分,然后调用原生mysql驱动包的方法来实现和数据库交互。

调用过程采取链式调用的方法。比如:


db.Where().Order().Limit().Select()

注意:暴露的CURD方法,名字要清晰,无歧义,不要搞一大堆复杂的间接调用,使用起来要简单。

搭建前先整理列举出mysql常用的curd的方法,按照列表一个个实现:

  • 连接 Connect
  • 设置表名 Table
  • 新增 Insert
  • 条件 Where
  • 删除 Delete
  • 更新 Update
  • 查询 Select
  • 设置查询字段 Field
  • 设置条数 Limit
  • 排序 Order
  • 聚合查询 Count/Sum/Max/Min/Avg
  • 分组 Group
  • 分组后判断 Having
  • 获取执行的SQL GetLastSql
  • 执行原生sql Exec/Query
  • 事务 Begin/Commit/Rollback

注意:Insert/Delete/Select/Update是整个链式调用的最后一步。是真正的和MySQL交互的方法,后面不能再接其他方法。

在开始搭建之前先设想下,自己开发的这个ORM是什么样的,如何去使用它:


// 注意:Users结构体的每一个元素后面都有一个sql:"xxx",叫Tag标签。因为go里面首字母大写表示是可见的属性,所以如果是可见的属性都是大写字母开头,而数据表里面的字段首字母名一般是小写,所以为了照顾这个特殊关系而进行转换和匹配,才用了这个标签。如果你的表的字段类型也是大小字母开头,那可以不需要这个标签,下面会具体说到如何转换匹配。

type Users struct {
    Username  string `sql:"username"`
    JobName   string `sql:"jobname"`
    CreatedAt int64  `sql:"created_at"`
}

// 增

user1 := Users{
    Username:  "admin",
    JobName:   "dev", 
    CreatedAt: time.Now().Unix(),
}

// insert into users (username,jobname,created_at) values ('admin', 'dev', 1645597211)

id, err := db.Table("users").Insert(user1)

// 删

// delete from users where id = 1

result1, err := db.Table("users").Where("id", "=", 1).Delete()

// 改

// update users set jobname = 'test' where id = 1)

result2, err := db.Table("users").Where("id", "=", 1).Update("jobname", "test")

// 查


// select id,username,jobname from users where (username like '%san') or (jobname = 'dev') order by id desc limit 1

result3, err := db.Table("users").Where("username", "like", "%san").OrWhere("jobname", "=", "dev").Field("id, username, jobname").Order("id", "desc").Limit(1).Select()


user2 := Users{
    Username:  "admin",
    JobName:   "dev", 
}

// select * from users where username='admin' and jobname='dev' limit 1
result4, err := db.Table("users").Where(user2).SelectOne()

开发阶段

一、连接 Connect

连接MySQL直接把原生的sql.Open方法套上函数即可。第一个版本以实现功能为主,就暂时先不考虑协程和长连接的保持。

1、先构建一个结构体,用来存储各种数据。因为这个ORM的底层本质是SQL拼接,所以我们需要把各种操作方法生成的数据,都保存到这个结构体的各个属性上,方便最后一步生成SQL。


type DbEngine struct {
   Db           *sql.DB  		// 用于直接进行CURD操作
   TableName    string
   WhereParam   string
   OrWhereParam string
   WhereExec    []interface{}
   UpdateParam  string
   UpdateExec   []interface{}
   FieldParam   string
   LimitParam   string
   OrderParam   string
   GroupParam   string
   HavingParam  string
   Prepare      string
   AllExec      []interface{}
   Sql          string
   TransStatus  int
   Tx           *sql.Tx   	 	// 数据库的事务操作,用于回滚和提交
}

2、编写连接方法


// 新建Mysql连接
func NewMysql(user, pwd, address, dbname string) (*DbEngine, error) {
    dsn := user + ":" + pwd + "@tcp(" + address + ")/" + dbname + "?charset=utf8mb4&timeout=5s&readTimeout=5s"
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        return nil, err
    }

   // 配置,先占个位
   db.SetMaxOpenConns(100)
   db.SetMaxIdleConns(100)

   return &DbEngine{
       Db:         db,
       FieldParam: "*",
   }, nil
}

二、设置/读取表名 Table/GetTable

每次调用Table()方法,就给本次执行设置一个表名。并且清空DbEngine的其他数据。


// 设置表名
func (e *DbEngine) Table(tableName string) *DbEngine {
   e.TableName = tableName

   e.resetDbEngine()
   return e
}


// 获取表名
func (e *DbEngine) GetTable() string {
   return e.TableName
}

// 重置引擎
func (e *DbEngine) resetDbEngine() {
	*e = DbEngine{
		Db:        e.Db,
		TableName: e.TableName,
	}
}

三、新增 Insert

这里涉及到的反射知识点会比较多,可自行了解

  • 单条插入

回顾下原生的单条插入,采用Prepare再Exec的方式,高效且安全:


stmt, err := db.Prepare("INSERT INTO users (username, jobname, created_at) VALUES (?, ?, ?)")

result, err := stmt.Exec("zhangsan", "dev", time.Now().Unix())

先在Prepare里把插入的数据的value值用占位符代替,有几个value就用几个占位符,在Exec里面,把value值一一对应上。

思路:为了方便,我们插入数据的参数要传kv键值对集合,比如[field1:val1,field2:val2,field3:val3]。在go语言中可以选择Map或者Struct,由于数据表不同字段可能是不同类型,所以选择结构体。


type Users struct {
    Username  string `sql:"username"`
    JobName   string `sql:"jobname"`
    CreatedAt int64  `sql:"created_at"`
}

// 增

user1 := Users{
    Username:  "admin",
    JobName:   "dev", 
    CreatedAt: time.Now().Unix(),
}

// insert into users (username,jobname,created_at) values ('admin', 'dev', 1645597211)

id, err := e.Table("users").Insert(user1)

第一步:将结构体中sql:"xxx"标签进行解析,解析成(username, jobname, created_at)。

第二步:将user1的元素的值都取出来,与字段一一对应的放入到Exec中。

如何从结构体中解析出对应的kv键值呢?可以通过反射来推导出结构体属性,属性、属性值、类型、tag标签。


// 反射获取结构体变量的类型
t := reflect.TypeOf(user1)
fmt.Printf("%+v\n", t) // main.User

// 反射获取结构体属性对应值
v := reflect.ValueOf(user1)
fmt.Printf("%+v\n", v) // {Username:admin JobName:dev CreatedAt:1645597211}

// 其他反射方法自行了解
t.NumField() // 获取结构体字段总数
t.Field(i)	 // 获取结构体属性,属性名、包括类型、tag标签等的值
v.Field(i)	 // 获取结构体属性值
...

单条插入,代码如下:


func (e *DbEngine) insertData(data interface{}) (int64, error) {
	// 反射type和value
	t := reflect.TypeOf(data)
	v := reflect.ValueOf(data)

	// 字段名
	var fieldName []string
	// 占位符
	var placeholder []string
	// 字段总数
	num := t.NumField()

	for i := 0; i < num; i++ {
	  // 判断字段值
	  if !v.Field(i).CanInterface() {
	    continue
	  }

	  // 解析tag
	  sqlTag := t.Field(i).Tag.Get("sql")
	  if sqlTag != "" {
	    // 跳过自增字段
	    if strings.Contains(strings.ToLower(sqlTag), "auto_increment") {
	      continue
	    } else {
	      fieldName = append(fieldName, strings.Split(sqlTag, ",")[0])
	      placeholder = append(placeholder, "?")
	    }
	  } else {
	    fieldName = append(fieldName, t.Field(i).Name)
	    placeholder = append(placeholder, "?")
	  }

	  // 组装字段值
	  e.AllExec = append(e.AllExec, v.Field(i).Interface())
	}

	// sql拼接
	e.Prepare =  "insert into " + e.GetTable() + " (" + strings.Join(fieldName, ",") + ") values(" + strings.Join(placeholder, ",") + ")"

	var stmt *sql.Stmt
	var err error

	// 预处理
	stmt, err = e.Db.Prepare(e.Prepare)
	if err != nil {
    	return 0, e.setError(err)
  	}

	// 生成sql
	e.generateSql()
	// 执行exec
	result, err := stmt.Exec(e.AllExec...)
	if err != nil {
	  return 0, e.setError(err)
	}

	id, _ := result.LastInsertId()

	return id, nil
}

// 设置错误内容
func (e *DbEngine) setError(err error) error {
  _, file, line, _ := runtime.Caller(1)
  return errors.New("File: " + file + ":" + strconv.Itoa(line) + ", " + err.Error())
}
  • 批量插入

不描述那么多了,直接说下思路:
1、传入参数为一个切片数组结构体,[]struct
2、通过反射获取切片数据,计算该数组元素个数,进行循环处理
3、2个for循环,最外层for循环获取切片数组中每个子元素的值和类型,也即结构体的值和结构体类型,然后进行判断
4、里面的for循环则和单条插入的操作一样,反射获取字段名,计算字段总数,解析tag,组装占位符、字段值
5、最终拼接sql进行操作

批量插入,代码如下:


func (e *DbEngine) batchInsertData(batchData interface{}) (int64, error) {
	// 反射解析
	v := reflect.ValueOf(batchData)
	// 切片长度
	l := v.Len()
	// 字段名
	var fieldName []string
	// 总占位符
	var placeholderString []string

	for i := 0; i < l; i++ {
		value := v.Index(i) 	// 子元素值,即结构体
		typed := value.Type()   // 子元素值类型

		if typed.Kind() != reflect.Struct {
		  panic("参数必须是切片结构体类型")
		}

		num := value.NumField()

		// 子占位符
		var placeholder []string

		for j := 0; j < num; j++ {
		  // 判断字段值
		  if !value.Field(j).CanInterface() {
		    continue
		  }

		  // 解析tag
		  sqlTag := typed.Field(j).Tag.Get("sql")
		  if sqlTag != "" {
		    // 跳过自增字段
		    if strings.Contains(strings.ToLower(sqlTag), "auto_increment") {
		      continue
		    } else {
		      // 字段名只记录一个循环的即可
		      if i == 1 {
		        fieldName = append(fieldName, strings.Split(sqlTag, ",")[0])
		      }
		      placeholder = append(placeholder, "?")
		    }
		  } else {
		    if i == 1 {
		      fieldName = append(fieldName, typed.Field(j).Name)
		    }
		    placeholder = append(placeholder, "?")
		  }

		  // 组装字段值
		  e.AllExec = append(e.AllExec, value.Field(j).Interface())
		}

		// 组装成多个()括号的字段值
		placeholderString = append(placeholderString, "("+strings.Join(placeholder, ",")+")")
	}

	// 拼接sql
	e.Prepare = "insert into " + e.GetTable() + " (" + strings.Join(fieldName, ",") + ") values " + strings.Join(placeholderString, ",")

	var stmt *sql.Stmt
	var err error

	// 预处理
	stmt, err = e.Db.Prepare(e.Prepare)
	if err != nil {
		return 0, e.setError(err)
	}

	// 生成sql
	e.generateSql()
	// 执行exec
	result, err := stmt.Exec(e.AllExec...)
	if err != nil {
	return 0, e.setError(err)
	}

	// 获取自增ID
	id, _ := result.LastInsertId()
	return id, nil
}
  • 合二为一,整合单条和批量为一个方法

为了让ORM简洁和优雅一点,把单条插入和批量插入整合一起,只对外暴露一个方法。通过反射判断传入的参数是单结构体还是切片结构体。


func (e *DbEngine) Insert(data interface{}) (int64, error) {
	// 参数类型
	typed := reflect.ValueOf(data).Kind()

	// 判断是批量还是单个插入
	if typed == reflect.Struct {
		return e.insertData(data)
	} else if typed == reflect.Slice || typed == reflect.Array {
		return e.batchInsertData(data)
	} else {
		return 0, errors.New("参数有误")
	}
}

四、条件 Where

Where方法主要是为了替换sql中where后面的条件语句,结合各种比较符号:=、!=、>、<、like、in或者用and、or去隔开多个条件。
思路:同样是先占位符进行预处理,然后再执行exec替换值的方式操作,所以Where方法的逻辑也是和Insert一样分两步走

  • 方式一:结构体参数

type Users struct {
    Username  string `sql:"username"`
    JobName   string `sql:"jobname"`
    CreatedAt int64  `sql:"created_at"`
}

user1 := Users{
    Username:  "admin",
    JobName:   "dev", 
}

// select * from users where username = 'admin' and jobname = 'dev';

result, err := e.Table("users").Where(user1).Select()

where部分是中间层,不会去执行结果,只是将数据拆分出来,存入到DB引擎结构体的WhereParam和WhereExec属性中,由最后的增删改查操作使用。

结构体参数Where方法,代码如下:


func (e *DbEngine) Where(data interface{}) *DbEngine {

    // 反射type和value
    t := reflect.TypeOf(data)
    v := reflect.ValueOf(data)

    // 字段名
    var fieldNameArray []string
    // 字段总数
    num := t.NumField()

    for i := 0; i < num; i++ {
      // 判断字段值
      if !v.Field(i).CanInterface() {
        continue
      }

      // 解析tag
      sqlTag := t.Field(i).Tag.Get("sql")
      if sqlTag != "" {
      	 fieldNameArray = append(fieldNameArray, strings.Split(sqlTag, ",")[0]+"=?")
      } else {
      	 fieldNameArray = append(fieldNameArray, t.Field(i).Name+"=?")
      }

      // 组装值
      e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
    }

    // 多次调用判断
	if e.WhereParam != "" {
		// 如果不为空,则说明这是第二次调用了
		e.WhereParam += " and ("
	} else {
		e.WhereParam += "("
	}

    // 拼接条件语句
    e.WhereParam += strings.Join(fieldNameArray, " and ") + ") "

    return e 
}

打印下WhereParam和WhereExec的值


// fmt.Println
WhereParam = "(username=? and jobname=?) "
WhereExec = []interface{"admin", "dev"}

// 多次调用Where方法,每次都是一个and连接
result, err := e.Table("users").Where(user1).Where(user2).Select()

// fmt.Println
WhereParam = "(username=? and jobname=?) and (created_at=?)"
WhereExec = []interface{"admin", "dev", 1645597211}

注意:结构体参数的方式每个条件都是等于“=”的关系,如果需要其他条件关系的话,可以使用下面的字符串参数的方式

  • 方式二:字符串参数

结构体参数方式除了上面说的只能是使用“=”的条件外,他每次使用还得写个结构体然后赋值才能使用,有时只有一个参数的话使用结构体参数方式不够便捷。所以提供了字符串参数的方式,较为简单便捷。

此方式的Where方法有三个参数,(查询的字段,条件符号,值)如:


// 也可以用其他条件符号,如:=、!=、>、<、>=、<=、not in、in、like
Where("id", "=", 1)
Where("id", ">", 2)
Where("id", ">=", 3)
Where("id", "in", []int{1, 2, 3, 4})
Where("name", "like", "%tim")

字符串参数方法,代码如下:


func (e *DbEngine) Where(field string, opt string, value interface{}) *DbEngine {
    // 获取符号
    data2 := strings.Trim(strings.ToLower(opt), " ")

    // 多次调用判断
	if e.WhereParam != "" {
		e.WhereParam += " and ("
	} else {
		e.WhereParam += "("
	}

    // 判断是否是in
    if data2 == "in" || data2 == "not in" {
      typed := reflect.TypeOf(value).Kind()
      // 判断传入的是切片或数组
      if typed != reflect.Slice && typed != reflect.Array {
		 panic("in类的操作参数必须是切片或数组")
      }

      // 反射值
      v := reflect.ValueOf(value)
      // 切片长度
      l := v.Len()
      // 占位符
      ps := make([]string, l)

      for i := 0; i < l; i++ {
         ps[i] = "?"
         e.WhereExec = append(e.WhereExec, v.Index(i).Interface())
      }

      // 拼接
      e.WhereParam += field + " " + data2 + " (" + strings.Join(ps, ",") + ")) "

    } else {
      e.WhereParam += field + " " + opt + " ?) "
      e.WhereExec = append(e.WhereExec, value)
    }

    return e 
}
  • 合二为一,整合为一个方法,并增加了"="等于号的情况下,可以省略等于号,只传字段和值两个参数

// Where and条件方法
func (e *DbEngine) Where(data ...interface{}) *DbEngine {
	// 判断是结构体参数还是字符串参数
	var dataType int
	if len(data) == 1 {
		dataType = 1 // 结构体参数
	} else if len(data) == 2 {
		dataType = 2 // 直接等于"="的情况,省去等于号,只传字段和值两个参数
	} else if len(data) == 3 {
		dataType = 3 // 字符串参数
	} else {
		panic("参数错误")
	}

	// 多次调用判断
	if e.WhereParam != "" {
		e.WhereParam += " and ("
	} else {
		e.WhereParam += "("
	}

	// 结构体
	if dataType == 1 {
		t := reflect.TypeOf(data[0])
		v := reflect.ValueOf(data[0])

		// 字段名
		var fieldNameArr []string
		// 字段总数
		num := t.NumField()

		for i := 0; i < num; i++ {
			// 判断字段值
			if !v.Field(i).CanInterface() {
				continue
			}

			// 解析tag
			sqlTag := t.Field(i).Tag.Get("sql")
			if sqlTag != "" {
				fieldNameArr = append(fieldNameArr, strings.Split(sqlTag, ",")[0]+"=?")
			} else {
				fieldNameArr = append(fieldNameArr, t.Field(i).Name+"=?")
			}
			e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
		}
		// 拼接
		e.WhereParam += strings.Join(fieldNameArr, " and ") + ") "

	} else if dataType == 2 {
		// 直接等于"="
		e.WhereParam += data[0].(string) + "=?) "
		e.WhereExec = append(e.WhereExec, data[1])
	} else if dataType == 3 {
		// 字符串参数
		// 获取符号
		data2 := strings.Trim(strings.ToLower(data[1].(string)), " ")
		// 判断是否是in操作
		if data2 == "in" || data2 == "not in" {
			// 判断传入的是切片
			reType := reflect.TypeOf(data[2]).Kind()
			if reType != reflect.Slice && reType != reflect.Array {
				panic("in类操作传入参数必须是切片或数组")
			}

			// 反射切片值
			v := reflect.ValueOf(data[2])
			// 数组/切片长度
			l := v.Len()
			// 占位符
			ps := make([]string, l)

			for i := 0; i < l; i++ {
				ps[i] = "?"
				e.WhereExec = append(e.WhereExec, v.Index(i).Interface())
			}
			// 拼接
			e.WhereParam += data[0].(string) + " " + data2 + " (" + strings.Join(ps, ",") + ")) "
		} else {
			e.WhereParam += data[0].(string) + " " + data[1].(string) + " ?) "
			e.WhereExec = append(e.WhereExec, data[2])
		}
	}
	return e
}

五、条件 OrWhere

思路:与Where方法是一样,只需要将whereParam替换成OrWhereParam,WhereExec还是维持不变,然后把and拼接改为or即可

同样支持三种方式,直接上代码:


// OrWhere or条件方法
func (e *DbEngine) OrWhere(data ...interface{}) *DbEngine {
	// 判断是结构体参数还是字符串参数
	var dataType int
	if len(data) == 1 {
		dataType = 1 // 结构体参数
	} else if len(data) == 2 {
		dataType = 2 // 直接等于"="的情况,省去等于号,只传字段和值两个参数
	} else if len(data) == 3 {
		dataType = 3 // 字符串参数
	} else {
		panic("参数错误")
	}

	// 判断方法使用顺序
	if e.WhereParam == "" {
		panic("OrWhere方法必须在Where方式之后调用")
	}

	e.OrWhereParam += " or ("

	// 结构体
	if dataType == 1 {
		t := reflect.TypeOf(data[0])
		v := reflect.ValueOf(data[0])

		// 字段名
		var fieldNameArr []string
		// 字段总数
		num := t.NumField()

		for i := 0; i < num; i++ {
			// 判断字段值
			if !v.Field(i).CanInterface() {
				continue
			}

			// 解析tag
			sqlTag := t.Field(i).Tag.Get("sql")
			if sqlTag != "" {
				fieldNameArr = append(fieldNameArr, strings.Split(sqlTag, ",")[0]+"=?")
			} else {
				fieldNameArr = append(fieldNameArr, t.Field(i).Name+"=?")
			}
			e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
		}
		// 拼接
		e.OrWhereParam += strings.Join(fieldNameArr, " and ") + ") "

	} else if dataType == 2 {
		// 直接等于"="
		e.OrWhereParam += data[0].(string) + "=?) "
		e.WhereExec = append(e.WhereExec, data[1])
	} else if dataType == 3 {
		// 字符串参数
		// 获取符号
		data2 := strings.Trim(strings.ToLower(data[1].(string)), " ")
		// 判断是否是in操作
		if data2 == "in" || data2 == "not in" {
			// 判断传入的是切片
			reType := reflect.TypeOf(data[2]).Kind()
			if reType != reflect.Slice && reType != reflect.Array {
				panic("in类操作传入参数必须是切片或数组")
			}

			// 反射切片值
			v := reflect.ValueOf(data[2])
			// 数组/切片长度
			l := v.Len()
			// 占位符
			ps := make([]string, l)

			for i := 0; i < l; i++ {
				ps[i] = "?"
				e.WhereExec = append(e.WhereExec, v.Index(i).Interface())
			}
			// 拼接
			e.OrWhereParam += data[0].(string) + " " + data2 + " (" + strings.Join(ps, ",") + ")) "
		} else {
			e.OrWhereParam += data[0].(string) + " " + data[1].(string) + " ?) "
			e.WhereExec = append(e.WhereExec, data[2])
		}
	}
	return e
}

注意:必须先调用Where后才能调用OrWhere方法。因为一般用到了or,前面肯定也有前置where条件。
为了使这个方法更简单使用,不搞得复杂,这种方式的or实质上是针对多次调用where之间的去使用的,是不支持同一个where里面的数据是or关系的,一般也不会用到这种,不然性能也是极差。

六、删除 Delete

思路:Delete方法是直接与数据库交互的最后一步,不需要再去处理各种数据,只需要把Where或OrWhere方法中组装好的参数和值写入预处理和执行方法中即可

代码如下:


// Delete 删除方法
func (e *DbEngine) Delete() (int64, error) {
	// 拼接sql
	e.Prepare = "delete from " + e.GetTable()

	// 若where参数不为空
	if e.WhereParam != "" || e.OrWhereParam != "" {
		e.Prepare += " where " + e.WhereParam + e.OrWhereParam
	}
	// limit不为空
	if e.LimitParam != "" {
		e.Prepare += "limit " + e.LimitParam
	}

	var stmt *sql.Stmt
	var err error
	// 预处理
	stmt, err = e.Db.Prepare(e.Prepare)
	if err != nil {
		return 0, err
	}

	e.AllExec = e.WhereExec

	// 生成sql
	e.generateSql()
	// 执行
	result, err := stmt.Exec(e.AllExec...)
	if err != nil {
		return 0, e.setError(err)
	}

	// 影响行数
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return 0, e.setError(err)
	}
	return rowsAffected, nil
}

调用方式


// delete from users where (id = 1) or (username = 'admin');
rowsAffected, err := e.Table("users").Where("id", 1).OrWhere("username", "admin").Delete()

7、更新 Update


update users set jobname = 'web' where (id = 1) or (username = 'admin');

思路:Update和Delete都是与数据库交互的最后一步,但是Delete不同的是他不仅需要Where或OrWhere方法中组装好的参数和值,还需要处理要更新的数据,也就是set jobname='web’这一步,我们可以参考下Where方法,搞一个字符串参数传值和结构体传值的形式,用起来灵活一点,具体实现方式跟实现Insert几乎一样。

例如:


// 字符串参数
e.Table("users").Where("id", 1).Update("jobname", "web")

// 结构体参数
e.Table("users").Where("id", 1).Update(user1)

代码如下:


// Update 更新
func (e *DbEngine) Update(data ...interface{}) (int64, error) {
	// 判断结构体参数还是字符串参数
	var dataType int
	if len(data) == 1 {
		dataType = 1 // 结构体参数
	} else if len(data) == 2 {
		dataType = 2 // 直接等于"="的情况,省去等于号,只传字段和值两个参数
	} else {
		return 0, errors.New("参数错误")
	}

	// 结构体参数
	if dataType == 1 {
		t := reflect.TypeOf(data[0])
		v := reflect.ValueOf(data[0])

		// 字段名
		var fieldNameArr []string

		for i := 0; i < t.NumField(); i++ {
			// 判断字段值
			if !v.Field(i).CanInterface() {
				continue
			}
			// 解析tag
			sqlTag := t.Field(i).Tag.Get("sql")
			if sqlTag != "" {
				fieldNameArr = append(fieldNameArr, strings.Split(sqlTag, ",")[0]+"=?")
			} else {
				fieldNameArr = append(fieldNameArr, t.Field(i).Name+"=?")
			}
			// 更新值
			e.UpdateExec = append(e.UpdateExec, v.Field(i).Interface())
		}
		// 更新字段
		e.UpdateParam += strings.Join(fieldNameArr, ",")
	} else if dataType == 2 {
		// 直接等于的情况
		e.UpdateParam += data[0].(string) + "=?"
		e.UpdateExec = append(e.UpdateExec, data[1])
	}

	// 拼接sql
	e.Prepare = "update " + e.GetTable() + " set " + e.UpdateParam

	// where不为空
	if e.WhereParam != "" || e.OrWhereParam != "" {
		e.Prepare += " where " + e.WhereParam + e.OrWhereParam
	}

	// limit不为空
	if e.LimitParam != "" {
		e.Prepare += "limit " + e.LimitParam
	}

	var stmt *sql.Stmt
	var err error
	// 预处理
	stmt, err = e.Db.Prepare(e.Prepare)
	if err != nil {
		return 0, e.setError(err)
	}

	// 合并UpdateExec和WhereExec,即update table set params1 = ? where id = ?的问号参数值
	if e.WhereExec != nil {
		// 此处切记要加...,目的是把切片里的一个个参数都追加到UpdateExec后面,形成类似php中array_merge的效果
		e.AllExec = append(e.UpdateExec, e.WhereExec...)
	}

	// 生成sql
	e.generateSql()
	// 执行
	result, err := stmt.Exec(e.AllExec...)
	if err != nil {
		return 0, e.setError(err)
	}

	// 影响行数
	id, _ := result.RowsAffected()
	return id, nil
}

八、查询 Select

Go的原生代码中,查询的写法是没有Prepare和Exec的,而是通过QueryRow和Query方法来查询数据,然后再用Scan去绑定赋值数据,写法真的麻烦


// 单条查询,需要先把查询的字段定义好,然后再Scan()去绑定赋值
var username, jobname, createdAt string

err := db.QueryRow("select username, jobname, created_at from users where id=?", 1).Scan(&username, &jobname, &createdAt)
if err != nil {
    fmt.Println("QueryRow error :", err.Error())
}
fmt.Println(username, jobname, createdAt)


// 多条查询,先把需要查询的字段结构体定义好,再初始化切片结构体,for循环给这个切片赋值
rows, err := db.Query("select username, jobname, created_at from users where username=?", "admin")
if err != nil {
    fmt.Println("QueryRow error :", err.Error())
}

type Users struct {
    Username   string `json:"username"`
    Jobname    string `json:"jobname"`
    CreatedAt  string `json:"created_at"`
}

// 初始化切片结构体
var user []Users

for rows.Next() {
    var username1, jobname1, createdAt1 string

    if err := rows.Scan(&username1, &jobname1, &createdAt1); err != nil {
        fmt.Println("Query error :", err.Error())
    }
    user = append(user, Users{Username: username1, Jobname: jobname1, CreatedAt: createdAt1})
}

思路:
为了简化查询内部实现的复杂度,单条查询也用Query方法,然后使用limit 1限制条数,继而满足条件。这样处理的话多条查询和单条查询都能使用同样的方法处理,只是对条数做出限制。
要先定义结构体,然后再初始化一个切片结构体,这操作有点麻烦,我的想法是直接按照表里的字段名返回在切片map中。
那怎么拿到表里的字段有哪些呢?DB.Query提供了Columns的方法,能返回本次查询的表字段。
观察一下,rows.Scan()的数据绑定,它是需要先定义字段变量和类型,每次循环的时候把地址传给Scan方法,通过地址来动态引用赋值,所以这几个字段变量名字不重要,反正最后都是传他们的地址。

  • 多条查询之切片map

// Select 多条查询,返回值为切片map
func (e *DbEngine) Select() ([]map[string]string, error) {
	// 拼接sql
	e.Prepare = "select " + e.FieldParam + " from " + e.GetTable()

	// where不为空
	if e.WhereParam != "" || e.OrWhereParam != "" {
		e.Prepare += " where " + e.WhereParam + e.OrWhereParam
	}
	// group不为空
	if e.GroupParam != "" {
		e.Prepare += " group by " + e.GroupParam
	}
	// having不为空
	if e.HavingParam != "" {
		e.Prepare += " having " + e.HavingParam
	}
	// order by不为空
	if e.OrderParam != "" {
		e.Prepare += " order by " + e.OrderParam
	}
	// limit不为空
	if e.LimitParam != "" {
		e.Prepare += " limit " + e.LimitParam
	}
	// 查询条件对应值
	e.AllExec = e.WhereExec

	// 生成sql
	e.generateSql()
	// 查询数据
	rows, err := e.Db.Query(e.Prepare, e.AllExec...)
	if err != nil {
		return nil, e.setError(err)
	}

	// 查询字段名
	columns, err := rows.Columns()
	if err != nil {
		return nil, e.setError(err)
	}

	// 每个字段的值
	values := make([][]byte, len(columns))
	// 存放对应的每个字段值的地址
	scans := make([]interface{}, len(columns))

	// values中的值地址绑定到scan的每个元素上
	for i := range values {
		scans[i] = &values[i]
	}

	results := make([]map[string]string, 0)
	for rows.Next() {
		// 使用可变参数解决需要定义的字段数无法确定的问题
		if err := rows.Scan(scans...); err != nil {
			// Scan查询出来的值放到scans[i] = &values[i],也就是每行数据都放在values里
			return nil, e.setError(err)
		}

		// 每行数据
		row := make(map[string]string)

		// 循环values数据,通过相同的键,取columns里面对应的字段,生成1个map
		for k, v := range values {
			key := columns[k]
			row[key] = string(v)
		}
		results = append(results, row)
	}
	return results, nil
}

注意:这个方法有一个小问题,会把所有字段的类型都转为字符串,但是理论上影响不大。

Select方法调用,如下:


result, err := e.Table("users").Where("created_at", ">=", 1645597211).Select()

// fmt.Printf("%T\n", result)
[]map[string]string

// fmt.Println(result)
[map[id:1 username:admin jobname:dev] map[id:2 username:zhangsan jobname:dev]]
  • 单条查询之map

思路:单条查询就是在多条查询基础上增加limit 1的条数限制,然后取出数据返回在map中


// SelectOne 单条查询,返回值为map
func (e *DbEngine) SelectOne() (map[string]string, error) {
	// limit 1
	results, err := e.Limit(1).Select()
	if err != nil {
		return nil, e.setError(err)
	}

	// 判断结果是否为空
	if len(results) == 0 {
		return nil, nil
	} else {
		return results[0], nil
	}
}

SelectOne方法调用,如下:


result, err := e.Table("users").Where("id", "=", 1).SelectOne()

// fmt.Printf("%T\n", result)
map[string]string

// fmt.Println(result)
map[id:1 username:admin jobname:dev]

有的小伙伴可能还是习惯了类似GORM那种查询操作,先定义好结构体,然后进行引用赋值。

如下:


type Users struct {
    Username  string `sql:"username"`
    JobName   string `sql:"jobname"`
    CreatedAt int64  `sql:"created_at"`
}

var user []Users

// select * from users where created_at >= 1645597211;

err := e.Table("users").Where("created_at", ">=", 1645597211).Find(&user)

if err != nil {
	panic(err)
}

fmt.Println(user)

思路:
定义一个结构体,字段属性定义好对应的tag标签
初始化空切片结构体,然后通过&取地址符传到Find方法
Find方法内部先获取表的字段,再通过tag标签和各种反射操作,将数据绑定到传入的切片结构体进行赋值。这一步想想就觉得麻烦

话不多说,直接上代码:

  • 多条查询之切片结构体

// Find 多条查询,切片结构体
func (e *DbEngine) Find(result interface{}) error {
	// 判断参数是否为指针变量
	if reflect.ValueOf(result).Kind() != reflect.Ptr {
		return e.setError(errors.New("参数错误"))
	}
	// 判断是空指针
	if reflect.ValueOf(result).IsNil() {
		return e.setError(errors.New("参数错误,不能是空指针"))
	}

	// 拼接sql
	e.Prepare = "select " + e.FieldParam + " from " + e.GetTable()

	// where不为空
	if e.WhereParam != "" || e.OrWhereParam != "" {
		e.Prepare += " where " + e.WhereParam + e.OrWhereParam
	}
	// group不为空
	if e.GroupParam != "" {
		e.Prepare += " group by " + e.GroupParam
	}
	// having不为空
	if e.HavingParam != "" {
		e.Prepare += " having " + e.HavingParam
	}
	// order by不为空
	if e.OrderParam != "" {
		e.Prepare += " order by " + e.OrderParam
	}
	// limit不为空
	if e.LimitParam != "" {
		e.Prepare += " limit " + e.LimitParam
	}
	// 查询条件对应值
	e.AllExec = e.WhereExec

	// 生成sql
	e.generateSql()
	// 查询数据
	rows, err := e.Db.Query(e.Prepare, e.AllExec...)
	if err != nil {
		return e.setError(err)
	}

	// 查询字段名
	columns, err := rows.Columns()
	if err != nil {
		return e.setError(err)
	}

	// 每个字段的值
	values := make([][]byte, len(columns))
	// 存放对应的每个字段值的地址
	scans := make([]interface{}, len(columns))
	// values中的值地址绑定到scan的每个元素上
	for i := range values {
		scans[i] = &values[i]
	}

	// 获取指针切片结构体的值
	sliceStruct := reflect.ValueOf(result).Elem()
	// 获取单个结构体的类型
	structType := sliceStruct.Type().Elem()

	for rows.Next() {
		// 根据结构体类型,生成一个新的结构体
		newStruct := reflect.New(structType).Elem()

		if err := rows.Scan(scans...); err != nil {
			// Scan查询出来的值放到scans[i] = &values[i],也就是每行数据都放在values里
			return e.setError(err)
		}

		// 遍历每行数据各个字段
		for k, v := range values {
			key := columns[k]
			value := string(v)

			num := structType.NumField()
			// 遍历结构体
			for i := 0; i < num; i++ {
				// 解析tag标签
				sqlTag := structType.Field(i).Tag.Get("sql")

				var fieldName string
				if sqlTag != "" {
					fieldName = strings.Split(sqlTag, ",")[0]
				} else {
					fieldName = structType.Field(i).Name
				}

				// 结构体中没这个字段属性
				if key != fieldName {
					continue
				}
				// 反射判断类型,赋值结构体
				if err := e.reflectSetStruct(newStruct, i, value); err != nil {
					return err
				}
			}
		}
		// 赋值,reflect.Append将新结构体的值追加到切片结构体中;sliceStruct.Set将追加后的结果赋值到传入的指针切片结构体中
		sliceStruct.Set(reflect.Append(sliceStruct, newStruct))
	}

	return nil
}

因为结构体严格区分字段类型,会有一个结构体赋值问题,处理方法:将查询的结果集中各个字段值的类型进行遍历判断,然后转换成结构体对应的字段类型赋值回去。同样是基于反射处理。

代码如下:


// 通过反射判断结构体字段类型,然后将对应的结构体的字段值转化为相应类型,然后再赋值回结构体中
func (e *DbEngine) reflectSetStruct(target reflect.Value, i int, value string) error {
	switch target.Field(i).Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		res, err := strconv.ParseInt(value, 10, 64)
		if err != nil {
			return e.setError(err)
		}
		target.Field(i).SetInt(res)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		res, err := strconv.ParseUint(value, 10, 64)
		if err != nil {
			return e.setError(err)
		}
		target.Field(i).SetUint(res)
	case reflect.Float64:
		res, err := strconv.ParseFloat(value, 64)
		if err != nil {
			return e.setError(err)
		}
		target.Field(i).SetFloat(res)
	case reflect.String:
		target.Field(i).SetString(value)
	case reflect.Bool:
		res, err := strconv.ParseBool(value)
		if err != nil {
			return e.setError(err)
		}
		target.Field(i).SetBool(res)
	}
	return nil
}
  • 单条查询之结构体

同样的单条查询就是在多条查询基础上增加limit 1的条数限制,然后取出数据返回在结构体中。但是,有个问题:我们的Find方法的参数是一个切片结构体指针,而单条查询传入的应该是单个结构体指针,类型不匹配,还得想办法搞定他(同样还是查询了反射的文档资料得出答案)。

代码如下:


// FindOne 单条查询,结构体
func (e *DbEngine) FindOne(result interface{}) error {
	// 获取结构体原始值
	structValue := reflect.Indirect(reflect.ValueOf(result))
	// 反射出一个相同结构体类型的切片
	structSlice := reflect.New(reflect.SliceOf(structValue.Type())).Elem()

	// 查询数据,传入指针
	if err := e.Limit(1).Find(structSlice.Addr().Interface()); err != nil {
		return err
	}

	// 判断返回值长度
	if structSlice.Len() == 0 {
		return e.setError(errors.New("Not Found Error"))
	}

	// 取切片里的第0个数据,并赋值给原始值结构体指针
	structValue.Set(structSlice.Index(0))
	return nil
}

九、设置查询字段 Field

平时很多人查询数据的时候都喜欢用select* ,把表字段都查出来,在一些大表中执行查询的话是非差低效和浪费资源的。
Field()方法指定本次查询的字段,实现方式也比较简单,直接赋值给FieldParam属性即可

代码如下:


// Field 设置查询字段
func (e *DbEngine) Field(field string) *DbEngine {
	e.FieldParam = field
	return e
}

注意:FieldParam的初始值是“”,一开始在NewMysql里面已经初始化过。所以即使没进行字段设置,Prepare的值也是select,也是不影响sql的完整性。
我们是直接传入字段列名,并没有对传入字段名做检验。这个将在第二版中进行优化。

调用方式如下:


// 不可放置在前后两端,中间部分任意位置都可以
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Select()

十、设置条数 Limit

limit一般用来控制获取的数据量大小或者分页,如:单条查询 limit 1、分页查询(一页十条数据) limit 0,9。所以我们的Limit方法也实现这两种。

如:


// 查询一条
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Limit(1).Select()

// 查询十条
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Limit(0, 9).Select()

思路:有两种方式传参,那同样也是类似于Insert的传参方式,使用可变参数。同时对参数进行判断长度和限制长度。然后将对应的limt语句拼接好赋值到LimitParamb中

代码如下:


// Limit 分页
func (e *DbEngine) Limit(limit ...int64) *DbEngine {
	if len(limit) == 1 {
		e.LimitParam = strconv.Itoa(int(limit[0]))
	} else if len(limit) == 2 {
		e.LimitParam = strconv.Itoa(int(limit[0])) + "," + strconv.Itoa(int(limit[1]))
	} else {
		panic("参数错误")
	}
	return e
}

十一、排序 Order

排序的实现思路也是比较简单:根据字段,升序asc、降序desc,参数为字符串传入,然后赋值到OrderParam中即可

代码如下:


// Order 排序
func (e *DbEngine) Order(order string) *DbEngine {
	if order == "" {
		panic("参数错误")
	}

	// 多次调用判断
	if e.OrderParam != "" {
		e.OrderParam += ","
	}
	e.OrderParam += order
	return e
}

调用方式如下:


// 按ID降序排列
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Order("id desc").Select()
// 先按ID升序排列,再按分数降序排列
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Order("id asc,score desc").Select()

十二、聚合查询 Count/Sum/Max/Min/Avg


# 都是类似写法(函数(字段)),不一一列举了
select count(*) from users;
...

思路:
聚合函数在平时查询统计的时候用的还是比较多,它们的SQL拼接方式都类似,把函数和字段替换到select* 里面。
他们聚合查询后的数据都只有一条。所以我就不考虑用Select中的Db.Query+for next方法,直接用Db.QueryRow方法,此方法就是原生的查询单条的方式。
既然几个聚合函数写法类似,那就抽一个通用方法,然后具体暴漏出来的聚合方法去调用他即可。

  • 公共部分:聚合查询

// 公共部分:聚合查询
func (e *DbEngine) aggregateQuery(name, field string) (interface{}, error) {
	// 拼接sql
	e.Prepare = "select " + name + "(" + field + ") as aggre from " + e.GetTable()

	// where不为空
	if e.WhereParam != "" || e.OrWhereParam != "" {
		e.Prepare += " where " + e.WhereParam + e.OrWhereParam
	}
	// limit不为空
	if e.LimitParam != "" {
		e.Prepare += " limit " + e.LimitParam
	}
	e.AllExec = e.WhereExec

	// 生成sql
	e.generateSql()
	// 执行绑定,由于最终的聚合结果值不确定类型,可能是整数、小数、浮点数(平均值),所以定义了一个interface变量
	var aggre interface{}

	// 查询单条
	err := e.Db.QueryRow(e.Prepare, e.AllExec...).Scan(&aggre)
	if err != nil {
		return nil, e.setError(err)
	}

	return aggre, err
}
  • 总数 Count

// Count 总数
func (e *DbEngine) Count() (int64, error) {
	count, err := e.aggregateQuery("count", "*")
	if err != nil {
		return 0, e.setError(err)
	}
	return count.(int64), err
}
  • 总和 Sum

// Sum 总和
func (e *DbEngine) Sum(field string) (string, error) {
	sum, err := e.aggregateQuery("sum", field)
	if err != nil {
		return "0", e.setError(err)
	}
	return string(sum.([]byte)), nil
}
  • 最大值 Max

// Max 最大值
func (e *DbEngine) Max(field string) (string, error) {
	max, err := e.aggregateQuery("max", field)
	if err != nil {
		return "0", e.setError(err)
	}
	return string(max.([]byte)), nil
}
  • 最小值 Min

// Min 最小值
func (e *DbEngine) Min(field string) (string, error) {
	min, err := e.aggregateQuery("min", field)
	if err != nil {
		return "0", e.setError(err)
	}
	return string(min.([]byte)), nil
}
  • 平均值 Avg

// Avg 平均值
func (e *DbEngine) Avg(field string) (string, error) {
	avg, err := e.aggregateQuery("avg", field)
	if err != nil {
		return "0", e.setError(err)
	}
	return string(avg.([]byte)), nil
}

调用方式如下:



count, err := e.Table("users").Where("id", ">=", 1).Count()

sum, err := e.Table("users").Where("id", ">=", 1).Sum("id")

max, err := e.Table("users").Where("id", ">=", 1).Max("id")

min, err := e.Table("users").Where("id", ">=", 1).Min("id")

avg, err := e.Table("users").Where("id", ">=", 1).Avg("id")

十三、分组 Group

分组用于我们对某1个或者几个字段进行归类,然后查询归类后的数据。可以搭配上Field(count( * ) as num)来做更加具体的分组查询
思路:直接就是可变参数逗号拼接字符串

代码如下:


// Group 分组
func (e *DbEngine) Group(group ...string) *DbEngine {
	if len(group) != 0 {
		e.GroupParam = strings.Join(group, ",")
	}
	return e
}

调用方式如下:


result,err := e.Table("users").Field("jobname, count(*) as num").Group("jobname").Select()

十四、分组后判断 Having

思路:Having的作用和Where是一样的,只不过Having是作用在Group后的。同样的,我们要实现和Where一样的参数支持:结构体、字符串

代码如下:


// Having 分组后判断
func (e *DbEngine) Having(having ...interface{}) *DbEngine {
	// 判断是结构体
	var dataType int
	if len(having) == 1 {
		dataType = 1
	} else if len(having) == 2 {
		dataType = 2
	} else if len(having) == 3 {
		dataType = 3
	} else {
		panic("参数错误")
	}

	// 多次调用判断
	if e.HavingParam != "" {
		e.HavingParam += "and ("
	} else {
		e.HavingParam += "("
	}

	// 若是结构体
	if dataType == 1 {
		// 反射type和value
		t := reflect.TypeOf(having[0])
		v := reflect.ValueOf(having[0])

		// 字段名
		var fieldNameArray []string
		// 字段总数
		num := t.NumField()

		for i := 0; i < num; i++ {
			// 判断字段值
			if !v.Field(i).CanInterface() {
				continue
			}

			// 解析tag
			sqlTag := t.Field(i).Tag.Get("sql")
			if sqlTag != "" {
				fieldNameArray = append(fieldNameArray, strings.Split(sqlTag, ",")[0]+"=?")
			} else {
				fieldNameArray = append(fieldNameArray, t.Field(i).Name+"=?")
			}
			// 组装值
			e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
		}
		// 拼接
		e.HavingParam += strings.Join(fieldNameArray, " and ") + ") "
	} else if dataType == 2 {
		// 直接等于"="
		e.HavingParam += having[0].(string) + "=?) "
		e.WhereExec = append(e.WhereExec, having[1])
	} else if dataType == 3 {
		// 字符串参数
		e.HavingParam += having[0].(string) + " " + having[1].(string) + " ?) "
		e.WhereExec = append(e.WhereExec, having[2])
	}
	return e
}

调用方式如下:


result,err := e.Table("users").Field("jobname, count(*) as num").Group("jobname").Having("jobname", "dev").Select()
if err != nil {
    panic(err)
}

十五、获取执行SQL GetLastSql

基本上所有的ORM方法本质上都是组装好原生Sql。我们排查数据库错误时多数都是需要通过最终的sql进行分析。所以这个ORM也需要提供这个方法。
思路:我们所有的ORM方法都已经在预处理的部分把sql组装好了,占位符对应的参数值也按占位符顺序一一对应排列好了。所以我们就只需要吧问号?替换成对应的参数值即可。

代码如下:


// 生成SQL,在链式调用最后一步执行的方法中去调用这个方法。就可以生成执行的sql语句。调用GetLastSql就可以打印出来了。
func (e *DbEngine) generateSql() {
	e.Sql = e.Prepare
	for _, i2 := range e.AllExec {
		// 还可以继续根据需要补充类型
		switch i2.(type) {
		case int:
			e.Sql = strings.Replace(e.Sql, "?", strconv.Itoa(i2.(int)), 1)
		case int64:
			e.Sql = strings.Replace(e.Sql, "?", strconv.FormatInt(i2.(int64), 10), 1)
		case bool:
			e.Sql = strings.Replace(e.Sql, "?", strconv.FormatBool(i2.(bool)), 1)
		default:
			e.Sql = strings.Replace(e.Sql, "?", "'"+i2.(string)+"'", 1)
		}
	}
}

// GetLastSql 获取执行的sql
func (e *DbEngine) GetLastSql() string {
	return e.Sql
}

调用方式如下:


result, err := e.Table("users").Where("id", ">=", 1).Order("id desc").Select()

fmt.Println(e.GetLastSql()) //select * from users where (id >= 1) order by id desc

十六、执行原生sql Exec/Query

虽然不推荐直接执行sql语句,但在一些场景中有一些比较复杂的sql,这时就需要用到。

  • sql执行原生增删改操作 Exec

思路:直接使用Exec方法暴力执行sql


// Exec 直接执行增删改sql
func (e *DbEngine) Exec(sql string) (id int64, err error) {
	result, err := e.Db.Exec(sql)
	e.Sql = sql
	if err != nil {
		return 0, e.setError(err)
	}

	// insert返回ID,其他返回影响行数
	if strings.Contains(sql, "insert") {
		lastInsertId, _ := result.LastInsertId()
		return lastInsertId, nil
	} else {
		rowsAffected, _ := result.RowsAffected()
		return rowsAffected, nil
	}
}

调用如下:


result, err:= e.Exec("insert into users(username,jobname,created_at) values('root', 'leader', 1645597211)");

result, err := e.Exec("delete from users where id=1")
...
  • sql执行原生查询操作 Query

思路:把Select方法拿来改下,删减一些东西就行了


// Query 直接执行查询sql
func (e *DbEngine) Query(sql string) ([]map[string]string, error) {
	rows, err := e.Db.Query(sql)
	// 直接写入sql
	e.Sql = sql
	if err != nil {
		return nil, e.setError(err)
	}

	// 查询字段名
	columns, err := rows.Columns()
	if err != nil {
		return nil, e.setError(err)
	}

	// 每个字段的值
	values := make([][]byte, len(columns))
	// 存放对应的每个字段值的地址
	scans := make([]interface{}, len(columns))
	// values中的值地址绑定到scan的每个元素上
	for i := range values {
		scans[i] = &values[i]
	}

	results := make([]map[string]string, 0)
	for rows.Next() {
		// 使用可变参数解决需要定义的字段数无法确定的问题
		if err := rows.Scan(scans...); err != nil {
			// Scan查询出来的值放到scans[i] = &values[i],也就是每行数据都放在values里
			return nil, e.setError(err)
		}

		// 每行数据
		row := make(map[string]string)
		for k, v := range values {
			key := columns[k]
			row[key] = string(v)
		}
		results = append(results, row)
	}
	return results, nil
}

调用如下:


result, err := e.Query("SELECT * FROM users limit 1")

十七、事务 Begin/Commit/Rollback

事务操作平时在业务中用的也是很多,它用于在多次执行增删改的操作的时候,如果其中一个出现问题,可以一起回滚数据,确保了数据的一致性。

  • 开启事务 Begin

思路:开启事务只需要调用原生库的方法,然后设置事务状态即可


// Begin 开启事务
func (e *DbEngine) Begin() error {
	// 调用原生方法
	tx, err := e.Db.Begin()
	if err != nil {
		return e.setError(err)
	}

	// 事务资源
	e.Tx = tx
	// 1代表开启了事务
	e.TransStatus = 1
	return nil
}

关键的一步:在具体的增删改方法中,通过事务状态去判断是否执行事务操作


// 如Inert方法中

func (e *DbEngine) insertData(data interface{}) (int64, error) {
	...

	// 原先这里已经有初始化过这个字段,用于判断是否是事务
	var stmt *sql.Stmt
	var err error

	// 只需要在原先这个地方增加这一个事务状态判断即可,后续其他操作不需改变。
	if e.TransStatus == 1 {
	  stmt, err = e.Tx.Prepare(e.Prepare)
	} else {
	  stmt, err = e.Db.Prepare(e.Prepare)
	}

	...
}

  • 回滚 Rollback

回滚是当我们的事务执行出现问题后,发送回滚指令给mysql服务器,请求将执行的结果还原到事务开始之前。
思路:实现比较简单,只需把事务状态归0,调用原生回滚即可


// Rollback 事务回滚
func (e *DbEngine) Rollback() error {
	e.TransStatus = 0
	return e.Tx.Rollback()
}
  • 提交事务 Commit

提交事务是代表事务中所有的操作都执行成功,然后发送确认提交指令给mysql服务器,如果不执行这个确认提交指令的话,数据操作是不会生效的。
思路:与实现回滚操作一样


// Commit 提交事务
func (e *DbEngine) Commit() error {
	e.TransStatus = 0
	return e.Tx.Commit()
}

调用例子如下:


err := e.Begin()

flag := true
if err != nil {
	panic(err)
}

result1, err1 := e.Table("users").Where("id", "=", 1).Update("username", "lisi")
if err1 != nil {
    flag = false
    panic(err)
}

// 更新失败/找不到数据
if result1 <= 0 {
    flag = false
    fmt.Println("update failed")
}

fmt.Println(result1)
fmt.Println(e.GetLastSql())

result2, err2 := e.Table("users").Where("id", "=", 2).Delete()
if err2 != nil {
    flag = false
    fmt.Println(err2.Error())
}

if result2 <= 0 {
    flag = false
    fmt.Println("delete failed")
}

fmt.Println(result2)
fmt.Println(e.GetLastSql())

// 每一步的执行如果失败都将flag设置为false,最后通过判断flag来回滚和提交
if flag {
    _ = e.Commit()
    fmt.Println("success")
} else {
    _ = e.Rollback()
    fmt.Println("error")
}

基准测试

自己写完之后,还是很想跟GORM做下对比,看看造的轮子性能怎么样,能不能用。先搞个测试的200W数据测下。


package mysql

import (
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"testing"
)

func BenchmarkSelect(b *testing.B) {
	e, _ := NewMysql("root", "123456", "127.0.0.1:13306", "test")

	type Author struct {
		Id   int    `sql:"id"`
		Name string `sql:"name"`
		Age  int    `sql:"age"`
	}
	var author []Author

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = e.Table("author").Where("id", ">=", 0).Limit(100).Find(&author)
	}
	b.StopTimer()
}

func BenchmarkGormSelect(b *testing.B) {
	dsn := "root:123456@tcp(127.0.0.1:13306)/test?charset=utf8mb4&parseTime=True&loc=Local"
	db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{})

	type Author struct {
		Id   int    `gorm:"id"`
		Name string `gorm:"name"`
		Age  int    `gorm:"age"`
	}
	var author []Author

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		db.Table("author").Where("id >= ?", 50).Limit(50).Find(&author)
	}
	b.StopTimer()
}

func BenchmarkUpdate(b *testing.B) {
	e, _ := NewMysql("root", "123456", "127.0.0.1:13306", "test")

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = e.Table("author").Where("id", "=", 27).Update("age", 20)
	}
	b.StopTimer()
}

func BenchmarkGormUpdate(b *testing.B) {
	dsn := "root:123456@tcp(127.0.0.1:13306)/test?charset=utf8mb4&parseTime=True&loc=Local"
	db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{})

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		db.Table("author").Where("id = ?", 27).Update("age", 18)
	}
	b.StopTimer()
}

基准测试如下:


➜  mysql go test -bench=. -benchmem
goos: darwin
goarch: amd64
pkg: demo/mysql
cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz
BenchmarkSmallormSelect-8  1296   843769 ns/op    911 B/op    25 allocs/op
BenchmarkGormSelect-8      598    1998827 ns/op   29250 B/op  1058 allocs/op
BenchmarkSmallormUpdate-8  1197   864404 ns/op    727 B/op    21 allocs/op
BenchmarkGormUpdate-8      314    4216470 ns/op   6246 B/op   76 allocs/op
PASS
ok      demo/mysql     6.890s


总结:到这里ORM的基本功能大致实现了90%,性能也还不错。这里面难点主要在Insert、Select两个方法的实现。后续的版本可以做下优化,还可以继续往下迭代的点:联表查询、日志、安全以及性能优化。

ps:建议别自己造轮子,有现成好用并且长期维护的ORM用起来不香吗,以上更多的是锻炼逻辑思维能力,主要是学习下实现思路。

我是六涛sheliutao,文章编写总结不易,转载注明出处,喜欢本篇文章的小伙伴欢迎点赞、关注,有问题可以评论区留言或者私信我,相互交流!!!

参考链接