gomock 是 Google 开源的 Go 语言 mock(模拟对象)测试框架。

GoMock is a mocking framework for the Go programming language.
https://github.com/golang/mock

快速开始

安装 mockgen

To get the latest released version use:
Go version < 1.16

GO111MODULE=on go get github.com/golang/mock/mockgen@v1.6.0

Go 1.16+

go install github.com/golang/mock/mockgen@v1.6.0

定义好被测接口

// mockgen -source=./driver/navigator_driver.go -destination ./driver/navigator_driver_mock.go -package driver

// INavigatorDriver is the interface under test. The mockgen command shown
// above reads it from navigator_driver.go and generates a mock for it.
type INavigatorDriver interface {
    // Query takes a context, a SQL client, a SQL key and a SQL string, plus
    // optional search options. It returns a slice of string-keyed maps
    // (presumably one per result row — confirm against the implementation)
    // and an error.
    Query(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        sqlKey,
        sql string,
        searchOptions ...*engine.Option,
    ) ([]map[string]interface{}, error)

    // BatchGetProductInfoMap takes a date string, a list of int64 ids and a
    // list of entity field names. It returns a map keyed by int64
    // (presumably product id -> product info — confirm) and an error.
    BatchGetProductInfoMap(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        date string,
        ids []int64,
        entityFields []string,
    ) (map[int64]interface{}, error)

    // BatchGetBrandInfoMap has the same signature as BatchGetProductInfoMap;
    // presumably the returned map is keyed by brand id — confirm.
    BatchGetBrandInfoMap(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        date string,
        ids []int64,
        entityFields []string,
    ) (map[int64]interface{}, error)
}

// NavigatorDriver is declared without fields. It is presumably the production
// implementation of INavigatorDriver; its methods are not shown in this excerpt.
type NavigatorDriver struct {
}

使用 mockgen 命令行自动生成 gomock代码

gomock通过mockgen命令生成包含mock对象的.go文件,其生成的mock对象具备mock+stub的强大功能.

mockgen -source=./driver/navigator_driver.go -destination ./driver/navigator_driver_mock.go -package driver

其中, navigator_driver_mock.go 是生成的 mock 代码.

代码目录:


类型关系:

生成的Mock Stub代码如下:

// Code generated by MockGen. DO NOT EDIT.
// Source: ./driver/navigator_driver.go

// Package driver is a generated GoMock package.
package driver

import (
    context "context"
    reflect "reflect"

    ...
    gomock "github.com/golang/mock/gomock"
)

// MockINavigatorDriver is a mock of INavigatorDriver interface.
type MockINavigatorDriver struct {
    // ctrl is the gomock controller that every mocked call is forwarded to.
    ctrl     *gomock.Controller
    // recorder is the expectation recorder handed out by EXPECT.
    recorder *MockINavigatorDriverMockRecorder
}

// MockINavigatorDriverMockRecorder is the mock recorder for MockINavigatorDriver.
type MockINavigatorDriverMockRecorder struct {
    // mock points back to the owning mock so recorded calls can reach its controller.
    mock *MockINavigatorDriver
}

// NewMockINavigatorDriver creates a new mock instance.
func NewMockINavigatorDriver(ctrl *gomock.Controller) *MockINavigatorDriver {
    mock := &MockINavigatorDriver{ctrl: ctrl}
    // The recorder keeps a back-pointer to the mock it records expectations for.
    mock.recorder = &MockINavigatorDriverMockRecorder{mock}
    return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockINavigatorDriver) EXPECT() *MockINavigatorDriverMockRecorder {
    // The recorder is created once in NewMockINavigatorDriver and reused on every call.
    return m.recorder
}

// BatchGetBrandInfoMap mocks base method.
// It forwards the call and its arguments to the controller and converts the
// two recorded return values back to the declared result types.
func (m *MockINavigatorDriver) BatchGetBrandInfoMap(Ctx context.Context, SqlClient *sqlclient.SQLClient, date string, ids []int64, entityFields []string) (map[int64]interface{}, error) {
    m.ctrl.T.Helper()
    ret := m.ctrl.Call(m, "BatchGetBrandInfoMap", Ctx, SqlClient, date, ids, entityFields)
    // The ok result of each type assertion is discarded, so a nil or
    // differently typed value yields the zero value instead of a panic.
    ret0, _ := ret[0].(map[int64]interface{})
    ret1, _ := ret[1].(error)
    return ret0, ret1
}

// BatchGetBrandInfoList indicates an expected call of BatchGetBrandInfoMap.
// NOTE(review): this recorder method still carries the "List" name while it
// records calls to "BatchGetBrandInfoMap". That suggests the file is stale or
// hand-edited; re-running mockgen would presumably emit a recorder method
// named BatchGetBrandInfoMap, and callers would then need updating.
func (mr *MockINavigatorDriverMockRecorder) BatchGetBrandInfoList(Ctx, SqlClient, date, ids, entityFields interface{}) *gomock.Call {
    mr.mock.ctrl.T.Helper()
    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetBrandInfoMap", reflect.TypeOf((*MockINavigatorDriver)(nil).BatchGetBrandInfoMap), Ctx, SqlClient, date, ids, entityFields)
}

// BatchGetProductInfoMap mocks base method.
// It forwards the call and its arguments to the controller and converts the
// two recorded return values back to the declared result types.
func (m *MockINavigatorDriver) BatchGetProductInfoMap(Ctx context.Context, SqlClient *sqlclient.SQLClient, date string, ids []int64, entityFields []string) (map[int64]interface{}, error) {
    m.ctrl.T.Helper()
    ret := m.ctrl.Call(m, "BatchGetProductInfoMap", Ctx, SqlClient, date, ids, entityFields)
    // The ok result of each type assertion is discarded, so a nil or
    // differently typed value yields the zero value instead of a panic.
    ret0, _ := ret[0].(map[int64]interface{})
    ret1, _ := ret[1].(error)
    return ret0, ret1
}

// BatchGetProductInfoList indicates an expected call of BatchGetProductInfoMap.
// NOTE(review): this recorder method still carries the "List" name while it
// records calls to "BatchGetProductInfoMap". That suggests the file is stale
// or hand-edited; re-running mockgen would presumably emit a recorder method
// named BatchGetProductInfoMap, and callers would then need updating.
func (mr *MockINavigatorDriverMockRecorder) BatchGetProductInfoList(Ctx, SqlClient, date, ids, entityFields interface{}) *gomock.Call {
    mr.mock.ctrl.T.Helper()
    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchGetProductInfoMap", reflect.TypeOf((*MockINavigatorDriver)(nil).BatchGetProductInfoMap), Ctx, SqlClient, date, ids, entityFields)
}

// Query mocks base method.
// The fixed arguments and every variadic search option are flattened into a
// single []interface{} before being handed to the controller.
func (m *MockINavigatorDriver) Query(Ctx context.Context, SqlClient *sqlclient.SQLClient, sqlKey, sql string, searchOptions ...*engine.Option) ([]map[string]interface{}, error) {
    m.ctrl.T.Helper()
    varargs := []interface{}{Ctx, SqlClient, sqlKey, sql}
    // Each option is appended individually because a []*engine.Option cannot
    // be spread directly into a []interface{}.
    for _, a := range searchOptions {
        varargs = append(varargs, a)
    }
    ret := m.ctrl.Call(m, "Query", varargs...)
    // The ok result of each type assertion is discarded, so a nil or
    // differently typed value yields the zero value instead of a panic.
    ret0, _ := ret[0].([]map[string]interface{})
    ret1, _ := ret[1].(error)
    return ret0, ret1
}

// Query indicates an expected call of Query.
// All parameters are interface{} so that either concrete values or gomock
// matchers can be supplied when recording the expectation.
func (mr *MockINavigatorDriverMockRecorder) Query(Ctx, SqlClient, sqlKey, sql interface{}, searchOptions ...interface{}) *gomock.Call {
    mr.mock.ctrl.T.Helper()
    // Fixed arguments first, followed by the variadic ones, in call order.
    varargs := append([]interface{}{Ctx, SqlClient, sqlKey, sql}, searchOptions...)
    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockINavigatorDriver)(nil).Query), varargs...)
}

测试代码实例

我们来 Mock 如下代码中的这个接口调用的返回值:

datasourceData, _ := navigatorDriver.Query(u.Ctx, u.Datasource.SqlClient, u.Datasource.SqlKey, navigatorSQL)
// RenderDataTable builds the rows of a data-table component.
//
// It works in two stages:
//  1. Compile a CQL statement from the datasource definition and fetch the raw
//     rows through u.INavigatorDriver. This is the call the tests replace with
//     a mock.
//  2. Copy those rows into a SQLite table, run the IndexInfo/RankInfo
//     projection over it, and unmarshal the result into RSD records.
//
// It returns the resulting records, or an error when the navigator query
// fails, the row-count query yields no rows, or the SQLite table is empty.
func (u *UIComponent) RenderDataTable() ([]map[string]interface{}, error) {
    // Debug dump of the component. A marshal failure only affects this
    // output, so the error is intentionally ignored.
    bu, _ := json.Marshal(u)
    fmt.Println(string(bu))

    endTime := u.DateRangeFilter.EndDate
    dateType := u.DateRangeFilter.DaysType
    // Date string of the current period.
    dateStr := indexu.GetDateStr(endTime)

    // 1. Fetch from the datasource.
    // 1.1 Columns to select.
    columnNames := make([]string, 0, len(u.Datasource.Columns))
    for _, e := range u.Datasource.Columns {
        columnNames = append(columnNames, e.Name)
    }

    // 1.2 Datasource table.
    datasourceTable := u.Datasource.TableName

    // 1.3 Filter condition; the DateRangeFilter part is fixed.
    whereExpr := buildWhereExpr(dateStr, dateType, u)
    datasourceCQL := alpha.NewCQL().SELECT(columnNames...).FROM(datasourceTable).WHERE(whereExpr).ORDERBY2(u.Datasource.DSOrderType, u.Datasource.DSOrderColumn).LIMIT2(u.Datasource.DSLimit)

    // 1.4 Fetch the rows from the datasource (navigator).
    // The driver is injected so tests can pass a mock; production uses the
    // real navigator query.
    navigatorDriver := u.INavigatorDriver
    navigatorSQL := datasourceCQL.Compile()
    logu.CtxInfo(u.Ctx, "RenderDataTable", "navigatorSQL: %v", navigatorSQL)
    datasourceData, err := navigatorDriver.Query(u.Ctx, u.Datasource.SqlClient, u.Datasource.SqlKey, navigatorSQL)
    if err != nil {
        // Previously discarded; without this check a failed query was
        // indistinguishable from an empty datasource.
        return nil, fmt.Errorf("query navigator datasource: %w", err)
    }
    // 1.5 Add the ranking field (RankKey) to the RSD data.
    if u.NeedRank {
        datasourceData = rocket.NewRSD2(datasourceData, u.Ctx, u.Datasource.SqlClient, u.INavigatorDriver).WithRank(u.RankKey, u.Datasource.Columns).Records
    }

    // 2. In-memory index computation.
    // NOTE(review): the second result of InitSqlite is discarded; it looks
    // like an error and should be checked once its type is confirmed.
    sqlite, _ := driver.InitSqlite(u.Ctx, map[string]interface{}{})

    // 2.1 Derive column metadata from the rows the datasource returned.
    columns := ParseColumnsMeta(datasourceData)
    fmt.Println(columns)

    // 2.2 Generate a table name.
    sqliteTableName := driver.GenerateUniqSQLiteTableName()

    // 2.3 Create the SQLite table.
    driver.CreateTable(u.Ctx, sqliteTableName, columns, sqlite)
    // 2.4 Copy the datasource rows into it.
    driver.InsertData(u.Ctx, sqliteTableName, datasourceData, columns, sqlite)

    // 2.5 Verify the number of rows that were loaded.
    // u.Ctx replaces the nil context passed before, matching the later query.
    // NOTE(review): the trailing results of driver.Query are discarded; the
    // last one looks like an error — confirm and handle it.
    count, _, _ := driver.Query(u.Ctx, fmt.Sprintf("select count(1) as count from %s", sqliteTableName), sqlite)
    if len(count) == 0 {
        // Guards the count[0] access below, which used to panic here.
        return nil, fmt.Errorf("row count query on %s returned no rows", sqliteTableName)
    }

    if nums, err := convert.ToInt64E(count[0]["count"]); err == nil && nums <= 0 {
        return nil, fmt.Errorf("datasource empty")
    }

    // 2.6 Non-index columns are selected as-is.
    selectItem := []string{}
    for _, c := range u.Datasource.Columns { // metadata describing the index rules
        if !c.IsDataIndex {
            selectItem = append(selectItem, c.Name)
        }
    }
    // 2.7 Index columns are wrapped in the IndexInfo UDF.
    indexColumns := getIndexColumns(u.Datasource.Columns)
    for _, column := range indexColumns {
        columnName := column.Name
        selectItem = append(selectItem, fmt.Sprintf("IndexInfo(%s) as %s", columnName, columnName))
    }

    // 2.8 Add the ranking UDF to the CQL.
    if u.NeedRank {
        // The rank key is specified separately; it is not a data column.
        rankKey := u.RankKey
        selectItem = append(selectItem, fmt.Sprintf("RankInfo(%s) as %s", rankKey, rankKey))
    }

    memCQL := alpha.NewCQL().
        SELECT(selectItem...).
        FROM(sqliteTableName)

    // Pagination is applied only when a limit is set; the offset is optional.
    if u.DFLimit != nil {
        if u.DFOffset != nil {
            memCQL = memCQL.LIMIT3(*u.DFLimit, *u.DFOffset)
        } else {
            memCQL = memCQL.LIMIT2(*u.DFLimit)
        }
    }

    incrSQL := memCQL.Compile()
    fmt.Println(incrSQL)

    // NOTE(review): as above, the trailing results of driver.Query are discarded.
    result, _, _ := driver.Query(u.Ctx, incrSQL, sqlite)

    rsd := rocket.NewRSD2(result, u.Ctx, u.Datasource.SqlClient, u.INavigatorDriver).
        UnmarshalIndexInfo(u.Datasource.Columns).
        UnmarshalRankInfo(u.NeedRank, u.RankKey).
        FillEntityInfoColumn(dateStr, u.Datasource.Columns)

    return rsd.Records, nil
}

接口定义

// mockgen -source=./driver/navigator_driver.go -destination ./driver/navigator_driver_mock.go -package driver

// INavigatorDriver is the interface being mocked. The mockgen command shown
// above reads it from navigator_driver.go and generates a mock for it.
type INavigatorDriver interface {
    // Query takes a context, a SQL client, a SQL key and a SQL string, plus
    // optional search options. It returns a slice of string-keyed maps
    // (presumably one per result row — confirm against the implementation)
    // and an error.
    Query(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        sqlKey,
        sql string,
        searchOptions ...*engine.Option,
    ) ([]map[string]interface{}, error)

    // BatchGetProductInfoMap takes a date string, a list of int64 ids and a
    // list of entity field names. It returns a map keyed by int64
    // (presumably product id -> product info — confirm) and an error.
    BatchGetProductInfoMap(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        date string,
        ids []int64,
        entityFields []string,
    ) (map[int64]interface{}, error)

    // BatchGetBrandInfoMap has the same signature as BatchGetProductInfoMap;
    // presumably the returned map is keyed by brand id — confirm.
    BatchGetBrandInfoMap(Ctx context.Context,
        SqlClient *sqlclient.SQLClient,
        date string,
        ids []int64,
        entityFields []string,
    ) (map[int64]interface{}, error)
}

// NavigatorDriver is declared without fields. It is presumably the production
// implementation of INavigatorDriver; its methods are not shown in this excerpt.
type NavigatorDriver struct {
}

mock 测试代码

关键代码行:

    // The controller tracks expectations; Finish verifies them when the test ends.
    ctrl := gomock.NewController(t)
    defer ctrl.Finish()

    mockDriver := driver.NewMockINavigatorDriver(ctrl)
    // Expected return value for the navigator query (NavigatorQueryList).
    mockDriver.
        EXPECT().
        Query(ctx, SqlClient, "compass_strategy_chance_property_product_stats_di", gomock.Any(), gomock.Any()).
        Return(driver.MockNavigatorQueryListProductStats())

完整代码:


var (
    // ctx is the background context shared by the test and its expectations.
    ctx       = context.Background()
    // SqlClient is a gomock.Any() matcher, not a real SQL client. It is passed
    // in the SqlClient position of recorded expectations to accept any value.
    SqlClient = gomock.Any()
)

// TestDataTableUIComponent renders a data-table component against a mocked
// INavigatorDriver, so no real navigator backend is required. The test fails
// if rendering or JSON encoding of the result returns an error.
func TestDataTableUIComponent(t *testing.T) {
    ctrl := gomock.NewController(t)
    defer ctrl.Finish()

    mockDriver := driver.NewMockINavigatorDriver(ctrl)
    // Expected return value for the navigator query (NavigatorQueryList).
    mockDriver.
        EXPECT().
        Query(ctx, SqlClient, "compass_strategy_chance_property_product_stats_di", gomock.Any(), gomock.Any()).
        Return(driver.MockNavigatorQueryListProductStats())

    // Expected return value for the product-info lookup.
    mockDriver.
        EXPECT().
        BatchGetProductInfoList(ctx, SqlClient, gomock.Any(), gomock.Any(), gomock.Any()).
        Return(driver.MockNavigatorQueryListProudctMap())

    //mockDriver.
    //  EXPECT().
    //  BatchGetBrandInfoList(ctx, SqlClient, gomock.Any(), gomock.Any(), gomock.Any()).
    //  Return(driver.MockNavigatorQueryListBrandMap())

    // Set up the datasource.
    columns := []datasource.Column{
        {Name: "date"},
        {Name: "days_type"},
        {Name: "stats_date"},
        {Name: "cate_id"},
        {Name: "cate_name"},
        {Name: "property_name"},
        {Name: "market_name"},
        {Name: "product_property_value"},
        {Name: "product_id", IsRowKey: true, NeedFillEntityInfo: true, EntityType: datasource.Product, EntityInfoColumnKey: "product_info"},
        {Name: "pay_amt", IsDataIndex: true},
        {Name: "pay_combo_cnt", IsDataIndex: true},
    }

    ds := &datasource.DataSource{
        TableName:     "compass_strategy_chance_property_product_stats_di",
        Columns:       columns,
        SqlKey:        "compass_strategy_chance_property_product_stats_di",
        SearchOptions: []*engine.Option{},
        DSOrderColumn: "pay_combo_cnt",
        DSOrderType:   alpha.DESC,
        DSLimit:       50,
    }

    // Create the component. The local is named "component" so that it does
    // not shadow the UIComponent type.
    component := NewUIComponent(
        ctx,
        mockDriver,
        DataTable,
        ds,
        &DateRangeFilter{
            DaysType:  constu.DateType_LAST_SEVEN_DAYS,
            StartDate: 0,
            EndDate:   1653177600,
        },
        &DimFilter{
            DimCondition: map[string]string{"cate_id": "123",
                "market_name":            "碎花",
                "product_property_value": "长款裙子",
            },
        },
    )

    // In-memory pagination.
    pageNo := int64(2)
    pageSize := int64(5)

    // NOTE(review): dflimit receives the value normally used as an offset,
    // (pageNo-1)*pageSize, and dfoffset receives the page size. Both are 5
    // here, so the result is unaffected; confirm the argument order LIMIT3
    // expects before changing the page parameters.
    dflimit := (pageNo - 1) * pageSize
    dfoffset := pageSize

    component.DFOrderColumn = "pay_combo_cnt"
    component.DFOrderType = alpha.DESC
    component.DFLimit = &dflimit
    component.DFOffset = &dfoffset
    component.NeedRank = true
    component.RankKey = "rank"

    // Render is the component's single data-producing entry point.
    result, err := component.Render()
    if err != nil {
        t.Fatalf("Render() returned error: %v", err)
    }

    fmt.Println("size:", len(result))

    fmt.Println("====================================================================================")
    b, err := json.Marshal(result)
    if err != nil {
        t.Fatalf("marshal render result: %v", err)
    }
    fmt.Println(string(b))
}
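gomock is a mocking framework for the Go programming language. It integrates well with Go's built-in testing package, but can be used in other contexts too.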

Installation

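Once you have installed Go, install the mockgen tool.
Note: if you have not done so already, be sure to add $GOPATH/bin to your PATH.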

To get the latest released version use:

Go version < 1.16

GO111MODULE=on go get github.com/golang/mock/mockgen@v1.6.0

Go 1.16+

go install github.com/golang/mock/mockgen@v1.6.0
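If you use mockgen in your CI pipeline, it may be more appropriate to pin a specific mockgen version. Keep the gomock library in sync with the version of mockgen used to generate your mocks.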

Running mockgen

mockgen has two modes of operation: source and reflect.

Source mode

Source mode generates mock interfaces from a source file.
It is enabled by using the -source flag. Other flags that
may be useful in this mode are -imports and -aux_files.

Example:

mockgen -source=foo.go [other options]

Reflect mode

Reflect mode generates mock interfaces by building a program
that uses reflection to understand interfaces. It is enabled
by passing two non-flag arguments: an import path, and a
comma-separated list of symbols.

You can use "." to refer to the current path's package.

Example:

mockgen database/sql/driver Conn,Driver

# Convenient for `go:generate`.
mockgen . Conn,Driver

Flags

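The mockgen command is used to generate source code for a mock class given a Go source file containing interfaces to be mocked. It supports the following flags:

-source: A file containing interfaces to be mocked.
-destination: A file to which to write the resulting source code. If you don't set this, the code is printed to standard output.
-package: The package to use for the resulting mock class source code. If you don't set this, the package name is mock_ concatenated with the package of the input file.
-imports: A list of explicit imports that should be used in the resulting source code, specified as a comma-separated list of elements of the form foo=bar/baz, where bar/baz is the package being imported and foo is the identifier to use for the package in the generated source code.
-aux_files: A list of additional files that should be consulted to resolve e.g. embedded interfaces defined in a different file. This is specified as a comma-separated list of elements of the form foo=bar/baz.go, where bar/baz.go is the source file and foo is the package name of that file used by the -source file.
-build_flags: (reflect mode only) Flags passed verbatim to go build.
-mock_names: A list of custom names for generated mocks. This is specified as a comma-separated list of elements of the form Repository=MockSensorRepository,Endpoint=MockSensorEndpoint, where Repository is the interface name and MockSensorRepository is the desired mock name.
-self_package: The full package import path for the generated code.
-copyright_file: Copyright file used to add a copyright header to the resulting source code.
-debug_parser: Print out parser results only.
-exec_only: (reflect mode) If set, execute this reflection program.
-prog_only: (reflect mode) Only generate the reflection program; write it to stdout and exit.
-write_package_comment: Writes a package documentation comment (godoc) if true.

For an example of the use of mockgen, see the sample/ directory. In simple cases, you will need only the -source flag.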

Building Mocks

// Foo is the example interface that gets mocked; Bar takes an int and returns an int.
type Foo interface {
  Bar(x int) int
}

// SUT stands in for the code under test that consumes a Foo; its body is elided.
func SUT(f Foo) {
 // ...
}

// TestFoo shows a mock used as an assertion: SUT must call Bar(99) exactly
// once, and that call returns 101.
func TestFoo(t *testing.T) {
  controller := gomock.NewController(t)
  // Finish asserts on exit that the expected Bar() call actually happened.
  defer controller.Finish()

  fooMock := NewMockFoo(controller)

  // The first and only call to Bar() must receive 99; anything else fails.
  expectBar := fooMock.EXPECT().Bar(gomock.Eq(99))
  expectBar.Return(101)

  SUT(fooMock)
}
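If you are using Go 1.14+ and gomock 1.5.0+, you no longer need to call ctrl.Finish() explicitly. It is called automatically from a self-registered Cleanup function when a *testing.T is passed to gomock.NewController(t).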

Building Stubs

// Foo is the example interface that gets stubbed; Bar takes an int and returns an int.
type Foo interface {
  Bar(x int) int
}

// SUT stands in for the code under test that consumes a Foo; its body is elided.
func SUT(f Foo) {
 // ...
}

// TestFoo shows a mock used as a stub: canned answers are registered with
// AnyTimes(), so no particular call count is asserted.
func TestFoo(t *testing.T) {
  controller := gomock.NewController(t)
  defer controller.Finish()

  fooMock := NewMockFoo(controller)

  // No assertion is made. Whenever Bar is invoked with 99, this function runs
  // and its result is returned.
  slowBar := func(_ int) int {
    time.Sleep(1*time.Second)
    return 101
  }
  fooMock.EXPECT().Bar(gomock.Eq(99)).DoAndReturn(slowBar).AnyTimes()

  // No assertion is made. Bar(101) returns 103.
  fooMock.EXPECT().Bar(gomock.Eq(101)).Return(103).AnyTimes()

  SUT(fooMock)
}

Modifying Failure Messages

When a matcher reports a failure, it prints the received (Got) vs the expected (Want) value:
Got: [3]
Want: is equal to 2
Expected call at user_test.go:33 doesn't match the argument at index 1.
Got: [0 1 1 2 3]
Want: is equal to 1
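Modifying Want

The Want value comes from the matcher's String() method. If the matcher's default output doesn't meet your needs, it can be modified as follows: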
gomock.WantFormatter(
  gomock.StringerFunc(func() string { return "is equal to fifteen" }),
  gomock.Eq(15),
)
This changes the gomock.Eq(15) matcher's output for Want: from "is equal to 15" to "is equal to fifteen".
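Modifying Got

The Got value comes from the object's String() method if it is available. In some cases the output of an object is difficult to read (e.g. []byte), and it would be helpful for the test to print it differently. The following changes how the Got value is formatted: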
gomock.GotFormatterAdapter(
  gomock.GotFormatterFunc(func(i interface{}) string {
    // Leading 0s
    return fmt.Sprintf("%02d", i)
  }),
  gomock.Eq(15),
)
If the received value is 3, then it will be printed as 03.

Debugging Errors

reflect vendoring error

cannot find package "."
... github.com/golang/mock/mockgen/model

If you come across this error while using reflect mode and vendoring
dependencies there are three workarounds you can choose from:

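1. Use source mode.
2. Include an empty import: import _ "github.com/golang/mock/mockgen/model".
3. Add --build_flags=--mod=mod to your mockgen command.

This error is due to changes in the default behavior of the go command in more recent versions.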