使用 tidb-lite 可以在 Golang 代码中运行 mocktikv 模式的 TiDB。

tidb-lite 可以用于数据库相关代码的单元测试,如果你的应用使用到了 MySQL/TiDB,就会有大量的数据库相关的代码要进行单元测试。

另外,有的应用希望在本地持久化一些数据,并且以数据库的方式存储,方便对数据进行管理。这种场景下也可以使用 tidb-lite,开发者使用兼容 MySQL 协议的 SQL 对数据进行本地处理。

本文主要介绍如何使用 tidb-lite 进行数据库相关代码的单元测试。那么如何写数据库相关代码的单元测试呢?

简单的方法是在进行单元测试时启动一个数据库实例,但是这样做不太优雅,因为这样的话单元测试就对环境有了一定的要求。

另一个通用的方法是在测试中 mock SQL 服务,目前比较流行的方案是使用 go-sqlmock。

go-sqlmock 的问题

首先我们看一下如何使用 go-sqlmock 进行数据库相关代码的单元测试。

比如我们有下面这些代码:

package main

import (

"database/sql"

_ "github.com/go-sql-driver/mysql"

)

func recordStats(db *sql.DB, userID, productID int64) (err error) {

tx, err := db.Begin()

if err != nil {

return

}

defer func() {

switch err {

case nil:

err = tx.Commit()

default:

tx.Rollback()

}

}()

if _, err = tx.Exec("UPDATE products SET views = views + 1"); err != nil {

return

}

if _, err = tx.Exec("INSERT INTO product_viewers (user_id, product_id) VALUES (?, ?)", userID, productID); err != nil {

return

}

return

}

func main() {

// @NOTE: the real connection is not required for tests

db, err := sql.Open("mysql", "root@/blog")

if err != nil {

panic(err)

}

defer db.Close()

if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil {

panic(err)

}

}

recordStats 函数会将商品的查看次数加一,并把用户加入到该商品的查看者列表中。使用 go-sqlmock 对该函数进行单元测试的代码如下:

package main

import (

"fmt"

"testing"

"github.com/DATA-DOG/go-sqlmock"

)

// a successful case

func TestShouldUpdateStats(t *testing.T) {

db, mock, err := sqlmock.New()

if err != nil {

t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)

}

defer db.Close()

mock.ExpectBegin()

mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1))

mock.ExpectExec("INSERT INTO product_viewers").WithArgs(2, 3).WillReturnResult(sqlmock.NewResult(1, 1))

mock.ExpectCommit()

// now we execute our method

if err = recordStats(db, 2, 3); err != nil {

t.Errorf("error was not expected while updating stats: %s", err)

}

// we make sure that all expectations were met

if err := mock.ExpectationsWereMet(); err != nil {

t.Errorf("there were unfulfilled expectations: %s", err)

}

}

go-sqlmock 需要对数据库的每一步操作以及顺序都需要事先定义好(包括执行事务的 begin 和 commit),如果实际执行的操作或者步骤不一致就会报错。而且需要定义好预计返回的数据,如果表比较复杂那就非常麻烦。

而实际上我们的单元测试可能只需要关注这个函数的返回结果是否正确,而不是该函数各个操作的执行顺序。

tidb-lite 的优点

可以使用 tidb-lite 来代替 go-sqlmock。

简单

最重要的一个优点就是简单。在代码中直接运行一个 TiDB,而不是在运行单元测试前运行一个 MySQL/TiDB 实例,这样可以保证单元测试不依赖于外部环境;另外,我们也不需要像 go-sqlmock 那样写大量冗余枯燥的测试代码,而是把测试的重点关注在函数的正确性上。

兼容 MySQL 协议

TiDB 高度兼容 MySQL 协议,使用 tidb-lite 几乎可以完全模拟 MySQL 的环境。

tidb-lite 的用法

例如我们有下面这些代码:

package example

import (

"context"

"database/sql"

"fmt"

"github.com/pingcap/errors"

"github.com/pingcap/log"

"go.uber.org/zap"

)

// GetRowCount returns row count of the table.

// if not specify where condition, return total row count of the table.

func GetRowCount(ctx context.Context, db *sql.DB, schemaName string, tableName string, where string) (int64, error) {

/*

select count example result:

mysql> SELECT count(1) cnt from `test`.`itest` where id > 0;

+------+

| cnt |

+------+

| 100 |

+------+

*/

query := fmt.Sprintf("SELECT COUNT(1) cnt FROM `%s`.`%s`", schemaName, tableName)

if len(where) > 0 {

query += fmt.Sprintf(" WHERE %s", where)

}

log.Debug("get row count", zap.String("sql", query))

var cnt sql.NullInt64

err := db.QueryRowContext(ctx, query).Scan(&cnt)

if err != nil {

return 0, errors.Trace(err)

}

if !cnt.Valid {

return 0, errors.NotFoundf("table `%s`.`%s`", schemaName, tableName)

}

return cnt.Int64, nil

}

GetRowCount 用于获取表中符合条件的行的数量,使用 tidb-lite 写该函数的单元测试的代码如下:

package example

import (

"context"

"testing"

"time"

tidblite "github.com/WangXiangUSTC/tidb-lite"

. "github.com/pingcap/check"

)

func TestClient(t *testing.T) {

TestingT(t)

}

var _ = Suite(&testExampleSuite{})

type testExampleSuite struct{}

func (t *testExampleSuite) TestGetRowCount(c *C) {

tidbServer, err := tidblite.NewTiDBServer(tidblite.NewOptions(c.MkDir()))

c.Assert(err, IsNil)

dbConn, err := tidbServer.CreateConn()

c.Assert(err, IsNil)

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)

defer cancel()

_, err = dbConn.ExecContext(ctx, "create database example_test")

c.Assert(err, IsNil)

_, err = dbConn.ExecContext(ctx, "create table example_test.t(id int primary key, name varchar(24))")

c.Assert(err, IsNil)

_, err = dbConn.ExecContext(ctx, "insert into example_test.t values(1, 'a'),(2, 'b'),(3, 'c')")

c.Assert(err, IsNil)

count, err := GetRowCount(ctx, dbConn, "example_test", "t", "id > 2")

c.Assert(err, IsNil)

c.Assert(count, Equals, int64(1))

count, err = GetRowCount(ctx, dbConn, "example_test", "t", "")

c.Assert(err, IsNil)

c.Assert(count, Equals, int64(3))

tidbServer.Close()

tidbServer2, err := tidblite.NewTiDBServer(tidblite.NewOptions(c.MkDir()))

c.Assert(err, IsNil)

defer tidbServer2.Close()

dbConn2, err := tidbServer2.CreateConn()

c.Assert(err, IsNil)

_, err = dbConn2.ExecContext(ctx, "create database example_test")

c.Assert(err, IsNil)

}

首先我们使用 NewTiDBServer 创建 TiDB 实例,然后使用 CreateConn 获取数据库链接,然后就可以使用这个链接访问数据库,生成测试数据,验证该函数的正确性。可以看 README 了解详细使用方法。