package upload import ( "context" "crypto/sha256" "errors" "fmt" "git.bvbej.com/bvbej/base-golang/pkg/color" "git.bvbej.com/bvbej/base-golang/pkg/ticker" "git.bvbej.com/bvbej/base-golang/pkg/token" "github.com/rs/cors" "github.com/tus/tusd/pkg/filestore" tus "github.com/tus/tusd/pkg/handler" "go.uber.org/zap" "net/http" "os" "strings" "sync" "time" ) var _ Server = (*server)(nil) type Server interface { GetUploadToken(string, string, time.Duration) string GetFileInfo(string) (*tus.FileInfo, error) Start(func(string, string, tus.FileInfo)) error Stop() error } type server struct { headerTokenKey string uploading sync.Map config Config token token.Token store filestore.FileStore logger *zap.Logger httpServer *http.Server ctx context.Context done context.CancelFunc checker ticker.Ticker completedEvent func(sha256, param string, info tus.FileInfo) } type Config struct { ListenAddr string Path string Dir string Secret string DisableDownload bool Debug bool } func New(conf Config, logger *zap.Logger) Server { ctx, cancelFunc := context.WithCancel(context.Background()) return &server{ config: conf, uploading: sync.Map{}, headerTokenKey: "Authorization", logger: logger, token: token.New(conf.Secret), ctx: ctx, done: cancelFunc, checker: ticker.New(time.Minute), } } func (s *server) GetUploadToken(sha256, param string, ttl time.Duration) string { sign, _ := s.token.JwtSign(sha256, param, ttl) return sign } func (s *server) GetFileInfo(id string) (*tus.FileInfo, error) { upload, err := s.store.GetUpload(context.Background(), id) if err != nil { return nil, err } info, err := upload.GetInfo(context.Background()) if err != nil { return nil, err } return &info, nil } func (s *server) Start(completedEvent func(sha256, param string, info tus.FileInfo)) error { s.completedEvent = completedEvent composer := tus.NewStoreComposer() if err := os.MkdirAll(s.config.Dir, os.ModePerm); err != nil { return err } s.store = filestore.New(s.config.Dir) s.store.UseIn(composer) handler, err := tus.NewHandler(tus.Config{ StoreComposer: composer, BasePath: s.config.Path, Logger: zap.NewStdLog(s.logger), NotifyCompleteUploads: true, NotifyTerminatedUploads: true, DisableTermination: true, DisableDownload: s.config.DisableDownload, RespectForwardedHeaders: strings.Contains(s.config.ListenAddr, "127.0.0.1"), PreUploadCreateCallback: func(hook tus.HookEvent) error { authStr := hook.HTTPRequest.Header.Get(s.headerTokenKey) jwtClaims, err := s.token.JwtParse(authStr) if err == nil { _, ok := s.uploading.Load(authStr) if !ok { s.uploading.Store(authStr, jwtClaims.ExpiresAt.Time) return nil } return errors.New("repeated") } return errors.New("unauthorized") }, PreFinishResponseCallback: func(hook tus.HookEvent) error { authStr := hook.HTTPRequest.Header.Get(s.headerTokenKey) jwtParse, err := s.token.JwtParse(authStr) if err != nil { return errors.New("token expired") } _, ok := s.uploading.Load(authStr) if ok { s.uploading.Delete(authStr) } upload, err := s.store.GetUpload(context.Background(), hook.Upload.ID) if err != nil { return err } info, err := upload.GetInfo(context.Background()) path, exist := info.Storage["Path"] if err != nil || !exist { return errors.New("file not found") } content, err := os.ReadFile(path) if err != nil { return err } hash := sha256.New() hash.Write(content) sha256Byte := hash.Sum(nil) sha256String := fmt.Sprintf("%x", sha256Byte) if !s.config.Debug && sha256String != strings.ToLower(jwtParse.ID) { _ = os.Remove(path) _ = os.Remove(path + ".info") return errors.New("file check error") } return nil }, }) if err != nil { return err } go func() { for { select { case event := <-handler.CompleteUploads: authStr := event.HTTPRequest.Header.Get(s.headerTokenKey) jwtParse, _ := s.token.JwtParse(authStr) if s.completedEvent != nil { go func() { s.completedEvent(jwtParse.ID, jwtParse.Subject, event.Upload) }() } case <-s.ctx.Done(): return } } }() go func() { for { select { case event := <-handler.TerminatedUploads: upload, _ := s.store.GetUpload(context.Background(), event.Upload.ID) if upload != nil { info, _ := upload.GetInfo(context.Background()) path, exist := info.Storage["Path"] if exist { _ = os.Remove(path) _ = os.Remove(path + ".info") } } case <-s.ctx.Done(): return } } }() s.checker.Process(func() { s.uploading.Range(func(key, value any) bool { t := value.(time.Time) if t.Before(time.Now()) { s.uploading.Delete(key) } return true }) }) //监听服务 addr := s.config.ListenAddr mux := http.NewServeMux() mux.Handle(s.config.Path, http.StripPrefix(s.config.Path, handler)) s.httpServer = &http.Server{ Addr: addr, Handler: cors.AllowAll().Handler(mux), } go func() { if err = s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { s.logger.Sugar().Fatal("upload server startup err", zap.Error(err)) } }() fmt.Println(color.Green(fmt.Sprintf("* [register tusd listen %s]", addr))) return nil } func (s *server) Stop() error { s.done() return s.httpServer.Close() }