go get github.com/go-sql-driver/mysql
基础类（package mysql，主从连接池封装）:
package mysql
import (
"context"
"database/sql"
"fmt"
"math/rand"
"time"
"why/config"
"why/log"
"github.com/opentracing/opentracing-go"
_ "github.com/go-sql-driver/mysql"
)
// DB bundles one master *sql.DB pool, zero or more slave *sql.DB pools
// (chosen at random by SlaveDB for the SlaveDB* helpers), and the Config
// that New used to open them.
type DB struct {
	masterDB *sql.DB
	slaveDB  []*sql.DB
	Config   *Config
}
// GetDSN builds a go-sql-driver/mysql DSN of the form
// "user:password@tcp(host:port)/db[?charset=...]" from the "mysql" config
// section named conn.
//
// The charset parameter is appended only when it is configured. An empty
// "charset=" would be handed to the driver as a literal (empty) charset name,
// which it then tries to apply with SET NAMES, failing the connection.
func GetDSN(conn string) string {
	cfg := config.GetConfigEntrance("mysql", conn)
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s",
		cfg["user"], cfg["password"], cfg["host"], cfg["port"], cfg["db"])
	if charset := cfg["charset"]; charset != "" {
		dsn += "?charset=" + charset
	}
	return dsn
}
// New opens the master pool and every slave pool described by c, applies each
// pool's MaxOpen/MaxIdle limits, and pings every pool to verify connectivity.
//
// On any failure, every pool opened so far is closed and (nil, err) is
// returned, so a failed call neither leaks connections nor hands back a
// half-initialised *DB. A nil c is reported as an error instead of panicking.
func New(c *Config) (db *DB, err error) {
	if c == nil {
		return nil, fmt.Errorf("init db error: nil config")
	}
	d := &DB{Config: c}

	// Release every pool opened so far if any step below fails. Close errors
	// are ignored on purpose: err already describes the real failure.
	defer func() {
		if err == nil {
			return
		}
		if d.masterDB != nil {
			_ = d.masterDB.Close()
		}
		for _, s := range d.slaveDB {
			_ = s.Close()
		}
	}()

	if d.masterDB, err = sql.Open("mysql", c.Master.DSN); err != nil {
		return nil, errorsWrap(err, "init master db error")
	}
	d.masterDB.SetMaxOpenConns(c.Master.MaxOpen)
	d.masterDB.SetMaxIdleConns(c.Master.MaxIdle)
	if err = d.masterDB.Ping(); err != nil {
		return nil, errorsWrap(err, "master db ping error")
	}

	for _, sc := range c.Slave {
		var slave *sql.DB
		if slave, err = sql.Open("mysql", sc.DSN); err != nil {
			return nil, errorsWrap(err, "init slave db error")
		}
		// Track the pool before pinging so the deferred cleanup closes it
		// if the ping fails.
		d.slaveDB = append(d.slaveDB, slave)
		slave.SetMaxOpenConns(sc.MaxOpen)
		slave.SetMaxIdleConns(sc.MaxIdle)
		if err = slave.Ping(); err != nil {
			return nil, errorsWrap(err, "slave db ping error")
		}
	}
	return d, nil
}
// MasterDB exposes the master *sql.DB pool opened by New.
func (db *DB) MasterDB() *sql.DB {
	master := db.masterDB
	return master
}
// SlaveDB picks one slave pool uniformly at random (math/rand). When no
// slave pools are configured it falls back to the master pool.
func (db *DB) SlaveDB() *sql.DB {
	switch count := len(db.slaveDB); count {
	case 0:
		return db.masterDB
	default:
		return db.slaveDB[rand.Intn(count)]
	}
}
// MasterDBClose releases the master pool's resources. It is a no-op that
// returns nil when no master pool was opened.
func (db *DB) MasterDBClose() error {
	if db.masterDB == nil {
		return nil
	}
	return db.masterDB.Close()
}
// SlaveDBClose releases the resources of every slave pool.
//
// Every pool is closed even if an earlier Close fails, so one bad pool cannot
// leak the rest. The first error encountered is returned; nil means every
// Close succeeded.
func (db *DB) SlaveDBClose() (err error) {
	for _, s := range db.slaveDB {
		if cerr := s.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}
	return err
}
// operate selects which pool (master or slave) and which database/sql call
// (*DB).operate performs.
type operate int64

// Supported operations. The numeric values come from iota.
const (
	operateMasterExec     operate = iota // ExecContext on the master pool
	operateMasterQuery                   // QueryContext on the master pool
	operateMasterQueryRow                // QueryRowContext on the master pool
	operateSlaveQuery                    // QueryContext on a slave pool
	operateSlaveQueryRow                 // QueryRowContext on a slave pool
)

