package main

import (
	"database/sql"
	"fmt"
	//"log"
	"strconv"
	"sync"

	_ "github.com/go-sql-driver/mysql"
)

// dbConn is the package-wide database handle. It is nil until dbInit assigns it.
var dbConn *sql.DB

// err receives the result of sql.Open in dbInit. The query helpers visible in
// this file declare their own local err rather than using this one.
var err error

// SimpleSession holds the two values this file reads back for a session row:
// Username comes from the login_name column and TTL from the TTL column,
// parsed as a base-10 int64.
// NOTE(review): the sample TTL used in this file looks like a Unix timestamp
// in milliseconds — confirm the intended unit.
type SimpleSession struct {
	Username string
	TTL      int64
}

// dbInit opens the package-level MySQL handle (dbConn) and verifies that the
// server can actually be reached. It panics with the error text if either the
// handle cannot be created or the connectivity check fails.
func dbInit() {
	// DSN layout: user:password@/database?charset=encoding
	dbConn, err = sql.Open("mysql", "root:root@/video_server?charset=utf8")
	if err != nil {
		panic(err.Error())
	}

	// sql.Open may only validate its arguments without connecting, so a bad
	// password or unreachable server would otherwise surface only on the
	// first query. Ping forces a real connection attempt now.
	if err = dbConn.Ping(); err != nil {
		panic(err.Error())
	}
}

// InserSession stores a new row in the sessions table.
//
// sid is written to session_id, ttl to TTL (as its base-10 text form) and
// uname to login_name. It returns the error from preparing or executing the
// statement, or nil on success. The misspelled name is kept so existing
// callers continue to compile.
func InserSession(sid string, ttl int64, uname string) error {
	stmtIns, err := dbConn.Prepare("INSERT INTO sessions (session_id, TTL, login_name) VALUES (?, ?, ?)")
	if err != nil {
		return err
	}
	// Registered right after Prepare so the statement is released even when
	// Exec fails.
	defer stmtIns.Close()

	ttlstr := strconv.FormatInt(ttl, 10) // base-10 text form of ttl
	_, err = stmtIns.Exec(sid, ttlstr, uname)
	return err
}

// DeleteSession removes the row whose session_id equals sid.
//
// It returns the error from preparing or executing the statement, or nil on
// success. The number of affected rows is not inspected.
func DeleteSession(sid string) error {
	stmtOut, err := dbConn.Prepare("DELETE FROM sessions WHERE session_id = ?")
	if err != nil {
		return err
	}
	defer stmtOut.Close()

	// Exec, not Query: a DELETE produces no result set, and an unclosed
	// *sql.Rows from Query would keep a pooled connection checked out.
	_, err = stmtOut.Exec(sid)
	return err
}

// UpdateSession sets the TTL column to ttl for the row whose session_id
// equals sid.
//
// It returns the error from preparing or executing the statement, or nil on
// success. The number of affected rows is not inspected.
func UpdateSession(sid string, ttl int64) error {
	stmtOut, err := dbConn.Prepare("UPDATE sessions SET TTL = ? WHERE session_id = ?")
	if err != nil {
		return err
	}
	defer stmtOut.Close()

	// Exec, not Query: an UPDATE produces no result set, and an unclosed
	// *sql.Rows from Query would keep a pooled connection checked out.
	_, err = stmtOut.Exec(ttl, sid)
	return err
}

// SelectOne loads the session whose session_id equals sid.
//
// On success it returns a SimpleSession populated from the TTL and login_name
// columns. It returns a nil session and a non-nil error when the statement
// cannot be prepared, when the row cannot be scanned (sql.ErrNoRows if no row
// matches sid), or when the stored TTL is not a base-10 integer.
func SelectOne(sid string) (*SimpleSession, error) {
	stmtOut, err := dbConn.Prepare("SELECT TTL, login_name FROM sessions WHERE session_id=?")
	if err != nil {
		return nil, err
	}
	// Registered right after Prepare so every return path releases the
	// statement.
	defer stmtOut.Close()

	var ttl string
	var uname string
	// The Scan error must be captured here. Previously it was dropped, so a
	// missing row fell through and was reported as a confusing integer-parse
	// failure on an empty string.
	if err := stmtOut.QueryRow(sid).Scan(&ttl, &uname); err != nil {
		return nil, err
	}

	// Convert the textual TTL to int64.
	res, err := strconv.ParseInt(ttl, 10, 64)
	if err != nil {
		return nil, err
	}

	return &SimpleSession{Username: uname, TTL: res}, nil
}

// SelectAll loads every row of the sessions table into a sync.Map keyed by
// session_id, with *SimpleSession values.
//
// Rows whose TTL is not a base-10 integer are skipped. Any failure to prepare,
// query, scan or iterate returns a nil map and the error, so callers never
// receive a silently truncated result.
func SelectAll() (*sync.Map, error) {
	// Columns are named explicitly so the Scan below does not depend on the
	// table's column order or on extra columns being absent.
	stmtOut, err := dbConn.Prepare("SELECT session_id, TTL, login_name FROM sessions")
	if err != nil {
		return nil, err
	}
	defer stmtOut.Close()

	rows, err := stmtOut.Query()
	if err != nil {
		return nil, err
	}
	// Closing the rows returns the underlying connection to the pool.
	defer rows.Close()

	m := &sync.Map{}
	for rows.Next() {
		var id, ttlstr, loginName string
		if err := rows.Scan(&id, &ttlstr, &loginName); err != nil {
			return nil, err
		}

		// Store each row as soon as it is decoded. A malformed TTL only
		// drops that one row.
		ttl, err := strconv.ParseInt(ttlstr, 10, 64)
		if err != nil {
			continue
		}
		m.Store(id, &SimpleSession{Username: loginName, TTL: ttl})
	}
	// rows.Next returns false both at the end of the data and on error, so
	// the two cases have to be told apart here.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return m, nil
}

// InsertData is a demo helper: it inserts one hard-coded session row through
// InserSession and prints the error, if any, to standard output.
func InsertData() {
	if insertErr := InserSession("8bed7b4f-d201-4ca1-a816-8adce84a9164", 1598602333474, "xiaoming"); insertErr != nil {
		fmt.Println(insertErr)
	}
}

// DeleteData is a demo helper: it deletes one hard-coded session through
// DeleteSession and prints the error, if any, to standard output.
func DeleteData() {
	if deleteErr := DeleteSession("8bed7b4f-d201-4ca1-a816-8adce84a9164"); deleteErr != nil {
		fmt.Println(deleteErr)
	}
}

// UpdateData is a demo helper: it sets a new TTL on one hard-coded session
// through UpdateSession and prints the error, if any, to standard output.
func UpdateData() {
	if updateErr := UpdateSession("8bed7b4f-d201-4ca1-a816-8adce84a9164", 1598602333475); updateErr != nil {
		fmt.Println(updateErr)
	}
}

// SelectDataOne is a demo helper: it looks up one hard-coded session id
// through SelectOne and prints either the error or the session.
func SelectDataOne() {
	session, lookupErr := SelectOne("8bed7b4f-d201-4ca1-a816-8adce84a9163")
	if lookupErr != nil {
		fmt.Println(lookupErr)
		return
	}
	fmt.Println(session) // placeholder: put real handling of the session here
}

// 查询多条数据 func SelectDataAll() { ss, err := SelectAll() if err != nil { fmt.Println(err) } else { fmt.Println(ss) // your code } } func main() { // 连接数据库 dbInit() // 增 //InsertData() // 删 //DeleteData() // 改 //UpdateData() // 查 一条数据 //SelectDataOne() // 查 多条数据 SelectDataAll() }