8`updated_at`timestampNOTNULLDEFAULTCURRENT_TIMESTAMPONUPDATECURRENT_TIMESTAMPCOMMENT'更新时间',

9PRIMARY KEY( `id`),

10KEY`idx_email`( `email`)

11) ENGINE= InnoDBDEFAULTCHARSET=utf8 COMMENT= '用户表';

同时,golang代码里定义一个与之对应的struct :

1typeUser struct{

2ID int64`json:"id"`// 自增主键

3Age int64`json:"age"`// 年龄

4FirstName string`json:"first_name"`// 姓

5LastName string`json:"last_name"`// 名

6Email string`json:"email"`// 邮箱地址

7CreatedAt time.Time `json:"created_at"`// 创建时间

8UpdatedAt time.Time `json:"updated_at"`// 更新时间

9}

与mysql交互需要用到一个go标准包和一个驱动 ,代码import 如下:

1packageorm

2

3import(

4"database/sql"

5

6//register driver

7_ "github.com/go-sql-driver/mysql"

8)

首先按照database 维度建立连接,写一个可以返回mysql连接的函数:

1//Connect db by dsn e.g. "user:password@tcp(127.0.0.1:3306)/dbname"

2funcConnect(dsn string) (*sql.DB, error){

3conn, err := sql.Open( "mysql", dsn)

4iferr != nil{

5returnnil, err

6}

7//设置连接池

8conn.SetMaxOpenConns( 100)

9conn.SetMaxIdleConns( 10)

10conn.SetConnMaxLifetime( 10* time.Minute)

11returnconn, conn.Ping

12}

设计一个struct 用于实现orm(go不是面向对象的语言,没有class ):

1//Query will build a sql

2type Query struct{

3db *sql.DB

4table string

5}

最后将通过 Query 拼接出sql语句与mysql交互,所以写一个绑定函数:

1//Table bind db and table

2funcTable(db *sql.DB, tableName string) func* Query{

3returnfunc* Query{

4return&Query{

5db: db,

6table: tableName,

7}

8}

9}

返回值是一个闭包函数,这样使用时直接调用这个闭包函数就可以获取一个绑定好的database和table的Query ,比如现在有数据库orm_db 和user 表:

1//全局变量ormDB和users

2ormDB, _ := Connect( "user:password@tcp(127.0.0.1:3306)/orm_db")

3users := Table(ormDB, "user")

4//调用

5users.Insert(...)

准备工作到此完成,下面进入正题。

Insert方法

首先分析一下标准 insert 语句:

1insertintouser(first_name, last_name) values( 'Tom', 'Cat'), ( 'Tom', 'Cruise')

把sql语句中变化的部分抽象出来,其实就是key (字段)和value (值),那么orm里的Insert 方法原型就有了,如下,参数是struct或者map,因为它们都能提供键值对:

1//Insert in can be *User, []*User, map[string]interface{}

2func(q *Query)Insert(in interface{}) ( int64, error) {

3varkeys, values [] string

4v := reflect.ValueOf(in)

5//剥离指针

6forv.Kind == reflect.Ptr {

7v = v.Elem

8}

9switchv.Kind {

10casereflect.Struct:

11keys, values = sKV(v)

12casereflect.Map:

13keys, values = mKV(v)

14casereflect.Slice:

15fori := 0; i < v.Len; i++ {

16//Kind是切片时,可以用Index方法遍历

17sv := v.Index(i)

18forsv.Kind == reflect.Ptr || sv.Kind == reflect.Interface {

19sv = sv.Elem

20}

21//切片元素不是struct或者指针,报错

22ifsv.Kind != reflect.Struct {

23return0, errors.New( "method Insert error: in slice is not structs")

24}

25//keys只保存一次就行,因为后面的都一样了

26iflen(keys) == 0{

27keys, values = sKV(sv)

28continue

29}

30_, val := sKV(sv)

31values = append(values, val...)

32}

33default:

34return0, errors.New( "method Insert error: type error")

35}

36//todo

37//...

38}