// operationNames gives the tracing span / log operation name for each operate.
var operationNames = map[operate]string{
	operateMasterExec:     "masterDBExec",
	operateMasterQuery:    "masterDBQuery",
	operateMasterQueryRow: "masterDBQueryRow",
	operateSlaveQuery:     "slaveDBQuery",
	operateSlaveQueryRow:  "slaveDBQueryRow",
}
// operate runs query against the pool selected by op. It wraps the call in an
// opentracing span (a child of any span already in ctx) and logs slow or
// failed statements through the context's log header.
//
// The dynamic type of i depends on op:
//   - *sql.Rows for operateMasterQuery / operateSlaveQuery
//   - *sql.Row for operateMasterQueryRow / operateSlaveQueryRow
//   - sql.Result for operateMasterExec
//
// For the QueryRow operations err stays nil, because database/sql reports
// QueryRow errors from Scan. An unknown op yields a non-nil error.
func (db *DB) operate(ctx context.Context, op operate, query string, args ...interface{}) (i interface{}, err error) {
	var (
		parent        = opentracing.SpanFromContext(ctx)
		operationName = operationNames[op]
		span          = func() opentracing.Span {
			if parent == nil {
				return opentracing.StartSpan(operationName)
			}
			return opentracing.StartSpan(operationName, opentracing.ChildOf(parent.Context()))
		}()
		logFormat = log.LogHeaderFromContext(ctx)
		startAt   = time.Now()
	)
	lastModule := logFormat.Module
	// Deferred calls run LIFO: the tagging/logging closure runs first, then
	// span.Finish, and finally the log header's Module is restored.
	defer func() { logFormat.Module = lastModule }()
	defer span.Finish()
	defer func() {
		endAt := time.Now()
		elapsed := endAt.Sub(startAt)
		logFormat.StartTime = startAt
		logFormat.EndTime = endAt
		logFormat.LatencyTime = elapsed.Microseconds() // execution time in microseconds
		span.SetTag("error", err != nil)
		span.SetTag("db.type", "sql")
		span.SetTag("db.statement", query)
		logFormat.Module = "databus/mysql"
		// A zero ExecTimeout means "not configured". Without this guard every
		// statement would be reported as slow.
		if timeout := db.Config.ExecTimeout.Duration; timeout > 0 && elapsed > timeout {
			// %v, not %s: args is []interface{} and usually holds non-strings.
			log.Warnf(logFormat, "%s:[%s], params:%v, used: %d milliseconds", operationName, query,
				args, elapsed.Milliseconds())
		}
		if err != nil {
			log.Errorf(logFormat, "%s:[%s], params:%v, error: %s", operationName, query,
				args, err)
		}
	}()

	switch op {
	case operateMasterQuery:
		i, err = db.MasterDB().QueryContext(ctx, query, args...)
	case operateMasterQueryRow:
		i = db.MasterDB().QueryRowContext(ctx, query, args...)
	case operateMasterExec:
		i, err = db.MasterDB().ExecContext(ctx, query, args...)
	case operateSlaveQuery:
		i, err = db.SlaveDB().QueryContext(ctx, query, args...)
	case operateSlaveQueryRow:
		i = db.SlaveDB().QueryRowContext(ctx, query, args...)
	default:
		// Fail loudly instead of returning (nil, nil), which would make the
		// typed wrappers panic on their type assertion.
		err = fmt.Errorf("mysql: unknown operate %d", op)
	}
	return i, err
}
func (db *DB) MasterDBExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
r, err := db.operate(ctx, operateMasterExec, query, args...)
if err != nil {
return nil, err
}
return r.(sql.Result), err
}
func (db *DB) MasterDBQueryContext(ctx context.Context, query string, args ...interface{}) (result *sql.Rows, err error) {
r, err := db.operate(ctx, operateMasterQuery, query, args...)
if err != nil {
return nil, err
}
return r.(*sql.Rows), err
}
// MasterDBQueryRowContext runs a query expected to return at most one row on
// the master pool. As with database/sql's QueryRowContext, any error is
// reported when the caller calls Scan on the returned *sql.Row.
func (db *DB) MasterDBQueryRowContext(ctx context.Context, query string, args ...interface{}) (result *sql.Row) {
	out, _ := db.operate(ctx, operateMasterQueryRow, query, args...) // QueryRow errors surface via Scan
	result = out.(*sql.Row)
	return result
}
// SlaveDBQueryContext runs a row-returning query on a slave pool chosen by
// SlaveDB (falling back to the master when no slaves exist), traced and logged
// by operate under "slaveDBQuery". Per database/sql, the caller must Close the
// returned *sql.Rows.
func (db *DB) SlaveDBQueryContext(ctx context.Context, query string, args ...interface{}) (result *sql.Rows, err error) {
	// Must be operateSlaveQuery: operateMasterQuery would silently send reads
	// to the master and mislabel the span.
	r, err := db.operate(ctx, operateSlaveQuery, query, args...)
	if err != nil {
		return nil, err
	}
	return r.(*sql.Rows), nil
}
// SlaveDBQueryRowContext runs a query expected to return at most one row on a
// pool chosen by SlaveDB. As with database/sql's QueryRowContext, any error
// is reported when the caller calls Scan on the returned *sql.Row.
func (db *DB) SlaveDBQueryRowContext(ctx context.Context, query string, args ...interface{}) (result *sql.Row) {
	out, _ := db.operate(ctx, operateSlaveQueryRow, query, args...) // QueryRow errors surface via Scan
	result = out.(*sql.Row)
	return result
}
func errorsWrap(err error, msg string) error {
return fmt.Errorf("%s: %w", msg, err)
}
/* example
*/
公共数据库连接方法（package models）:
package models
import (
	"sync"

	"why/mysql"
	"why/util"
)
// Lazily created, process-wide connection handles, one per logical database.
// connMu guards both, so concurrent first calls to GetConn cannot race and
// open duplicate pools.
var (
	priceInstance *mysql.DB
	ymtInstance   *mysql.DB
	connMu        sync.Mutex
)

