store.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. package auth
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "github.com/google/uuid"
  7. jsonIterator "github.com/json-iterator/go"
  8. "github.com/redis/go-redis/v9"
  9. "github.com/tidwall/buntdb"
  10. "time"
  11. )
  12. var (
  13. jsonMarshal = jsonIterator.Marshal
  14. jsonUnmarshal = jsonIterator.Unmarshal
  15. )
  16. type TokenStore interface {
  17. Create(info TokenInfo) error
  18. RemoveByAccess(access string) error
  19. RemoveByRefresh(refresh string) error
  20. GetByAccess(access string) (TokenInfo, error)
  21. GetByRefresh(refresh string) (TokenInfo, error)
  22. }
  23. // NewMemoryTokenStore create a token buntStore instance based on memory
  24. func NewMemoryTokenStore() (TokenStore, error) {
  25. return NewFileTokenStore(":memory:")
  26. }
  27. // NewFileTokenStore create a token buntStore instance based on file
  28. func NewFileTokenStore(filename string) (TokenStore, error) {
  29. db, err := buntdb.Open(filename)
  30. if err != nil {
  31. return nil, err
  32. }
  33. return &buntStore{db: db}, nil
  34. }
  35. // buntStore token storage based on buntdb(https://github.com/tidwall/buntdb)
  36. type buntStore struct {
  37. db *buntdb.DB
  38. }
  39. func (ts *buntStore) remove(key string) error {
  40. err := ts.db.Update(func(tx *buntdb.Tx) error {
  41. _, err := tx.Delete(key)
  42. return err
  43. })
  44. if errors.Is(err, buntdb.ErrNotFound) {
  45. return nil
  46. }
  47. return err
  48. }
  49. func (ts *buntStore) getData(key string) (TokenInfo, error) {
  50. var ti TokenInfo
  51. err := ts.db.View(func(tx *buntdb.Tx) error {
  52. jv, err := tx.Get(key)
  53. if err != nil {
  54. return err
  55. }
  56. var tm Token
  57. err = jsonUnmarshal([]byte(jv), &tm)
  58. if err != nil {
  59. return err
  60. }
  61. ti = &tm
  62. return nil
  63. })
  64. if err != nil {
  65. if err == buntdb.ErrNotFound {
  66. return nil, nil
  67. }
  68. return nil, err
  69. }
  70. return ti, nil
  71. }
  72. func (ts *buntStore) getBasicID(key string) (string, error) {
  73. var basicID string
  74. err := ts.db.View(func(tx *buntdb.Tx) error {
  75. v, err := tx.Get(key)
  76. if err != nil {
  77. return err
  78. }
  79. basicID = v
  80. return nil
  81. })
  82. if err != nil {
  83. if err == buntdb.ErrNotFound {
  84. return "", nil
  85. }
  86. return "", err
  87. }
  88. return basicID, nil
  89. }
  90. // Create and buntStore the new token information
  91. func (ts *buntStore) Create(info TokenInfo) error {
  92. ct := time.Now()
  93. jv, err := jsonMarshal(info)
  94. if err != nil {
  95. return err
  96. }
  97. return ts.db.Update(func(tx *buntdb.Tx) error {
  98. basicID := uuid.Must(uuid.NewRandom()).String()
  99. aexp := info.GetAccessExpiresIn()
  100. rexp := aexp
  101. expires := true
  102. if refresh := info.GetRefresh(); refresh != "" {
  103. rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
  104. if aexp.Seconds() > rexp.Seconds() {
  105. aexp = rexp
  106. }
  107. expires = info.GetRefreshExpiresIn() != 0
  108. _, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
  109. if err != nil {
  110. return err
  111. }
  112. }
  113. _, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
  114. if err != nil {
  115. return err
  116. }
  117. _, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
  118. return err
  119. })
  120. }
  121. // RemoveByAccess use the access token to delete the token information
  122. func (ts *buntStore) RemoveByAccess(access string) error {
  123. return ts.remove(access)
  124. }
  125. // RemoveByRefresh use the refresh token to delete the token information
  126. func (ts *buntStore) RemoveByRefresh(refresh string) error {
  127. return ts.remove(refresh)
  128. }
  129. // GetByAccess use the access token for token information data
  130. func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
  131. basicID, err := ts.getBasicID(access)
  132. if err != nil {
  133. return nil, err
  134. }
  135. return ts.getData(basicID)
  136. }
  137. // GetByRefresh use the refresh token for token information data
  138. func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
  139. basicID, err := ts.getBasicID(refresh)
  140. if err != nil {
  141. return nil, err
  142. }
  143. return ts.getData(basicID)
  144. }
  145. /*------------------------------------------------------------------------------------*/
  146. // NewRedisStoreWithCli create an instance of a redis store
  147. func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
  148. store := &redisStore{
  149. cli: cli,
  150. ctx: context.TODO(),
  151. ns: keyNamespace,
  152. }
  153. return store
  154. }
  155. // TokenStore redis token store
  156. type redisStore struct {
  157. cli *redis.Client
  158. ctx context.Context
  159. ns string
  160. }
  161. func (s *redisStore) wrapperKey(key string) string {
  162. return fmt.Sprintf("%s%s", s.ns, key)
  163. }
  164. func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
  165. if err := result.Err(); err != nil {
  166. if err == redis.Nil {
  167. return true, nil
  168. }
  169. return false, err
  170. }
  171. return false, nil
  172. }
  173. func (s *redisStore) remove(key string) error {
  174. result := s.cli.Del(s.ctx, s.wrapperKey(key))
  175. _, err := s.checkError(result)
  176. return err
  177. }
  178. func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
  179. basicID, err := s.getBasicID(tokenString)
  180. if err != nil {
  181. return err
  182. } else if basicID == "" {
  183. return nil
  184. }
  185. err = s.remove(tokenString)
  186. if err != nil {
  187. return err
  188. }
  189. token, err := s.getToken(basicID)
  190. if err != nil {
  191. return err
  192. } else if token == nil {
  193. return nil
  194. }
  195. checkToken := token.GetRefresh()
  196. if isRefresh {
  197. checkToken = token.GetAccess()
  198. }
  199. result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
  200. if err = result.Err(); err != nil && err != redis.Nil {
  201. return err
  202. } else if result.Val() == 0 {
  203. return s.remove(basicID)
  204. }
  205. return nil
  206. }
  207. func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
  208. if ok, err := s.checkError(result); err != nil {
  209. return nil, err
  210. } else if ok {
  211. return nil, nil
  212. }
  213. buf, err := result.Bytes()
  214. if err != nil {
  215. if err == redis.Nil {
  216. return nil, nil
  217. }
  218. return nil, err
  219. }
  220. var token Token
  221. if err = jsonUnmarshal(buf, &token); err != nil {
  222. return nil, err
  223. }
  224. return &token, nil
  225. }
  226. func (s *redisStore) getToken(key string) (TokenInfo, error) {
  227. result := s.cli.Get(s.ctx, s.wrapperKey(key))
  228. return s.parseToken(result)
  229. }
  230. func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
  231. if ok, err := s.checkError(result); err != nil {
  232. return "", err
  233. } else if ok {
  234. return "", nil
  235. }
  236. return result.Val(), nil
  237. }
  238. func (s *redisStore) getBasicID(token string) (string, error) {
  239. result := s.cli.Get(s.ctx, s.wrapperKey(token))
  240. return s.parseBasicID(result)
  241. }
  242. // Create and store the new token information
  243. func (s *redisStore) Create(info TokenInfo) error {
  244. ct := time.Now()
  245. jv, err := jsonMarshal(info)
  246. if err != nil {
  247. return err
  248. }
  249. pipe := s.cli.TxPipeline()
  250. basicID := uuid.Must(uuid.NewRandom()).String()
  251. aexp := info.GetAccessExpiresIn()
  252. rexp := aexp
  253. if refresh := info.GetRefresh(); refresh != "" {
  254. rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
  255. if aexp.Seconds() > rexp.Seconds() {
  256. aexp = rexp
  257. }
  258. pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
  259. }
  260. pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
  261. pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
  262. if _, err = pipe.Exec(s.ctx); err != nil {
  263. return err
  264. }
  265. return nil
  266. }
  267. // RemoveByAccess Use the access token to delete the token information
  268. func (s *redisStore) RemoveByAccess(access string) error {
  269. return s.removeToken(access, false)
  270. }
  271. // RemoveByRefresh Use the refresh token to delete the token information
  272. func (s *redisStore) RemoveByRefresh(refresh string) error {
  273. return s.removeToken(refresh, true)
  274. }
  275. // GetByAccess Use the access token for token information data
  276. func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
  277. basicID, err := s.getBasicID(access)
  278. if err != nil || basicID == "" {
  279. return nil, err
  280. }
  281. return s.getToken(basicID)
  282. }
  283. // GetByRefresh Use the refresh token for token information data
  284. func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
  285. basicID, err := s.getBasicID(refresh)
  286. if err != nil || basicID == "" {
  287. return nil, err
  288. }
  289. return s.getToken(basicID)
  290. }