package auth

import (
	"context"
	"fmt"
	"time"

	"github.com/google/uuid"
	jsonIterator "github.com/json-iterator/go"
	redis "github.com/redis/go-redis/v9"
	"github.com/tidwall/buntdb"
)

// JSON codec used to serialize token information before it is stored.
var (
	jsonMarshal   = jsonIterator.Marshal
	jsonUnmarshal = jsonIterator.Unmarshal
)

// TokenStore is the persistence contract for issued tokens. Token information
// can be created, looked up and removed by either its access token or its
// refresh token.
type TokenStore interface {
	Create(info TokenInfo) error
	RemoveByAccess(access string) error
	RemoveByRefresh(refresh string) error
	GetByAccess(access string) (TokenInfo, error)
	GetByRefresh(refresh string) (TokenInfo, error)
}

// NewMemoryTokenStore creates a buntdb-backed token store that keeps its data
// in memory (it opens the special ":memory:" database).
func NewMemoryTokenStore() (TokenStore, error) {
	return NewFileTokenStore(":memory:")
}

// NewFileTokenStore creates a buntdb-backed token store using the database
// opened from filename. The error from buntdb.Open is returned unchanged.
func NewFileTokenStore(filename string) (TokenStore, error) {
	db, err := buntdb.Open(filename)
	if err != nil {
		return nil, err
	}
	return &buntStore{db: db}, nil
}

// buntStore is a TokenStore based on buntdb (https://github.com/tidwall/buntdb).
//
// Each token is stored under a random "basic ID" key holding the JSON payload;
// the access and refresh token strings are stored as index keys whose value is
// that basic ID.
type buntStore struct {
	db *buntdb.DB
}

// remove deletes key from the database. A missing key is not an error.
func (ts *buntStore) remove(key string) error {
	err := ts.db.Update(func(tx *buntdb.Tx) error {
		_, err := tx.Delete(key)
		return err
	})
	if err == buntdb.ErrNotFound {
		return nil
	}
	return err
}

// getData loads and decodes the token payload stored under key.
// It returns (nil, nil) when the key does not exist.
func (ts *buntStore) getData(key string) (TokenInfo, error) {
	var ti TokenInfo
	err := ts.db.View(func(tx *buntdb.Tx) error {
		jv, err := tx.Get(key)
		if err != nil {
			return err
		}

		var tm Token
		if err := jsonUnmarshal([]byte(jv), &tm); err != nil {
			return err
		}
		ti = &tm
		return nil
	})
	if err != nil {
		if err == buntdb.ErrNotFound {
			return nil, nil
		}
		return nil, err
	}
	return ti, nil
}

// getBasicID resolves an access or refresh token string to the basic ID it
// indexes. It returns ("", nil) when the token is unknown.
func (ts *buntStore) getBasicID(key string) (string, error) {
	var basicID string
	err := ts.db.View(func(tx *buntdb.Tx) error {
		v, err := tx.Get(key)
		if err != nil {
			return err
		}
		basicID = v
		return nil
	})
	if err != nil {
		if err == buntdb.ErrNotFound {
			return "", nil
		}
		return "", err
	}
	return basicID, nil
}

// Create stores the new token information.
//
// The payload is written under a freshly generated basic ID, and the access
// token (plus the refresh token, when present) is written as an index entry
// pointing at it. An error is returned if the payload cannot be marshalled,
// the ID cannot be generated, or any write fails.
func (ts *buntStore) Create(info TokenInfo) error {
	ct := time.Now()
	jv, err := jsonMarshal(info)
	if err != nil {
		return err
	}

	// Report a failing random source as an error rather than panicking.
	id, err := uuid.NewRandom()
	if err != nil {
		return fmt.Errorf("generating token id: %w", err)
	}
	basicID := id.String()

	return ts.db.Update(func(tx *buntdb.Tx) error {
		aexp := info.GetAccessExpiresIn()
		rexp := aexp
		expires := true

		if refresh := info.GetRefresh(); refresh != "" {
			// Remaining refresh lifetime measured from now.
			rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
			// The access entry never outlives the refresh entry.
			if aexp.Seconds() > rexp.Seconds() {
				aexp = rexp
			}
			// A zero refresh lifetime disables expiry for all entries.
			expires = info.GetRefreshExpiresIn() != 0

			if _, _, err := tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp}); err != nil {
				return err
			}
		}

		if _, _, err := tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp}); err != nil {
			return err
		}

		_, _, err := tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
		return err
	})
}

// RemoveByAccess uses the access token to delete the token information.
// Only the access index entry is deleted.
func (ts *buntStore) RemoveByAccess(access string) error {
	return ts.remove(access)
}

// RemoveByRefresh uses the refresh token to delete the token information.
// Only the refresh index entry is deleted.
func (ts *buntStore) RemoveByRefresh(refresh string) error {
	return ts.remove(refresh)
}

// GetByAccess uses the access token to look up the token information.
// It returns (nil, nil) when the access token is unknown.
func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
	basicID, err := ts.getBasicID(access)
	// Skip the payload lookup on a miss, as the redis store does.
	if err != nil || basicID == "" {
		return nil, err
	}
	return ts.getData(basicID)
}

// GetByRefresh uses the refresh token to look up the token information.
// It returns (nil, nil) when the refresh token is unknown.
func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
	basicID, err := ts.getBasicID(refresh)
	// Skip the payload lookup on a miss, as the redis store does.
	if err != nil || basicID == "" {
		return nil, err
	}
	return ts.getData(basicID)
}