参数 in 可以是一个 User (前文定义好的结构体)实例的指针(或者指针集合),也可以是一个map,这两个结构都可以提供键值对,我们通过反射来分析它的 类型 ,然后根据类型执行相应的逻辑。

reflect包里的有两个重要结构 Type 和 Value ,Type是一个接口,定义了所有类型相关的api,reflect里的 *rtype 实现了这个接口,通过reflect.TypeOf函数可以获取任何传入值的 *rtype 。Value是一个struct,通过reflect.ValueOf函数获取,它在 *rtype 的基础上又封装了传入值的unsafe.Pointer类型的 地址 以及这个值的 元数据 。

在Type和Value之上还有一个 Kind ,它代表传入值的 原始类型 ,比如:

1typemyInt int

2vari myInt

3t := reflect.TypeOf(i)

4k := t.Kind

t是myInt,而k是int,Type和Kind是不同的,这一点要注意区分。

如果Type的Kind是指针、接口、切片、map等复合类型,可以调用Elem方法获取基类型。

如果Value的Kind是指针、接口,可以调用Elem方法获取实际值。

Value上还定义了一个 Interface 方法,它是ValueOf方法的反操作。

有了上面这些反射方法,我们可以封装一个 sKV 函数,它专门处理struct类型的值,获取key(取json tag)和value:

1funcsKV(v reflect.Value)([] string, [] string) {

2varkeys, values [] string

3t := v.Type

4forn := 0; n < t.NumField; n++ {

5tf := t.Field(n)

6vf := v.Field(n)

7//忽略非导出字段

8iftf.Anonymous {

9continue

10}

11//忽略无效、零值字段

12if!vf.IsValid || reflect.DeepEqual(vf.Interface, reflect.Zero(vf.Type).Interface) {

13continue

14}

15forvf.Type.Kind == reflect.Ptr {

16vf = vf.Elem

17}

18//有时候根据需求会组合struct,这里处理下,支持获取嵌套的struct tag和value

19//如果字段值是time类型之外的struct,递归获取keys和values

20ifvf.Kind == reflect.Struct && tf.Type.Name != "Time"{

21cKeys, cValues := sKV(vf)

22keys = append(keys, cKeys...)

23values = append(values, cValues...)

24continue

25}

26//根据字段的json tag获取key,忽略无tag字段

27key := strings.Split(tf.Tag.Get( "json"), ",")[ 0]

28ifkey == ""{

29continue

30}

31value := format(vf)

32ifvalue != ""{

33keys = append(keys, key)

34values = append(values, value)

35}

36}

37returnkeys, values

38}

sKV 函数里需要格式化字符串,那么定义一个format 函数。

time.Time 类型怎么转化成各种数据库的时间类型我有点拿不准,所以需要对比时间类型的值时,一律用unxi时间戳,感觉比较省事不会出错:

1funcformat(v reflect.Value)string{

2//断言出time类型直接转unix时间戳

3ift, ok := v.Interface.(time.Time); ok {

4returnfmt.Sprintf( "FROM_UNIXTIME(%d)", t.Unix)

5}

6switchv.Kind {

7casereflect.String:

8returnfmt.Sprintf( `'%s'`, v.Interface)

9casereflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:

10returnfmt.Sprintf( `%d`, v.Interface)

11casereflect.Float32, reflect.Float64:

12returnfmt.Sprintf( `%f`, v.Interface)

13//如果是切片类型,遍历元素,递归格式化成"(, , , )"形式

14casereflect.Slice:

15varvalues [] string

16fori := 0; i < v.Len; i++ {

17values = append(values, format(v.Index(i)))

18}

19returnfmt.Sprintf( `(%s)`, strings.Join(values, ","))

20//接口类型剥一层递归

21casereflect.Interface:

22returnformat(v.Elem)

23}

24return""

25}

map类型处理起来和struct不同,所以我们再定义一个mKV 函数,目的和sKV一样,都是获取键值对:

