package mux

import (
	"bytes"
	stdCtx "context"
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync"

	"git.bvbej.com/bvbej/base-golang/pkg/errno"
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
	"github.com/gin-gonic/gin"
	"github.com/gin-gonic/gin/binding"
	"go.uber.org/zap"
)

// HandlerFunc is a request handler that receives the wrapped Context.
type HandlerFunc func(c Context)

// Trace is an alias for trace.T.
type Trace = trace.T

// Keys under which per-request values are stored on the underlying gin.Context.
const (
	_Alias            = "_alias_"
	_TraceName        = "_trace_"
	_LoggerName       = "_logger_"
	_BodyName         = "_body_"
	_PayloadName      = "_payload_"
	_GraphPayloadName = "_graph_payload_"
	_AbortErrorName   = "_abort_error_"
	_UserID           = "_user_id_"
)

// contextPool recycles *context wrappers between requests.
var contextPool = &sync.Pool{
	New: func() any {
		return new(context)
	},
}

// newContext takes a wrapper from the pool and points it at ctx.
func newContext(ctx *gin.Context) Context {
	wrapped := contextPool.Get().(*context)
	wrapped.ctx = ctx
	return wrapped
}

// releaseContext clears the wrapper's reference to the gin.Context and
// returns the wrapper to the pool. ctx must be a *context.
func releaseContext(ctx Context) {
	c := ctx.(*context)
	c.ctx = nil
	contextPool.Put(c)
}

var _ Context = (*context)(nil)

// Context is the per-request API handed to handlers. It wraps *gin.Context.
type Context interface {
	init()

	// Context returns the underlying *gin.Context.
	Context() *gin.Context

	// ShouldBindQuery binds the query string into obj.
	// tag: `form:"xxx"` (note: do not write `query`)
	ShouldBindQuery(obj any) error

	// ShouldBindPostForm binds an x-www-form-urlencoded body into obj.
	// tag: `form:"xxx"`
	ShouldBindPostForm(obj any) error

	// ShouldBindForm binds form data (query string and post form) into obj.
	// tag: `form:"xxx"`
	ShouldBindForm(obj any) error

	// ShouldBindJSON binds a JSON body into obj.
	// tag: `json:"xxx"`
	ShouldBindJSON(obj any) error

	// ShouldBindURI binds path parameters into obj
	// (e.g. for a route path such as /user/:name).
	// tag: `uri:"xxx"`
	ShouldBindURI(obj any) error

	// Redirect redirects the request to location with the given status code.
	Redirect(code int, location string)

	// Trace returns the Trace object, or nil when there is none.
	Trace() Trace
	setTrace(trace Trace)
	disableTrace()

	// Logger returns the Logger object, or nil when none has been set.
	Logger() *zap.Logger
	setLogger(logger *zap.Logger)

	// Payload records the payload of a successful response.
	Payload(payload any)
	getPayload() any

	// GraphPayload records a GraphQL response value; its structure differs
	// from the regular API response.
	GraphPayload(payload any)
	getGraphPayload() any

	// HTML renders the template name+".html" with obj.
	HTML(name string, obj any)

	// AbortWithError aborts the request and records err as the error response.
	AbortWithError(err errno.Error)
	abortError() errno.Error

	// Header returns a copy of the request headers.
	Header() http.Header
	// GetHeader returns the request header value for key.
	GetHeader(key string) string
	// SetHeader sets a response header.
	SetHeader(key, value string)

	// UserID returns the stored user ID, or 0 when none has been set.
	UserID() uint64
	// SetUserID stores the user ID.
	SetUserID(userID uint64)

	// Authorization returns the value of the Authorization request header.
	Authorization() string

	// Alias returns the route alias (for the metrics URI), or "" when unset.
	Alias() string
	setAlias(path string)

	// RequestInputParams returns all parameters (Request.Form).
	RequestInputParams() url.Values
	// RequestQueryParams returns the query-string parameters.
	RequestQueryParams() url.Values
	// RequestPostFormParams returns the post-form parameters.
	RequestPostFormParams() url.Values

	// Request returns the *http.Request.
	Request() *http.Request
	// RawData returns the cached request body.
	RawData() []byte
	// Method returns Request.Method.
	Method() string
	// Host returns Request.Host.
	Host() string
	// Path returns Request.URL.Path (without the query string).
	Path() string
	// URI returns the unescaped Request.URL.RequestURI().
	URI() string

	// RequestContext returns a StdContext carrying the Trace and Logger.
	// NOTE: the current implementation wraps context.Background(), not the
	// request's own context, so it is not canceled when the client disconnects.
	RequestContext() StdContext

	// ResponseWriter returns the gin.ResponseWriter.
	ResponseWriter() gin.ResponseWriter
}

