8`updated_at`timestampNOTNULLDEFAULTCURRENT_TIMESTAMPONUPDATECURRENT_TIMESTAMPCOMMENT'更新时间',
9PRIMARY KEY( `id`),
10KEY`idx_email`( `email`)
11) ENGINE= InnoDBDEFAULTCHARSET=utf8 COMMENT= '用户表';
同时,golang代码里定义一个与之对应的struct :
1typeUser struct{
2ID int64`json:"id"`// 自增主键
3Age int64`json:"age"`// 年龄
4FirstName string`json:"first_name"`// 姓
5LastName string`json:"last_name"`// 名
6Email string`json:"email"`// 邮箱地址
7CreatedAt time.Time `json:"created_at"`// 创建时间
8UpdatedAt time.Time `json:"updated_at"`// 更新时间
9}
与mysql交互需要用到一个go标准包和一个驱动 ,代码import 如下:
1packageorm
2
3import(
4"database/sql"
5
6//register driver
7_ "github.com/go-sql-driver/mysql"
8)
首先按照database 维度建立连接,写一个可以返回mysql连接的函数:
1//Connect db by dsn e.g. "user:password@tcp(127.0.0.1:3306)/dbname"
2funcConnect(dsn string) (*sql.DB, error){
3conn, err := sql.Open( "mysql", dsn)
4iferr != nil{
5returnnil, err
6}
7//设置连接池
8conn.SetMaxOpenConns( 100)
9conn.SetMaxIdleConns( 10)
10conn.SetConnMaxLifetime( 10* time.Minute)
11returnconn, conn.Ping
12}
设计一个struct 用于实现orm(go不是面向对象的语言,没有class ):
1//Query will build a sql
2type Query struct{
3db *sql.DB
4table string
5}
最后将通过 Query 拼接出sql语句与mysql交互,所以写一个绑定函数:
1//Table bind db and table
2funcTable(db *sql.DB, tableName string) func* Query{
3returnfunc* Query{
4return&Query{
5db: db,
6table: tableName,
7}
8}
9}
返回值是一个闭包函数,这样使用时直接调用这个闭包函数就可以获取一个绑定好的database和table的Query ,比如现在有数据库orm_db 和user 表:
1//全局变量ormDB和users
2ormDB, _ := Connect( "user:password@tcp(127.0.0.1:3306)/orm_db")
3users := Table(ormDB, "user")
4//调用
5users.Insert(...)
准备工作到此完成,下面进入正题。
Insert方法
首先分析一下标准 insert 语句:
1insertintouser(first_name, last_name) values( 'Tom', 'Cat'), ( 'Tom', 'Cruise')
把sql语句中变化的部分抽象出来,其实就是key (字段)和value (值),那么orm里的Insert 方法原型就有了,如下,参数是struct或者map,因为它们都能提供键值对:
1//Insert in can be *User, []*User, map[string]interface{}
2func(q *Query)Insert(in interface{}) ( int64, error) {
3varkeys, values [] string
4v := reflect.ValueOf(in)
5//剥离指针
6forv.Kind == reflect.Ptr {
7v = v.Elem
8}
9switchv.Kind {
10casereflect.Struct:
11keys, values = sKV(v)
12casereflect.Map:
13keys, values = mKV(v)
14casereflect.Slice:
15fori := 0; i < v.Len; i++ {
16//Kind是切片时,可以用Index方法遍历
17sv := v.Index(i)
18forsv.Kind == reflect.Ptr || sv.Kind == reflect.Interface {
19sv = sv.Elem
20}
21//切片元素不是struct或者指针,报错
22ifsv.Kind != reflect.Struct {
23return0, errors.New( "method Insert error: in slice is not structs")
24}
25//keys只保存一次就行,因为后面的都一样了
26iflen(keys) == 0{
27keys, values = sKV(sv)
28continue
29}
30_, val := sKV(sv)
31values = append(values, val...)
32}
33default:
34return0, errors.New( "method Insert error: type error")
35}
36//todo
37//...
38}
参数 in 可以是一个 User (前文定义好的结构体)实例的指针(或者指针集合),也可以是一个map,这两个结构都可以提供键值对,我们通过反射来分析它的 类型 ,然后根据类型执行相应的逻辑。
reflect包里的有两个重要结构 Type 和 Value ,Type是一个接口,定义了所有类型相关的api,reflect里的 *rtype 实现了这个接口,通过reflect.TypeOf函数可以获取任何传入值的 *rtype 。Value是一个struct,通过reflect.ValueOf函数获取,它在 *rtype 的基础上又封装了传入值的unsafe.Pointer类型的 地址 以及这个值的 元数据 。
在Type和Value之上还有一个 Kind ,它代表传入值的 原始类型 ,比如:
1typemyInt int
2vari myInt
3t := reflect.TypeOf(i)
4k := t.Kind
t是myInt,而k是int,Type和Kind是不同的,这一点要注意区分。
如果Type的Kind是指针、接口、切片、map等复合类型,可以调用Elem方法获取基类型。
如果Value的Kind是指针、接口,可以调用Elem方法获取实际值。
Value上还定义了一个 Interface 方法,它是ValueOf方法的反操作。
有了上面这些反射方法,我们可以封装一个 sKV 函数,它专门处理struct类型的值,获取key(取json tag)和value:
1funcsKV(v reflect.Value)([] string, [] string) {
2varkeys, values [] string
3t := v.Type
4forn := 0; n < t.NumField; n++ {
5tf := t.Field(n)
6vf := v.Field(n)
7//忽略非导出字段
8iftf.Anonymous {
9continue
10}
11//忽略无效、零值字段
12if!vf.IsValid || reflect.DeepEqual(vf.Interface, reflect.Zero(vf.Type).Interface) {
13continue
14}
15forvf.Type.Kind == reflect.Ptr {
16vf = vf.Elem
17}
18//有时候根据需求会组合struct,这里处理下,支持获取嵌套的struct tag和value
19//如果字段值是time类型之外的struct,递归获取keys和values
20ifvf.Kind == reflect.Struct && tf.Type.Name != "Time"{
21cKeys, cValues := sKV(vf)
22keys = append(keys, cKeys...)
23values = append(values, cValues...)
24continue
25}
26//根据字段的json tag获取key,忽略无tag字段
27key := strings.Split(tf.Tag.Get( "json"), ",")[ 0]
28ifkey == ""{
29continue
30}
31value := format(vf)
32ifvalue != ""{
33keys = append(keys, key)
34values = append(values, value)
35}
36}
37returnkeys, values
38}
sKV 函数里需要格式化字符串,那么定义一个format 函数。
time.Time 类型怎么转化成各种数据库的时间类型我有点拿不准,所以需要对比时间类型的值时,一律用unxi时间戳,感觉比较省事不会出错:
1funcformat(v reflect.Value)string{
2//断言出time类型直接转unix时间戳
3ift, ok := v.Interface.(time.Time); ok {
4returnfmt.Sprintf( "FROM_UNIXTIME(%d)", t.Unix)
5}
6switchv.Kind {
7casereflect.String:
8returnfmt.Sprintf( `'%s'`, v.Interface)
9casereflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
10returnfmt.Sprintf( `%d`, v.Interface)
11casereflect.Float32, reflect.Float64:
12returnfmt.Sprintf( `%f`, v.Interface)
13//如果是切片类型,遍历元素,递归格式化成"(, , , )"形式
14casereflect.Slice:
15varvalues [] string
16fori := 0; i < v.Len; i++ {
17values = append(values, format(v.Index(i)))
18}
19returnfmt.Sprintf( `(%s)`, strings.Join(values, ","))
20//接口类型剥一层递归
21casereflect.Interface:
22returnformat(v.Elem)
23}
24return""
25}
map类型处理起来和struct不同,所以我们再定义一个mKV 函数,目的和sKV一样,都是获取键值对:
1funcmKV(v reflect.Value)([] string, [] string) {
2varkeys, values [] string
3//获取map的key组成的切片
4mapKeys := v.MapKeys
5for_, key := rangemapKeys {
6value := format(v.MapIndex(key))
7ifvalue != ""{
8values = append(values, value)
9keys = append(keys, key.Interface.( string))
10}
11}
12returnkeys, values
13}
利用sKV和mKV函数取到键值对后,就得到了insert语句中的变化部分,补全Insert方法的todo 部分:
1// Insertincan be User, * User, [] User, []* User, map[ string] interface{}
2func (q * Query) Insert( ininterface{}) (int64, error) {
3//already done
4kl := len( keys)
5vl := len( values)
6ifkl == 0|| vl == 0{
7return0, errors.New( "method Insert error: no data")
8}
9varinsertValue string
10//插入多条记录时需要用 ","拼接一下 values
11ifkl < vl {
12vartmpValues [] string
13forkl <= vl {
14ifkl%( len( keys)) == 0{
15tmpValues = append(tmpValues, fmt.Sprintf( "(%s)", strings.Join( values[kl- len( keys):kl], ",")))
16}
17kl++
18}
19insertValue = strings.Join(tmpValues, ",")
20} else{
21insertValue = fmt.Sprintf( "(%s)", strings.Join( values, ","))
22}
23query:= fmt.Sprintf( `insert into %s (%s) values %s`, q.table, strings.Join( keys, ","), insertValue)
24log.Printf( "insert sql: %s", query)
25st, err := q.DB.Prepare( query)
26iferr != nil {
27return0, err
28}
29result, err := st.Exec
30iferr != nil {
31return0, err
32}
33returnresult.LastInsertId
34}
原理很简单,利用反射分析参数,取键值对,然后拼接sql语句,再通过mysql驱动入库。
调用示例:
1user1 := &User{
2Age: 30,
3FirstName: "Tom",
4LastName: "Cat",
5}
6user2 := User{
7Age: 30,
8FirstName: "Tom",
9LastName: "Curise",
10}
11user3 := User{
12Age: 30,
13FirstName: "Tom",
14LastName: "Hanks",
15}
16user4 := map[string]interface{}{
17"age": 30,
18"first_name": "Tom",
19"last_name": "Zzy",
20}
21users.Insert([]interface{}{user1, user2})
22users.Insert(user3)
23users.Insert(user4)
增删改查的 增 部分到此完成,因为查询语句非常复杂多变,所以有了数据后,先进行 查 。
Select方法
先分析一下标准 select 语句
1selectid, age fromuserwherefirst_name = 'Tom'andlast_name = 'Cat'
可见sql语句的变量部分是select 后面的字段和where 后面的键值对,所以我们需要一个Where 来方法构造查询条件,并且需要一个Select 方法最后执行查询,最终形成一个链式调用效果:
1varuser[]User
2users.Where(?) .WhereNot(?) .Limit(100) .Offset(100) .Order(" iddesc") .Only(" id", " age") .Select(& user)
所以需要改造Query如下,增加属性用于暂存链式调用中添加的值:
1//Query will build a sql
2type Query struct{
3db *sql.DB
4table string
5wheres [] string
6only [] string
7limit string
8offset string
9order string
10errs [] string
11}
为Query添加Where方法,支持struct和map参数,同时支持传如同"age > 10" 形式的字符串:
1//Where args can be string, User, *User, map[string]interface{}
2func(q *Query)Where(wheres ... interface{}) * Query{
3for_, w := rangewheres {
4v := reflect.ValueOf(w)
5forv.Kind == reflect.Ptr {
6v = v.Elem
7}
8switchv.Kind {
9casereflect.String:
10q.wheres = append(q.wheres, w.( string))
11casereflect.Struct:
12//todo
13casereflect.Map:
14//todo
15default:
16q.errs = append(q.errs, "method Where error: type error")
17}
18}
19returnq
20}
但是考虑到后面还会实现一个WhereNot 方法,所以把公共逻辑抽到一个where 函数里,并且直接复用之前的sKV、mKv函数获取键值对:
1funcwhere(eq bool, w interface{}) ( string, error) {
2varkeys, values [] string
3v := reflect.ValueOf(w)
4forv.Kind == reflect.Ptr {
5v = v.Elem
6}
7switchv.Kind {
8casereflect.String:
9returnw.( string), nil
10casereflect.Struct:
11keys, values = sKV(v)
12casereflect.Map:
13keys, values = mKV(v)
14default:
15return"", errors.New( "method Where error: type error")
16}
17iflen(keys) != len(values) {
18return"", errors.New( "method Where error: len(keys) not equal len(values))")
19}
20varwheres [] string
21//之前的format函数里,已经将切片类型值处理成"( , , ,)“形式
22foridx, key := rangekeys {
23ifeq {
24ifstrings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {
25wheres = append(wheres, fmt.Sprintf( "%s in %s", key, values[idx]))
26continue
27}
28wheres = append(wheres, fmt.Sprintf( "%s = %s", key, values[idx]))
29continue
30}
31ifstrings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {
32wheres = append(wheres, fmt.Sprintf( "%s not in %s", key, values[idx]))
33continue
34}
35wheres = append(wheres, fmt.Sprintf( "%s != %s", key, values[idx]))
36}
37returnstrings.Join(wheres, " and "), nil
38}
Where方法最终变成:
1//Where args can be string, User, *User, map[string]interface{}
2func(q *Query)Where(wheres ... interface{}) * Query{
3for_, w := rangewheres {
4str, err := where( true, w)
5q.wheres = append(q.wheres, str)
6iferr != nil{
7//因为需要达到链式调用的效果,所以把错误都搜集起来,最后再处理
8q.errs = append(q.errs, err.Error)
9}
10}
11returnq
12}
WhereNot把调用where的第一个参数改成false就行了,不贴代码了。
Limit 、Offset 、Order 、Only 这几个方法也很简单:
1//Limit .
2func(q *Query)Limit(limit uint) * Query{
3q.limit = fmt.Sprintf( "limit %d", limit)
4returnq
5}
6
7//Offset .
8func(q *Query)Offset(offset uint) * Query{
9q.offset = fmt.Sprintf( "offset %d", offset)
10returnq
11}
12
13//Order .
14func(q *Query)Order(ord string) * Query{
15q.order = fmt.Sprintf( "order by %s", ord)
16returnq
17}
18
19//Only 指定需要查询的字段
20func(q *Query)Only(columns ... string) * Query{
21q.only = append(q.only, columns...)
22returnq
23}
有了上面这些条件之后,我们可以写一个 toSQL 方法,把Query的属性组装成一条sql语句:
1func (q *Query) toSQL string{
2varwherestring
3iflen( q.wheres) > 0 {
4where= fmt.Sprintf(` where%s`, strings.Join(q.wheres, " and "))
5}
6sqlStr := fmt.Sprintf(` select%s from%s %s %s %s %s`, strings.Join(q.only, ","), q.table, where, q.order, q.limit, q.offset)
7log.Printf( "select sql: %s", sqlStr)
8returnsqlStr
9}
有了sql语句我们就可以查询数据了,但是想查一个表的全部字段时,为了方便,只需要传入对应的 struct ,比如 user 表对应的 User ,我们就直接分析这个struct,取它的tag作为查询字段,而不需要再调用Only方法指定字段。
另外,因为golang中的参数传递全都是值传递,要修改传入值,必须传值的指针,这里要注意一点:
1varuser User
2users.Select(&user)
3varuserPtr *User
4users.Select(user)
这两种声明方式是不同的,后者只声明了一个指针类型,是错误的。
综上,我们首先为Select方法做一下的参数检查,确保传入值是一个正确的指针,并确保only属性有值:
1// Selectdest must be a ptr, e.g. * user, *[] user, *[]* user, * map, *[] map, * int, *[] int
2func (q * Query) Select(dest interface{}) error{
3iflen(q.errs) != 0{
4returnerrors.New(strings.Join(q.errs, "
5" ))
6}
7t := reflect.TypeOf(dest)
8v := reflect.ValueOf(dest)
9typeErr := errors.New( "method Select error: type error")
10ift.Kind != reflect.Ptr {
11returntypeErr
12}
13//如果是用 varuserPtr * User方式声明的变量,则不可取址
14if!v.Elem.CanAddr {
15returntypeErr
16}
17t = t.Elem
18v = v.Elem
19//如果 only此时仍然为空,说明 Only方法未被调用,我们从 struct上取tag填充
20iflen(q.only) == 0{
21switcht.Kind {
22casereflect.Struct:
23ift.Name != "Time"{
24q.only = sK(v)
25}
26casereflect.Slice:
27//获取切片的基本类型给一个局部变量
28t := t.Elem
29ift.Kind == reflect.Ptr {
30t = t.Elem
31}
32ift.Kind == reflect.Struct {
33ift.Name != "Time"{
34q.only = sK(reflect.Zero(t))
35}
36}
37}
38}
39iflen(q.only) == 0{
40returnerrors.New( "method Select error: type error, no columns to select")
41}
42ift.Kind != reflect.Slice {
43q.limit = "limit 1"
44}
45//todo
46}
这里只取struct的tag,不取value,我们定义一个新的sK函数:
1funcsK(v reflect.Value)[] string{
2varkeys [] string
3t := v.Type
4forn := 0; n < t.NumField; n++ {
5tf := t.Field(n)
6vf := v.Field(n)
7//忽略非导出字段
8iftf.Anonymous {
9continue
10}
11forvf.Type.Kind == reflect.Ptr {
12vf = vf.Elem
13}
14//如果字段值是time类型之外的struct,递归获取keys
15ifvf.Kind == reflect.Struct && tf.Type.Name != "Time"{
16keys = append(keys, sK(vf)...)
17continue
18}
19//根据字段的json tag获取key,忽略无tag字段
20key := strings.Split(tf.Tag.Get( "json"), ",")[ 0]
21ifkey == ""{
22continue
23}
24keys = append(keys, key)
25}
26returnkeys
27}
现在sql语句已经完备了,可以执行最后的取值步骤了。
我们根据传入Select的指针的基类型生成实际数据,对其取址后交给sql包的Scan 方法填充,然后Set 回去,所以这里需要一个address 函数用于取址:
1funcaddress(dest reflect.Value, columns [] string) [] interface{} {
2dest = dest.Elem
3t := dest.Type
4addrs := make([] interface{}, 0)
5switcht.Kind {
6casereflect.Struct:
7forn := 0; n < t.NumField; n++ {
8tf := t.Field(n)
9vf := dest.Field(n)
10iftf.Anonymous {
11continue
12}
13forvf.Type.Kind == reflect.Ptr {
14vf = vf.Elem
15}
16//如果字段值是time类型之外的struct,递归取址
17ifvf.Kind == reflect.Struct && tf.Type.Name != "Time"{
18nVf := reflect.New(vf.Type)
19vf.Set(nVf.Elem)
20addrs = append(addrs, address(nVf, columns)...)
21continue
22}
23column := strings.Split(tf.Tag.Get( "json"), ",")[ 0]
24ifcolumn == ""{
25continue
26}
27//只取选定的字段的地址
28for_, col := rangecolumns {
29ifcol == column {
30addrs = append(addrs, vf.Addr.Interface)
31break
32}
33}
34}
35default:
36addrs = append(addrs, dest.Addr.Interface)
37}
38returnaddrs
39}
Value.Addr 函数可用于取址,前提是 Value.CanAddr 返回true。
relfect.New 可以根据 Type 来 new 出一个 Value ,这个Value是一个 指针 ,它的基值是可以取址的,把它的基值 Set 到目标值上,就达到了根据Type从无到有生成对应值的目的。
因为map不能用new函数生成,所以需要写一个用于生成map的函数 setMap :
1//map的value类型必须是interface{},因为无类型信息,所以mysql驱动会返回一个字节切片,需要自行用[]byte断言
2func(q *Query)setMap(rows *sql.Rows, t reflect.Type)(reflect.Value, error){
3ift.Elem.Kind != reflect.Interface {
4returnreflect.ValueOf( nil), errors.New( "method setMap error: type error, must be map[string]interface{}")
5}
6m := reflect.MakeMap(t)
7addrs := make([] interface{}, len(q.only))
8foridx := rangeq.only {
9addrs[idx] = new( interface{})
10}
11iferr := rows.Scan(addrs...); err != nil{
12returnreflect.ValueOf( nil), err
13}
14foridx, column := rangeq.only {
15//从指针剥出interface{},再剥出实际值
16m.SetMapIndex(reflect.ValueOf(column), reflect.ValueOf(addrs[idx]).Elem.Elem)
17}
18returnm, nil
19}
reflect.MakeMap 跟make 作用差不多,它接受一个Kind 是reflect.Map 的Type 作为参数,生成一个对应类型的map。
对于其它适用于new 的类型,写一个通用的函数setElem 处理:
1//适用于基类型和struct
2func(q *Query)setElem(rows *sql.Rows, t reflect.Type)(reflect.Value, error){
3addrsErr := errors.New( "method setElem error: columns not match addresses")
4dest := reflect.New(t)
5addrs := address(dest, q.only)
6iflen(q.only) != len(addrs) {
7returnreflect.ValueOf( nil), addrsErr
8}
9iferr := rows.Scan(addrs...); err != nil{
10returnreflect.ValueOf( nil), err
11}
12returndest, nil
13}
这些函数完成后,就可以着手完善Select里的todo部分了:
1//already done
2rows, err := q.DB.Query(q.toSQL)
3iferr != nil{
4returnerr
5}
6switcht.Kind {
7casereflect.Slice:
8dt := t.Elem
9fordt.Kind == reflect.Ptr {
10dt = dt.Elem
11}
12sl := reflect.MakeSlice(t, 0, 0)
13forrows.Next {
14vardestination reflect.Value
15ifdt.Kind == reflect.Map {
16destination, err = q.setMap(rows, dt)
17} else{
18destination, err = q.setElem(rows, dt)
19}
20iferr != nil{
21returnerr
22}
23//区分切片元素是否指针
24switcht.Elem.Kind {
25casereflect.Ptr, reflect.Map:
26sl = reflect.Append(sl, destination)
27default:
28sl = reflect.Append(sl, destination.Elem)
29}
30}
31v.Set(sl)
32returnnil
33casereflect.Map:
34forrows.Next {
35m, err := q.setMap(rows, t)
36iferr != nil{
37returnerr
38}
39v.Set(m)
40}
41returnnil
42default:
43forrows.Next {
44destination, err := q.setElem(rows, t)
45iferr != nil{
46returnerr
47}
48v.Set(destination.Elem)
49}
50}
51returnnil
至此,Select方法就大功告成了,部分调用方式示例:
1varuser User
2users
3.Where( "first_name = 'Tom'", map[ string] interface{}{
4"id": [] int{ 1, 2, 3, 4},
5})
6.WhereNot(&User{LastName: "Cat"})
7.Only( "last_name")
8.Select(&user)
9
10varuserMore []User
11users.Where( "first_name = 'Tom'").Order( "id desc").Select(&userMore)
12varuserMoreP []*User
13users.Where( "first_name = 'Tom'").Select(&userMoreP)
14varlastName string
15users.Where(&User{FirstName: "Tom"}).Only( "last_name").Select(&lastName)
16varlastNames [] string
17users.Where( map[ string] interface{}{
18"first_name": "Tom",
19}).Only( "last_name").Select(&lastNames)
20varuserM map[ string] interface{}
21users.Where(&User{FirstName: "Tom"}).Only( "last_name").Select(&userM)
22varuserMS [] map[ string] interface{}
23users.Where( "age > 10").Only( "last_name", "age").Limit( 100).Select(&userMS)
篇幅有限,下半部分在次篇。
www.bytedancing.com