package auth

import (
	"encoding/json"
	"errors"
	"time"

	"github.com/google/uuid"
	"github.com/tidwall/buntdb"
)

// TokenStore persists token information and looks it up by either its
// access token or its refresh token.
type TokenStore interface {
	// Create stores info so it can later be found by its access token and,
	// if it has one, its refresh token.
	Create(info TokenInfo) error
	// RemoveByAccess removes the given access token from the store.
	RemoveByAccess(access string) error
	// RemoveByRefresh removes the given refresh token from the store.
	RemoveByRefresh(refresh string) error
	// GetByAccess returns the token information associated with access.
	GetByAccess(access string) (TokenInfo, error)
	// GetByRefresh returns the token information associated with refresh.
	GetByRefresh(refresh string) (TokenInfo, error)
}

// Compile-time check that *Store satisfies TokenStore.
var _ TokenStore = (*Store)(nil)

// NewMemoryTokenStore creates a TokenStore whose buntdb database is kept
// in memory only.
func NewMemoryTokenStore() (TokenStore, error) {
	return NewFileTokenStore(":memory:")
}

// NewFileTokenStore creates a TokenStore backed by the buntdb database at
// filename. The special name ":memory:" keeps the data in memory.
func NewFileTokenStore(filename string) (TokenStore, error) {
	db, err := buntdb.Open(filename)
	if err != nil {
		return nil, err
	}
	return &Store{db: db}, nil
}

// Store is a TokenStore backed by buntdb (https://github.com/tidwall/buntdb).
//
// Each token is saved under three kinds of keys: a random "basic ID" that
// holds the JSON-encoded token data, and the access token and (optionally)
// the refresh token, each of which maps to that basic ID.
//
// NOTE(review): the underlying *buntdb.DB is never closed by anything in this
// file; confirm the owner releases it elsewhere.
type Store struct {
	db *buntdb.DB
}

// remove deletes key from the database. A key that does not exist is not
// treated as an error.
func (ts *Store) remove(key string) error {
	err := ts.db.Update(func(tx *buntdb.Tx) error {
		_, err := tx.Delete(key)
		return err
	})
	if errors.Is(err, buntdb.ErrNotFound) {
		return nil
	}
	return err
}

// getData loads the JSON-encoded token stored under key and decodes it into
// a Token. It returns (nil, nil) when key is not found.
func (ts *Store) getData(key string) (TokenInfo, error) {
	var ti TokenInfo
	err := ts.db.View(func(tx *buntdb.Tx) error {
		jv, err := tx.Get(key)
		if err != nil {
			return err
		}
		var tm Token
		if err := json.Unmarshal([]byte(jv), &tm); err != nil {
			return err
		}
		ti = &tm
		return nil
	})
	if errors.Is(err, buntdb.ErrNotFound) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return ti, nil
}

// getBasicID returns the basic ID that an access or refresh token key points
// to. It returns ("", nil) when key is not found.
func (ts *Store) getBasicID(key string) (string, error) {
	var basicID string
	err := ts.db.View(func(tx *buntdb.Tx) error {
		v, err := tx.Get(key)
		if err != nil {
			return err
		}
		basicID = v
		return nil
	})
	if errors.Is(err, buntdb.ErrNotFound) {
		return "", nil
	}
	if err != nil {
		return "", err
	}
	return basicID, nil
}

// Create stores info under a freshly generated basic ID and indexes it by its
// access token and, when present, its refresh token. All writes happen in one
// buntdb update transaction.
//
// Expiry of the stored entries:
//   - A zero GetAccessExpiresIn or GetRefreshExpiresIn means that entry never
//     expires.
//   - The refresh entry expires at GetRefreshCreateAt()+GetRefreshExpiresIn().
//   - The access entry expires GetAccessExpiresIn() from now, but never later
//     than an expiring refresh entry.
//   - The data entry lives as long as the refresh entry, or as long as the
//     access entry when there is no refresh token.
func (ts *Store) Create(info TokenInfo) error {
	now := time.Now()
	jv, err := json.Marshal(info)
	if err != nil {
		return err
	}
	id, err := uuid.NewRandom()
	if err != nil {
		return err
	}
	basicID := id.String()

	accessTTL := info.GetAccessExpiresIn()
	accessOpts := buntdb.SetOptions{Expires: accessTTL != 0, TTL: accessTTL}
	// Without a refresh token the data only has to outlive the access entry.
	dataOpts := accessOpts

	refresh := info.GetRefresh()
	var refreshOpts buntdb.SetOptions
	if refresh != "" {
		refreshOpts = buntdb.SetOptions{
			Expires: info.GetRefreshExpiresIn() != 0,
			TTL:     info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(now),
		}
		// The data must stay readable for as long as the refresh token can be used.
		dataOpts = refreshOpts
		// An access entry must not outlive an expiring data entry it points to.
		// The access entry's own expiry is kept when the refresh entry never
		// expires.
		if refreshOpts.Expires && (!accessOpts.Expires || accessOpts.TTL > refreshOpts.TTL) {
			accessOpts = refreshOpts
		}
	}

	return ts.db.Update(func(tx *buntdb.Tx) error {
		if refresh != "" {
			if _, _, err := tx.Set(refresh, basicID, &refreshOpts); err != nil {
				return err
			}
		}
		if _, _, err := tx.Set(basicID, string(jv), &dataOpts); err != nil {
			return err
		}
		_, _, err := tx.Set(info.GetAccess(), basicID, &accessOpts)
		return err
	})
}

// RemoveByAccess deletes the access-token entry. The token data and any
// refresh-token entry are left in place. A missing access token is not an
// error.
func (ts *Store) RemoveByAccess(access string) error {
	return ts.remove(access)
}

// RemoveByRefresh deletes the refresh-token entry. The token data and the
// access-token entry are left in place. A missing refresh token is not an
// error.
func (ts *Store) RemoveByRefresh(refresh string) error {
	return ts.remove(refresh)
}

// GetByAccess returns the token information associated with the access
// token, or (nil, nil) when no entry is found.
func (ts *Store) GetByAccess(access string) (TokenInfo, error) {
	basicID, err := ts.getBasicID(access)
	if err != nil {
		return nil, err
	}
	if basicID == "" { // unknown access token
		return nil, nil
	}
	return ts.getData(basicID)
}

// GetByRefresh returns the token information associated with the refresh
// token, or (nil, nil) when no entry is found.
func (ts *Store) GetByRefresh(refresh string) (TokenInfo, error) {
	basicID, err := ts.getBasicID(refresh)
	if err != nil {
		return nil, err
	}
	if basicID == "" { // unknown refresh token
		return nil, nil
	}
	return ts.getData(basicID)
}