// context is the Context implementation backed by a *gin.Context.
type context struct {
	ctx *gin.Context
}

// StdContext embeds a standard context.Context together with the request's
// Trace and *zap.Logger.
type StdContext struct {
	stdCtx.Context
	Trace
	*zap.Logger
}

// init reads the whole request body, caches it under _BodyName, and replaces
// Request.Body with a fresh reader over the same bytes so it can be read again.
// It panics when the body cannot be read.
func (c *context) init() {
	body, err := c.ctx.GetRawData()
	if err != nil {
		panic(err)
	}

	c.ctx.Set(_BodyName, body)                                // cache the body so trace can use it
	c.ctx.Request.Body = io.NopCloser(bytes.NewBuffer(body)) // re-construct the request body
}

// Context returns the underlying *gin.Context.
func (c *context) Context() *gin.Context {
	return c.ctx
}

// ShouldBindQuery binds the query string into obj.
// tag: `form:"xxx"` (note: do not write `query`)
func (c *context) ShouldBindQuery(obj any) error {
	return c.ctx.ShouldBindWith(obj, binding.Query)
}

// ShouldBindPostForm binds the post form into obj (the query string is ignored).
// tag: `form:"xxx"`
func (c *context) ShouldBindPostForm(obj any) error {
	return c.ctx.ShouldBindWith(obj, binding.FormPost)
}

// ShouldBindForm binds both the query string and the post form into obj.
// When the same field exists in both, the post form value takes precedence.
// tag: `form:"xxx"`
func (c *context) ShouldBindForm(obj any) error {
	return c.ctx.ShouldBindWith(obj, binding.Form)
}

// ShouldBindJSON binds a JSON body into obj.
// tag: `json:"xxx"`
func (c *context) ShouldBindJSON(obj any) error {
	return c.ctx.ShouldBindWith(obj, binding.JSON)
}

// ShouldBindURI binds path parameters into obj
// (e.g. for a route path such as /user/:name).
// tag: `uri:"xxx"`
func (c *context) ShouldBindURI(obj any) error {
	return c.ctx.ShouldBindUri(obj)
}

// Redirect redirects the request to location with the given status code.
func (c *context) Redirect(code int, location string) {
	c.ctx.Redirect(code, location)
}

// Trace returns the Trace stored under _TraceName. It returns nil when the key
// is absent or holds nil (which is what disableTrace stores).
func (c *context) Trace() Trace {
	t, ok := c.ctx.Get(_TraceName)
	if !ok || t == nil {
		return nil
	}

	return t.(Trace)
}

// setTrace stores trace under _TraceName.
func (c *context) setTrace(trace Trace) {
	c.ctx.Set(_TraceName, trace)
}

// disableTrace stores a nil Trace, so Trace() returns nil afterwards.
func (c *context) disableTrace() {
	c.setTrace(nil)
}

// Logger returns the *zap.Logger stored under _LoggerName, or nil when the key
// is absent.
func (c *context) Logger() *zap.Logger {
	logger, ok := c.ctx.Get(_LoggerName)
	if !ok {
		return nil
	}

	return logger.(*zap.Logger)
}

// setLogger stores logger under _LoggerName.
func (c *context) setLogger(logger *zap.Logger) {
	c.ctx.Set(_LoggerName, logger)
}

// getPayload returns the value stored by Payload, or nil when none was stored.
func (c *context) getPayload() any {
	if payload, ok := c.ctx.Get(_PayloadName); ok {
		return payload
	}

	return nil
}

// Payload stores payload under _PayloadName.
func (c *context) Payload(payload any) {
	c.ctx.Set(_PayloadName, payload)
}

// getGraphPayload returns the value stored by GraphPayload, or nil when none
// was stored.
func (c *context) getGraphPayload() any {
	if payload, ok := c.ctx.Get(_GraphPayloadName); ok {
		return payload
	}

	return nil
}

// GraphPayload stores payload under _GraphPayloadName.
func (c *context) GraphPayload(payload any) {
	c.ctx.Set(_GraphPayloadName, payload)
}

// HTML renders the template name+".html" with obj and status 200.
func (c *context) HTML(name string, obj any) {
	c.ctx.HTML(http.StatusOK, name+".html", obj)
}

// Header returns a deep copy of the request headers, so callers cannot mutate
// the headers held by the request.
func (c *context) Header() http.Header {
	header := c.ctx.Request.Header

	clone := make(http.Header, len(header))
	for k, v := range header {
		value := make([]string, len(v))
		copy(value, v)

		clone[k] = value
	}

	return clone
}

