mongo.go 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. package database
import (
	"context"
	"fmt"
	"net/url"
	"strings"
	"time"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readpref"
)
  // Compile-time assertion that *mongoDB implements the MongoDB interface.
  10. var _ MongoDB = (*mongoDB)(nil)
  // MongoDB is the handle this package hands out for a connected database.
  11. type MongoDB interface {
  // i is an unexported marker method: because it is unexported, types
  // declared outside this package cannot implement MongoDB directly.
  12. i()
  // GetDB returns the *mongo.Database handle held by the implementation.
  13. GetDB() *mongo.Database
  // Close releases the connection and reports any error from doing so.
  14. Close() error
  15. }
  // MongoDBConfig holds the settings New needs to open a connection. The
  // struct tags give the YAML key for each field.
  16. type MongoDBConfig struct {
  // Addr is appended after the "mongodb://" scheme (and after the optional
  // credentials) when New builds the connection URI.
  17. Addr string `yaml:"addr"`
  // User and Pass are used as URI credentials only when both are non-empty.
  18. User string `yaml:"user"`
  19. Pass string `yaml:"pass"`
  // Name is the database name New passes to client.Database.
  20. Name string `yaml:"name"`
  // Timeout is multiplied by time.Second in New, so it is effectively a
  // count of seconds rather than a ready-made duration.
  21. Timeout time.Duration `yaml:"timeout"`
  22. }
  // mongoDB is the package's implementation of MongoDB; New constructs it.
  23. type mongoDB struct {
  // client is the connected driver client; Close disconnects it.
  24. client *mongo.Client
  // db is the database handle returned by GetDB.
  25. db *mongo.Database
  // timeout bounds the Disconnect call made by Close.
  26. timeout time.Duration
  27. }
  // i is a no-op whose only purpose is to satisfy the unexported marker
  // method of the MongoDB interface.
  28. func (m *mongoDB) i() {}
  29. func New(cfg MongoDBConfig) (MongoDB, error) {
  30. timeout := cfg.Timeout * time.Second
  31. connectCtx, connectCancelFunc := context.WithTimeout(context.Background(), timeout)
  32. defer connectCancelFunc()
  33. var auth string
  34. if len(cfg.User) > 0 && len(cfg.Pass) > 0 {
  35. auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
  36. }
  37. client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(
  38. fmt.Sprintf("mongodb://%s%s", auth, cfg.Addr),
  39. ))
  40. if err != nil {
  41. return nil, err
  42. }
  43. pingCtx, pingCancelFunc := context.WithTimeout(context.Background(), timeout)
  44. defer pingCancelFunc()
  45. err = client.Ping(pingCtx, readpref.Primary())
  46. if err != nil {
  47. return nil, err
  48. }
  49. return &mongoDB{
  50. client: client,
  51. db: client.Database(cfg.Name),
  52. timeout: timeout,
  53. }, nil
  54. }
  // GetDB returns the database handle stored on m.
  55. func (m *mongoDB) GetDB() *mongo.Database {
  56. return m.db
  57. }
  58. func (m *mongoDB) Close() error {
  59. disconnectCtx, disconnectCancelFunc := context.WithTimeout(context.Background(), m.timeout)
  60. defer disconnectCancelFunc()
  61. err := m.client.Disconnect(disconnectCtx)
  62. if err != nil {
  63. return err
  64. }
  65. return nil
  66. }