mysql.go 6.1 KB


  1. package database
  2. import (
  3. "fmt"
  4. "git.bvbej.com/bvbej/base-golang/pkg/errors"
  5. "git.bvbej.com/bvbej/base-golang/pkg/time_parse"
  6. "git.bvbej.com/bvbej/base-golang/pkg/trace"
  7. "gorm.io/driver/mysql"
  8. "gorm.io/gorm"
  9. "gorm.io/gorm/logger"
  10. "gorm.io/gorm/schema"
  11. "gorm.io/gorm/utils"
  12. "log"
  13. "os"
  14. "time"
  15. )
  // Callback names and per-statement setting keys used by the tracing plugin.
const (
	// callBackBeforeName is the name the start-time hook is registered under.
	callBackBeforeName = "core:before"
	// callBackAfterName is the name the reporting hook is registered under.
	callBackAfterName = "core:after"

	// startTime is the InstanceSet key holding the statement's start time.
	startTime = "_start_time"
	// traceCtxName is the InstanceSet key holding the trace value
	// (asserted to trace.T by the after hook).
	traceCtxName = "_trace_ctx_name"
)
  22. var _ MysqlRepo = (*mysqlRepo)(nil)
  23. type MysqlRepo interface {
  24. i()
  25. GetRead(options ...Option) *gorm.DB
  26. GetWrite(options ...Option) *gorm.DB
  27. Close() error
  28. }
  // MySQLConfig is the YAML configuration for a read/write pair of MySQL
// connections that share one set of pool settings.
type MySQLConfig struct {
	// Read describes the connection NewMysql opens for GetRead.
	Read struct {
		Addr string `yaml:"addr"` // server address, placed inside tcp(...) in the DSN
		User string `yaml:"user"` // user name
		Pass string `yaml:"pass"` // password
		Name string `yaml:"name"` // database (schema) name
	} `yaml:"read"`
	// Write describes the connection NewMysql opens for GetWrite.
	Write struct {
		Addr string `yaml:"addr"` // server address, placed inside tcp(...) in the DSN
		User string `yaml:"user"` // user name
		Pass string `yaml:"pass"` // password
		Name string `yaml:"name"` // database (schema) name
	} `yaml:"write"`
	// Base holds pool settings applied to both connections.
	Base struct {
		MaxOpenConn int `yaml:"maxOpenConn"` // maximum number of open connections
		MaxIdleConn int `yaml:"maxIdleConn"` // maximum number of idle connections
		// ConnMaxLifeTime is a number of MINUTES: the connection code
		// multiplies it by time.Minute. NOTE(review): give a plain integer
		// (e.g. 5); a duration string such as "5m" would presumably be
		// decoded as a Duration and then multiplied again.
		ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"`
	} `yaml:"base"`
}
  48. type mysqlRepo struct {
  49. read *gorm.DB
  50. write *gorm.DB
  51. }
  52. func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
  53. dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
  54. cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
  55. if err != nil {
  56. return nil, err
  57. }
  58. dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
  59. cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
  60. if err != nil {
  61. return nil, err
  62. }
  63. return &mysqlRepo{
  64. read: dbr,
  65. write: dbw,
  66. }, nil
  67. }
  68. func (d *mysqlRepo) i() {}
  69. func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
  70. opt := newOption()
  71. for _, f := range options {
  72. f(opt)
  73. }
  74. db := d.read
  75. if opt.Trace != nil {
  76. db.InstanceSet(traceCtxName, opt.Trace)
  77. }
  78. return db
  79. }
  80. func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
  81. opt := newOption()
  82. for _, f := range options {
  83. f(opt)
  84. }
  85. db := d.write
  86. if opt.Trace != nil {
  87. db.InstanceSet(traceCtxName, opt.Trace)
  88. }
  89. return db
  90. }
  91. func (d *mysqlRepo) Close() (err error) {
  92. rdb, err1 := d.read.DB()
  93. if err1 != nil {
  94. err = errors.WithStack(err1)
  95. }
  96. err2 := rdb.Close()
  97. if err2 != nil {
  98. err = errors.WithStack(err2)
  99. }
  100. wdb, err3 := d.write.DB()
  101. if err3 != nil {
  102. err = errors.WithStack(err3)
  103. }
  104. err4 := wdb.Close()
  105. if err4 != nil {
  106. err = errors.WithStack(err4)
  107. }
  108. return err
  109. }
  110. func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
  111. dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
  112. user,
  113. pass,
  114. addr,
  115. dbName,
  116. true,
  117. "Local")
  118. // 日志配置
  119. newLogger := logger.New(
  120. log.New(os.Stdout, "\r\n", log.LstdFlags),
  121. logger.Config{
  122. SlowThreshold: time.Second, // 慢SQL阈值
  123. Colorful: true, // 彩色打印
  124. IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
  125. LogLevel: logger.Error, // 日志级别
  126. },
  127. )
  128. db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
  129. NamingStrategy: schema.NamingStrategy{
  130. SingularTable: true,
  131. },
  132. Logger: newLogger,
  133. })
  134. if err != nil {
  135. return nil, errors.Wrap(err, fmt.Sprintf("[db connection failed] Database name: %s", dbName))
  136. }
  137. db.Set("gorm:table_options", "CHARSET=utf8mb4")
  138. sqlDB, err := db.DB()
  139. if err != nil {
  140. return nil, err
  141. }
  142. // 设置连接池 用于设置最大打开的连接数,默认值为0表示不限制.设置最大的连接数,可以避免并发太高导致连接mysql出现too many connections的错误。
  143. sqlDB.SetMaxOpenConns(maxOpenConn)
  144. // 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
  145. sqlDB.SetMaxIdleConns(maxIdleConn)
  146. // 设置最大连接超时
  147. sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
  148. // 使用插件
  149. err = db.Use(&TracePlugin{})
  150. if err != nil {
  151. return nil, err
  152. }
  153. return db, nil
  154. }
  // -----------------------------------------------------------------------------

// TracePlugin is a GORM plugin that times statements and reports the executed
// SQL to a trace attached to the statement instance.
type TracePlugin struct{}

// Name returns the name the plugin is registered under: "TracePlugin".
func (*TracePlugin) Name() string {
	const pluginName = "TracePlugin"
	return pluginName
}
  160. func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
  161. // 开始前
  162. _ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
  163. _ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
  164. _ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
  165. _ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
  166. _ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
  167. _ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
  168. // 结束后
  169. _ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
  170. _ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
  171. _ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
  172. _ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
  173. _ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
  174. _ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
  175. return
  176. }
  177. func before(db *gorm.DB) {
  178. db.InstanceSet(startTime, time.Now())
  179. }
  180. func after(db *gorm.DB) {
  181. _traceCtx, isExist := db.InstanceGet(traceCtxName)
  182. if !isExist {
  183. return
  184. }
  185. _trace, ok := _traceCtx.(trace.T)
  186. if !ok {
  187. return
  188. }
  189. _ts, isExist := db.InstanceGet(startTime)
  190. if !isExist {
  191. return
  192. }
  193. ts, ok := _ts.(time.Time)
  194. if !ok {
  195. return
  196. }
  197. sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
  198. sqlInfo := new(trace.SQL)
  199. sqlInfo.Timestamp = time_parse.CSTLayoutString()
  200. sqlInfo.SQL = sql
  201. sqlInfo.Stack = utils.FileWithLineNum()
  202. sqlInfo.Rows = db.Statement.RowsAffected
  203. sqlInfo.CostSeconds = time.Since(ts).Seconds()
  204. _trace.AppendSQL(sqlInfo)
  205. }