// GetHeader returns the request header value for key.
func (c *context) GetHeader(key string) string {
	return c.ctx.GetHeader(key)
}

// SetHeader sets the response header key to value.
func (c *context) SetHeader(key, value string) {
	c.ctx.Header(key, value)
}

// UserID returns the user ID stored by SetUserID, or 0 when none was stored.
func (c *context) UserID() uint64 {
	val, ok := c.ctx.Get(_UserID)
	if !ok {
		return 0
	}

	return val.(uint64)
}

// SetUserID stores userID under _UserID.
func (c *context) SetUserID(userID uint64) {
	c.ctx.Set(_UserID, userID)
}

// Authorization returns the value of the Authorization request header.
func (c *context) Authorization() string {
	return c.ctx.GetHeader("Authorization")
}

// AbortWithError aborts the request with err's HTTP status code and records
// err under _AbortErrorName. A status code of 0 is replaced by 500.
// A nil err is a no-op.
func (c *context) AbortWithError(err errno.Error) {
	if err == nil {
		return
	}

	httpCode := err.GetHttpCode()
	if httpCode == 0 {
		httpCode = http.StatusInternalServerError
	}

	c.ctx.AbortWithStatus(httpCode)
	c.ctx.Set(_AbortErrorName, err)
}

// abortError returns the error recorded by AbortWithError, or nil when no
// error has been recorded.
func (c *context) abortError() errno.Error {
	val, ok := c.ctx.Get(_AbortErrorName)
	if !ok {
		return nil
	}

	// Comma-ok form: a single-value assertion would panic on a nil or
	// unexpected value.
	err, _ := val.(errno.Error)
	return err
}

// Alias returns the route alias stored by setAlias, or "" when none was stored.
func (c *context) Alias() string {
	path, ok := c.ctx.Get(_Alias)
	if !ok {
		return ""
	}

	return path.(string)
}

// setAlias stores the whitespace-trimmed path under _Alias. Paths that are
// empty after trimming are ignored.
func (c *context) setAlias(path string) {
	if path = strings.TrimSpace(path); path != "" {
		c.ctx.Set(_Alias, path)
	}
}

// RequestInputParams returns all parameters (Request.Form).
func (c *context) RequestInputParams() url.Values {
	// Best effort: a parse error is ignored and whatever was parsed is returned.
	_ = c.ctx.Request.ParseForm()
	return c.ctx.Request.Form
}

// RequestQueryParams returns the parameters parsed from the raw query string.
func (c *context) RequestQueryParams() url.Values {
	// Best effort: a parse error is ignored and whatever was parsed is returned.
	query, _ := url.ParseQuery(c.ctx.Request.URL.RawQuery)
	return query
}

// RequestPostFormParams returns the post-form parameters (Request.PostForm).
func (c *context) RequestPostFormParams() url.Values {
	// Best effort: a parse error is ignored and whatever was parsed is returned.
	_ = c.ctx.Request.ParseForm()
	return c.ctx.Request.PostForm
}

// Request returns the *http.Request.
func (c *context) Request() *http.Request {
	return c.ctx.Request
}

// RawData returns the request body cached by init, or nil when no body has
// been cached.
func (c *context) RawData() []byte {
	body, ok := c.ctx.Get(_BodyName)
	if !ok {
		return nil
	}

	return body.([]byte)
}

// Method returns the request method.
func (c *context) Method() string {
	return c.ctx.Request.Method
}

// Host returns the request host.
func (c *context) Host() string {
	return c.ctx.Request.Host
}

// Path returns the request path (without the query string).
func (c *context) Path() string {
	return c.ctx.Request.URL.Path
}

// URI returns the query-unescaped Request.URL.RequestURI().
// The unescape error is discarded.
func (c *context) URI() string {
	uri, _ := url.QueryUnescape(c.ctx.Request.URL.RequestURI())
	return uri
}

// RequestContext returns a StdContext wrapping the Trace and Logger.
// NOTE: it is built on context.Background(), not c.ctx.Request.Context(), so
// it is NOT canceled when the client disconnects.
func (c *context) RequestContext() StdContext {
	return StdContext{
		//c.ctx.Request.Context(),
		Context: stdCtx.Background(),
		Trace:   c.Trace(),
		Logger:  c.Logger(),
	}
}

// ResponseWriter returns the gin.ResponseWriter.
func (c *context) ResponseWriter() gin.ResponseWriter {
	return c.ctx.Writer
}