1funcmKV(v reflect.Value)([] string, [] string) {

2varkeys, values [] string

3//获取map的key组成的切片

4mapKeys := v.MapKeys

5for_, key := rangemapKeys {

6value := format(v.MapIndex(key))

7ifvalue != ""{

8values = append(values, value)

9keys = append(keys, key.Interface.( string))

10}

11}

12returnkeys, values

13}

利用sKV和mKV函数取到键值对后,就得到了insert语句中的变化部分,补全Insert方法的todo 部分:

1// Insertincan be User, * User, [] User, []* User, map[ string] interface{}

2func (q * Query) Insert( ininterface{}) (int64, error) {

3//already done

4kl := len( keys)

5vl := len( values)

6ifkl == 0|| vl == 0{

7return0, errors.New( "method Insert error: no data")

8}

9varinsertValue string

10//插入多条记录时需要用 ","拼接一下 values

11ifkl < vl {

12vartmpValues [] string

13forkl <= vl {

14ifkl%( len( keys)) == 0{

15tmpValues = append(tmpValues, fmt.Sprintf( "(%s)", strings.Join( values[kl- len( keys):kl], ",")))

16}

17kl++

18}

19insertValue = strings.Join(tmpValues, ",")

20} else{

21insertValue = fmt.Sprintf( "(%s)", strings.Join( values, ","))

22}

23query:= fmt.Sprintf( `insert into %s (%s) values %s`, q.table, strings.Join( keys, ","), insertValue)

24log.Printf( "insert sql: %s", query)

25st, err := q.DB.Prepare( query)

26iferr != nil {

27return0, err

28}

29result, err := st.Exec

30iferr != nil {

31return0, err

32}

33returnresult.LastInsertId

34}

原理很简单,利用反射分析参数,取键值对,然后拼接sql语句,再通过mysql驱动入库。

调用示例:

1user1 := &User{

2Age: 30,

3FirstName: "Tom",

4LastName: "Cat",

5}

6user2 := User{

7Age: 30,

8FirstName: "Tom",

9LastName: "Curise",

10}

11user3 := User{

12Age: 30,

13FirstName: "Tom",

14LastName: "Hanks",

15}

16user4 := map[string]interface{}{

17"age": 30,

18"first_name": "Tom",

19"last_name": "Zzy",

20}

21users.Insert([]interface{}{user1, user2})

22users.Insert(user3)

23users.Insert(user4)

增删改查的 增 部分到此完成,因为查询语句非常复杂多变,所以有了数据后,先进行 查 。

Select方法

先分析一下标准 select 语句

1selectid, age fromuserwherefirst_name = 'Tom'andlast_name = 'Cat'

可见sql语句的变量部分是select 后面的字段和where 后面的键值对,所以我们需要一个Where 来方法构造查询条件,并且需要一个Select 方法最后执行查询,最终形成一个链式调用效果:

1varuser[]User

2users.Where(?) .WhereNot(?) .Limit(100) .Offset(100) .Order(" iddesc") .Only(" id", " age") .Select(& user)

所以需要改造Query如下,增加属性用于暂存链式调用中添加的值:

1//Query will build a sql

2type Query struct{

3db *sql.DB

4table string

5wheres [] string

6only [] string

7limit string

8offset string

9order string

10errs [] string

11}

为Query添加Where方法,支持struct和map参数,同时支持传如同"age > 10" 形式的字符串:

1//Where args can be string, User, *User, map[string]interface{}

2func(q *Query)Where(wheres ... interface{}) * Query{

3for_, w := rangewheres {

4v := reflect.ValueOf(w)

5forv.Kind == reflect.Ptr {

6v = v.Elem

7}

8switchv.Kind {

9casereflect.String:

10q.wheres = append(q.wheres, w.( string))

11casereflect.Struct:

12//todo

13casereflect.Map:

14//todo

15default:

16q.errs = append(q.errs, "method Where error: type error")

17}

18}

19returnq

20}

但是考虑到后面还会实现一个WhereNot 方法,所以把公共逻辑抽到一个where 函数里,并且直接复用之前的sKV、mKv函数获取键值对:

1funcwhere(eq bool, w interface{}) ( string, error) {

2varkeys, values [] string

3v := reflect.ValueOf(w)

4forv.Kind == reflect.Ptr {

5v = v.Elem

6}

7switchv.Kind {

8casereflect.String:

9returnw.( string), nil

10casereflect.Struct:

11keys, values = sKV(v)

12casereflect.Map:

13keys, values = mKV(v)

14default:

15return"", errors.New( "method Where error: type error")

16}

17iflen(keys) != len(values) {

18return"", errors.New( "method Where error: len(keys) not equal len(values))")

19}

20varwheres [] string

21//之前的format函数里,已经将切片类型值处理成"( , , ,)“形式

22foridx, key := rangekeys {

23ifeq {

24ifstrings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {

25wheres = append(wheres, fmt.Sprintf( "%s in %s", key, values[idx]))

26continue

27}

28wheres = append(wheres, fmt.Sprintf( "%s = %s", key, values[idx]))

29continue

30}

31ifstrings.HasPrefix(values[idx], "(") && strings.HasSuffix(values[idx], ")") {

32wheres = append(wheres, fmt.Sprintf( "%s not in %s", key, values[idx]))

33continue

34}

35wheres = append(wheres, fmt.Sprintf( "%s != %s", key, values[idx]))

36}

37returnstrings.Join(wheres, " and "), nil

38}

Where方法最终变成:

1//Where args can be string, User, *User, map[string]interface{}

2func(q *Query)Where(wheres ... interface{}) * Query{

3for_, w := rangewheres {

4str, err := where( true, w)

5q.wheres = append(q.wheres, str)

6iferr != nil{

7//因为需要达到链式调用的效果,所以把错误都搜集起来,最后再处理

8q.errs = append(q.errs, err.Error)

9}

10}

11returnq

12}

WhereNot把调用where的第一个参数改成false就行了,不贴代码了。

Limit 、Offset 、Order 、Only 这几个方法也很简单:

1//Limit .

2func(q *Query)Limit(limit uint) * Query{

3q.limit = fmt.Sprintf( "limit %d", limit)

4returnq

5}

6

7//Offset .

8func(q *Query)Offset(offset uint) * Query{

9q.offset = fmt.Sprintf( "offset %d", offset)

10returnq

11}

12

13//Order .

14func(q *Query)Order(ord string) * Query{

15q.order = fmt.Sprintf( "order by %s", ord)

16returnq

17}

18

19//Only 指定需要查询的字段

20func(q *Query)Only(columns ... string) * Query{

21q.only = append(q.only, columns...)

22returnq

23}

有了上面这些条件之后,我们可以写一个 toSQL 方法,把Query的属性组装成一条sql语句:

1func (q *Query) toSQL string{

2varwherestring

3iflen( q.wheres) > 0 {

4where= fmt.Sprintf(` where%s`, strings.Join(q.wheres, " and "))

5}

6sqlStr := fmt.Sprintf(` select%s from%s %s %s %s %s`, strings.Join(q.only, ","), q.table, where, q.order, q.limit, q.offset)

7log.Printf( "select sql: %s", sqlStr)

8returnsqlStr

9}

有了sql语句我们就可以查询数据了,但是想查一个表的全部字段时,为了方便,只需要传入对应的 struct ,比如 user 表对应的 User ,我们就直接分析这个struct,取它的tag作为查询字段,而不需要再调用Only方法指定字段。

另外,因为golang中的参数传递全都是值传递,要修改传入值,必须传值的指针,这里要注意一点:

1varuser User

2users.Select(&user)

3varuserPtr *User

4users.Select(user)

这两种声明方式是不同的,后者只声明了一个指针类型,是错误的。

综上,我们首先为Select方法做一下的参数检查,确保传入值是一个正确的指针,并确保only属性有值:

1// Selectdest must be a ptr, e.g. * user, *[] user, *[]* user, * map, *[] map, * int, *[] int

