package mux import ( "errors" "fmt" "net/http" "net/url" "runtime/debug" "time" "git.bvbej.com/bvbej/base-golang/pkg/color" "git.bvbej.com/bvbej/base-golang/pkg/env" "git.bvbej.com/bvbej/base-golang/pkg/errno" "git.bvbej.com/bvbej/base-golang/pkg/limiter" "git.bvbej.com/bvbej/base-golang/pkg/trace" "git.bvbej.com/bvbej/base-golang/pkg/validator" "github.com/gin-contrib/pprof" "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/binding" "github.com/prometheus/client_golang/prometheus/promhttp" cors "github.com/rs/cors/wrapper/gin" "go.uber.org/multierr" "go.uber.org/zap" "golang.org/x/time/rate" ) type Option func(*option) type option struct { enableCors bool enablePProf bool enablePrometheus bool enableOpenBrowser string staticDirs []string panicNotify OnPanicNotify recordMetrics RecordMetrics rateLimiter limiter.RateLimiter } const SuccessCode = 0 type Failure struct { ResultCode int `json:"result_code"` // 业务码 ResultInfo string `json:"result_info"` // 描述信息 } type Success struct { ResultCode int `json:"result_code"` // 业务码 ResultData any `json:"result_data"` //返回数据 } /******************************************************************************/ // OnPanicNotify 发生panic时通知用 type OnPanicNotify func(ctx Context, err any, stackInfo string) // RecordMetrics 记录prometheus指标用 // 如果使用AliasForRecordMetrics配置了别名,uri将被替换为别名。 type RecordMetrics func(method, uri string, success bool, costSeconds float64) // DisableTrace 禁用追踪链 func DisableTrace(ctx Context) { ctx.disableTrace() } // WithPanicNotify 设置panic时的通知回调 func WithPanicNotify(notify OnPanicNotify) Option { return func(opt *option) { opt.panicNotify = notify fmt.Println(color.Green("* [register panic notify]")) } } // WithRecordMetrics 设置记录prometheus记录指标回调 func WithRecordMetrics(record RecordMetrics) Option { return func(opt *option) { opt.recordMetrics = record } } // WithEnableCors 开启CORS func WithEnableCors() Option { return func(opt *option) { opt.enableCors = true fmt.Println(color.Green("* [register cors]")) } } // WithEnableRate 开启限流 func WithEnableRate(limit rate.Limit, burst int) Option { return func(opt *option) { opt.rateLimiter = limiter.NewRateLimiter(limit, burst) fmt.Println(color.Green("* [register rate]")) } } // WithStaticDir 设置静态文件目录 func WithStaticDir(dirs []string) Option { return func(opt *option) { opt.staticDirs = dirs fmt.Println(color.Green("* [register rate]")) } } // AliasForRecordMetrics 对请求uri起个别名,用于prometheus记录指标。 // 如:Get /user/:username 这样的uri,因为username会有非常多的情况,这样记录prometheus指标会非常的不有好。 func AliasForRecordMetrics(path string) HandlerFunc { return func(ctx Context) { ctx.setAlias(path) } } /******************************************************************************/ // RouterGroup 包装gin的RouterGroup type RouterGroup interface { Group(string, ...HandlerFunc) RouterGroup IRoutes } var _ IRoutes = (*router)(nil) // IRoutes 包装gin的IRoutes type IRoutes interface { Any(string, ...HandlerFunc) GET(string, ...HandlerFunc) POST(string, ...HandlerFunc) DELETE(string, ...HandlerFunc) PATCH(string, ...HandlerFunc) PUT(string, ...HandlerFunc) OPTIONS(string, ...HandlerFunc) HEAD(string, ...HandlerFunc) } type router struct { group *gin.RouterGroup } func (r *router) Group(relativePath string, handlers ...HandlerFunc) RouterGroup { group := r.group.Group(relativePath, wrapHandlers(handlers...)...) return &router{group: group} } func (r *router) Any(relativePath string, handlers ...HandlerFunc) { r.group.Any(relativePath, wrapHandlers(handlers...)...) } func (r *router) GET(relativePath string, handlers ...HandlerFunc) { r.group.GET(relativePath, wrapHandlers(handlers...)...) } func (r *router) POST(relativePath string, handlers ...HandlerFunc) { r.group.POST(relativePath, wrapHandlers(handlers...)...) } func (r *router) DELETE(relativePath string, handlers ...HandlerFunc) { r.group.DELETE(relativePath, wrapHandlers(handlers...)...) } func (r *router) PATCH(relativePath string, handlers ...HandlerFunc) { r.group.PATCH(relativePath, wrapHandlers(handlers...)...) } func (r *router) PUT(relativePath string, handlers ...HandlerFunc) { r.group.PUT(relativePath, wrapHandlers(handlers...)...) } func (r *router) OPTIONS(relativePath string, handlers ...HandlerFunc) { r.group.OPTIONS(relativePath, wrapHandlers(handlers...)...) } func (r *router) HEAD(relativePath string, handlers ...HandlerFunc) { r.group.HEAD(relativePath, wrapHandlers(handlers...)...) } func wrapHandlers(handlers ...HandlerFunc) []gin.HandlerFunc { list := make([]gin.HandlerFunc, len(handlers)) for i, handler := range handlers { fn := handler list[i] = func(c *gin.Context) { ctx := newContext(c) defer releaseContext(ctx) fn(ctx) } } return list } /******************************************************************************/ var _ Mux = (*mux)(nil) type Mux interface { http.Handler Group(relativePath string, handlers ...HandlerFunc) RouterGroup Routes() gin.RoutesInfo HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc) } type mux struct { engine *gin.Engine } func (m *mux) ServeHTTP(w http.ResponseWriter, req *http.Request) { m.engine.ServeHTTP(w, req) } func (m *mux) Group(relativePath string, handlers ...HandlerFunc) RouterGroup { return &router{ group: m.engine.Group(relativePath, wrapHandlers(handlers...)...), } } func (m *mux) Routes() gin.RoutesInfo { return m.engine.Routes() } func (m *mux) HandlerFunc(relativePath string, handlerFunc gin.HandlerFunc) { m.engine.GET(relativePath, handlerFunc) } func New(logger *zap.Logger, options ...Option) (Mux, error) { if logger == nil { return nil, errors.New("logger required") } gin.SetMode(gin.ReleaseMode) binding.Validator = validator.Validator newMux := &mux{ engine: gin.New(), } fmt.Println(color.Green(fmt.Sprintf("* [register env %s]", env.Active().Value()))) // withoutLogPaths 这些请求,默认不记录日志 withoutTracePaths := map[string]bool{ "/metrics": true, "/favicon.ico": true, "/system/health": true, } opt := new(option) for _, f := range options { f(opt) } if opt.enablePProf { pprof.Register(newMux.engine) fmt.Println(color.Green("* [register pprof]")) } if opt.enablePrometheus { newMux.engine.GET("/metrics", gin.WrapH(promhttp.Handler())) fmt.Println(color.Green("* [register prometheus]")) } if opt.enableCors { newMux.engine.Use(cors.AllowAll()) } if opt.staticDirs != nil { for _, dir := range opt.staticDirs { newMux.engine.StaticFS(dir, gin.Dir(dir, false)) } } // recover两次,防止处理时发生panic,尤其是在OnPanicNotify中。 newMux.engine.Use(func(ctx *gin.Context) { defer func() { if err := recover(); err != nil { logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", string(debug.Stack()))) } }() ctx.Next() }) newMux.engine.Use(func(ctx *gin.Context) { ts := time.Now() newCtx := newContext(ctx) defer releaseContext(newCtx) newCtx.init() newCtx.setLogger(logger) if !withoutTracePaths[ctx.Request.URL.Path] { if traceId := newCtx.GetHeader(trace.Header); traceId != "" { newCtx.setTrace(trace.New(traceId)) } else { newCtx.setTrace(trace.New("")) } } defer func() { if err := recover(); err != nil { stackInfo := string(debug.Stack()) logger.Error("got panic", zap.String("panic", fmt.Sprintf("%+v", err)), zap.String("stack", stackInfo)) newCtx.AbortWithError(errno.NewError( http.StatusInternalServerError, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)), ) if notify := opt.panicNotify; notify != nil { notify(newCtx, err, stackInfo) } } if ctx.Writer.Status() == http.StatusNotFound { return } var ( response any businessCode int businessCodeMsg string abortErr error graphResponse any ) if ctx.IsAborted() { for i := range ctx.Errors { // gin error multierr.AppendInto(&abortErr, ctx.Errors[i]) } if err := newCtx.abortError(); err != nil { // customer err multierr.AppendInto(&abortErr, err.GetErr()) response = err businessCode = err.GetBusinessCode() businessCodeMsg = err.GetMsg() if x := newCtx.Trace(); x != nil { newCtx.SetHeader(trace.Header, x.ID()) } ctx.JSON(err.GetHttpCode(), &Failure{ ResultCode: businessCode, ResultInfo: businessCodeMsg, }) } } else { response = newCtx.getPayload() if response != nil { if x := newCtx.Trace(); x != nil { newCtx.SetHeader(trace.Header, x.ID()) } ctx.JSON(http.StatusOK, response) } } graphResponse = newCtx.getGraphPayload() if opt.recordMetrics != nil { uri := newCtx.Path() if alias := newCtx.Alias(); alias != "" { uri = alias } opt.recordMetrics( newCtx.Method(), uri, !ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK, time.Since(ts).Seconds(), ) } var t *trace.Trace if x := newCtx.Trace(); x != nil { t = x.(*trace.Trace) } else { return } decodedURL, _ := url.QueryUnescape(ctx.Request.URL.RequestURI()) t.WithRequest(&trace.Request{ TTL: "un-limit", Method: ctx.Request.Method, DecodedURL: decodedURL, Header: ctx.Request.Header, Body: string(newCtx.RawData()), }) var responseBody any if response != nil { responseBody = response } if graphResponse != nil { responseBody = graphResponse } t.WithResponse(&trace.Response{ Header: ctx.Writer.Header(), HttpCode: ctx.Writer.Status(), HttpCodeMsg: http.StatusText(ctx.Writer.Status()), BusinessCode: businessCode, BusinessCodeMsg: businessCodeMsg, Body: responseBody, CostSeconds: time.Since(ts).Seconds(), }) t.Success = !ctx.IsAborted() && ctx.Writer.Status() == http.StatusOK t.CostSeconds = time.Since(ts).Seconds() logger.Info("core-interceptor", zap.Any("method", ctx.Request.Method), zap.Any("path", decodedURL), zap.Any("http_code", ctx.Writer.Status()), zap.Any("business_code", businessCode), zap.Any("success", t.Success), zap.Any("cost_seconds", t.CostSeconds), zap.Any("trace_id", t.Identifier), zap.Any("trace_info", t), zap.Error(abortErr), ) }() ctx.Next() }) if opt.rateLimiter != nil { newMux.engine.Use(func(ctx *gin.Context) { newCtx := newContext(ctx) defer releaseContext(newCtx) if !opt.rateLimiter.Allow(ctx.ClientIP()) { newCtx.AbortWithError(errno.NewError( http.StatusTooManyRequests, http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)), ) return } ctx.Next() }) } newMux.engine.NoMethod(wrapHandlers(DisableTrace)...) newMux.engine.NoRoute(wrapHandlers(DisableTrace)...) system := newMux.Group("/system") { // 健康检查 system.GET("/health", func(ctx Context) { resp := &struct { Timestamp time.Time `json:"timestamp"` Environment string `json:"environment"` Host string `json:"host"` Status string `json:"status"` }{ Timestamp: time.Now(), Environment: env.Active().Value(), Host: ctx.Host(), Status: "ok", } ctx.Payload(resp) }) } return newMux, nil }