store.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. package auth
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/google/uuid"
  6. jsonIterator "github.com/json-iterator/go"
  7. redis "github.com/redis/go-redis/v9"
  8. "github.com/tidwall/buntdb"
  9. "time"
  10. )
  11. var (
  12. jsonMarshal = jsonIterator.Marshal
  13. jsonUnmarshal = jsonIterator.Unmarshal
  14. )
  15. type TokenStore interface {
  16. Create(info TokenInfo) error
  17. RemoveByAccess(access string) error
  18. RemoveByRefresh(refresh string) error
  19. GetByAccess(access string) (TokenInfo, error)
  20. GetByRefresh(refresh string) (TokenInfo, error)
  21. }
  22. // NewMemoryTokenStore create a token buntStore instance based on memory
  23. func NewMemoryTokenStore() (TokenStore, error) {
  24. return NewFileTokenStore(":memory:")
  25. }
  26. // NewFileTokenStore create a token buntStore instance based on file
  27. func NewFileTokenStore(filename string) (TokenStore, error) {
  28. db, err := buntdb.Open(filename)
  29. if err != nil {
  30. return nil, err
  31. }
  32. return &buntStore{db: db}, nil
  33. }
  34. // buntStore token storage based on buntdb(https://github.com/tidwall/buntdb)
  35. type buntStore struct {
  36. db *buntdb.DB
  37. }
  38. func (ts *buntStore) remove(key string) error {
  39. err := ts.db.Update(func(tx *buntdb.Tx) error {
  40. _, err := tx.Delete(key)
  41. return err
  42. })
  43. if err == buntdb.ErrNotFound {
  44. return nil
  45. }
  46. return err
  47. }
  48. func (ts *buntStore) getData(key string) (TokenInfo, error) {
  49. var ti TokenInfo
  50. err := ts.db.View(func(tx *buntdb.Tx) error {
  51. jv, err := tx.Get(key)
  52. if err != nil {
  53. return err
  54. }
  55. var tm Token
  56. err = jsonUnmarshal([]byte(jv), &tm)
  57. if err != nil {
  58. return err
  59. }
  60. ti = &tm
  61. return nil
  62. })
  63. if err != nil {
  64. if err == buntdb.ErrNotFound {
  65. return nil, nil
  66. }
  67. return nil, err
  68. }
  69. return ti, nil
  70. }
  71. func (ts *buntStore) getBasicID(key string) (string, error) {
  72. var basicID string
  73. err := ts.db.View(func(tx *buntdb.Tx) error {
  74. v, err := tx.Get(key)
  75. if err != nil {
  76. return err
  77. }
  78. basicID = v
  79. return nil
  80. })
  81. if err != nil {
  82. if err == buntdb.ErrNotFound {
  83. return "", nil
  84. }
  85. return "", err
  86. }
  87. return basicID, nil
  88. }
  89. // Create and buntStore the new token information
  90. func (ts *buntStore) Create(info TokenInfo) error {
  91. ct := time.Now()
  92. jv, err := jsonMarshal(info)
  93. if err != nil {
  94. return err
  95. }
  96. return ts.db.Update(func(tx *buntdb.Tx) error {
  97. basicID := uuid.Must(uuid.NewRandom()).String()
  98. aexp := info.GetAccessExpiresIn()
  99. rexp := aexp
  100. expires := true
  101. if refresh := info.GetRefresh(); refresh != "" {
  102. rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
  103. if aexp.Seconds() > rexp.Seconds() {
  104. aexp = rexp
  105. }
  106. expires = info.GetRefreshExpiresIn() != 0
  107. _, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: expires, TTL: rexp})
  108. if err != nil {
  109. return err
  110. }
  111. }
  112. _, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: expires, TTL: rexp})
  113. if err != nil {
  114. return err
  115. }
  116. _, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: expires, TTL: aexp})
  117. return err
  118. })
  119. }
  120. // RemoveByAccess use the access token to delete the token information
  121. func (ts *buntStore) RemoveByAccess(access string) error {
  122. return ts.remove(access)
  123. }
  124. // RemoveByRefresh use the refresh token to delete the token information
  125. func (ts *buntStore) RemoveByRefresh(refresh string) error {
  126. return ts.remove(refresh)
  127. }
  128. // GetByAccess use the access token for token information data
  129. func (ts *buntStore) GetByAccess(access string) (TokenInfo, error) {
  130. basicID, err := ts.getBasicID(access)
  131. if err != nil {
  132. return nil, err
  133. }
  134. return ts.getData(basicID)
  135. }
  136. // GetByRefresh use the refresh token for token information data
  137. func (ts *buntStore) GetByRefresh(refresh string) (TokenInfo, error) {
  138. basicID, err := ts.getBasicID(refresh)
  139. if err != nil {
  140. return nil, err
  141. }
  142. return ts.getData(basicID)
  143. }
  144. /*------------------------------------------------------------------------------------*/
  145. // NewRedisStoreWithCli create an instance of a redis store
  146. func NewRedisStoreWithCli(cli *redis.Client, keyNamespace string) TokenStore {
  147. store := &redisStore{
  148. cli: cli,
  149. ctx: context.TODO(),
  150. ns: keyNamespace,
  151. }
  152. return store
  153. }
  154. // TokenStore redis token store
  155. type redisStore struct {
  156. cli *redis.Client
  157. ctx context.Context
  158. ns string
  159. }
  160. func (s *redisStore) wrapperKey(key string) string {
  161. return fmt.Sprintf("%s%s", s.ns, key)
  162. }
  163. func (s *redisStore) checkError(result redis.Cmder) (bool, error) {
  164. if err := result.Err(); err != nil {
  165. if err == redis.Nil {
  166. return true, nil
  167. }
  168. return false, err
  169. }
  170. return false, nil
  171. }
  172. func (s *redisStore) remove(key string) error {
  173. result := s.cli.Del(s.ctx, s.wrapperKey(key))
  174. _, err := s.checkError(result)
  175. return err
  176. }
  177. func (s *redisStore) removeToken(tokenString string, isRefresh bool) error {
  178. basicID, err := s.getBasicID(tokenString)
  179. if err != nil {
  180. return err
  181. } else if basicID == "" {
  182. return nil
  183. }
  184. err = s.remove(tokenString)
  185. if err != nil {
  186. return err
  187. }
  188. token, err := s.getToken(basicID)
  189. if err != nil {
  190. return err
  191. } else if token == nil {
  192. return nil
  193. }
  194. checkToken := token.GetRefresh()
  195. if isRefresh {
  196. checkToken = token.GetAccess()
  197. }
  198. result := s.cli.Exists(s.ctx, s.wrapperKey(checkToken))
  199. if err = result.Err(); err != nil && err != redis.Nil {
  200. return err
  201. } else if result.Val() == 0 {
  202. return s.remove(basicID)
  203. }
  204. return nil
  205. }
  206. func (s *redisStore) parseToken(result *redis.StringCmd) (TokenInfo, error) {
  207. if ok, err := s.checkError(result); err != nil {
  208. return nil, err
  209. } else if ok {
  210. return nil, nil
  211. }
  212. buf, err := result.Bytes()
  213. if err != nil {
  214. if err == redis.Nil {
  215. return nil, nil
  216. }
  217. return nil, err
  218. }
  219. var token Token
  220. if err = jsonUnmarshal(buf, &token); err != nil {
  221. return nil, err
  222. }
  223. return &token, nil
  224. }
  225. func (s *redisStore) getToken(key string) (TokenInfo, error) {
  226. result := s.cli.Get(s.ctx, s.wrapperKey(key))
  227. return s.parseToken(result)
  228. }
  229. func (s *redisStore) parseBasicID(result *redis.StringCmd) (string, error) {
  230. if ok, err := s.checkError(result); err != nil {
  231. return "", err
  232. } else if ok {
  233. return "", nil
  234. }
  235. return result.Val(), nil
  236. }
  237. func (s *redisStore) getBasicID(token string) (string, error) {
  238. result := s.cli.Get(s.ctx, s.wrapperKey(token))
  239. return s.parseBasicID(result)
  240. }
  241. // Create and store the new token information
  242. func (s *redisStore) Create(info TokenInfo) error {
  243. ct := time.Now()
  244. jv, err := jsonMarshal(info)
  245. if err != nil {
  246. return err
  247. }
  248. pipe := s.cli.TxPipeline()
  249. basicID := uuid.Must(uuid.NewRandom()).String()
  250. aexp := info.GetAccessExpiresIn()
  251. rexp := aexp
  252. if refresh := info.GetRefresh(); refresh != "" {
  253. rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct)
  254. if aexp.Seconds() > rexp.Seconds() {
  255. aexp = rexp
  256. }
  257. pipe.Set(s.ctx, s.wrapperKey(refresh), basicID, rexp)
  258. }
  259. pipe.Set(s.ctx, s.wrapperKey(info.GetAccess()), basicID, aexp)
  260. pipe.Set(s.ctx, s.wrapperKey(basicID), jv, rexp)
  261. if _, err = pipe.Exec(s.ctx); err != nil {
  262. return err
  263. }
  264. return nil
  265. }
  266. // RemoveByAccess Use the access token to delete the token information
  267. func (s *redisStore) RemoveByAccess(access string) error {
  268. return s.removeToken(access, false)
  269. }
  270. // RemoveByRefresh Use the refresh token to delete the token information
  271. func (s *redisStore) RemoveByRefresh(refresh string) error {
  272. return s.removeToken(refresh, true)
  273. }
  274. // GetByAccess Use the access token for token information data
  275. func (s *redisStore) GetByAccess(access string) (TokenInfo, error) {
  276. basicID, err := s.getBasicID(access)
  277. if err != nil || basicID == "" {
  278. return nil, err
  279. }
  280. return s.getToken(basicID)
  281. }
  282. // GetByRefresh Use the refresh token for token information data
  283. func (s *redisStore) GetByRefresh(refresh string) (TokenInfo, error) {
  284. basicID, err := s.getBasicID(refresh)
  285. if err != nil || basicID == "" {
  286. return nil, err
  287. }
  288. return s.getToken(basicID)
  289. }