2func (q * Query) Select(dest interface{}) error{

3iflen(q.errs) != 0{

4returnerrors.New(strings.Join(q.errs, "

5" ))

6}

7t := reflect.TypeOf(dest)

8v := reflect.ValueOf(dest)

9typeErr := errors.New( "method Select error: type error")

10ift.Kind != reflect.Ptr {

11returntypeErr

12}

13//如果是用 varuserPtr * User方式声明的变量,则不可取址

14if!v.Elem.CanAddr {

15returntypeErr

16}

17t = t.Elem

18v = v.Elem

19//如果 only此时仍然为空,说明 Only方法未被调用,我们从 struct上取tag填充

20iflen(q.only) == 0{

21switcht.Kind {

22casereflect.Struct:

23ift.Name != "Time"{

24q.only = sK(v)

25}

26casereflect.Slice:

27//获取切片的基本类型给一个局部变量

28t := t.Elem

29ift.Kind == reflect.Ptr {

30t = t.Elem

31}

32ift.Kind == reflect.Struct {

33ift.Name != "Time"{

34q.only = sK(reflect.Zero(t))

35}

36}

37}

38}

39iflen(q.only) == 0{

40returnerrors.New( "method Select error: type error, no columns to select")

41}

42ift.Kind != reflect.Slice {

43q.limit = "limit 1"

44}

45//todo

46}

这里只取struct的tag,不取value,我们定义一个新的sK函数:

1funcsK(v reflect.Value)[] string{

2varkeys [] string

3t := v.Type

4forn := 0; n < t.NumField; n++ {

5tf := t.Field(n)

6vf := v.Field(n)

7//忽略非导出字段

8iftf.Anonymous {

9continue

10}

11forvf.Type.Kind == reflect.Ptr {

12vf = vf.Elem

13}

14//如果字段值是time类型之外的struct,递归获取keys

15ifvf.Kind == reflect.Struct && tf.Type.Name != "Time"{

16keys = append(keys, sK(vf)...)

17continue

18}

19//根据字段的json tag获取key,忽略无tag字段

20key := strings.Split(tf.Tag.Get( "json"), ",")[ 0]

21ifkey == ""{

22continue

23}

24keys = append(keys, key)

25}

26returnkeys

27}

现在sql语句已经完备了,可以执行最后的取值步骤了。

我们根据传入Select的指针的基类型生成实际数据,对其取址后交给sql包的Scan 方法填充,然后Set 回去,所以这里需要一个address 函数用于取址:

1funcaddress(dest reflect.Value, columns [] string) [] interface{} {

2dest = dest.Elem

3t := dest.Type

4addrs := make([] interface{}, 0)

5switcht.Kind {

6casereflect.Struct:

7forn := 0; n < t.NumField; n++ {

8tf := t.Field(n)

9vf := dest.Field(n)

10iftf.Anonymous {

11continue

12}

13forvf.Type.Kind == reflect.Ptr {

14vf = vf.Elem

15}

16//如果字段值是time类型之外的struct,递归取址

17ifvf.Kind == reflect.Struct && tf.Type.Name != "Time"{

18nVf := reflect.New(vf.Type)

19vf.Set(nVf.Elem)

20addrs = append(addrs, address(nVf, columns)...)

21continue

22}

23column := strings.Split(tf.Tag.Get( "json"), ",")[ 0]

24ifcolumn == ""{

25continue

26}

27//只取选定的字段的地址

28for_, col := rangecolumns {

29ifcol == column {

30addrs = append(addrs, vf.Addr.Interface)

31break

32}

33}

34}

35default:

36addrs = append(addrs, dest.Addr.Interface)

37}

38returnaddrs

39}

Value.Addr 函数可用于取址,前提是 Value.CanAddr 返回true。

relfect.New 可以根据 Type 来 new 出一个 Value ,这个Value是一个 指针 ,它的基值是可以取址的,把它的基值 Set 到目标值上,就达到了根据Type从无到有生成对应值的目的。

因为map不能用new函数生成,所以需要写一个用于生成map的函数 setMap :

1//map的value类型必须是interface{},因为无类型信息,所以mysql驱动会返回一个字节切片,需要自行用[]byte断言

2func(q *Query)setMap(rows *sql.Rows, t reflect.Type)(reflect.Value, error){

3ift.Elem.Kind != reflect.Interface {

4returnreflect.ValueOf( nil), errors.New( "method setMap error: type error, must be map[string]interface{}")

5}

6m := reflect.MakeMap(t)

7addrs := make([] interface{}, len(q.only))

8foridx := rangeq.only {

9addrs[idx] = new( interface{})

10}

11iferr := rows.Scan(addrs...); err != nil{

12returnreflect.ValueOf( nil), err

13}

14foridx, column := rangeq.only {

15//从指针剥出interface{},再剥出实际值

16m.SetMapIndex(reflect.ValueOf(column), reflect.ValueOf(addrs[idx]).Elem.Elem)

17}

18returnm, nil

19}

reflect.MakeMap 跟make 作用差不多,它接受一个Kind 是reflect.Map 的Type 作为参数,生成一个对应类型的map。

对于其它适用于new 的类型,写一个通用的函数setElem 处理:

1//适用于基类型和struct

2func(q *Query)setElem(rows *sql.Rows, t reflect.Type)(reflect.Value, error){

3addrsErr := errors.New( "method setElem error: columns not match addresses")

4dest := reflect.New(t)

5addrs := address(dest, q.only)

6iflen(q.only) != len(addrs) {

7returnreflect.ValueOf( nil), addrsErr

8}

9iferr := rows.Scan(addrs...); err != nil{

10returnreflect.ValueOf( nil), err

11}

12returndest, nil

13}

这些函数完成后,就可以着手完善Select里的todo部分了:

1//already done

2rows, err := q.DB.Query(q.toSQL)

3iferr != nil{

4returnerr

5}

6switcht.Kind {

7casereflect.Slice:

8dt := t.Elem

9fordt.Kind == reflect.Ptr {

10dt = dt.Elem

11}

12sl := reflect.MakeSlice(t, 0, 0)

13forrows.Next {

14vardestination reflect.Value

15ifdt.Kind == reflect.Map {

16destination, err = q.setMap(rows, dt)

17} else{

18destination, err = q.setElem(rows, dt)

19}

20iferr != nil{

21returnerr

22}

23//区分切片元素是否指针

24switcht.Elem.Kind {

25casereflect.Ptr, reflect.Map:

26sl = reflect.Append(sl, destination)

27default:

28sl = reflect.Append(sl, destination.Elem)

29}

30}

31v.Set(sl)

32returnnil

33casereflect.Map:

34forrows.Next {

35m, err := q.setMap(rows, t)

36iferr != nil{

37returnerr

38}

39v.Set(m)

40}

41returnnil

42default:

43forrows.Next {

44destination, err := q.setElem(rows, t)

45iferr != nil{

46returnerr

47}

48v.Set(destination.Elem)

49}

50}

51returnnil

至此,Select方法就大功告成了,部分调用方式示例:

1varuser User

2users

3.Where( "first_name = 'Tom'", map[ string] interface{}{

4"id": [] int{ 1, 2, 3, 4},

5})

6.WhereNot(&User{LastName: "Cat"})

7.Only( "last_name")

8.Select(&user)

9

10varuserMore []User

11users.Where( "first_name = 'Tom'").Order( "id desc").Select(&userMore)

12varuserMoreP []*User

13users.Where( "first_name = 'Tom'").Select(&userMoreP)

14varlastName string

15users.Where(&User{FirstName: "Tom"}).Only( "last_name").Select(&lastName)

16varlastNames [] string

17users.Where( map[ string] interface{}{

18"first_name": "Tom",

19}).Only( "last_name").Select(&lastNames)

20varuserM map[ string] interface{}

21users.Where(&User{FirstName: "Tom"}).Only( "last_name").Select(&userM)

22varuserMS [] map[ string] interface{}

23users.Where( "age > 10").Only( "last_name", "age").Limit( 100).Select(&userMS)

篇幅有限,下半部分在次篇。

www.bytedancing.com