Golang之搭建ORM框架
背景
用过GORM后,我个人觉得不够简洁有些笨重,有较大的学习成本。本着学习和探索的目的,自己实现了个较为简洁优雅的go版本ORM。本文主要从基础原理开始将,再到实现步骤,最终完成这个简洁版的ORM(MySQL版),目的不是说一定要实现这个ORM来用,而是重点在于学习实现的思路,以及可改进方案。
ORM是什么?为什么要用ORM?
不管是用PHP、Go等语言,我们在做需求的时候,应该都使用过ORM。特别在新项目中,我们会用ORM框架去连接数据库,而不是直接用原生去写SQL连接,原因有很多,比如有安全、性能等的考虑,但我觉得更多的还是因为开发效率和懒,因为有些SQL写起来很复杂很累,特别是查询的时候,又是分页,又是结果集,还需要自己去判断和遍历,开发效率真的低。如果有个ORM,数据库config配置完,几个链式调用,结果就出来了,方便多了。
通俗来讲就是:ORM是提供与数据库交互的中间件。ORM提供更加方便的curd方法与数据库产生交互。
Go原生如何去连接MySQL
了解了ORM后,再来看看如何原生连接MySQL,只有清楚他们之间交互的原理,才能更好的搭建ORM(造轮子)。
第一步:导入数据库包和MySQL驱动包
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)
第二步:连接MySQL
db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4")
if err != nil {
panic(err)
}
第三步:快速过一遍增删改查
func main() {
db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4")
if err != nil {
panic(err)
}
// 增
// 方式一:
result, err := db.Exec("INSERT INTO users (username, jobname, created_at) VALUES (?, ?, ?)", "zhangsan", "dev", "2022-02-23")
// 方式二:
stmt, err := db.Prepare("INSERT INTO users (username, jobname, created_at) VALUES (?, ?, ?)")
result2, err := stmt.Exec("zhangsan", "dev", time.Now().Format("2006-01-02"))
// 删
// 方式一:
result3, err := db.Exec("delete from users where id=?", 1)
// 方式二:
stmt, err := db.Prepare("delete from users where id=?")
result4, err := stmt.Exec(1)
// 改
// 方式一:
result, err := db.Exec("update users set username=? where id=?", "lisi", 2)
// 方式二:
stmt, err := db.Prepare("update users set username=? where id=?")
result, err := stmt.Exec("lisi", 2)
// 查
// 单条
var username, jobname, createdAt string
err := db.QueryRow("select username, jobname, created_at from users where id=?", 4).Scan(&username, &jobname, &createdAt)
if err != nil {
fmt.Println("QueryRow error :", err.Error())
}
fmt.Println("username: ", username, "jobname: ", jobname, "created_at: ", createdAt)
// 多条
rows, err := db.Query("select username, jobname, created_at from users where username=?", "wangwu")
if err != nil {
fmt.Println("QueryRow error :", err.Error())
}
// 结构体
type Users struct {
Username string `json:"username"`
JobName string `json:"jobname"`
CreatedAt string `json:"created_at"`
}
// 初始化
var user []Users
// 判断及遍历
for rows.Next() {
var username1, jobname1, createdAt1 string
if err := rows.Scan(&username1, &jobname1, &createdAt1); err != nil {
fmt.Println("Query error :", err.Error())
}
user = append(user, Users{Username: username1, JobName: jobname1, CreatedAt: createdAt1})
}
}
总结下:Go里面原生连接MySQL的方法,就是直接写sql,简单粗暴点(方式一)就直接Exec执行,想性能高一点就可以麻烦点(方式二)就先Prepare再Exec(此处原因自行进行Benchmark实验可以得出,不做过多赘述)。学习成本是较低,但问题还是麻烦和开发效率低。
所以我在想基于这个原生代码的优势,撸一个ORM:1、能提供各种方法提高开发效率;2、底层直接转换拼接成完整SQL,然后调用原生的组件和MySQL进行交互。这样就可以降低学习其他ORM的成本,也可以提高开发效率,也用得顺手方便。
搭建ORM的构思
原理:简单的SQL拼接,暴露出各种CURD方法,并在底层逻辑拼接成Prepare和Eexc占位符的部分,然后调用原生mysql驱动包的方法来实现和数据库交互。
调用过程采取链式调用的方法。比如:
db.Where().Order().Limit().Select()
注意:暴露的CURD方法,名字要清晰,无歧义,不要搞一大堆复杂的间接调用,使用起来要简单。
搭建前先整理列举出mysql常用的curd的方法,按照列表一个个实现:
- 连接 Connect
- 设置表名 Table
- 新增 Insert
- 条件 Where
- 删除 Delete
- 更新 Update
- 查询 Select
- 设置查询字段 Field
- 设置条数 Limit
- 排序 Order
- 聚合查询 Count/Sum/Max/Min/Avg
- 分组 Group
- 分组后判断 Having
- 获取执行的SQL GetLastSql
- 执行原生sql Exec/Query
- 事务 Begin/Commit/Rollback
注意:Insert/Delete/Select/Update是整个链式调用的最后一步。是真正的和MySQL交互的方法,后面不能再接其他方法。
在开始搭建之前先设想下,自己开发的这个ORM是什么样的,如何去使用它:
// 注意:Users结构体的每一个元素后面都有一个sql:"xxx",叫Tag标签。因为go里面首字母大写表示是可见的属性,所以如果是可见的属性都是大写字母开头,而数据表里面的字段首字母名一般是小写,所以为了照顾这个特殊关系而进行转换和匹配,才用了这个标签。如果你的表的字段类型也是大小字母开头,那可以不需要这个标签,下面会具体说到如何转换匹配。
type Users struct {
Username string `sql:"username"`
JobName string `sql:"jobname"`
CreatedAt int64 `sql:"created_at"`
}
// 增
user1 := Users{
Username: "admin",
JobName: "dev",
CreatedAt: time.Now().Unix(),
}
// insert into users (username,jobname,created_at) values ('admin', 'dev', 1645597211)
id, err := db.Table("users").Insert(user1)
// 删
// delete from users where id = 1
result1, err := db.Table("users").Where("id", "=", 1).Delete()
// 改
// update users set jobname = 'test' where id = 1)
result2, err := db.Table("users").Where("id", "=", 1).Update("jobname", "test")
// 查
// select id,username,jobname from users where (username like '%san') or (jobname = 'dev') order by id desc limit 1
result3, err := db.Table("users").Where("username", "like", "%san").OrWhere("jobname", "=", "dev").Field("id, username, jobname").Order("id", "desc").Limit(1).Select()
user2 := Users{
Username: "admin",
JobName: "dev",
}
// select * from users where username='admin' and jobname='dev' limit 1
result4, err := db.Table("users").Where(user2).SelectOne()
开发阶段
一、连接 Connect
连接MySQL直接把原生的sql.Open方法套上函数即可。第一个版本以实现功能为主,就暂时先不考虑协程和长连接的保持。
1、先构建一个结构体,用来存储各种数据。因为这个ORM的底层本质是SQL拼接,所以我们需要把各种操作方法生成的数据,都保存到这个结构体的各个属性上,方便最后一步生成SQL。
type DbEngine struct {
Db *sql.DB // 用于直接进行CURD操作
TableName string
WhereParam string
OrWhereParam string
WhereExec []interface{}
UpdateParam string
UpdateExec []interface{}
FieldParam string
LimitParam string
OrderParam string
GroupParam string
HavingParam string
Prepare string
AllExec []interface{}
Sql string
TransStatus int
Tx *sql.Tx // 数据库的事务操作,用于回滚和提交
}
2、编写连接方法
// 新建Mysql连接
func NewMysql(user, pwd, address, dbname string) (*DbEngine, error) {
dsn := user + ":" + pwd + "@tcp(" + address + ")/" + dbname + "?charset=utf8mb4&timeout=5s&readTimeout=5s"
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, err
}
// 配置,先占个位
db.SetMaxOpenConns(100)
db.SetMaxIdleConns(100)
return &DbEngine{
Db: db,
FieldParam: "*",
}, nil
}
二、设置/读取表名 Table/GetTable
每次调用Table()方法,就给本次执行设置一个表名。并且清空DbEngine的其他数据。
// 设置表名
func (e *DbEngine) Table(tableName string) *DbEngine {
e.TableName = tableName
e.resetDbEngine()
return e
}
// 获取表名
func (e *DbEngine) GetTable() string {
return e.TableName
}
// 重置引擎
func (e *DbEngine) resetDbEngine() {
*e = DbEngine{
Db: e.Db,
TableName: e.TableName,
}
}
三、新增 Insert
这里涉及到的反射知识点会比较多,可自行了解
- 单条插入
回顾下原生的单条插入,采用Prepare再Exec的方式,高效且安全:
stmt, err := db.Prepare("INSERT INTO users (username, jobname, created_at) VALUES (?, ?, ?)")
result, err := stmt.Exec("zhangsan", "dev", time.Now().Unix())
先在Prepare里把插入的数据的value值用占位符代替,有几个value就用几个占位符,在Exec里面,把value值一一对应上。
思路:为了方便,我们插入数据的参数要传kv键值对集合,比如[field1:val1,field2:val2,field3:val3]。在go语言中可以选择Map或者Struct,由于数据表不同字段可能是不同类型,所以选择结构体。
type Users struct {
Username string `sql:"username"`
JobName string `sql:"jobname"`
CreatedAt int64 `sql:"created_at"`
}
// 增
user1 := Users{
Username: "admin",
JobName: "dev",
CreatedAt: time.Now().Unix(),
}
// insert into users (username,jobname,created_at) values ('admin', 'dev', 1645597211)
id, err := e.Table("users").Insert(user1)
第一步:将结构体中sql:"xxx"标签进行解析,解析成(username, jobname, created_at)。
第二步:将user1的元素的值都取出来,与字段一一对应的放入到Exec中。
如何从结构体中解析出对应的kv键值呢?可以通过反射来推导出结构体属性,属性、属性值、类型、tag标签。
// 反射获取结构体变量的类型
t := reflect.TypeOf(user1)
fmt.Printf("%+v\n", t) // main.User
// 反射获取结构体属性对应值
v := reflect.ValueOf(user1)
fmt.Printf("%+v\n", v) // {Username:admin JobName:dev CreatedAt:1645597211}
// 其他反射方法自行了解
t.NumField() // 获取结构体字段总数
t.Field(i) // 获取结构体属性,属性名、包括类型、tag标签等的值
v.Field(i) // 获取结构体属性值
...
单条插入,代码如下:
func (e *DbEngine) insertData(data interface{}) (int64, error) {
// 反射type和value
t := reflect.TypeOf(data)
v := reflect.ValueOf(data)
// 字段名
var fieldName []string
// 占位符
var placeholder []string
// 字段总数
num := t.NumField()
for i := 0; i < num; i++ {
// 判断字段值
if !v.Field(i).CanInterface() {
continue
}
// 解析tag
sqlTag := t.Field(i).Tag.Get("sql")
if sqlTag != "" {
// 跳过自增字段
if strings.Contains(strings.ToLower(sqlTag), "auto_increment") {
continue
} else {
fieldName = append(fieldName, strings.Split(sqlTag, ",")[0])
placeholder = append(placeholder, "?")
}
} else {
fieldName = append(fieldName, t.Field(i).Name)
placeholder = append(placeholder, "?")
}
// 组装字段值
e.AllExec = append(e.AllExec, v.Field(i).Interface())
}
// sql拼接
e.Prepare = "insert into " + e.GetTable() + " (" + strings.Join(fieldName, ",") + ") values(" + strings.Join(placeholder, ",") + ")"
var stmt *sql.Stmt
var err error
// 预处理
stmt, err = e.Db.Prepare(e.Prepare)
if err != nil {
return 0, e.setError(err)
}
// 生成sql
e.generateSql()
// 执行exec
result, err := stmt.Exec(e.AllExec...)
if err != nil {
return 0, e.setError(err)
}
id, _ := result.LastInsertId()
return id, nil
}
// 设置错误内容
func (e *DbEngine) setError(err error) error {
_, file, line, _ := runtime.Caller(1)
return errors.New("File: " + file + ":" + strconv.Itoa(line) + ", " + err.Error())
}
- 批量插入
不描述那么多了,直接说下思路:
1、传入参数为一个切片数组结构体,[]struct
2、通过反射获取切片数据,计算该数组元素个数,进行循环处理
3、2个for循环,最外层for循环获取切片数组中每个子元素的值和类型,也即结构体的值和结构体类型,然后进行判断
4、里面的for循环则和单条插入的操作一样,反射获取字段名,计算字段总数,解析tag,组装占位符、字段值
5、最终拼接sql进行操作
批量插入,代码如下:
func (e *DbEngine) batchInsertData(batchData interface{}) (int64, error) {
// 反射解析
v := reflect.ValueOf(batchData)
// 切片长度
l := v.Len()
// 字段名
var fieldName []string
// 总占位符
var placeholderString []string
for i := 0; i < l; i++ {
value := v.Index(i) // 子元素值,即结构体
typed := value.Type() // 子元素值类型
if typed.Kind() != reflect.Struct {
panic("参数必须是切片结构体类型")
}
num := value.NumField()
// 子占位符
var placeholder []string
for j := 0; j < num; j++ {
// 判断字段值
if !value.Field(j).CanInterface() {
continue
}
// 解析tag
sqlTag := typed.Field(j).Tag.Get("sql")
if sqlTag != "" {
// 跳过自增字段
if strings.Contains(strings.ToLower(sqlTag), "auto_increment") {
continue
} else {
// 字段名只记录一个循环的即可
if i == 1 {
fieldName = append(fieldName, strings.Split(sqlTag, ",")[0])
}
placeholder = append(placeholder, "?")
}
} else {
if i == 1 {
fieldName = append(fieldName, typed.Field(j).Name)
}
placeholder = append(placeholder, "?")
}
// 组装字段值
e.AllExec = append(e.AllExec, value.Field(j).Interface())
}
// 组装成多个()括号的字段值
placeholderString = append(placeholderString, "("+strings.Join(placeholder, ",")+")")
}
// 拼接sql
e.Prepare = "insert into " + e.GetTable() + " (" + strings.Join(fieldName, ",") + ") values " + strings.Join(placeholderString, ",")
var stmt *sql.Stmt
var err error
// 预处理
stmt, err = e.Db.Prepare(e.Prepare)
if err != nil {
return 0, e.setError(err)
}
// 生成sql
e.generateSql()
// 执行exec
result, err := stmt.Exec(e.AllExec...)
if err != nil {
return 0, e.setError(err)
}
// 获取自增ID
id, _ := result.LastInsertId()
return id, nil
}
- 合二为一,整合单条和批量为一个方法
为了让ORM简洁和优雅一点,把单条插入和批量插入整合一起,只对外暴露一个方法。通过反射判断传入的参数是单结构体还是切片结构体。
func (e *DbEngine) Insert(data interface{}) (int64, error) {
// 参数类型
typed := reflect.ValueOf(data).Kind()
// 判断是批量还是单个插入
if typed == reflect.Struct {
return e.insertData(data)
} else if typed == reflect.Slice || typed == reflect.Array {
return e.batchInsertData(data)
} else {
return 0, errors.New("参数有误")
}
}
四、条件 Where
Where方法主要是为了替换sql中where后面的条件语句,结合各种比较符号:=、!=、>、<、like、in或者用and、or去隔开多个条件。
思路:同样是先占位符进行预处理,然后再执行exec替换值的方式操作,所以Where方法的逻辑也是和Insert一样分两步走
- 方式一:结构体参数
type Users struct {
Username string `sql:"username"`
JobName string `sql:"jobname"`
CreatedAt int64 `sql:"created_at"`
}
user1 := Users{
Username: "admin",
JobName: "dev",
}
// select * from users where username = 'admin' and jobname = 'dev';
result, err := e.Table("users").Where(user1).Select()
where部分是中间层,不会去执行结果,只是将数据拆分出来,存入到DB引擎结构体的WhereParam和WhereExec属性中,由最后的增删改查操作使用。
结构体参数Where方法,代码如下:
func (e *DbEngine) Where(data interface{}) *DbEngine {
// 反射type和value
t := reflect.TypeOf(data)
v := reflect.ValueOf(data)
// 字段名
var fieldNameArray []string
// 字段总数
num := t.NumField()
for i := 0; i < num; i++ {
// 判断字段值
if !v.Field(i).CanInterface() {
continue
}
// 解析tag
sqlTag := t.Field(i).Tag.Get("sql")
if sqlTag != "" {
fieldNameArray = append(fieldNameArray, strings.Split(sqlTag, ",")[0]+"=?")
} else {
fieldNameArray = append(fieldNameArray, t.Field(i).Name+"=?")
}
// 组装值
e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
}
// 多次调用判断
if e.WhereParam != "" {
// 如果不为空,则说明这是第二次调用了
e.WhereParam += " and ("
} else {
e.WhereParam += "("
}
// 拼接条件语句
e.WhereParam += strings.Join(fieldNameArray, " and ") + ") "
return e
}
打印下WhereParam和WhereExec的值
// fmt.Println
WhereParam = "(username=? and jobname=?) "
WhereExec = []interface{"admin", "dev"}
// 多次调用Where方法,每次都是一个and连接
result, err := e.Table("users").Where(user1).Where(user2).Select()
// fmt.Println
WhereParam = "(username=? and jobname=?) and (created_at=?)"
WhereExec = []interface{"admin", "dev", 1645597211}
注意:结构体参数的方式每个条件都是等于“=”的关系,如果需要其他条件关系的话,可以使用下面的字符串参数的方式
- 方式二:字符串参数
结构体参数方式除了上面说的只能是使用“=”的条件外,他每次使用还得写个结构体然后赋值才能使用,有时只有一个参数的话使用结构体参数方式不够便捷。所以提供了字符串参数的方式,较为简单便捷。
此方式的Where方法有三个参数,(查询的字段,条件符号,值)如:
// 也可以用其他条件符号,如:=、!=、>、<、>=、<=、not in、in、like
Where("id", "=", 1)
Where("id", ">", 2)
Where("id", ">=", 3)
Where("id", "in", []int{1, 2, 3, 4})
Where("name", "like", "%tim")
字符串参数方法,代码如下:
func (e *DbEngine) Where(field string, opt string, value interface{}) *DbEngine {
// 获取符号
data2 := strings.Trim(strings.ToLower(opt), " ")
// 多次调用判断
if e.WhereParam != "" {
e.WhereParam += " and ("
} else {
e.WhereParam += "("
}
// 判断是否是in
if data2 == "in" || data2 == "not in" {
typed := reflect.TypeOf(value).Kind()
// 判断传入的是切片或数组
if typed != reflect.Slice && typed != reflect.Array {
panic("in类的操作参数必须是切片或数组")
}
// 反射值
v := reflect.ValueOf(value)
// 切片长度
l := v.Len()
// 占位符
ps := make([]string, l)
for i := 0; i < l; i++ {
ps[i] = "?"
e.WhereExec = append(e.WhereExec, v.Index(i).Interface())
}
// 拼接
e.WhereParam += field + " " + data2 + " (" + strings.Join(ps, ",") + ")) "
} else {
e.WhereParam += field + " " + opt + " ?) "
e.WhereExec = append(e.WhereExec, value)
}
return e
}
- 合二为一,整合为一个方法,并增加了"="等于号的情况下,可以省略等于号,只传字段和值两个参数
// Where and条件方法
func (e *DbEngine) Where(data ...interface{}) *DbEngine {
// 判断是结构体参数还是字符串参数
var dataType int
if len(data) == 1 {
dataType = 1 // 结构体参数
} else if len(data) == 2 {
dataType = 2 // 直接等于"="的情况,省去等于号,只传字段和值两个参数
} else if len(data) == 3 {
dataType = 3 // 字符串参数
} else {
panic("参数错误")
}
// 多次调用判断
if e.WhereParam != "" {
e.WhereParam += " and ("
} else {
e.WhereParam += "("
}
// 结构体
if dataType == 1 {
t := reflect.TypeOf(data[0])
v := reflect.ValueOf(data[0])
// 字段名
var fieldNameArr []string
// 字段总数
num := t.NumField()
for i := 0; i < num; i++ {
// 判断字段值
if !v.Field(i).CanInterface() {
continue
}
// 解析tag
sqlTag := t.Field(i).Tag.Get("sql")
if sqlTag != "" {
fieldNameArr = append(fieldNameArr, strings.Split(sqlTag, ",")[0]+"=?")
} else {
fieldNameArr = append(fieldNameArr, t.Field(i).Name+"=?")
}
e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
}
// 拼接
e.WhereParam += strings.Join(fieldNameArr, " and ") + ") "
} else if dataType == 2 {
// 直接等于"="
e.WhereParam += data[0].(string) + "=?) "
e.WhereExec = append(e.WhereExec, data[1])
} else if dataType == 3 {
// 字符串参数
// 获取符号
data2 := strings.Trim(strings.ToLower(data[1].(string)), " ")
// 判断是否是in操作
if data2 == "in" || data2 == "not in" {
// 判断传入的是切片
reType := reflect.TypeOf(data[2]).Kind()
if reType != reflect.Slice && reType != reflect.Array {
panic("in类操作传入参数必须是切片或数组")
}
// 反射切片值
v := reflect.ValueOf(data[2])
// 数组/切片长度
l := v.Len()
// 占位符
ps := make([]string, l)
for i := 0; i < l; i++ {
ps[i] = "?"
e.WhereExec = append(e.WhereExec, v.Index(i).Interface())
}
// 拼接
e.WhereParam += data[0].(string) + " " + data2 + " (" + strings.Join(ps, ",") + ")) "
} else {
e.WhereParam += data[0].(string) + " " + data[1].(string) + " ?) "
e.WhereExec = append(e.WhereExec, data[2])
}
}
return e
}
五、条件 OrWhere
思路:与Where方法是一样,只需要将whereParam替换成OrWhereParam,WhereExec还是维持不变,然后把and拼接改为or即可
同样支持三种方式,直接上代码:
// OrWhere or条件方法
func (e *DbEngine) OrWhere(data ...interface{}) *DbEngine {
// 判断是结构体参数还是字符串参数
var dataType int
if len(data) == 1 {
dataType = 1 // 结构体参数
} else if len(data) == 2 {
dataType = 2 // 直接等于"="的情况,省去等于号,只传字段和值两个参数
} else if len(data) == 3 {
dataType = 3 // 字符串参数
} else {
panic("参数错误")
}
// 判断方法使用顺序
if e.WhereParam == "" {
panic("OrWhere方法必须在Where方式之后调用")
}
e.OrWhereParam += " or ("
// 结构体
if dataType == 1 {
t := reflect.TypeOf(data[0])
v := reflect.ValueOf(data[0])
// 字段名
var fieldNameArr []string
// 字段总数
num := t.NumField()
for i := 0; i < num; i++ {
// 判断字段值
if !v.Field(i).CanInterface() {
continue
}
// 解析tag
sqlTag := t.Field(i).Tag.Get("sql")
if sqlTag != "" {
fieldNameArr = append(fieldNameArr, strings.Split(sqlTag, ",")[0]+"=?")
} else {
fieldNameArr = append(fieldNameArr, t.Field(i).Name+"=?")
}
e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
}
// 拼接
e.OrWhereParam += strings.Join(fieldNameArr, " and ") + ") "
} else if dataType == 2 {
// 直接等于"="
e.OrWhereParam += data[0].(string) + "=?) "
e.WhereExec = append(e.WhereExec, data[1])
} else if dataType == 3 {
// 字符串参数
// 获取符号
data2 := strings.Trim(strings.ToLower(data[1].(string)), " ")
// 判断是否是in操作
if data2 == "in" || data2 == "not in" {
// 判断传入的是切片
reType := reflect.TypeOf(data[2]).Kind()
if reType != reflect.Slice && reType != reflect.Array {
panic("in类操作传入参数必须是切片或数组")
}
// 反射切片值
v := reflect.ValueOf(data[2])
// 数组/切片长度
l := v.Len()
// 占位符
ps := make([]string, l)
for i := 0; i < l; i++ {
ps[i] = "?"
e.WhereExec = append(e.WhereExec, v.Index(i).Interface())
}
// 拼接
e.OrWhereParam += data[0].(string) + " " + data2 + " (" + strings.Join(ps, ",") + ")) "
} else {
e.OrWhereParam += data[0].(string) + " " + data[1].(string) + " ?) "
e.WhereExec = append(e.WhereExec, data[2])
}
}
return e
}
注意:必须先调用Where后才能调用OrWhere方法。因为一般用到了or,前面肯定也有前置where条件。
为了使这个方法更简单使用,不搞得复杂,这种方式的or实质上是针对多次调用where之间的去使用的,是不支持同一个where里面的数据是or关系的,一般也不会用到这种,不然性能也是极差。
六、删除 Delete
思路:Delete方法是直接与数据库交互的最后一步,不需要再去处理各种数据,只需要把Where或OrWhere方法中组装好的参数和值写入预处理和执行方法中即可
代码如下:
// Delete 删除方法
func (e *DbEngine) Delete() (int64, error) {
// 拼接sql
e.Prepare = "delete from " + e.GetTable()
// 若where参数不为空
if e.WhereParam != "" || e.OrWhereParam != "" {
e.Prepare += " where " + e.WhereParam + e.OrWhereParam
}
// limit不为空
if e.LimitParam != "" {
e.Prepare += "limit " + e.LimitParam
}
var stmt *sql.Stmt
var err error
// 预处理
stmt, err = e.Db.Prepare(e.Prepare)
if err != nil {
return 0, err
}
e.AllExec = e.WhereExec
// 生成sql
e.generateSql()
// 执行
result, err := stmt.Exec(e.AllExec...)
if err != nil {
return 0, e.setError(err)
}
// 影响行数
rowsAffected, err := result.RowsAffected()
if err != nil {
return 0, e.setError(err)
}
return rowsAffected, nil
}
调用方式
// delete from users where (id = 1) or (username = 'admin');
rowsAffected, err := e.Table("users").Where("id", 1).OrWhere("username", "admin").Delete()
7、更新 Update
update users set jobname = 'web' where (id = 1) or (username = 'admin');
思路:Update和Delete都是与数据库交互的最后一步,但是Delete不同的是他不仅需要Where或OrWhere方法中组装好的参数和值,还需要处理要更新的数据,也就是set jobname='web’这一步,我们可以参考下Where方法,搞一个字符串参数传值和结构体传值的形式,用起来灵活一点,具体实现方式跟实现Insert几乎一样。
例如:
// 字符串参数
e.Table("users").Where("id", 1).Update("jobname", "web")
// 结构体参数
e.Table("users").Where("id", 1).Update(user1)
代码如下:
// Update 更新
func (e *DbEngine) Update(data ...interface{}) (int64, error) {
// 判断结构体参数还是字符串参数
var dataType int
if len(data) == 1 {
dataType = 1 // 结构体参数
} else if len(data) == 2 {
dataType = 2 // 直接等于"="的情况,省去等于号,只传字段和值两个参数
} else {
return 0, errors.New("参数错误")
}
// 结构体参数
if dataType == 1 {
t := reflect.TypeOf(data[0])
v := reflect.ValueOf(data[0])
// 字段名
var fieldNameArr []string
for i := 0; i < t.NumField(); i++ {
// 判断字段值
if !v.Field(i).CanInterface() {
continue
}
// 解析tag
sqlTag := t.Field(i).Tag.Get("sql")
if sqlTag != "" {
fieldNameArr = append(fieldNameArr, strings.Split(sqlTag, ",")[0]+"=?")
} else {
fieldNameArr = append(fieldNameArr, t.Field(i).Name+"=?")
}
// 更新值
e.UpdateExec = append(e.UpdateExec, v.Field(i).Interface())
}
// 更新字段
e.UpdateParam += strings.Join(fieldNameArr, ",")
} else if dataType == 2 {
// 直接等于的情况
e.UpdateParam += data[0].(string) + "=?"
e.UpdateExec = append(e.UpdateExec, data[1])
}
// 拼接sql
e.Prepare = "update " + e.GetTable() + " set " + e.UpdateParam
// where不为空
if e.WhereParam != "" || e.OrWhereParam != "" {
e.Prepare += " where " + e.WhereParam + e.OrWhereParam
}
// limit不为空
if e.LimitParam != "" {
e.Prepare += "limit " + e.LimitParam
}
var stmt *sql.Stmt
var err error
// 预处理
stmt, err = e.Db.Prepare(e.Prepare)
if err != nil {
return 0, e.setError(err)
}
// 合并UpdateExec和WhereExec,即update table set params1 = ? where id = ?的问号参数值
if e.WhereExec != nil {
// 此处切记要加...,目的是把切片里的一个个参数都追加到UpdateExec后面,形成类似php中array_merge的效果
e.AllExec = append(e.UpdateExec, e.WhereExec...)
}
// 生成sql
e.generateSql()
// 执行
result, err := stmt.Exec(e.AllExec...)
if err != nil {
return 0, e.setError(err)
}
// 影响行数
id, _ := result.RowsAffected()
return id, nil
}
八、查询 Select
Go的原生代码中,查询的写法是没有Prepare和Exec的,而是通过QueryRow和Query方法来查询数据,然后再用Scan去绑定赋值数据,写法真的麻烦
// 单条查询,需要先把查询的字段定义好,然后再Scan()去绑定赋值
var username, jobname, createdAt string
err := db.QueryRow("select username, jobname, created_at from users where id=?", 1).Scan(&username, &jobname, &createdAt)
if err != nil {
fmt.Println("QueryRow error :", err.Error())
}
fmt.Println(username, jobname, createdAt)
// 多条查询,先把需要查询的字段结构体定义好,再初始化切片结构体,for循环给这个切片赋值
rows, err := db.Query("select username, jobname, created_at from users where username=?", "admin")
if err != nil {
fmt.Println("QueryRow error :", err.Error())
}
type Users struct {
Username string `json:"username"`
Jobname string `json:"jobname"`
CreatedAt string `json:"created_at"`
}
// 初始化切片结构体
var user []Users
for rows.Next() {
var username1, jobname1, createdAt1 string
if err := rows.Scan(&username1, &jobname1, &createdAt1); err != nil {
fmt.Println("Query error :", err.Error())
}
user = append(user, Users{Username: username1, Jobname: jobname1, CreatedAt: createdAt1})
}
思路:
为了简化查询内部实现的复杂度,单条查询也用Query方法,然后使用limit 1限制条数,继而满足条件。这样处理的话多条查询和单条查询都能使用同样的方法处理,只是对条数做出限制。
要先定义结构体,然后再初始化一个切片结构体,这操作有点麻烦,我的想法是直接按照表里的字段名返回在切片map中。
那怎么拿到表里的字段有哪些呢?DB.Query提供了Columns的方法,能返回本次查询的表字段。
观察一下,rows.Scan()的数据绑定,它是需要先定义字段变量和类型,每次循环的时候把地址传给Scan方法,通过地址来动态引用赋值,所以这几个字段变量名字不重要,反正最后都是传他们的地址。
- 多条查询之切片map
// Select 多条查询,返回值为切片map
func (e *DbEngine) Select() ([]map[string]string, error) {
// 拼接sql
e.Prepare = "select " + e.FieldParam + " from " + e.GetTable()
// where不为空
if e.WhereParam != "" || e.OrWhereParam != "" {
e.Prepare += " where " + e.WhereParam + e.OrWhereParam
}
// group不为空
if e.GroupParam != "" {
e.Prepare += " group by " + e.GroupParam
}
// having不为空
if e.HavingParam != "" {
e.Prepare += " having " + e.HavingParam
}
// order by不为空
if e.OrderParam != "" {
e.Prepare += " order by " + e.OrderParam
}
// limit不为空
if e.LimitParam != "" {
e.Prepare += " limit " + e.LimitParam
}
// 查询条件对应值
e.AllExec = e.WhereExec
// 生成sql
e.generateSql()
// 查询数据
rows, err := e.Db.Query(e.Prepare, e.AllExec...)
if err != nil {
return nil, e.setError(err)
}
// 查询字段名
columns, err := rows.Columns()
if err != nil {
return nil, e.setError(err)
}
// 每个字段的值
values := make([][]byte, len(columns))
// 存放对应的每个字段值的地址
scans := make([]interface{}, len(columns))
// values中的值地址绑定到scan的每个元素上
for i := range values {
scans[i] = &values[i]
}
results := make([]map[string]string, 0)
for rows.Next() {
// 使用可变参数解决需要定义的字段数无法确定的问题
if err := rows.Scan(scans...); err != nil {
// Scan查询出来的值放到scans[i] = &values[i],也就是每行数据都放在values里
return nil, e.setError(err)
}
// 每行数据
row := make(map[string]string)
// 循环values数据,通过相同的键,取columns里面对应的字段,生成1个map
for k, v := range values {
key := columns[k]
row[key] = string(v)
}
results = append(results, row)
}
return results, nil
}
注意:这个方法有一个小问题,会把所有字段的类型都转为字符串,但是理论上影响不大。
Select方法调用,如下:
result, err := e.Table("users").Where("created_at", ">=", 1645597211).Select()
// fmt.Printf("%T\n", result)
[]map[string]string
// fmt.Println(result)
[map[id:1 username:admin jobname:dev] map[id:2 username:zhangsan jobname:dev]]
- 单条查询之map
思路:单条查询就是在多条查询基础上增加limit 1的条数限制,然后取出数据返回在map中
// SelectOne 单条查询,返回值为map
func (e *DbEngine) SelectOne() (map[string]string, error) {
// limit 1
results, err := e.Limit(1).Select()
if err != nil {
return nil, e.setError(err)
}
// 判断结果是否为空
if len(results) == 0 {
return nil, nil
} else {
return results[0], nil
}
}
SelectOne方法调用,如下:
result, err := e.Table("users").Where("id", "=", 1).SelectOne()
// fmt.Printf("%T\n", result)
map[string]string
// fmt.Println(result)
map[id:1 username:admin jobname:dev]
有的小伙伴可能还是习惯了类似GORM那种查询操作,先定义好结构体,然后进行引用赋值。
如下:
type Users struct {
Username string `sql:"username"`
JobName string `sql:"jobname"`
CreatedAt int64 `sql:"created_at"`
}
var user []Users
// select * from users where created_at >= 1645597211;
err := e.Table("users").Where("created_at", ">=", 1645597211).Find(&user)
if err != nil {
panic(err)
}
fmt.Println(user)
思路:
定义一个结构体,字段属性定义好对应的tag标签
初始化空切片结构体,然后通过&取地址符传到Find方法
Find方法内部先获取表的字段,再通过tag标签和各种反射操作,将数据绑定到传入的切片结构体进行赋值。这一步想想就觉得麻烦
话不多说,直接上代码:
- 多条查询之切片结构体
// Find 多条查询,切片结构体
func (e *DbEngine) Find(result interface{}) error {
// 判断参数是否为指针变量
if reflect.ValueOf(result).Kind() != reflect.Ptr {
return e.setError(errors.New("参数错误"))
}
// 判断是空指针
if reflect.ValueOf(result).IsNil() {
return e.setError(errors.New("参数错误,不能是空指针"))
}
// 拼接sql
e.Prepare = "select " + e.FieldParam + " from " + e.GetTable()
// where不为空
if e.WhereParam != "" || e.OrWhereParam != "" {
e.Prepare += " where " + e.WhereParam + e.OrWhereParam
}
// group不为空
if e.GroupParam != "" {
e.Prepare += " group by " + e.GroupParam
}
// having不为空
if e.HavingParam != "" {
e.Prepare += " having " + e.HavingParam
}
// order by不为空
if e.OrderParam != "" {
e.Prepare += " order by " + e.OrderParam
}
// limit不为空
if e.LimitParam != "" {
e.Prepare += " limit " + e.LimitParam
}
// 查询条件对应值
e.AllExec = e.WhereExec
// 生成sql
e.generateSql()
// 查询数据
rows, err := e.Db.Query(e.Prepare, e.AllExec...)
if err != nil {
return e.setError(err)
}
// 查询字段名
columns, err := rows.Columns()
if err != nil {
return e.setError(err)
}
// 每个字段的值
values := make([][]byte, len(columns))
// 存放对应的每个字段值的地址
scans := make([]interface{}, len(columns))
// values中的值地址绑定到scan的每个元素上
for i := range values {
scans[i] = &values[i]
}
// 获取指针切片结构体的值
sliceStruct := reflect.ValueOf(result).Elem()
// 获取单个结构体的类型
structType := sliceStruct.Type().Elem()
for rows.Next() {
// 根据结构体类型,生成一个新的结构体
newStruct := reflect.New(structType).Elem()
if err := rows.Scan(scans...); err != nil {
// Scan查询出来的值放到scans[i] = &values[i],也就是每行数据都放在values里
return e.setError(err)
}
// 遍历每行数据各个字段
for k, v := range values {
key := columns[k]
value := string(v)
num := structType.NumField()
// 遍历结构体
for i := 0; i < num; i++ {
// 解析tag标签
sqlTag := structType.Field(i).Tag.Get("sql")
var fieldName string
if sqlTag != "" {
fieldName = strings.Split(sqlTag, ",")[0]
} else {
fieldName = structType.Field(i).Name
}
// 结构体中没这个字段属性
if key != fieldName {
continue
}
// 反射判断类型,赋值结构体
if err := e.reflectSetStruct(newStruct, i, value); err != nil {
return err
}
}
}
// 赋值,reflect.Append将新结构体的值追加到切片结构体中;sliceStruct.Set将追加后的结果赋值到传入的指针切片结构体中
sliceStruct.Set(reflect.Append(sliceStruct, newStruct))
}
return nil
}
因为结构体严格区分字段类型,会有一个结构体赋值问题,处理方法:将查询的结果集中各个字段值的类型进行遍历判断,然后转换成结构体对应的字段类型赋值回去。同样是基于反射处理。
代码如下:
// 通过反射判断结构体字段类型,然后将对应的结构体的字段值转化为相应类型,然后再赋值回结构体中
func (e *DbEngine) reflectSetStruct(target reflect.Value, i int, value string) error {
switch target.Field(i).Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
res, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return e.setError(err)
}
target.Field(i).SetInt(res)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
res, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return e.setError(err)
}
target.Field(i).SetUint(res)
case reflect.Float64:
res, err := strconv.ParseFloat(value, 64)
if err != nil {
return e.setError(err)
}
target.Field(i).SetFloat(res)
case reflect.String:
target.Field(i).SetString(value)
case reflect.Bool:
res, err := strconv.ParseBool(value)
if err != nil {
return e.setError(err)
}
target.Field(i).SetBool(res)
}
return nil
}
- 单条查询之结构体
同样的单条查询就是在多条查询基础上增加limit 1的条数限制,然后取出数据返回在结构体中。但是,有个问题:我们的Find方法的参数是一个切片结构体指针,而单条查询传入的应该是单个结构体指针,类型不匹配,还得想办法搞定他(同样还是查询了反射的文档资料得出答案)。
代码如下:
// FindOne 单条查询,结构体
func (e *DbEngine) FindOne(result interface{}) error {
// 获取结构体原始值
structValue := reflect.Indirect(reflect.ValueOf(result))
// 反射出一个相同结构体类型的切片
structSlice := reflect.New(reflect.SliceOf(structValue.Type())).Elem()
// 查询数据,传入指针
if err := e.Limit(1).Find(structSlice.Addr().Interface()); err != nil {
return err
}
// 判断返回值长度
if structSlice.Len() == 0 {
return e.setError(errors.New("Not Found Error"))
}
// 取切片里的第0个数据,并赋值给原始值结构体指针
structValue.Set(structSlice.Index(0))
return nil
}
九、设置查询字段 Field
平时很多人查询数据的时候都喜欢用select* ,把表字段都查出来,在一些大表中执行查询的话是非差低效和浪费资源的。
Field()方法指定本次查询的字段,实现方式也比较简单,直接赋值给FieldParam属性即可
代码如下:
// Field 设置查询字段
func (e *DbEngine) Field(field string) *DbEngine {
e.FieldParam = field
return e
}
注意:FieldParam的初始值是“”,一开始在NewMysql里面已经初始化过。所以即使没进行字段设置,Prepare的值也是select,也是不影响sql的完整性。
我们是直接传入字段列名,并没有对传入字段名做检验。这个将在第二版中进行优化。
调用方式如下:
// 不可放置在前后两端,中间部分任意位置都可以
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Select()
十、设置条数 Limit
limit一般用来控制获取的数据量大小或者分页,如:单条查询 limit 1、分页查询(一页十条数据) limit 0,9。所以我们的Limit方法也实现这两种。
如:
// 查询一条
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Limit(1).Select()
// 查询十条
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Limit(0, 9).Select()
思路:有两种方式传参,那同样也是类似于Insert的传参方式,使用可变参数。同时对参数进行判断长度和限制长度。然后将对应的limt语句拼接好赋值到LimitParamb中
代码如下:
// Limit 分页
func (e *DbEngine) Limit(limit ...int64) *DbEngine {
if len(limit) == 1 {
e.LimitParam = strconv.Itoa(int(limit[0]))
} else if len(limit) == 2 {
e.LimitParam = strconv.Itoa(int(limit[0])) + "," + strconv.Itoa(int(limit[1]))
} else {
panic("参数错误")
}
return e
}
十一、排序 Order
排序的实现思路也是比较简单:根据字段,升序asc、降序desc,参数为字符串传入,然后赋值到OrderParam中即可
代码如下:
// Order 排序
func (e *DbEngine) Order(order string) *DbEngine {
if order == "" {
panic("参数错误")
}
// 多次调用判断
if e.OrderParam != "" {
e.OrderParam += ","
}
e.OrderParam += order
return e
}
调用方式如下:
// 按ID降序排列
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Order("id desc").Select()
// 先按ID升序排列,再按分数降序排列
e.Table("users").Field("id,username").Where("created_at", ">=", 1645597211).Order("id asc,score desc").Select()
十二、聚合查询 Count/Sum/Max/Min/Avg
# 都是类似写法(函数(字段)),不一一列举了
select count(*) from users;
...
思路:
聚合函数在平时查询统计的时候用的还是比较多,它们的SQL拼接方式都类似,把函数和字段替换到select* 里面。
他们聚合查询后的数据都只有一条。所以我就不考虑用Select中的Db.Query+for next方法,直接用Db.QueryRow方法,此方法就是原生的查询单条的方式。
既然几个聚合函数写法类似,那就抽一个通用方法,然后具体暴漏出来的聚合方法去调用他即可。
- 公共部分:聚合查询
// 公共部分:聚合查询
func (e *DbEngine) aggregateQuery(name, field string) (interface{}, error) {
// 拼接sql
e.Prepare = "select " + name + "(" + field + ") as aggre from " + e.GetTable()
// where不为空
if e.WhereParam != "" || e.OrWhereParam != "" {
e.Prepare += " where " + e.WhereParam + e.OrWhereParam
}
// limit不为空
if e.LimitParam != "" {
e.Prepare += " limit " + e.LimitParam
}
e.AllExec = e.WhereExec
// 生成sql
e.generateSql()
// 执行绑定,由于最终的聚合结果值不确定类型,可能是整数、小数、浮点数(平均值),所以定义了一个interface变量
var aggre interface{}
// 查询单条
err := e.Db.QueryRow(e.Prepare, e.AllExec...).Scan(&aggre)
if err != nil {
return nil, e.setError(err)
}
return aggre, err
}
- 总数 Count
// Count 总数
func (e *DbEngine) Count() (int64, error) {
count, err := e.aggregateQuery("count", "*")
if err != nil {
return 0, e.setError(err)
}
return count.(int64), err
}
- 总和 Sum
// Sum 总和
func (e *DbEngine) Sum(field string) (string, error) {
sum, err := e.aggregateQuery("sum", field)
if err != nil {
return "0", e.setError(err)
}
return string(sum.([]byte)), nil
}
- 最大值 Max
// Max 最大值
func (e *DbEngine) Max(field string) (string, error) {
max, err := e.aggregateQuery("max", field)
if err != nil {
return "0", e.setError(err)
}
return string(max.([]byte)), nil
}
- 最小值 Min
// Min 最小值
func (e *DbEngine) Min(field string) (string, error) {
min, err := e.aggregateQuery("min", field)
if err != nil {
return "0", e.setError(err)
}
return string(min.([]byte)), nil
}
- 平均值 Avg
// Avg 平均值
func (e *DbEngine) Avg(field string) (string, error) {
avg, err := e.aggregateQuery("avg", field)
if err != nil {
return "0", e.setError(err)
}
return string(avg.([]byte)), nil
}
调用方式如下:
count, err := e.Table("users").Where("id", ">=", 1).Count()
sum, err := e.Table("users").Where("id", ">=", 1).Sum("id")
max, err := e.Table("users").Where("id", ">=", 1).Max("id")
min, err := e.Table("users").Where("id", ">=", 1).Min("id")
avg, err := e.Table("users").Where("id", ">=", 1).Avg("id")
十三、分组 Group
分组用于我们对某1个或者几个字段进行归类,然后查询归类后的数据。可以搭配上Field(count( * ) as num)来做更加具体的分组查询
思路:直接就是可变参数逗号拼接字符串
代码如下:
// Group 分组
func (e *DbEngine) Group(group ...string) *DbEngine {
if len(group) != 0 {
e.GroupParam = strings.Join(group, ",")
}
return e
}
调用方式如下:
result,err := e.Table("users").Field("jobname, count(*) as num").Group("jobname").Select()
十四、分组后判断 Having
思路:Having的作用和Where是一样的,只不过Having是作用在Group后的。同样的,我们要实现和Where一样的参数支持:结构体、字符串
代码如下:
// Having 分组后判断
func (e *DbEngine) Having(having ...interface{}) *DbEngine {
// 判断是结构体
var dataType int
if len(having) == 1 {
dataType = 1
} else if len(having) == 2 {
dataType = 2
} else if len(having) == 3 {
dataType = 3
} else {
panic("参数错误")
}
// 多次调用判断
if e.HavingParam != "" {
e.HavingParam += "and ("
} else {
e.HavingParam += "("
}
// 若是结构体
if dataType == 1 {
// 反射type和value
t := reflect.TypeOf(having[0])
v := reflect.ValueOf(having[0])
// 字段名
var fieldNameArray []string
// 字段总数
num := t.NumField()
for i := 0; i < num; i++ {
// 判断字段值
if !v.Field(i).CanInterface() {
continue
}
// 解析tag
sqlTag := t.Field(i).Tag.Get("sql")
if sqlTag != "" {
fieldNameArray = append(fieldNameArray, strings.Split(sqlTag, ",")[0]+"=?")
} else {
fieldNameArray = append(fieldNameArray, t.Field(i).Name+"=?")
}
// 组装值
e.WhereExec = append(e.WhereExec, v.Field(i).Interface())
}
// 拼接
e.HavingParam += strings.Join(fieldNameArray, " and ") + ") "
} else if dataType == 2 {
// 直接等于"="
e.HavingParam += having[0].(string) + "=?) "
e.WhereExec = append(e.WhereExec, having[1])
} else if dataType == 3 {
// 字符串参数
e.HavingParam += having[0].(string) + " " + having[1].(string) + " ?) "
e.WhereExec = append(e.WhereExec, having[2])
}
return e
}
调用方式如下:
result,err := e.Table("users").Field("jobname, count(*) as num").Group("jobname").Having("jobname", "dev").Select()
if err != nil {
panic(err)
}
十五、获取执行SQL GetLastSql
基本上所有的ORM方法本质上都是组装好原生Sql。我们排查数据库错误时多数都是需要通过最终的sql进行分析。所以这个ORM也需要提供这个方法。
思路:我们所有的ORM方法都已经在预处理的部分把sql组装好了,占位符对应的参数值也按占位符顺序一一对应排列好了。所以我们就只需要吧问号?替换成对应的参数值即可。
代码如下:
// 生成SQL,在链式调用最后一步执行的方法中去调用这个方法。就可以生成执行的sql语句。调用GetLastSql就可以打印出来了。
func (e *DbEngine) generateSql() {
e.Sql = e.Prepare
for _, i2 := range e.AllExec {
// 还可以继续根据需要补充类型
switch i2.(type) {
case int:
e.Sql = strings.Replace(e.Sql, "?", strconv.Itoa(i2.(int)), 1)
case int64:
e.Sql = strings.Replace(e.Sql, "?", strconv.FormatInt(i2.(int64), 10), 1)
case bool:
e.Sql = strings.Replace(e.Sql, "?", strconv.FormatBool(i2.(bool)), 1)
default:
e.Sql = strings.Replace(e.Sql, "?", "'"+i2.(string)+"'", 1)
}
}
}
// GetLastSql 获取执行的sql
func (e *DbEngine) GetLastSql() string {
return e.Sql
}
调用方式如下:
result, err := e.Table("users").Where("id", ">=", 1).Order("id desc").Select()
fmt.Println(e.GetLastSql()) //select * from users where (id >= 1) order by id desc
十六、执行原生sql Exec/Query
虽然不推荐直接执行sql语句,但在一些场景中有一些比较复杂的sql,这时就需要用到。
- sql执行原生增删改操作 Exec
思路:直接使用Exec方法暴力执行sql
// Exec 直接执行增删改sql
func (e *DbEngine) Exec(sql string) (id int64, err error) {
result, err := e.Db.Exec(sql)
e.Sql = sql
if err != nil {
return 0, e.setError(err)
}
// insert返回ID,其他返回影响行数
if strings.Contains(sql, "insert") {
lastInsertId, _ := result.LastInsertId()
return lastInsertId, nil
} else {
rowsAffected, _ := result.RowsAffected()
return rowsAffected, nil
}
}
调用如下:
result, err:= e.Exec("insert into users(username,jobname,created_at) values('root', 'leader', 1645597211)");
result, err := e.Exec("delete from users where id=1")
...
- sql执行原生查询操作 Query
思路:把Select方法拿来改下,删减一些东西就行了
// Query 直接执行查询sql
func (e *DbEngine) Query(sql string) ([]map[string]string, error) {
rows, err := e.Db.Query(sql)
// 直接写入sql
e.Sql = sql
if err != nil {
return nil, e.setError(err)
}
// 查询字段名
columns, err := rows.Columns()
if err != nil {
return nil, e.setError(err)
}
// 每个字段的值
values := make([][]byte, len(columns))
// 存放对应的每个字段值的地址
scans := make([]interface{}, len(columns))
// values中的值地址绑定到scan的每个元素上
for i := range values {
scans[i] = &values[i]
}
results := make([]map[string]string, 0)
for rows.Next() {
// 使用可变参数解决需要定义的字段数无法确定的问题
if err := rows.Scan(scans...); err != nil {
// Scan查询出来的值放到scans[i] = &values[i],也就是每行数据都放在values里
return nil, e.setError(err)
}
// 每行数据
row := make(map[string]string)
for k, v := range values {
key := columns[k]
row[key] = string(v)
}
results = append(results, row)
}
return results, nil
}
调用如下:
result, err := e.Query("SELECT * FROM users limit 1")
十七、事务 Begin/Commit/Rollback
事务操作平时在业务中用的也是很多,它用于在多次执行增删改的操作的时候,如果其中一个出现问题,可以一起回滚数据,确保了数据的一致性。
- 开启事务 Begin
思路:开启事务只需要调用原生库的方法,然后设置事务状态即可
// Begin 开启事务
func (e *DbEngine) Begin() error {
// 调用原生方法
tx, err := e.Db.Begin()
if err != nil {
return e.setError(err)
}
// 事务资源
e.Tx = tx
// 1代表开启了事务
e.TransStatus = 1
return nil
}
关键的一步:在具体的增删改方法中,通过事务状态去判断是否执行事务操作
// 如Inert方法中
func (e *DbEngine) insertData(data interface{}) (int64, error) {
...
// 原先这里已经有初始化过这个字段,用于判断是否是事务
var stmt *sql.Stmt
var err error
// 只需要在原先这个地方增加这一个事务状态判断即可,后续其他操作不需改变。
if e.TransStatus == 1 {
stmt, err = e.Tx.Prepare(e.Prepare)
} else {
stmt, err = e.Db.Prepare(e.Prepare)
}
...
}
- 回滚 Rollback
回滚是当我们的事务执行出现问题后,发送回滚指令给mysql服务器,请求将执行的结果还原到事务开始之前。
思路:实现比较简单,只需把事务状态归0,调用原生回滚即可
// Rollback 事务回滚
func (e *DbEngine) Rollback() error {
e.TransStatus = 0
return e.Tx.Rollback()
}
- 提交事务 Commit
提交事务是代表事务中所有的操作都执行成功,然后发送确认提交指令给mysql服务器,如果不执行这个确认提交指令的话,数据操作是不会生效的。
思路:与实现回滚操作一样
// Commit 提交事务
func (e *DbEngine) Commit() error {
e.TransStatus = 0
return e.Tx.Commit()
}
调用例子如下:
err := e.Begin()
flag := true
if err != nil {
panic(err)
}
result1, err1 := e.Table("users").Where("id", "=", 1).Update("username", "lisi")
if err1 != nil {
flag = false
panic(err)
}
// 更新失败/找不到数据
if result1 <= 0 {
flag = false
fmt.Println("update failed")
}
fmt.Println(result1)
fmt.Println(e.GetLastSql())
result2, err2 := e.Table("users").Where("id", "=", 2).Delete()
if err2 != nil {
flag = false
fmt.Println(err2.Error())
}
if result2 <= 0 {
flag = false
fmt.Println("delete failed")
}
fmt.Println(result2)
fmt.Println(e.GetLastSql())
// 每一步的执行如果失败都将flag设置为false,最后通过判断flag来回滚和提交
if flag {
_ = e.Commit()
fmt.Println("success")
} else {
_ = e.Rollback()
fmt.Println("error")
}
基准测试
自己写完之后,还是很想跟GORM做下对比,看看造的轮子性能怎么样,能不能用。先搞个测试的200W数据测下。
package mysql
import (
"gorm.io/driver/mysql"
"gorm.io/gorm"
"testing"
)
func BenchmarkSelect(b *testing.B) {
e, _ := NewMysql("root", "123456", "127.0.0.1:13306", "test")
type Author struct {
Id int `sql:"id"`
Name string `sql:"name"`
Age int `sql:"age"`
}
var author []Author
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = e.Table("author").Where("id", ">=", 0).Limit(100).Find(&author)
}
b.StopTimer()
}
func BenchmarkGormSelect(b *testing.B) {
dsn := "root:123456@tcp(127.0.0.1:13306)/test?charset=utf8mb4&parseTime=True&loc=Local"
db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{})
type Author struct {
Id int `gorm:"id"`
Name string `gorm:"name"`
Age int `gorm:"age"`
}
var author []Author
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.Table("author").Where("id >= ?", 50).Limit(50).Find(&author)
}
b.StopTimer()
}
func BenchmarkUpdate(b *testing.B) {
e, _ := NewMysql("root", "123456", "127.0.0.1:13306", "test")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = e.Table("author").Where("id", "=", 27).Update("age", 20)
}
b.StopTimer()
}
func BenchmarkGormUpdate(b *testing.B) {
dsn := "root:123456@tcp(127.0.0.1:13306)/test?charset=utf8mb4&parseTime=True&loc=Local"
db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{})
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.Table("author").Where("id = ?", 27).Update("age", 18)
}
b.StopTimer()
}
基准测试如下:
➜ mysql go test -bench=. -benchmem
goos: darwin
goarch: amd64
pkg: demo/mysql
cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz
BenchmarkSmallormSelect-8 1296 843769 ns/op 911 B/op 25 allocs/op
BenchmarkGormSelect-8 598 1998827 ns/op 29250 B/op 1058 allocs/op
BenchmarkSmallormUpdate-8 1197 864404 ns/op 727 B/op 21 allocs/op
BenchmarkGormUpdate-8 314 4216470 ns/op 6246 B/op 76 allocs/op
PASS
ok demo/mysql 6.890s
总结:到这里ORM的基本功能大致实现了90%,性能也还不错。这里面难点主要在Insert、Select两个方法的实现。后续的版本可以做下优化,还可以继续往下迭代的点:联表查询、日志、安全以及性能优化。
ps:建议别自己造轮子,有现成好用并且长期维护的ORM用起来不香吗,以上更多的是锻炼逻辑思维能力,主要是学习下实现思路。
我是六涛sheliutao,文章编写总结不易,转载注明出处,喜欢本篇文章的小伙伴欢迎点赞、关注,有问题可以评论区留言或者私信我,相互交流!!!