jwt_access.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package auth
  2. import (
  3. "encoding/base64"
  4. "errors"
  5. "strings"
  6. "time"
  7. "github.com/golang-jwt/jwt/v4"
  8. "github.com/google/uuid"
  9. )
  10. // JWTAccessClaims jwt claims
  11. type JWTAccessClaims struct {
  12. jwt.RegisteredClaims
  13. }
  14. // Valid claims verification
  15. func (a *JWTAccessClaims) Valid() error {
  16. if a.ExpiresAt.Before(time.Now()) {
  17. return ErrInvalidAccessToken
  18. }
  19. return nil
  20. }
  21. // NewJWTAccessGenerate create to generate the jwt access token instance
  22. func NewJWTAccessGenerate(key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
  23. return &JWTAccessGenerate{
  24. SignedKey: key,
  25. SignedMethod: method,
  26. }
  27. }
  28. // JWTAccessGenerate generate the jwt access token
  29. type JWTAccessGenerate struct {
  30. SignedKey []byte
  31. SignedMethod jwt.SigningMethod
  32. }
  33. // Token based on the UUID generated token
  34. func (a *JWTAccessGenerate) Token(data *GenerateBasic, isGenRefresh bool) (string, string, error) {
  35. claims := &JWTAccessClaims{
  36. RegisteredClaims: jwt.RegisteredClaims{
  37. Issuer: "BvBeJ",
  38. Subject: data.UserID,
  39. ExpiresAt: jwt.NewNumericDate(data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn())),
  40. },
  41. }
  42. token := jwt.NewWithClaims(a.SignedMethod, claims)
  43. var key any
  44. if a.isEs() {
  45. v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
  46. if err != nil {
  47. return "", "", err
  48. }
  49. key = v
  50. } else if a.isRsOrPS() {
  51. v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
  52. if err != nil {
  53. return "", "", err
  54. }
  55. key = v
  56. } else if a.isHs() {
  57. key = a.SignedKey
  58. } else {
  59. return "", "", errors.New("unsupported sign method")
  60. }
  61. access, err := token.SignedString(key)
  62. if err != nil {
  63. return "", "", err
  64. }
  65. refresh := ""
  66. if isGenRefresh {
  67. t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String()
  68. refresh = base64.URLEncoding.EncodeToString([]byte(t))
  69. refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
  70. }
  71. return access, refresh, nil
  72. }
  73. func (a *JWTAccessGenerate) isEs() bool {
  74. return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
  75. }
  76. func (a *JWTAccessGenerate) isRsOrPS() bool {
  77. isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
  78. isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
  79. return isRs || isPs
  80. }
  81. func (a *JWTAccessGenerate) isHs() bool {
  82. return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
  83. }