/*------------------------------------------------------------------------------------*/

// NewRedisStoreWithCli creates a redis-backed token store that uses cli for
// all commands and prefixes every key with keyNamespace.
func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
	return &redisStore{
		cli: cli,
		ctx: context.TODO(),
		ns:  keyNamespace,
	}
}

// redisStore is a TokenStore based on redis. It uses the same layout as
// buntStore: a basic ID key holds the JSON payload and the access/refresh
// token keys hold that basic ID.
type redisStore struct {
	cli *redis.Client
	ctx context.Context // context passed to every redis command
	ns  string          // prefix prepended to every key
}

// wrapperKey returns key prefixed with the store's namespace.
func (s *redisStore) wrapperKey(key string) string {
	return fmt.Sprintf("%s%s", s.ns, key)
}

// checkError inspects the outcome of a redis command. It returns (true, nil)
// when the command reported redis.Nil (key not found), (false, err) for any
// other error and (false, nil) on success.
func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
	if err := result.Err(); err != nil {
		if err == redis.Nil {
			return true, nil
		}
		return false, err
	}
	return false, nil
}

// remove deletes the namespaced key; a redis.Nil result is not an error.
func (s *redisStore) remove(key string) error {
	result := s.cli.Del(s.ctx, s.wrapperKey(key))
	_, err := s.checkError(result)
	return err
}

// removeToken deletes the index entry for tokenString and, when the token's
// other index entry (access vs. refresh) no longer exists, the payload too.
// isRefresh reports whether tokenString is a refresh token.
func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
	basicID, err := s.getBasicID(tokenString)
	if err != nil {
		return err
	}
	if basicID == "" {
		// Unknown token: nothing to remove.
		return nil
	}

	if err = s.remove(tokenString); err != nil {
		return err
	}

	token, err := s.getToken(basicID)
	if err != nil {
		return err
	}
	if token == nil {
		return nil
	}

	// Check the counterpart of the token that was just removed.
	checkToken := token.GetRefresh()
	if isRefresh {
		checkToken = token.GetAccess()
	}

	result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
	if err = result.Err(); err != nil && err != redis.Nil {
		return err
	}
	if result.Val() == 0 {
		// No index entry references the payload any more.
		return s.remove(basicID)
	}
	return nil
}

// parseToken decodes the JSON payload carried by result.
// It returns (nil, nil) when the key was not found.
func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
	notFound, err := s.checkError(result)
	if err != nil {
		return nil, err
	}
	if notFound {
		return nil, nil
	}

	buf, err := result.Bytes()
	if err != nil {
		if err == redis.Nil {
			return nil, nil
		}
		return nil, err
	}

	var token Token
	if err = jsonUnmarshal(buf, &token); err != nil {
		return nil, err
	}
	return &token, nil
}

// getToken fetches and decodes the payload stored under the namespaced key.
func (s *redisStore) getToken(key string) (TokenInfo, error) {
	result := s.cli.Get(s.ctx, s.wrapperKey(key))
	return s.parseToken(result)
}

// parseBasicID extracts the basic ID carried by result.
// It returns ("", nil) when the key was not found.
func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
	notFound, err := s.checkError(result)
	if err != nil {
		return "", err
	}
	if notFound {
		return "", nil
	}
	return result.Val(), nil
}

// getBasicID resolves an access or refresh token string to its basic ID.
func (s *redisStore) getBasicID(token string) (string, error) {
	result := s.cli.Get(s.ctx, s.wrapperKey(token))
	return s.parseBasicID(result)
}

// Create stores the new token information.
//
// The index entries and the payload are queued on a transactional pipeline
// and executed together. An error is returned if the payload cannot be
// marshalled, the ID cannot be generated, or the pipeline fails.
func (s *redisStore) Create(info TokenInfo) error {
	ct := time.Now()
	jv, err := jsonMarshal(info)
	if err != nil {
		return err
	}

	// Report a failing random source as an error rather than panicking.
	id, err := uuid.NewRandom()
	if err != nil {
		return fmt.Errorf("generating token id: %w", err)
	}
	basicID := id.String()

	pipe := s.cli.TxPipeline()
	aexp := info.GetAccessExpiresIn()
	rexp := aexp

	if refresh := info.GetRefresh(); refresh != "" {
		// Remaining refresh lifetime measured from now.
		rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
		// The access entry never outlives the refresh entry.
		if aexp.Seconds() > rexp.Seconds() {
			aexp = rexp
		}
		pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
	}

	pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
	pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)

	_, err = pipe.Exec(s.ctx)
	return err
}

// RemoveByAccess uses the access token to delete the token information.
func (s *redisStore) RemoveByAccess(access string) error {
	return s.removeToken(access, false)
}

// RemoveByRefresh uses the refresh token to delete the token information.
func (s *redisStore) RemoveByRefresh(refresh string) error {
	return s.removeToken(refresh, true)
}

// GetByAccess uses the access token to look up the token information.
// It returns (nil, nil) when the access token is unknown.
func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
	basicID, err := s.getBasicID(access)
	if err != nil || basicID == "" {
		return nil, err
	}
	return s.getToken(basicID)
}

// GetByRefresh uses the refresh token to look up the token information.
// It returns (nil, nil) when the refresh token is unknown.
func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
	basicID, err := s.getBasicID(refresh)
	if err != nil || basicID == "" {
		return nil, err
	}
	return s.getToken(basicID)
}