// GetConn returns the shared *mysql.DB for conn ("hangqing" or "ymt360").
// The handle is created on first use and reused on every later call. The
// previous version never stored the new handle, so each call opened (and
// leaked) a fresh connection pool.
//
// It panics for any other conn name. If the constructor panics (it calls
// util.Must on the open error), nothing is cached and a later call retries.
func GetConn(conn string) *mysql.DB {
	connMu.Lock()
	defer connMu.Unlock()

	switch conn {
	case "hangqing":
		if priceInstance == nil {
			priceInstance = getPriceConn()
		}
		return priceInstance
	case "ymt360":
		if ymtInstance == nil {
			ymtInstance = getYmtConn()
		}
		return ymtInstance
	default:
		panic("err conn string")
	}
}
// Pool limits applied to both the write (master) and read (slave) pools.
const (
	defaultMaxOpen = 5
	defaultMaxIdle = 5
)

// getYmtConn opens the "ymt360" database: the master uses the "ymt360_write"
// config entry and a single slave uses "ymt360_read".
func getYmtConn() *mysql.DB {
	return newReadWriteConn("ymt360")
}

// getPriceConn opens the "hangqing" database: the master uses the
// "hangqing_write" config entry and a single slave uses "hangqing_read".
func getPriceConn() *mysql.DB {
	return newReadWriteConn("hangqing")
}

// newReadWriteConn builds a *mysql.DB whose master comes from the
// "<name>_write" config entry and whose only slave comes from "<name>_read".
// The error from mysql.New is passed to util.Must rather than returned.
func newReadWriteConn(name string) *mysql.DB {
	pool := func(entry string) mysql.Conn {
		return mysql.Conn{
			DSN:     mysql.GetDSN(entry),
			MaxOpen: defaultMaxOpen,
			MaxIdle: defaultMaxIdle,
		}
	}
	// Master is evaluated before Slave, so the write DSN is read first, as before.
	cfg := &mysql.Config{
		Master: pool(name + "_write"),
		Slave:  []mysql.Conn{pool(name + "_read")},
	}
	db, err := mysql.New(cfg)
	util.Must(err)
	return db
}
model 层（package hangqing）:
package hangqing
import (
"context"
"github.com/jmoiron/sqlx"
"why/util"
"models"
)
// HqCustomer holds one hq_customer row. Field names are kept as-is: sqlx's
// default mapper matches them to columns by lower-casing (Province_id ->
// province_id), and util.StructToMap output presumably keys on them.
type HqCustomer struct {
	Province_id, City_id, County_id, Location_id, Market_info_id int
	Point_key, Point_key2                                        string
	Product_id, Breed_id, Customer_id                            int
}
// GetCustomerBreedsByCid loads every hq_customer row whose customer_id is cid
// via the "hangqing" connection's SlaveDBQueryContext. Each row is returned as
// the map produced by util.StructToMap. The result is nil when no rows match.
//
// Query and scan errors are not returned; they are passed to util.Must
// (presumably panicking — see util.Must).
func GetCustomerBreedsByCid(ctx context.Context, cid int) (data []map[string]interface{}) {
	const query = "select province_id,city_id,county_id,location_id,market_info_id,point_key,point_key2,product_id,breed_id,customer_id from hq_customer where customer_id = ?"
	db := models.GetConn("hangqing")
	rows, err := db.SlaveDBQueryContext(ctx, query, cid)
	util.Must(err)
	// sqlx.StructScan does not close rows itself. Without this, a scan
	// failure (and the util.Must panic that follows) would leak the pooled
	// connection.
	defer rows.Close()

	var list []*HqCustomer
	util.Must(sqlx.StructScan(rows, &list))
	if len(list) == 0 {
		return nil
	}
	data = make([]map[string]interface{}, 0, len(list))
	for _, v := range list {
		data = append(data, util.StructToMap(*v))
	}
	return data
}