初始提交: Gitea 项目代码
This commit is contained in:
@@ -0,0 +1,196 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"slices"
|
||||
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/web/routing"
|
||||
"gitea.dev/modules/web/types"
|
||||
)
|
||||
|
||||
var responseStatusProviders = map[reflect.Type]func(req *http.Request) types.ResponseStatusProvider{}
|
||||
|
||||
func RegisterResponseStatusProvider[T any](fn func(req *http.Request) types.ResponseStatusProvider) {
|
||||
responseStatusProviders[reflect.TypeFor[T]()] = fn
|
||||
}
|
||||
|
||||
// responseWriter is a wrapper of http.ResponseWriter, to check whether the response has been written
|
||||
type responseWriter struct {
|
||||
respWriter http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
var _ types.ResponseStatusProvider = (*responseWriter)(nil)
|
||||
|
||||
func (r *responseWriter) WrittenStatus() int {
|
||||
return r.status
|
||||
}
|
||||
|
||||
func (r *responseWriter) Header() http.Header {
|
||||
return r.respWriter.Header()
|
||||
}
|
||||
|
||||
func (r *responseWriter) Write(bytes []byte) (int, error) {
|
||||
if r.status == 0 {
|
||||
r.status = http.StatusOK
|
||||
}
|
||||
return r.respWriter.Write(bytes)
|
||||
}
|
||||
|
||||
func (r *responseWriter) WriteHeader(statusCode int) {
|
||||
r.status = statusCode
|
||||
r.respWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := r.respWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
var (
|
||||
httpReqType = reflect.TypeFor[*http.Request]()
|
||||
respWriterType = reflect.TypeFor[http.ResponseWriter]()
|
||||
)
|
||||
|
||||
// preCheckHandler checks whether the handler is valid, developers could get first-time feedback, all mistakes could be found at startup
|
||||
func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {
|
||||
hasStatusProvider := false
|
||||
for _, argIn := range argsIn {
|
||||
if _, hasStatusProvider = argIn.Interface().(types.ResponseStatusProvider); hasStatusProvider {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasStatusProvider {
|
||||
panic(fmt.Sprintf("handler should have at least one ResponseStatusProvider argument, but got %s", fn.Type()))
|
||||
}
|
||||
if fn.Type().NumOut() != 0 {
|
||||
panic(fmt.Sprintf("handler should have no return value other than registered ones, but got %s", fn.Type()))
|
||||
}
|
||||
}
|
||||
|
||||
func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value {
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
err := fmt.Errorf("%v\n%s", recovered, log.Stack(2))
|
||||
log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err)
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
isPreCheck := req == nil
|
||||
|
||||
argsIn := make([]reflect.Value, fn.Type().NumIn())
|
||||
for i := 0; i < fn.Type().NumIn(); i++ {
|
||||
argTyp := fn.Type().In(i)
|
||||
switch argTyp {
|
||||
case respWriterType:
|
||||
argsIn[i] = reflect.ValueOf(resp)
|
||||
case httpReqType:
|
||||
argsIn[i] = reflect.ValueOf(req)
|
||||
default:
|
||||
if argFn, ok := responseStatusProviders[argTyp]; ok {
|
||||
if isPreCheck {
|
||||
argsIn[i] = reflect.ValueOf(&responseWriter{})
|
||||
} else {
|
||||
argsIn[i] = reflect.ValueOf(argFn(req))
|
||||
}
|
||||
} else {
|
||||
panic(fmt.Sprintf("unsupported argument type: %s", argTyp))
|
||||
}
|
||||
}
|
||||
}
|
||||
return argsIn
|
||||
}
|
||||
|
||||
func handleResponse(fn reflect.Value, ret []reflect.Value) {
|
||||
if len(ret) != 0 {
|
||||
panic(fmt.Sprintf("unsupported return values: %s", fn.Type()))
|
||||
}
|
||||
}
|
||||
|
||||
func hasResponseBeenWritten(argsIn []reflect.Value) bool {
|
||||
for _, argIn := range argsIn {
|
||||
if statusProvider, ok := argIn.Interface().(types.ResponseStatusProvider); ok {
|
||||
if statusProvider.WrittenStatus() != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type middlewareProvider = func(next http.Handler) http.Handler
|
||||
|
||||
func executeMiddlewaresHandler(w http.ResponseWriter, r *http.Request, middlewares []middlewareProvider, endpoint http.HandlerFunc) {
|
||||
handler := endpoint
|
||||
for _, middleware := range slices.Backward(middlewares) {
|
||||
handler = middleware(handler).ServeHTTP
|
||||
}
|
||||
handler(w, r)
|
||||
}
|
||||
|
||||
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) middlewareProvider {
|
||||
return func(next http.Handler) http.Handler {
|
||||
h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
|
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
||||
defer routing.RecordFuncInfo(req.Context(), funcInfo)()
|
||||
h.ServeHTTP(resp, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// toHandlerProvider converts a handler to a handler provider
|
||||
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
|
||||
func toHandlerProvider(handler any) middlewareProvider {
|
||||
funcInfo := routing.GetFuncInfo(handler)
|
||||
fn := reflect.ValueOf(handler)
|
||||
if fn.Type().Kind() != reflect.Func {
|
||||
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
|
||||
}
|
||||
|
||||
if hp, ok := handler.(middlewareProvider); ok {
|
||||
return wrapHandlerProvider(hp, funcInfo)
|
||||
} else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok {
|
||||
return wrapHandlerProvider(hp, funcInfo)
|
||||
}
|
||||
|
||||
provider := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(respOrig http.ResponseWriter, req *http.Request) {
|
||||
// wrap the response writer to check whether the response has been written
|
||||
resp := respOrig
|
||||
if _, ok := resp.(types.ResponseStatusProvider); !ok {
|
||||
resp = &responseWriter{respWriter: resp}
|
||||
}
|
||||
|
||||
// prepare the arguments for the handler and do pre-check
|
||||
argsIn := prepareHandleArgsIn(resp, req, fn, funcInfo)
|
||||
if req == nil {
|
||||
preCheckHandler(fn, argsIn)
|
||||
return // it's doing pre-check, just return
|
||||
}
|
||||
|
||||
defer routing.RecordFuncInfo(req.Context(), funcInfo)()
|
||||
ret := fn.Call(argsIn)
|
||||
|
||||
// handle the return value (no-op at the moment)
|
||||
handleResponse(fn, ret)
|
||||
|
||||
// if the response has not been written, call the next handler
|
||||
if next != nil && !hasResponseBeenWritten(argsIn) {
|
||||
next.ServeHTTP(resp, req)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
provider(nil).ServeHTTP(nil, nil) // do a pre-check to make sure all arguments and return values are supported
|
||||
return provider
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gitea.dev/modules/translation"
|
||||
"gitea.dev/modules/util"
|
||||
"gitea.dev/modules/validation"
|
||||
|
||||
"gitea.com/go-chi/binding"
|
||||
)
|
||||
|
||||
// Form form binding interface
|
||||
type Form interface {
|
||||
binding.Validator
|
||||
}
|
||||
|
||||
func init() {
|
||||
binding.SetNameMapper(util.ToSnakeCase)
|
||||
}
|
||||
|
||||
// AssignForm assign form values back to the template data.
|
||||
func AssignForm(form any, data map[string]any) {
|
||||
typ := reflect.TypeOf(form)
|
||||
val := reflect.ValueOf(form)
|
||||
|
||||
for typ.Kind() == reflect.Pointer {
|
||||
typ = typ.Elem()
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
fieldName := field.Tag.Get("form")
|
||||
// Allow ignored fields in the struct
|
||||
if fieldName == "-" {
|
||||
continue
|
||||
} else if len(fieldName) == 0 {
|
||||
fieldName = util.ToSnakeCase(field.Name)
|
||||
}
|
||||
|
||||
data[fieldName] = val.Field(i).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
func getRuleBody(field reflect.StructField, prefix string) string {
|
||||
for rule := range strings.SplitSeq(field.Tag.Get("binding"), ";") {
|
||||
if strings.HasPrefix(rule, prefix) {
|
||||
return rule[len(prefix) : len(rule)-1]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSize get size int form tag
|
||||
func GetSize(field reflect.StructField) string {
|
||||
return getRuleBody(field, "Size(")
|
||||
}
|
||||
|
||||
// GetMinSize get minimal size in form tag
|
||||
func GetMinSize(field reflect.StructField) string {
|
||||
return getRuleBody(field, "MinSize(")
|
||||
}
|
||||
|
||||
// GetMaxSize get max size in form tag
|
||||
func GetMaxSize(field reflect.StructField) string {
|
||||
return getRuleBody(field, "MaxSize(")
|
||||
}
|
||||
|
||||
// GetInclude get include in form tag
|
||||
func GetInclude(field reflect.StructField) string {
|
||||
return getRuleBody(field, "Include(")
|
||||
}
|
||||
|
||||
func ReportValidationError(errs binding.Errors, data map[string]any, fieldName, classification, errorMsg string) binding.Errors {
|
||||
errs.Add([]string{fieldName}, classification, errorMsg)
|
||||
|
||||
data["HasError"] = true
|
||||
data["ErrorMsg"] = fieldName + ": " + errorMsg
|
||||
data["Err_"+fieldName] = true
|
||||
// there is already a reported validation error, so no need to generate default error messages in Validate()
|
||||
data["HasErrorFormValidation"] = true
|
||||
return errs
|
||||
}
|
||||
|
||||
func Validate(errs binding.Errors, data map[string]any, f Form, l translation.Locale) binding.Errors {
|
||||
// try to restore the form's values as much as possible,
|
||||
// especially for RenderWithErrDeprecated to re-render the form with errors
|
||||
AssignForm(f, data)
|
||||
|
||||
if errs.Len() == 0 || data["HasErrorFormValidation"] == true {
|
||||
return errs
|
||||
}
|
||||
|
||||
// if HasError=true, then must set default error message
|
||||
// because still a lot of places use `ctx.Data["ErrorMsg"].(string)` even if the error fields can't be found
|
||||
data["HasError"] = true
|
||||
data["ErrorMsg"] = l.TrString("form.unknown_error")
|
||||
|
||||
typ := reflect.TypeOf(f)
|
||||
if typ.Kind() == reflect.Pointer {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
|
||||
field, fieldExists := typ.FieldByName(errs[0].FieldNames[0])
|
||||
if !fieldExists {
|
||||
return errs
|
||||
}
|
||||
|
||||
if field.Tag.Get("form") == "-" {
|
||||
return errs
|
||||
}
|
||||
|
||||
data["Err_"+field.Name] = true
|
||||
|
||||
trName := field.Tag.Get("locale")
|
||||
if len(trName) == 0 {
|
||||
trName = l.TrString("form." + field.Name)
|
||||
} else {
|
||||
trName = l.TrString(trName)
|
||||
}
|
||||
|
||||
switch errs[0].Classification {
|
||||
case binding.ERR_REQUIRED:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.require_error")
|
||||
case binding.ERR_ALPHA_DASH:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_error")
|
||||
case binding.ERR_ALPHA_DASH_DOT:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.alpha_dash_dot_error")
|
||||
case validation.ErrGitRefName:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.git_ref_name_error")
|
||||
case binding.ERR_SIZE:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.size_error", GetSize(field))
|
||||
case binding.ERR_MIN_SIZE:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.min_size_error", GetMinSize(field))
|
||||
case binding.ERR_MAX_SIZE:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.max_size_error", GetMaxSize(field))
|
||||
case binding.ERR_EMAIL:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.email_error")
|
||||
case binding.ERR_URL:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.url_error", errs[0].Message)
|
||||
case binding.ERR_INCLUDE:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.include_error", GetInclude(field))
|
||||
case validation.ErrGlobPattern:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.glob_pattern_error", errs[0].Message)
|
||||
case validation.ErrRegexPattern:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.regex_pattern_error", errs[0].Message)
|
||||
case validation.ErrUsername:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.username_error")
|
||||
case validation.ErrInvalidGroupTeamMap:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.invalid_group_team_map_error", errs[0].Message)
|
||||
case validation.ErrInvalidBadgeSlug:
|
||||
data["ErrorMsg"] = trName + l.TrString("form.invalid_slug_error")
|
||||
default:
|
||||
msg := errs[0].Classification
|
||||
if msg != "" && errs[0].Message != "" {
|
||||
msg += ": "
|
||||
}
|
||||
|
||||
msg += errs[0].Message
|
||||
if msg == "" {
|
||||
msg = l.TrString("form.unknown_error")
|
||||
}
|
||||
data["ErrorMsg"] = trName + ": " + msg
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
// Copyright 2020 The Macaron Authors
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"gitea.dev/modules/session"
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/util"
|
||||
)
|
||||
|
||||
const (
|
||||
CookieWebBannerDismissed = "gitea_disbnr"
|
||||
CookieTheme = "gitea_theme"
|
||||
cookieRedirectTo = "redirect_to"
|
||||
)
|
||||
|
||||
func GetRedirectToCookie(req *http.Request) string {
|
||||
return GetSiteCookie(req, cookieRedirectTo)
|
||||
}
|
||||
|
||||
// SetRedirectToCookie convenience function to set the RedirectTo cookie consistently
|
||||
func SetRedirectToCookie(resp http.ResponseWriter, value string) {
|
||||
SetSiteCookie(resp, cookieRedirectTo, value, 0)
|
||||
}
|
||||
|
||||
// DeleteRedirectToCookie convenience function to delete most cookies consistently
|
||||
func DeleteRedirectToCookie(resp http.ResponseWriter) {
|
||||
SetSiteCookie(resp, cookieRedirectTo, "", -1)
|
||||
}
|
||||
|
||||
func RedirectLinkUserLogin(req *http.Request) string {
|
||||
if req.Header.Get("X-Gitea-Fetch-Action") != "" {
|
||||
// when building the redirect link for a fetch request, the current link might be a partial page,
|
||||
// so we only redirect to the login page without redirect_to parameter
|
||||
return setting.AppSubURL + "/user/login"
|
||||
}
|
||||
return setting.AppSubURL + "/user/login?redirect_to=" + url.QueryEscape(setting.AppSubURL+req.URL.RequestURI())
|
||||
}
|
||||
|
||||
// GetSiteCookie returns given cookie value from request header.
|
||||
func GetSiteCookie(req *http.Request, name string) string {
|
||||
cookie, err := req.Cookie(name)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
val, _ := url.QueryUnescape(cookie.Value)
|
||||
return val
|
||||
}
|
||||
|
||||
// SetSiteCookie returns given cookie value from request header.
|
||||
func SetSiteCookie(resp http.ResponseWriter, name, value string, maxAge int) {
|
||||
// Previous versions would use a cookie path with a trailing /.
|
||||
// These are more specific than cookies without a trailing /, so
|
||||
// we need to delete these if they exist.
|
||||
deleteLegacySiteCookie(resp, name)
|
||||
|
||||
// HINT: INSTALL-PAGE-COOKIE-INIT: the cookie system is not properly initialized on the Install page, so there is no CookiePath
|
||||
cookie := &http.Cookie{
|
||||
Name: name,
|
||||
Value: url.QueryEscape(value),
|
||||
MaxAge: maxAge,
|
||||
Path: util.IfZero(setting.SessionConfig.CookiePath, "/"),
|
||||
Domain: setting.SessionConfig.Domain,
|
||||
Secure: setting.SessionConfig.Secure,
|
||||
HttpOnly: true,
|
||||
SameSite: setting.SessionConfig.SameSite,
|
||||
}
|
||||
resp.Header().Add("Set-Cookie", cookie.String())
|
||||
}
|
||||
|
||||
// deleteLegacySiteCookie deletes the cookie with the given name at the cookie
|
||||
// path with a trailing /, which would unintentionally override the cookie.
|
||||
func deleteLegacySiteCookie(resp http.ResponseWriter, name string) {
|
||||
if setting.SessionConfig.CookiePath == "" || strings.HasSuffix(setting.SessionConfig.CookiePath, "/") {
|
||||
// If the cookie path ends with /, no legacy cookies will take
|
||||
// precedence, so do nothing. The exception is that cookies with no
|
||||
// path could override other cookies, but it's complicated and we don't
|
||||
// currently handle that.
|
||||
return
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
Path: setting.SessionConfig.CookiePath + "/",
|
||||
Domain: setting.SessionConfig.Domain,
|
||||
Secure: setting.SessionConfig.Secure,
|
||||
HttpOnly: true,
|
||||
SameSite: setting.SessionConfig.SameSite,
|
||||
}
|
||||
resp.Header().Add("Set-Cookie", cookie.String())
|
||||
}
|
||||
|
||||
func init() {
|
||||
session.BeforeRegenerateSession = append(session.BeforeRegenerateSession, func(resp http.ResponseWriter, _ *http.Request) {
|
||||
// Ensure that a cookie with a trailing slash does not take precedence over
|
||||
// the cookie written by the middleware.
|
||||
deleteLegacySiteCookie(resp, setting.SessionConfig.CookieName)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gitea.dev/modules/public"
|
||||
"gitea.dev/modules/reqctx"
|
||||
"gitea.dev/modules/setting"
|
||||
)
|
||||
|
||||
const ContextDataKeySignedUser = "SignedUser"
|
||||
|
||||
func GetContextData(c context.Context) reqctx.ContextData {
|
||||
if rc := reqctx.GetRequestDataStore(c); rc != nil {
|
||||
return rc.GetData()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CommonTemplateContextData() reqctx.ContextData {
|
||||
return reqctx.ContextData{
|
||||
"PageTitleCommon": setting.AppName,
|
||||
|
||||
"IsLandingPageOrganizations": setting.LandingPageURL == setting.LandingPageOrganizations,
|
||||
|
||||
"ShowRegistrationButton": setting.Service.ShowRegistrationButton,
|
||||
"ShowMilestonesDashboardPage": setting.Service.ShowMilestonesDashboardPage,
|
||||
"ShowFooterVersion": setting.Other.ShowFooterVersion,
|
||||
"DisableDownloadSourceArchives": setting.Repository.DisableDownloadSourceArchives,
|
||||
|
||||
"EnableSwagger": setting.API.EnableSwagger,
|
||||
"EnableOpenIDSignIn": setting.Service.EnableOpenIDSignIn,
|
||||
"PageStartTime": time.Now(),
|
||||
|
||||
"RunModeIsProd": setting.IsProd,
|
||||
"ViteModeIsDev": public.IsViteDevMode(),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"gitea.dev/modules/reqctx"
|
||||
)
|
||||
|
||||
// Flash represents a one time data transfer between two requests.
|
||||
type Flash struct {
|
||||
DataStore reqctx.RequestDataStore
|
||||
url.Values
|
||||
ErrorMsg, WarningMsg, InfoMsg, SuccessMsg string
|
||||
}
|
||||
|
||||
func (f *Flash) set(name, msg string, current ...bool) {
|
||||
if f.Values == nil {
|
||||
f.Values = make(map[string][]string)
|
||||
}
|
||||
showInCurrentPage := len(current) > 0 && current[0]
|
||||
if showInCurrentPage {
|
||||
// assign it to the context data, then the template can use ".Flash.XxxMsg" to render the message
|
||||
f.DataStore.GetData()["Flash"] = f
|
||||
} else {
|
||||
// the message map will be saved into the cookie and be shown in next response (a new page response which decodes the cookie)
|
||||
f.Set(name, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func flashMsgStringOrHTML(msg any) string {
|
||||
switch v := msg.(type) {
|
||||
case string:
|
||||
return v
|
||||
case template.HTML:
|
||||
return string(v)
|
||||
}
|
||||
panic(fmt.Sprintf("unknown type: %T", msg))
|
||||
}
|
||||
|
||||
// Error sets error message
|
||||
func (f *Flash) Error(msg any, current ...bool) {
|
||||
f.ErrorMsg = flashMsgStringOrHTML(msg)
|
||||
f.set("error", f.ErrorMsg, current...)
|
||||
}
|
||||
|
||||
// Warning sets warning message
|
||||
func (f *Flash) Warning(msg any, current ...bool) {
|
||||
f.WarningMsg = flashMsgStringOrHTML(msg)
|
||||
f.set("warning", f.WarningMsg, current...)
|
||||
}
|
||||
|
||||
// Info sets info message
|
||||
func (f *Flash) Info(msg any, current ...bool) {
|
||||
f.InfoMsg = flashMsgStringOrHTML(msg)
|
||||
f.set("info", f.InfoMsg, current...)
|
||||
}
|
||||
|
||||
// Success sets success message
|
||||
func (f *Flash) Success(msg any, current ...bool) {
|
||||
f.SuccessMsg = flashMsgStringOrHTML(msg)
|
||||
f.set("success", f.SuccessMsg, current...)
|
||||
}
|
||||
|
||||
func ParseCookieFlashMessage(val string) *Flash {
|
||||
if vals, _ := url.ParseQuery(val); len(vals) > 0 {
|
||||
return &Flash{
|
||||
Values: vals,
|
||||
ErrorMsg: vals.Get("error"),
|
||||
SuccessMsg: vals.Get("success"),
|
||||
InfoMsg: vals.Get("info"),
|
||||
WarningMsg: vals.Get("warning"),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSiteCookieFlashMessage(dataStore reqctx.RequestDataStore, req *http.Request, cookieName string) (string, *Flash) {
|
||||
// Get the last flash message from cookie
|
||||
lastFlashCookie := GetSiteCookie(req, cookieName)
|
||||
lastFlashMsg := ParseCookieFlashMessage(lastFlashCookie)
|
||||
if lastFlashMsg != nil {
|
||||
lastFlashMsg.DataStore = dataStore
|
||||
return lastFlashCookie, lastFlashMsg
|
||||
}
|
||||
return lastFlashCookie, nil
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gitea.dev/modules/translation"
|
||||
"gitea.dev/modules/translation/i18n"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// Locale handle locale
|
||||
func Locale(resp http.ResponseWriter, req *http.Request) translation.Locale {
|
||||
// 1. Check URL arguments.
|
||||
lang := req.URL.Query().Get("lang")
|
||||
changeLang := lang != ""
|
||||
|
||||
// 2. Get language information from cookies.
|
||||
if len(lang) == 0 {
|
||||
ck, _ := req.Cookie("lang")
|
||||
if ck != nil {
|
||||
lang = ck.Value
|
||||
}
|
||||
}
|
||||
|
||||
// Check again in case someone changes the supported language list.
|
||||
if lang != "" && !i18n.DefaultLocales.HasLang(lang) {
|
||||
lang = ""
|
||||
changeLang = false
|
||||
}
|
||||
|
||||
// 3. Get language information from 'Accept-Language'.
|
||||
// The first element in the list is chosen to be the default language automatically.
|
||||
if len(lang) == 0 {
|
||||
tags, _, _ := language.ParseAcceptLanguage(req.Header.Get("Accept-Language"))
|
||||
tag := translation.Match(tags...)
|
||||
lang = tag.String()
|
||||
}
|
||||
|
||||
if changeLang {
|
||||
SetLocaleCookie(resp, lang, 1<<31-1)
|
||||
}
|
||||
|
||||
return translation.NewLocale(lang)
|
||||
}
|
||||
|
||||
// SetLocaleCookie convenience function to set the locale cookie consistently
|
||||
func SetLocaleCookie(resp http.ResponseWriter, lang string, maxAge int) {
|
||||
SetSiteCookie(resp, "lang", lang, maxAge)
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
)
|
||||
|
||||
// MockAfterMiddlewares is a general mock point, it's between middlewares and the handler
|
||||
const MockAfterMiddlewares = "MockAfterMiddlewares"
|
||||
|
||||
var routeMockPoints = map[string]func(next http.Handler) http.Handler{}
|
||||
|
||||
// RouterMockPoint registers a mock point as a middleware for testing, example:
|
||||
//
|
||||
// r.Use(web.RouterMockPoint("my-mock-point-1"))
|
||||
// r.Get("/foo", middleware2, web.RouterMockPoint("my-mock-point-2"), middleware2, handler)
|
||||
//
|
||||
// Then use web.RouteMock to mock the route execution.
|
||||
// It only takes effect in testing mode (setting.IsInTesting == true).
|
||||
func RouterMockPoint(pointName string) func(next http.Handler) http.Handler {
|
||||
if !setting.IsInTesting {
|
||||
return nil
|
||||
}
|
||||
routeMockPoints[pointName] = nil
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if h := routeMockPoints[pointName]; h != nil {
|
||||
h(next).ServeHTTP(w, r)
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RouteMock uses the registered mock point to mock the route execution, example:
|
||||
//
|
||||
// defer web.RouteMockReset()
|
||||
// web.RouteMock(web.MockAfterMiddlewares, func(ctx *context.Context) {
|
||||
// ctx.WriteResponse(...)
|
||||
// }
|
||||
//
|
||||
// Then the mock function will be executed as a middleware at the mock point.
|
||||
// It only takes effect in testing mode (setting.IsInTesting == true).
|
||||
func RouteMock(pointName string, h any) func() {
|
||||
if _, ok := routeMockPoints[pointName]; !ok {
|
||||
panic("route mock point not found: " + pointName)
|
||||
}
|
||||
old := routeMockPoints[pointName]
|
||||
routeMockPoints[pointName] = toHandlerProvider(h)
|
||||
return func() {
|
||||
routeMockPoints[pointName] = old
|
||||
}
|
||||
}
|
||||
|
||||
// RouteMockReset resets all mock points (no mock anymore)
|
||||
func RouteMockReset() {
|
||||
for k := range routeMockPoints {
|
||||
routeMockPoints[k] = nil // keep the keys because RouteMock will check the keys to make sure no misspelling
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRouteMock(t *testing.T) {
|
||||
setting.IsInTesting = true
|
||||
|
||||
r := NewRouter()
|
||||
middleware1 := func(resp http.ResponseWriter, req *http.Request) {
|
||||
resp.Header().Set("X-Test-Middleware1", "m1")
|
||||
}
|
||||
middleware2 := func(resp http.ResponseWriter, req *http.Request) {
|
||||
resp.Header().Set("X-Test-Middleware2", "m2")
|
||||
}
|
||||
handler := func(resp http.ResponseWriter, req *http.Request) {
|
||||
resp.Header().Set("X-Test-Handler", "h")
|
||||
}
|
||||
r.Get("/foo", middleware1, RouterMockPoint("mock-point"), middleware2, handler)
|
||||
|
||||
// normal request
|
||||
recorder := httptest.NewRecorder()
|
||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:8000/foo", nil)
|
||||
assert.NoError(t, err)
|
||||
r.ServeHTTP(recorder, req)
|
||||
assert.Len(t, recorder.Header(), 3)
|
||||
assert.Equal(t, "m1", recorder.Header().Get("X-Test-Middleware1"))
|
||||
assert.Equal(t, "m2", recorder.Header().Get("X-Test-Middleware2"))
|
||||
assert.Equal(t, "h", recorder.Header().Get("X-Test-Handler"))
|
||||
RouteMockReset()
|
||||
|
||||
// mock at "mock-point"
|
||||
RouteMock("mock-point", func(resp http.ResponseWriter, req *http.Request) {
|
||||
resp.Header().Set("X-Test-MockPoint", "a")
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
})
|
||||
recorder = httptest.NewRecorder()
|
||||
req, err = http.NewRequest(http.MethodGet, "http://localhost:8000/foo", nil)
|
||||
assert.NoError(t, err)
|
||||
r.ServeHTTP(recorder, req)
|
||||
assert.Len(t, recorder.Header(), 2)
|
||||
assert.Equal(t, "m1", recorder.Header().Get("X-Test-Middleware1"))
|
||||
assert.Equal(t, "a", recorder.Header().Get("X-Test-MockPoint"))
|
||||
RouteMockReset()
|
||||
|
||||
// mock at MockAfterMiddlewares
|
||||
RouteMock(MockAfterMiddlewares, func(resp http.ResponseWriter, req *http.Request) {
|
||||
resp.Header().Set("X-Test-MockPoint", "b")
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
})
|
||||
recorder = httptest.NewRecorder()
|
||||
req, err = http.NewRequest(http.MethodGet, "http://localhost:8000/foo", nil)
|
||||
assert.NoError(t, err)
|
||||
r.ServeHTTP(recorder, req)
|
||||
assert.Len(t, recorder.Header(), 3)
|
||||
assert.Equal(t, "m1", recorder.Header().Get("X-Test-Middleware1"))
|
||||
assert.Equal(t, "m2", recorder.Header().Get("X-Test-Middleware2"))
|
||||
assert.Equal(t, "b", recorder.Header().Get("X-Test-MockPoint"))
|
||||
RouteMockReset()
|
||||
}
|
||||
@@ -0,0 +1,295 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gitea.dev/modules/htmlutil"
|
||||
"gitea.dev/modules/reqctx"
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/web/middleware"
|
||||
"gitea.dev/modules/web/types"
|
||||
|
||||
"gitea.com/go-chi/binding"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// Bind binding an obj to a handler's context data
|
||||
func Bind[T any](_ T) http.HandlerFunc {
|
||||
return func(resp http.ResponseWriter, req *http.Request) {
|
||||
theObj := new(T) // create a new form obj for every request but not use obj directly
|
||||
data := middleware.GetContextData(req.Context())
|
||||
binding.Bind(req, theObj)
|
||||
SetForm(data, theObj)
|
||||
middleware.AssignForm(theObj, data)
|
||||
}
|
||||
}
|
||||
|
||||
// SetForm set the form object
|
||||
func SetForm(dataStore reqctx.ContextDataProvider, obj any) {
|
||||
dataStore.GetData()["__form"] = obj
|
||||
}
|
||||
|
||||
// GetForm returns the validate form information
|
||||
func GetForm(dataStore reqctx.RequestDataStore) any {
|
||||
return dataStore.GetData()["__form"]
|
||||
}
|
||||
|
||||
// Router defines a route based on chi's router
|
||||
type Router struct {
|
||||
chiRouter *chi.Mux
|
||||
|
||||
afterRouting []any
|
||||
|
||||
curGroupPrefix string
|
||||
curMiddlewares []any
|
||||
}
|
||||
|
||||
// NewRouter creates a new route
|
||||
func NewRouter() *Router {
|
||||
r := chi.NewRouter()
|
||||
return &Router{chiRouter: r}
|
||||
}
|
||||
|
||||
// BeforeRouting adds middlewares which will be executed before the request path gets routed
|
||||
// It should only be used for framework-level global middlewares when it needs to change request method & path.
|
||||
func (r *Router) BeforeRouting(middlewares ...any) {
|
||||
for _, m := range middlewares {
|
||||
if !isNilOrFuncNil(m) {
|
||||
r.chiRouter.Use(toHandlerProvider(m))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AfterRouting adds middlewares which will be executed after the request path gets routed
|
||||
// It can see the routed path and resolved path parameters
|
||||
func (r *Router) AfterRouting(middlewares ...any) {
|
||||
r.afterRouting = append(r.afterRouting, middlewares...)
|
||||
}
|
||||
|
||||
// Group mounts a sub-router along a "pattern" string.
|
||||
func (r *Router) Group(pattern string, fn func(), middlewares ...any) {
|
||||
previousGroupPrefix := r.curGroupPrefix
|
||||
previousMiddlewares := r.curMiddlewares
|
||||
r.curGroupPrefix += pattern
|
||||
r.curMiddlewares = append(r.curMiddlewares, middlewares...)
|
||||
|
||||
fn()
|
||||
|
||||
r.curGroupPrefix = previousGroupPrefix
|
||||
r.curMiddlewares = previousMiddlewares
|
||||
}
|
||||
|
||||
func (r *Router) getPattern(pattern string) string {
|
||||
newPattern := r.curGroupPrefix + pattern
|
||||
if !strings.HasPrefix(newPattern, "/") {
|
||||
newPattern = "/" + newPattern
|
||||
}
|
||||
if newPattern == "/" {
|
||||
return newPattern
|
||||
}
|
||||
return strings.TrimSuffix(newPattern, "/")
|
||||
}
|
||||
|
||||
func isNilOrFuncNil(v any) bool {
|
||||
if v == nil {
|
||||
return true
|
||||
}
|
||||
r := reflect.ValueOf(v)
|
||||
return r.Kind() == reflect.Func && r.IsNil()
|
||||
}
|
||||
|
||||
func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider {
|
||||
for _, m := range middlewares {
|
||||
if h, ok := m.(types.PreMiddlewareProvider); ok && h != nil {
|
||||
all = append(all, toHandlerProvider(middlewareProvider(h)))
|
||||
}
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider {
|
||||
for _, m := range middlewares {
|
||||
if _, ok := m.(types.PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) {
|
||||
all = append(all, toHandlerProvider(m))
|
||||
}
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
func wrapMiddlewareAndHandler(useMiddlewares, curMiddlewares, h []any) (_ []middlewareProvider, _ http.HandlerFunc, hasPreMiddlewares bool) {
|
||||
if len(h) == 0 {
|
||||
panic("no endpoint handler provided")
|
||||
}
|
||||
if isNilOrFuncNil(h[len(h)-1]) {
|
||||
panic("endpoint handler can't be nil")
|
||||
}
|
||||
|
||||
handlerProviders := make([]middlewareProvider, 0, len(useMiddlewares)+len(curMiddlewares)+len(h)+1)
|
||||
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, useMiddlewares)
|
||||
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, curMiddlewares)
|
||||
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, h)
|
||||
hasPreMiddlewares = len(handlerProviders) > 0
|
||||
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, useMiddlewares)
|
||||
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, curMiddlewares)
|
||||
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, h)
|
||||
|
||||
middlewares := handlerProviders[:len(handlerProviders)-1]
|
||||
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP
|
||||
mockPoint := RouterMockPoint(MockAfterMiddlewares)
|
||||
if mockPoint != nil {
|
||||
middlewares = append(middlewares, mockPoint)
|
||||
}
|
||||
return middlewares, handlerFunc, hasPreMiddlewares
|
||||
}
|
||||
|
||||
// Methods adds the same handlers for multiple http "methods" (separated by ",").
|
||||
// If any method is invalid, the lower level router will panic.
|
||||
func (r *Router) Methods(methods, pattern string, h ...any) {
|
||||
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
|
||||
fullPattern := r.getPattern(pattern)
|
||||
if strings.Contains(methods, ",") {
|
||||
methods := strings.SplitSeq(methods, ",")
|
||||
for method := range methods {
|
||||
r.chiRouter.With(middlewares...).Method(strings.TrimSpace(method), fullPattern, handlerFunc)
|
||||
}
|
||||
} else {
|
||||
r.chiRouter.With(middlewares...).Method(methods, fullPattern, handlerFunc)
|
||||
}
|
||||
}
|
||||
|
||||
// Mount attaches another Router along "/pattern/*"
|
||||
func (r *Router) Mount(pattern string, subRouter *Router) {
|
||||
handlerProviders := make([]middlewareProvider, 0, len(r.afterRouting)+len(r.curMiddlewares))
|
||||
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.afterRouting)
|
||||
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.curMiddlewares)
|
||||
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.afterRouting)
|
||||
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.curMiddlewares)
|
||||
r.chiRouter.With(handlerProviders...).Mount(r.getPattern(pattern), subRouter.chiRouter)
|
||||
}
|
||||
|
||||
// Any delegate requests for all methods
|
||||
func (r *Router) Any(pattern string, h ...any) {
|
||||
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
|
||||
r.chiRouter.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc)
|
||||
}
|
||||
|
||||
// Delete delegate delete method
|
||||
func (r *Router) Delete(pattern string, h ...any) {
|
||||
r.Methods("DELETE", pattern, h...)
|
||||
}
|
||||
|
||||
// Get delegate get method
|
||||
func (r *Router) Get(pattern string, h ...any) {
|
||||
r.Methods("GET", pattern, h...)
|
||||
}
|
||||
|
||||
// Head delegate head method
|
||||
func (r *Router) Head(pattern string, h ...any) {
|
||||
r.Methods("HEAD", pattern, h...)
|
||||
}
|
||||
|
||||
// Post delegate post method
|
||||
func (r *Router) Post(pattern string, h ...any) {
|
||||
r.Methods("POST", pattern, h...)
|
||||
}
|
||||
|
||||
// Put delegate put method
|
||||
func (r *Router) Put(pattern string, h ...any) {
|
||||
r.Methods("PUT", pattern, h...)
|
||||
}
|
||||
|
||||
// Patch delegate patch method
|
||||
func (r *Router) Patch(pattern string, h ...any) {
|
||||
r.Methods("PATCH", pattern, h...)
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler
|
||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient.
|
||||
r.normalizeRequestPath(w, req, r.chiRouter)
|
||||
}
|
||||
|
||||
// NotFound defines a handler to respond whenever a route could not be found.
|
||||
func (r *Router) NotFound(h http.HandlerFunc) {
|
||||
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, []any{h})
|
||||
r.chiRouter.NotFound(func(w http.ResponseWriter, r *http.Request) {
|
||||
executeMiddlewaresHandler(w, r, middlewares, handlerFunc)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *Router) normalizeRequestPath(resp http.ResponseWriter, req *http.Request, next http.Handler) {
|
||||
normalized := false
|
||||
normalizedPath := req.URL.EscapedPath()
|
||||
if normalizedPath == "" {
|
||||
normalizedPath, normalized = "/", true
|
||||
} else if normalizedPath != "/" {
|
||||
normalized = strings.HasSuffix(normalizedPath, "/")
|
||||
normalizedPath = strings.TrimRight(normalizedPath, "/")
|
||||
}
|
||||
removeRepeatedSlashes := strings.Contains(normalizedPath, "//")
|
||||
normalized = normalized || removeRepeatedSlashes
|
||||
|
||||
// the following code block is a slow-path for replacing all repeated slashes "//" to one single "/"
|
||||
// if the path doesn't have repeated slashes, then no need to execute it
|
||||
if removeRepeatedSlashes {
|
||||
buf := &strings.Builder{}
|
||||
for i := 0; i < len(normalizedPath); i++ {
|
||||
if i == 0 || normalizedPath[i-1] != '/' || normalizedPath[i] != '/' {
|
||||
buf.WriteByte(normalizedPath[i])
|
||||
}
|
||||
}
|
||||
normalizedPath = buf.String()
|
||||
}
|
||||
|
||||
// If the config tells Gitea to use a sub-url path directly without reverse proxy,
|
||||
// then we need to remove the sub-url path from the request URL path.
|
||||
// But "/v2" is special for OCI container registry, it should always be in the root of the site.
|
||||
if setting.UseSubURLPath {
|
||||
remainingPath, ok := strings.CutPrefix(normalizedPath, setting.AppSubURL+"/")
|
||||
if ok {
|
||||
normalizedPath = "/" + remainingPath
|
||||
} else if normalizedPath == setting.AppSubURL {
|
||||
normalizedPath = "/"
|
||||
} else if !strings.HasPrefix(normalizedPath+"/", "/v2/") {
|
||||
// do not respond to other requests, to simulate a real sub-path environment
|
||||
resp.Header().Add("Content-Type", "text/html; charset=utf-8")
|
||||
resp.WriteHeader(http.StatusNotFound)
|
||||
_, _ = htmlutil.HTMLPrintf(resp, `404 page not found, sub-path is: <a href="%s">%s</a>`, setting.AppSubURL, setting.AppSubURL)
|
||||
return
|
||||
}
|
||||
normalized = true
|
||||
}
|
||||
|
||||
// if the path is normalized, then fill it back to the request
|
||||
if normalized {
|
||||
decodedPath, err := url.PathUnescape(normalizedPath)
|
||||
if err != nil {
|
||||
http.Error(resp, "400 Bad Request: unable to unescape path "+normalizedPath, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
req.URL.RawPath = normalizedPath
|
||||
req.URL.Path = decodedPath
|
||||
}
|
||||
|
||||
next.ServeHTTP(resp, req)
|
||||
}
|
||||
|
||||
// Combo delegates requests to Combo
|
||||
func (r *Router) Combo(pattern string, h ...any) *Combo {
|
||||
return &Combo{r, pattern, h}
|
||||
}
|
||||
|
||||
// PathGroup creates a group of paths which could be matched by regexp.
|
||||
// It is only designed to resolve some special cases which chi router can't handle.
|
||||
// For most cases, it shouldn't be used because it needs to iterate all rules to find the matched one (inefficient).
|
||||
func (r *Router) PathGroup(pattern string, fn func(g *RouterPathGroup), h ...any) {
|
||||
g := &RouterPathGroup{r: r, pathParam: "*"}
|
||||
fn(g)
|
||||
r.Any(pattern, append(h, g.ServeHTTP)...)
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package web
|
||||
|
||||
// Combo represents a tiny group routes with same pattern
|
||||
type Combo struct {
|
||||
r *Router
|
||||
pattern string
|
||||
h []any
|
||||
}
|
||||
|
||||
// Get delegates Get method
|
||||
func (c *Combo) Get(h ...any) *Combo {
|
||||
c.r.Get(c.pattern, append(c.h, h...)...)
|
||||
return c
|
||||
}
|
||||
|
||||
// Post delegates Post method
|
||||
func (c *Combo) Post(h ...any) *Combo {
|
||||
c.r.Post(c.pattern, append(c.h, h...)...)
|
||||
return c
|
||||
}
|
||||
|
||||
// Delete delegates Delete method
|
||||
func (c *Combo) Delete(h ...any) *Combo {
|
||||
c.r.Delete(c.pattern, append(c.h, h...)...)
|
||||
return c
|
||||
}
|
||||
|
||||
// Put delegates Put method
|
||||
func (c *Combo) Put(h ...any) *Combo {
|
||||
c.r.Put(c.pattern, append(c.h, h...)...)
|
||||
return c
|
||||
}
|
||||
|
||||
// Patch delegates Patch method
|
||||
func (c *Combo) Patch(h ...any) *Combo {
|
||||
c.r.Patch(c.pattern, append(c.h, h...)...)
|
||||
return c
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"gitea.dev/modules/container"
|
||||
"gitea.dev/modules/util"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type RouterPathGroup struct {
|
||||
r *Router
|
||||
pathParam string
|
||||
matchers []*routerPathMatcher
|
||||
}
|
||||
|
||||
func (g *RouterPathGroup) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||
chiCtx := chi.RouteContext(req.Context())
|
||||
path := chiCtx.URLParam(g.pathParam)
|
||||
for _, m := range g.matchers {
|
||||
if m.matchPath(chiCtx, path) {
|
||||
chiCtx.RoutePatterns = append(chiCtx.RoutePatterns, m.pattern)
|
||||
executeMiddlewaresHandler(resp, req, m.middlewares, m.handlerFunc)
|
||||
return
|
||||
}
|
||||
}
|
||||
g.r.chiRouter.NotFoundHandler().ServeHTTP(resp, req)
|
||||
}
|
||||
|
||||
type RouterPathGroupPattern struct {
|
||||
pattern string
|
||||
re *regexp.Regexp
|
||||
params []routerPathParam
|
||||
middlewares []any
|
||||
}
|
||||
|
||||
// MatchPath matches the request method, and uses regexp to match the path.
|
||||
// The pattern uses "<...>" to define path parameters, for example, "/<name>" (different from chi router)
|
||||
// It is only designed to resolve some special cases that chi router can't handle.
|
||||
// For most cases, it shouldn't be used because it needs to iterate all rules to find the matched one (inefficient).
|
||||
func (g *RouterPathGroup) MatchPath(methods, pattern string, h ...any) {
|
||||
g.MatchPattern(methods, g.PatternRegexp(pattern), h...)
|
||||
}
|
||||
|
||||
func (g *RouterPathGroup) MatchPattern(methods string, pattern *RouterPathGroupPattern, h ...any) {
|
||||
g.matchers = append(g.matchers, newRouterPathMatcher(methods, pattern, h...))
|
||||
}
|
||||
|
||||
type routerPathParam struct {
|
||||
name string
|
||||
pathSepEnd bool
|
||||
captureGroup int
|
||||
}
|
||||
|
||||
type routerPathMatcher struct {
|
||||
methods container.Set[string]
|
||||
pattern string
|
||||
re *regexp.Regexp
|
||||
params []routerPathParam
|
||||
middlewares []middlewareProvider
|
||||
handlerFunc http.HandlerFunc
|
||||
}
|
||||
|
||||
func (p *routerPathMatcher) matchPath(chiCtx *chi.Context, path string) bool {
|
||||
if !p.methods.Contains(chiCtx.RouteMethod) {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
pathMatches := p.re.FindStringSubmatchIndex(path) // Golang regexp match pairs [start, end, start, end, ...]
|
||||
if pathMatches == nil {
|
||||
return false
|
||||
}
|
||||
var paramMatches [][]int
|
||||
for i := 2; i < len(pathMatches); {
|
||||
paramMatches = append(paramMatches, []int{pathMatches[i], pathMatches[i+1]})
|
||||
pmIdx := len(paramMatches) - 1
|
||||
end := pathMatches[i+1]
|
||||
i += 2
|
||||
for ; i < len(pathMatches); i += 2 {
|
||||
if pathMatches[i] >= end {
|
||||
break
|
||||
}
|
||||
paramMatches[pmIdx] = append(paramMatches[pmIdx], pathMatches[i], pathMatches[i+1])
|
||||
}
|
||||
}
|
||||
for i, pm := range paramMatches {
|
||||
groupIdx := p.params[i].captureGroup * 2
|
||||
if pm[groupIdx] == -1 || pm[groupIdx+1] == -1 {
|
||||
chiCtx.URLParams.Add(p.params[i].name, "")
|
||||
continue
|
||||
}
|
||||
val := path[pm[groupIdx]:pm[groupIdx+1]]
|
||||
if p.params[i].pathSepEnd {
|
||||
val = strings.TrimSuffix(val, "/")
|
||||
}
|
||||
chiCtx.URLParams.Add(p.params[i].name, val)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidMethod(name string) bool {
|
||||
switch name {
|
||||
case http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodHead, http.MethodOptions, http.MethodConnect, http.MethodTrace:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func newRouterPathMatcher(methods string, patternRegexp *RouterPathGroupPattern, h ...any) *routerPathMatcher {
|
||||
middlewares, handlerFunc, hasPreMiddlewares := wrapMiddlewareAndHandler(nil, patternRegexp.middlewares, h)
|
||||
if hasPreMiddlewares {
|
||||
panic("pre-middlewares are not supported in router path matcher")
|
||||
}
|
||||
p := &routerPathMatcher{methods: make(container.Set[string]), middlewares: middlewares, handlerFunc: handlerFunc}
|
||||
for method := range strings.SplitSeq(methods, ",") {
|
||||
method = strings.TrimSpace(method)
|
||||
if !isValidMethod(method) {
|
||||
panic("invalid HTTP method: " + method)
|
||||
}
|
||||
p.methods.Add(method)
|
||||
}
|
||||
p.pattern, p.re, p.params = patternRegexp.pattern, patternRegexp.re, patternRegexp.params
|
||||
return p
|
||||
}
|
||||
|
||||
func patternRegexp(pattern string, h ...any) *RouterPathGroupPattern {
|
||||
p := &RouterPathGroupPattern{middlewares: slices.Clone(h)}
|
||||
re := []byte{'^'}
|
||||
lastEnd := 0
|
||||
for lastEnd < len(pattern) {
|
||||
start := strings.IndexByte(pattern[lastEnd:], '<')
|
||||
if start == -1 {
|
||||
re = append(re, regexp.QuoteMeta(pattern[lastEnd:])...)
|
||||
break
|
||||
}
|
||||
end := strings.IndexByte(pattern[lastEnd+start:], '>')
|
||||
if end == -1 {
|
||||
panic("invalid pattern: " + pattern)
|
||||
}
|
||||
re = append(re, regexp.QuoteMeta(pattern[lastEnd:lastEnd+start])...)
|
||||
partName, partExp, _ := strings.Cut(pattern[lastEnd+start+1:lastEnd+start+end], ":")
|
||||
lastEnd += start + end + 1
|
||||
|
||||
// TODO: it could support to specify a "capture group" for the name, for example: "/<name[2]:(\d)-(\d)>"
|
||||
// it is not used so no need to implement it now
|
||||
param := routerPathParam{}
|
||||
if partExp == "*" {
|
||||
// "<part:*>" is a shorthand for optionally matching any string (but not greedy)
|
||||
partExp = ".*?"
|
||||
if lastEnd < len(pattern) && pattern[lastEnd] == '/' {
|
||||
// if this param part ends with path separator "/", then consider it together: "(.*?/)"
|
||||
partExp += "/"
|
||||
param.pathSepEnd = true
|
||||
lastEnd++
|
||||
}
|
||||
re = append(re, '(')
|
||||
re = append(re, partExp...)
|
||||
re = append(re, ')', '?') // the wildcard matching is optional
|
||||
} else {
|
||||
// the pattern is user-provided regexp, defaults to a path part (separated by "/")
|
||||
partExp = util.IfZero(partExp, "[^/]+")
|
||||
re = append(re, '(')
|
||||
re = append(re, partExp...)
|
||||
re = append(re, ')')
|
||||
}
|
||||
param.name = partName
|
||||
p.params = append(p.params, param)
|
||||
}
|
||||
re = append(re, '$')
|
||||
p.pattern, p.re = pattern, regexp.MustCompile(string(re))
|
||||
return p
|
||||
}
|
||||
|
||||
func (g *RouterPathGroup) PatternRegexp(pattern string, h ...any) *RouterPathGroupPattern {
|
||||
return patternRegexp(pattern, h...)
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/test"
|
||||
"gitea.dev/modules/util"
|
||||
"gitea.dev/modules/web/types"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func chiURLParamsToMap(chiCtx *chi.Context) map[string]string {
|
||||
pathParams := chiCtx.URLParams
|
||||
m := make(map[string]string, len(pathParams.Keys))
|
||||
for i, key := range pathParams.Keys {
|
||||
if key == "*" && pathParams.Values[i] == "" {
|
||||
continue // chi router will add an empty "*" key if there is a "Mount"
|
||||
}
|
||||
m[key] = pathParams.Values[i]
|
||||
}
|
||||
return util.Iif(len(m) == 0, nil, m)
|
||||
}
|
||||
|
||||
type testResult struct {
|
||||
method string
|
||||
pathParams map[string]string
|
||||
handlerMarks []string
|
||||
chiRoutePattern *string
|
||||
}
|
||||
|
||||
type testRecorder struct {
|
||||
res testResult
|
||||
}
|
||||
|
||||
func (r *testRecorder) reset() {
|
||||
r.res = testResult{}
|
||||
}
|
||||
|
||||
func (r *testRecorder) handle(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
|
||||
mark := util.OptionalArg(optMark, "")
|
||||
return func(resp http.ResponseWriter, req *http.Request) {
|
||||
chiCtx := chi.RouteContext(req.Context())
|
||||
r.res.method = req.Method
|
||||
r.res.pathParams = chiURLParamsToMap(chiCtx)
|
||||
r.res.chiRoutePattern = new(chiCtx.RoutePattern())
|
||||
if mark != "" {
|
||||
r.res.handlerMarks = append(r.res.handlerMarks, mark)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *testRecorder) provider(optMark ...string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
|
||||
r.handle(optMark...)(resp, req)
|
||||
next.ServeHTTP(resp, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *testRecorder) stop(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
|
||||
mark := util.OptionalArg(optMark, "")
|
||||
return func(resp http.ResponseWriter, req *http.Request) {
|
||||
if stop := req.FormValue("stop"); stop != "" && (mark == "" || mark == stop) {
|
||||
r.handle(stop)(resp, req)
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
} else if mark != "" {
|
||||
r.res.handlerMarks = append(r.res.handlerMarks, mark)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *testRecorder) test(t *testing.T, rt *Router, methodPath string, expected testResult) {
|
||||
r.reset()
|
||||
methodPathFields := strings.Fields(methodPath)
|
||||
req, err := http.NewRequest(methodPathFields[0], methodPathFields[1], nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
buff := &bytes.Buffer{}
|
||||
httpRecorder := httptest.NewRecorder()
|
||||
httpRecorder.Body = buff
|
||||
rt.ServeHTTP(httpRecorder, req)
|
||||
if expected.chiRoutePattern == nil {
|
||||
r.res.chiRoutePattern = nil
|
||||
}
|
||||
assert.Equal(t, expected, r.res)
|
||||
}
|
||||
|
||||
func TestPathProcessor(t *testing.T) {
|
||||
testProcess := func(pattern, uri string, expectedPathParams map[string]string) {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.RouteMethod = "GET"
|
||||
p := newRouterPathMatcher("GET", patternRegexp(pattern), http.NotFound)
|
||||
shouldProcess := expectedPathParams != nil
|
||||
assert.Equal(t, shouldProcess, p.matchPath(chiCtx, uri), "use pattern %s to process uri %s", pattern, uri)
|
||||
assert.Equal(t, expectedPathParams, chiURLParamsToMap(chiCtx), "use pattern %s to process uri %s", pattern, uri)
|
||||
}
|
||||
|
||||
// the "<...>" is intentionally designed to distinguish from chi's path parameters, because:
|
||||
// 1. their behaviors are totally different, we do not want to mislead developers
|
||||
// 2. we can write regexp in "<name:\w{3,4}>" easily and parse it easily
|
||||
testProcess("/<p1>/<p2>", "/a/b", map[string]string{"p1": "a", "p2": "b"})
|
||||
testProcess("/<p1:*>", "", map[string]string{"p1": ""}) // this is a special case, because chi router could use empty path
|
||||
testProcess("/<p1:*>", "/", map[string]string{"p1": ""})
|
||||
testProcess("/<p1:*>/<p2>", "/a", map[string]string{"p1": "", "p2": "a"})
|
||||
testProcess("/<p1:*>/<p2>", "/a/b", map[string]string{"p1": "a", "p2": "b"})
|
||||
testProcess("/<p1:*>/<p2>", "/a/b/c", map[string]string{"p1": "a/b", "p2": "c"})
|
||||
testProcess("/<p1:*>/part/<p2>", "/a/part/c", map[string]string{"p1": "a", "p2": "c"})
|
||||
testProcess("/<p1:*>/part/<p2>", "/part/c", map[string]string{"p1": "", "p2": "c"})
|
||||
testProcess("/<p1:*>/part/<p2>", "/a/other-part/c", nil)
|
||||
testProcess("/<p1:*>-part/<p2>", "/a-other-part/c", map[string]string{"p1": "a-other", "p2": "c"})
|
||||
}
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
type resultStruct = testResult
|
||||
resRecorder := &testRecorder{}
|
||||
h := resRecorder.handle
|
||||
stopMark := resRecorder.stop
|
||||
|
||||
r := NewRouter()
|
||||
r.NotFound(h("not-found:/"))
|
||||
r.Get("/{username}/{reponame}/{type:issues|pulls}", h("list-issues-a")) // this one will never be called
|
||||
r.Group("/{username}/{reponame}", func() {
|
||||
r.Get("/{type:issues|pulls}", h("list-issues-b"))
|
||||
r.Group("", func() {
|
||||
r.Get("/{type:issues|pulls}/{index}", h("view-issue"))
|
||||
}, stopMark())
|
||||
r.Group("/issues/{index}", func() {
|
||||
r.Post("/update", h("update-issue"))
|
||||
})
|
||||
})
|
||||
|
||||
m := NewRouter()
|
||||
m.NotFound(h("not-found:/api/v1"))
|
||||
r.Mount("/api/v1", m)
|
||||
m.Group("/repos", func() {
|
||||
m.Group("/{username}/{reponame}", func() {
|
||||
m.Group("/branches", func() {
|
||||
m.Get("", h())
|
||||
m.Post("", h())
|
||||
m.Group("/{name}", func() {
|
||||
m.Get("", h())
|
||||
m.Patch("", h())
|
||||
m.Delete("", h())
|
||||
})
|
||||
m.PathGroup("/*", func(g *RouterPathGroup) {
|
||||
g.MatchPattern("GET", g.PatternRegexp(`/<dir:*>/<file:[a-z]{1,2}>`, stopMark("s2")), stopMark("s3"), h("match-path"))
|
||||
}, stopMark("s1"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
testRoute := func(t *testing.T, methodPath string, expected resultStruct) {
|
||||
t.Run(methodPath, func(t *testing.T) {
|
||||
resRecorder.test(t, r, methodPath, expected)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("RootRouter", func(t *testing.T) {
|
||||
testRoute(t, "GET /the-user/the-repo/other", resultStruct{
|
||||
method: "GET",
|
||||
handlerMarks: []string{"not-found:/"},
|
||||
chiRoutePattern: new(""),
|
||||
})
|
||||
testRoute(t, "GET /the-user/the-repo/pulls", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "type": "pulls"},
|
||||
handlerMarks: []string{"list-issues-b"},
|
||||
})
|
||||
testRoute(t, "GET /the-user/the-repo/issues/123", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "type": "issues", "index": "123"},
|
||||
handlerMarks: []string{"view-issue"},
|
||||
chiRoutePattern: new("/{username}/{reponame}/{type:issues|pulls}/{index}"),
|
||||
})
|
||||
testRoute(t, "GET /the-user/the-repo/issues/123?stop=hijack", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "type": "issues", "index": "123"},
|
||||
handlerMarks: []string{"hijack"},
|
||||
})
|
||||
testRoute(t, "POST /the-user/the-repo/issues/123/update", resultStruct{
|
||||
method: "POST",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "index": "123"},
|
||||
handlerMarks: []string{"update-issue"},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Sub Router", func(t *testing.T) {
|
||||
testRoute(t, "GET /api/v1/other", resultStruct{
|
||||
method: "GET",
|
||||
handlerMarks: []string{"not-found:/api/v1"},
|
||||
})
|
||||
testRoute(t, "GET /api/v1/repos/the-user/the-repo/branches", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo"},
|
||||
})
|
||||
|
||||
testRoute(t, "POST /api/v1/repos/the-user/the-repo/branches", resultStruct{
|
||||
method: "POST",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo"},
|
||||
})
|
||||
|
||||
testRoute(t, "GET /api/v1/repos/the-user/the-repo/branches/master", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "name": "master"},
|
||||
})
|
||||
|
||||
testRoute(t, "PATCH /api/v1/repos/the-user/the-repo/branches/master", resultStruct{
|
||||
method: "PATCH",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "name": "master"},
|
||||
})
|
||||
|
||||
testRoute(t, "DELETE /api/v1/repos/the-user/the-repo/branches/master", resultStruct{
|
||||
method: "DELETE",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "name": "master"},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("MatchPath", func(t *testing.T) {
|
||||
testRoute(t, "GET /api/v1/repos/the-user/the-repo/branches/d1/d2/fn", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "*": "d1/d2/fn", "dir": "d1/d2", "file": "fn"},
|
||||
handlerMarks: []string{"s1", "s2", "s3", "match-path"},
|
||||
})
|
||||
testRoute(t, "GET /api/v1/repos/the-user/the-repo/branches/d1%2fd2/fn", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "*": "d1%2fd2/fn", "dir": "d1%2fd2", "file": "fn"},
|
||||
handlerMarks: []string{"s1", "s2", "s3", "match-path"},
|
||||
})
|
||||
testRoute(t, "GET /api/v1/repos/the-user/the-repo/branches/d1/d2/000", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"reponame": "the-repo", "username": "the-user", "*": "d1/d2/000"},
|
||||
handlerMarks: []string{"s1", "not-found:/api/v1"},
|
||||
})
|
||||
|
||||
testRoute(t, "GET /api/v1/repos/the-user/the-repo/branches/d1/d2/fn?stop=s1", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "*": "d1/d2/fn"},
|
||||
handlerMarks: []string{"s1"},
|
||||
})
|
||||
|
||||
testRoute(t, "GET /api/v1/repos/the-user/the-repo/branches/d1/d2/fn?stop=s2", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "*": "d1/d2/fn", "dir": "d1/d2", "file": "fn"},
|
||||
handlerMarks: []string{"s1", "s2"},
|
||||
})
|
||||
|
||||
testRoute(t, "GET /api/v1/repos/the-user/the-repo/branches/d1/d2/fn?stop=s3", resultStruct{
|
||||
method: "GET",
|
||||
pathParams: map[string]string{"username": "the-user", "reponame": "the-repo", "*": "d1/d2/fn", "dir": "d1/d2", "file": "fn"},
|
||||
handlerMarks: []string{"s1", "s2", "s3"},
|
||||
chiRoutePattern: new("/api/v1/repos/{username}/{reponame}/branches/<dir:*>/<file:[a-z]{1,2}>"),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestRouteNormalizePath(t *testing.T) {
|
||||
type paths struct {
|
||||
EscapedPath, RawPath, Path string
|
||||
}
|
||||
testPath := func(reqPath string, expectedPaths paths) {
|
||||
recorder := httptest.NewRecorder()
|
||||
recorder.Body = bytes.NewBuffer(nil)
|
||||
|
||||
actualPaths := paths{EscapedPath: "(none)", RawPath: "(none)", Path: "(none)"}
|
||||
r := NewRouter()
|
||||
r.Get("/*", func(resp http.ResponseWriter, req *http.Request) {
|
||||
actualPaths.EscapedPath = req.URL.EscapedPath()
|
||||
actualPaths.RawPath = req.URL.RawPath
|
||||
actualPaths.Path = req.URL.Path
|
||||
})
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqPath, nil)
|
||||
assert.NoError(t, err)
|
||||
r.ServeHTTP(recorder, req)
|
||||
assert.Equal(t, expectedPaths, actualPaths, "req path = %q", reqPath)
|
||||
}
|
||||
|
||||
// RawPath could be empty if the EscapedPath is the same as escape(Path) and it is already normalized
|
||||
testPath("/", paths{EscapedPath: "/", RawPath: "", Path: "/"})
|
||||
testPath("//", paths{EscapedPath: "/", RawPath: "/", Path: "/"})
|
||||
testPath("/%2f", paths{EscapedPath: "/%2f", RawPath: "/%2f", Path: "//"})
|
||||
testPath("///a//b/", paths{EscapedPath: "/a/b", RawPath: "/a/b", Path: "/a/b"})
|
||||
|
||||
defer test.MockVariableValue(&setting.UseSubURLPath, true)()
|
||||
defer test.MockVariableValue(&setting.AppSubURL, "/sub-path")()
|
||||
testPath("/", paths{EscapedPath: "(none)", RawPath: "(none)", Path: "(none)"}) // 404
|
||||
testPath("/sub-path", paths{EscapedPath: "/", RawPath: "/", Path: "/"})
|
||||
testPath("/sub-path/", paths{EscapedPath: "/", RawPath: "/", Path: "/"})
|
||||
testPath("/sub-path//a/b///", paths{EscapedPath: "/a/b", RawPath: "/a/b", Path: "/a/b"})
|
||||
testPath("/sub-path/%2f/", paths{EscapedPath: "/%2f", RawPath: "/%2f", Path: "//"})
|
||||
// "/v2" is special for OCI container registry, it should always be in the root of the site
|
||||
testPath("/v2", paths{EscapedPath: "/v2", RawPath: "/v2", Path: "/v2"})
|
||||
testPath("/v2/", paths{EscapedPath: "/v2", RawPath: "/v2", Path: "/v2"})
|
||||
testPath("/v2/%2f", paths{EscapedPath: "/v2/%2f", RawPath: "/v2/%2f", Path: "/v2//"})
|
||||
}
|
||||
|
||||
func TestPreMiddlewareProvider(t *testing.T) {
|
||||
resRecorder := &testRecorder{}
|
||||
h := resRecorder.handle
|
||||
p := resRecorder.provider
|
||||
|
||||
root := NewRouter()
|
||||
root.BeforeRouting(h("before-root"))
|
||||
root.AfterRouting(h("root"))
|
||||
root.Get("/a/1", h("mid"), types.PreMiddlewareProvider(p("pre-root")), h("end1"))
|
||||
|
||||
sub := NewRouter()
|
||||
sub.BeforeRouting(h("before-sub"))
|
||||
sub.AfterRouting(h("sub"))
|
||||
sub.Get("/2", h("mid"), types.PreMiddlewareProvider(p("pre-sub")), h("end2"))
|
||||
sub.NotFound(h("not-found"))
|
||||
|
||||
root.Mount("/a", sub)
|
||||
|
||||
resRecorder.test(t, root, "GET /a/1", testResult{
|
||||
method: "GET",
|
||||
handlerMarks: []string{"before-root", "pre-root", "root", "mid", "end1"},
|
||||
})
|
||||
resRecorder.test(t, root, "GET /a/2", testResult{
|
||||
method: "GET",
|
||||
handlerMarks: []string{"before-root", "root", "before-sub", "pre-sub", "sub", "mid", "end2"},
|
||||
})
|
||||
resRecorder.test(t, root, "GET /no-such", testResult{
|
||||
method: "GET",
|
||||
handlerMarks: []string{"before-root"},
|
||||
})
|
||||
resRecorder.test(t, root, "GET /a/no-such", testResult{
|
||||
method: "GET",
|
||||
handlerMarks: []string{"before-root", "root", "before-sub", "sub", "not-found"},
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"gitea.dev/modules/gtprof"
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/reqctx"
|
||||
"gitea.dev/modules/web/types"
|
||||
)
|
||||
|
||||
type contextKeyType struct{}
|
||||
|
||||
var contextKey contextKeyType
|
||||
|
||||
func getRequestRecord(ctx context.Context) *requestRecord {
|
||||
record, _ := ctx.Value(contextKey).(*requestRecord)
|
||||
return record
|
||||
}
|
||||
|
||||
// RecordFuncInfo records a func info into context
|
||||
func RecordFuncInfo(ctx context.Context, funcInfo *FuncInfo) (end func()) {
|
||||
end = func() {}
|
||||
if reqCtx := reqctx.FromContext(ctx); reqCtx != nil {
|
||||
var traceSpan *gtprof.TraceSpan
|
||||
traceSpan, end = gtprof.GetTracer().StartInContext(reqCtx, "http.func")
|
||||
traceSpan.SetAttributeString("func", funcInfo.shortName)
|
||||
}
|
||||
if record := getRequestRecord(ctx); record != nil {
|
||||
record.lock.Lock()
|
||||
record.funcInfo = funcInfo
|
||||
record.lock.Unlock()
|
||||
}
|
||||
return end
|
||||
}
|
||||
|
||||
func GetRequestRecordInfo(reqCtx context.Context) (ret struct {
|
||||
HasRecord bool
|
||||
IsLongPolling bool
|
||||
},
|
||||
) {
|
||||
record := getRequestRecord(reqCtx)
|
||||
if record == nil {
|
||||
return ret
|
||||
}
|
||||
ret.HasRecord = true
|
||||
record.lock.RLock()
|
||||
ret.IsLongPolling = record.isLongPolling
|
||||
record.lock.RUnlock()
|
||||
return ret
|
||||
}
|
||||
|
||||
// MarkLongPolling marks the request is a long-polling request, and the logger may output different message for it
|
||||
func MarkLongPolling() types.PreMiddlewareProvider {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
record := getRequestRecord(req.Context()) // it must exist
|
||||
record.lock.Lock()
|
||||
record.isLongPolling = true
|
||||
record.logLevel = log.TRACE
|
||||
record.lock.Unlock()
|
||||
next.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func MarkLogLevelTrace(resp http.ResponseWriter, req *http.Request) {
|
||||
record := getRequestRecord(req.Context())
|
||||
if record == nil {
|
||||
return
|
||||
}
|
||||
|
||||
record.lock.Lock()
|
||||
record.logLevel = log.TRACE
|
||||
record.lock.Unlock()
|
||||
}
|
||||
|
||||
// UpdatePanicError updates a context's error info, a panic may be recovered by other middlewares, but we still need to know that.
|
||||
func UpdatePanicError(ctx context.Context, err error) {
|
||||
record := getRequestRecord(ctx)
|
||||
if record == nil {
|
||||
return
|
||||
}
|
||||
|
||||
record.lock.Lock()
|
||||
record.panicError = err
|
||||
record.lock.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
funcInfoMap = map[uintptr]*FuncInfo{}
|
||||
funcInfoNameMap = map[string]*FuncInfo{}
|
||||
funcInfoMapMu sync.RWMutex
|
||||
)
|
||||
|
||||
// FuncInfo contains information about the function to be logged by the router log
|
||||
type FuncInfo struct {
|
||||
file string
|
||||
shortFile string
|
||||
line int
|
||||
name string
|
||||
shortName string
|
||||
}
|
||||
|
||||
// String returns a string form of the FuncInfo for logging
|
||||
func (info *FuncInfo) String() string {
|
||||
if info == nil {
|
||||
return "unknown-handler"
|
||||
}
|
||||
return fmt.Sprintf("%s:%d(%s)", info.shortFile, info.line, info.shortName)
|
||||
}
|
||||
|
||||
// GetFuncInfo returns the FuncInfo for a provided function and friendlyname
|
||||
func GetFuncInfo(fn any, friendlyName ...string) *FuncInfo {
|
||||
// ptr represents the memory position of the function passed in as v.
|
||||
// This will be used as program counter in FuncForPC below
|
||||
ptr := reflect.ValueOf(fn).Pointer()
|
||||
|
||||
// if we have been provided with a friendlyName look for the named funcs
|
||||
if len(friendlyName) == 1 {
|
||||
name := friendlyName[0]
|
||||
funcInfoMapMu.RLock()
|
||||
info, ok := funcInfoNameMap[name]
|
||||
funcInfoMapMu.RUnlock()
|
||||
if ok {
|
||||
return info
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise attempt to get pre-cached information for this function pointer
|
||||
funcInfoMapMu.RLock()
|
||||
info, ok := funcInfoMap[ptr]
|
||||
funcInfoMapMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
if len(friendlyName) == 1 {
|
||||
name := friendlyName[0]
|
||||
info = copyFuncInfo(info)
|
||||
info.shortName = name
|
||||
|
||||
funcInfoNameMap[name] = info
|
||||
funcInfoMapMu.Lock()
|
||||
funcInfoNameMap[name] = info
|
||||
funcInfoMapMu.Unlock()
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// This is likely the first time we have seen this function
|
||||
//
|
||||
// Get the runtime.func for this function (if we can)
|
||||
f := runtime.FuncForPC(ptr)
|
||||
if f != nil {
|
||||
info = convertToFuncInfo(f)
|
||||
|
||||
// cache this info globally
|
||||
funcInfoMapMu.Lock()
|
||||
funcInfoMap[ptr] = info
|
||||
|
||||
// if we have been provided with a friendlyName override the short name we've generated
|
||||
if len(friendlyName) == 1 {
|
||||
name := friendlyName[0]
|
||||
info = copyFuncInfo(info)
|
||||
info.shortName = name
|
||||
funcInfoNameMap[name] = info
|
||||
}
|
||||
funcInfoMapMu.Unlock()
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// convertToFuncInfo take a runtime.Func and convert it to a logFuncInfo, fill in shorten filename, etc
|
||||
func convertToFuncInfo(f *runtime.Func) *FuncInfo {
|
||||
file, line := f.FileLine(f.Entry())
|
||||
|
||||
info := &FuncInfo{
|
||||
file: strings.ReplaceAll(file, "\\", "/"),
|
||||
line: line,
|
||||
name: f.Name(),
|
||||
}
|
||||
|
||||
// only keep last 2 names in path, fall back to funcName if not
|
||||
info.shortFile = shortenFilename(info.file, info.name)
|
||||
|
||||
// remove package prefix. eg: "xxx.com/pkg1/pkg2.foo" => "pkg2.foo"
|
||||
pos := strings.LastIndexByte(info.name, '/')
|
||||
if pos >= 0 {
|
||||
info.shortName = info.name[pos+1:]
|
||||
} else {
|
||||
info.shortName = info.name
|
||||
}
|
||||
|
||||
// remove ".func[0-9]*" suffix for anonymous func
|
||||
info.shortName = trimAnonymousFunctionSuffix(info.shortName)
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func copyFuncInfo(l *FuncInfo) *FuncInfo {
|
||||
return &FuncInfo{
|
||||
file: l.file,
|
||||
shortFile: l.shortFile,
|
||||
line: l.line,
|
||||
name: l.name,
|
||||
shortName: l.shortName,
|
||||
}
|
||||
}
|
||||
|
||||
// shortenFilename generates a short source code filename from a full package path, eg: "gitea.dev/routers/common/logger_context.go" => "common/logger_context.go"
|
||||
func shortenFilename(filename, fallback string) string {
|
||||
if filename == "" {
|
||||
return fallback
|
||||
}
|
||||
if lastIndex := strings.LastIndexByte(filename, '/'); lastIndex >= 0 {
|
||||
if secondLastIndex := strings.LastIndexByte(filename[:lastIndex], '/'); secondLastIndex >= 0 {
|
||||
return filename[secondLastIndex+1:]
|
||||
}
|
||||
}
|
||||
return filename
|
||||
}
|
||||
|
||||
// trimAnonymousFunctionSuffix trims ".func[0-9]*" from the end of anonymous function names, we only want to see the main function names in logs
|
||||
func trimAnonymousFunctionSuffix(name string) string {
|
||||
// if the name is an anonymous name, it should be like "{main-function}.func1", so the length can not be less than 7
|
||||
if len(name) < 7 {
|
||||
return name
|
||||
}
|
||||
|
||||
funcSuffixIndex := strings.LastIndex(name, ".func")
|
||||
if funcSuffixIndex < 0 {
|
||||
return name
|
||||
}
|
||||
|
||||
hasFuncSuffix := true
|
||||
|
||||
// len(".func") = 5
|
||||
for i := funcSuffixIndex + 5; i < len(name); i++ {
|
||||
if name[i] < '0' || name[i] > '9' {
|
||||
hasFuncSuffix = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasFuncSuffix {
|
||||
return name[:funcSuffixIndex]
|
||||
}
|
||||
return name
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_shortenFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
filename string
|
||||
fallback string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"gitea.dev/routers/common/logger_context.go",
|
||||
"NO_FALLBACK",
|
||||
"common/logger_context.go",
|
||||
},
|
||||
{
|
||||
"common/logger_context.go",
|
||||
"NO_FALLBACK",
|
||||
"common/logger_context.go",
|
||||
},
|
||||
{
|
||||
"logger_context.go",
|
||||
"NO_FALLBACK",
|
||||
"logger_context.go",
|
||||
},
|
||||
{
|
||||
"",
|
||||
"USE_FALLBACK",
|
||||
"USE_FALLBACK",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("shortenFilename('%s')", tt.filename), func(t *testing.T) {
|
||||
gotShort := shortenFilename(tt.filename, tt.fallback)
|
||||
assert.Equal(t, tt.expected, gotShort)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_trimAnonymousFunctionSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"notAnonymous",
|
||||
"notAnonymous",
|
||||
},
|
||||
{
|
||||
"anonymous.func1",
|
||||
"anonymous",
|
||||
},
|
||||
{
|
||||
"notAnonymous.funca",
|
||||
"notAnonymous.funca",
|
||||
},
|
||||
{
|
||||
"anonymous.func100",
|
||||
"anonymous",
|
||||
},
|
||||
{
|
||||
"anonymous.func100.func6",
|
||||
"anonymous.func100",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := trimAnonymousFunctionSuffix(tt.name)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/web/types"
|
||||
)
|
||||
|
||||
var (
|
||||
startMessage = log.NewColoredValue("started ", log.DEBUG.ColorAttributes()...)
|
||||
slowMessage = log.NewColoredValue("slow ", log.WARN.ColorAttributes()...)
|
||||
pollingMessage = log.NewColoredValue("polling ", log.INFO.ColorAttributes()...)
|
||||
failedMessage = log.NewColoredValue("failed ", log.WARN.ColorAttributes()...)
|
||||
completedMessage = log.NewColoredValue("completed", log.INFO.ColorAttributes()...)
|
||||
unknownHandlerMessage = log.NewColoredValue("completed", log.ERROR.ColorAttributes()...)
|
||||
)
|
||||
|
||||
func logPrinter(logger log.Logger) func(trigger Event, record *requestRecord) {
|
||||
const callerName = "HTTPRequest"
|
||||
logRequest := func(level log.Level, fmt string, args ...any) {
|
||||
logger.Log(2, &log.Event{Level: level, Caller: callerName}, fmt, args...)
|
||||
}
|
||||
return func(trigger Event, record *requestRecord) {
|
||||
if trigger == StartEvent {
|
||||
if !logger.LevelEnabled(log.TRACE) {
|
||||
// for performance, if the "started" message shouldn't be logged, we just return as early as possible
|
||||
// developers can set the router log level to TRACE to get the "started" request messages.
|
||||
return
|
||||
}
|
||||
// when a request starts, we have no information about the handler function information, we only have the request path
|
||||
req := record.request
|
||||
logRequest(log.TRACE, "router: %s %v %s for %s", startMessage, log.ColoredMethod(req.Method), req.RequestURI, req.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
req := record.request
|
||||
|
||||
// Get data from the record
|
||||
record.lock.Lock()
|
||||
handlerFuncInfo := record.funcInfo.String()
|
||||
isLongPolling := record.isLongPolling
|
||||
isUnknownHandler := record.funcInfo == nil
|
||||
panicErr := record.panicError
|
||||
record.lock.Unlock()
|
||||
|
||||
if trigger == StillExecutingEvent {
|
||||
message := slowMessage
|
||||
logLevel := log.WARN
|
||||
if isLongPolling {
|
||||
logLevel = log.INFO
|
||||
message = pollingMessage
|
||||
}
|
||||
logRequest(logLevel, "router: %s %v %s for %s, elapsed %v @ %s",
|
||||
message,
|
||||
log.ColoredMethod(req.Method), req.RequestURI, req.RemoteAddr,
|
||||
log.ColoredTime(time.Since(record.startTime)),
|
||||
handlerFuncInfo,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if panicErr != nil {
|
||||
logRequest(log.WARN, "router: %s %v %s for %s, panic in %v @ %s, err=%v",
|
||||
failedMessage,
|
||||
log.ColoredMethod(req.Method), req.RequestURI, req.RemoteAddr,
|
||||
log.ColoredTime(time.Since(record.startTime)),
|
||||
handlerFuncInfo,
|
||||
panicErr,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
var status int
|
||||
if v, ok := record.respWriter.(types.ResponseStatusProvider); ok {
|
||||
status = v.WrittenStatus()
|
||||
}
|
||||
logLevel := record.logLevel
|
||||
if logLevel == log.UNDEFINED {
|
||||
logLevel = log.INFO
|
||||
}
|
||||
// lower the log level for some specific requests, in most cases these logs are not useful
|
||||
if status > 0 && status < 400 &&
|
||||
req.RequestURI == "/api/actions/runner.v1.RunnerService/FetchTask" /* Actions Runner polling */ {
|
||||
logLevel = log.TRACE
|
||||
}
|
||||
message := completedMessage
|
||||
if isUnknownHandler {
|
||||
logLevel = log.ERROR
|
||||
message = unknownHandlerMessage
|
||||
}
|
||||
|
||||
logRequest(logLevel, "router: %s %v %s for %s, %v %v in %v @ %s",
|
||||
message,
|
||||
log.ColoredMethod(req.Method), req.RequestURI, req.RemoteAddr,
|
||||
log.ColoredStatus(status), log.ColoredStatus(status, http.StatusText(status)), log.ColoredTime(time.Since(record.startTime)),
|
||||
handlerFuncInfo,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.dev/modules/graceful"
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/process"
|
||||
)
|
||||
|
||||
// Event indicates when the printer is triggered
|
||||
type Event int
|
||||
|
||||
const (
|
||||
// StartEvent at the beginning of a request
|
||||
StartEvent Event = iota
|
||||
|
||||
// StillExecutingEvent the request is still executing
|
||||
StillExecutingEvent
|
||||
|
||||
// EndEvent the request has ended (either completed or failed)
|
||||
EndEvent
|
||||
)
|
||||
|
||||
// logPrinterFunc is used to output the log for a request
|
||||
type logPrinterFunc func(trigger Event, record *requestRecord)
|
||||
|
||||
type loggerRequestManager struct {
|
||||
logPrint logPrinterFunc
|
||||
reqRecords sync.Map // it only contains the active requests which haven't been detected as "slow"
|
||||
}
|
||||
|
||||
func (manager *loggerRequestManager) startSlowQueryDetector(threshold time.Duration) {
|
||||
go graceful.GetManager().RunWithShutdownContext(func(ctx context.Context) {
|
||||
ctx, _, finished := process.GetManager().AddTypedContext(ctx, "Service: SlowQueryDetector", process.SystemProcessType, true)
|
||||
defer finished()
|
||||
// This go-routine checks all active requests every second.
|
||||
// If a request has been running for a long time (eg: /user/events), we also print a log with "still-executing" message
|
||||
// After the "still-executing" log is printed, the record will be removed from the map to prevent from duplicated logs in future
|
||||
// We do not care about accurate duration here. It just does the check periodically, 0.5s or 1.5s are all OK.
|
||||
t := time.NewTicker(time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
now := time.Now()
|
||||
|
||||
// print logs for slow requests
|
||||
manager.reqRecords.Range(func(key, value any) bool {
|
||||
index, record := key.(uint64), value.(*requestRecord)
|
||||
if now.Sub(record.startTime) >= threshold {
|
||||
manager.logPrint(StillExecutingEvent, record)
|
||||
manager.reqRecords.Delete(index)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (manager *loggerRequestManager) handleRequestRecord(record *requestRecord) func() {
|
||||
manager.reqRecords.Store(record.index, record)
|
||||
manager.logPrint(StartEvent, record)
|
||||
|
||||
return func() {
|
||||
// just in case there is a panic. now the panics are all recovered in middleware.go
|
||||
localPanicErr := recover()
|
||||
if localPanicErr != nil {
|
||||
record.lock.Lock()
|
||||
record.panicError = fmt.Errorf("%v\n%s", localPanicErr, log.Stack(2))
|
||||
record.lock.Unlock()
|
||||
}
|
||||
|
||||
manager.reqRecords.Delete(record.index)
|
||||
manager.logPrint(EndEvent, record)
|
||||
|
||||
if localPanicErr != nil {
|
||||
// the panic wasn't recovered before us, so we should pass it up, and let the framework handle the panic error
|
||||
panic(localPanicErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/setting"
|
||||
)
|
||||
|
||||
// NewRequestInfoHandler is a handler that saves request info into request context.
|
||||
// If router logger is enabled, it will also print request logs and detect slow requests.
|
||||
func NewRequestInfoHandler() func(next http.Handler) http.Handler {
|
||||
var reqLogger *loggerRequestManager
|
||||
if setting.IsRouteLogEnabled() {
|
||||
reqLogger = &loggerRequestManager{
|
||||
logPrint: logPrinter(log.GetLogger("router")),
|
||||
}
|
||||
reqLogger.startSlowQueryDetector(3 * time.Second)
|
||||
}
|
||||
var requestCounter atomic.Uint64
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
record := &requestRecord{
|
||||
index: requestCounter.Add(1),
|
||||
startTime: time.Now(),
|
||||
respWriter: w,
|
||||
}
|
||||
req = req.WithContext(context.WithValue(req.Context(), contextKey, record))
|
||||
record.request = req
|
||||
if reqLogger != nil {
|
||||
end := reqLogger.handleRequestRecord(record)
|
||||
defer end()
|
||||
}
|
||||
next.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.dev/modules/log"
|
||||
)
|
||||
|
||||
type requestRecord struct {
|
||||
// immutable fields
|
||||
index uint64 // unique number (per process) for the request
|
||||
startTime time.Time
|
||||
request *http.Request
|
||||
respWriter http.ResponseWriter
|
||||
|
||||
// mutex
|
||||
lock sync.RWMutex
|
||||
|
||||
// below are mutable fields
|
||||
funcInfo *FuncInfo
|
||||
// * for "mark as long polling"
|
||||
isLongPolling bool
|
||||
// * for router logger
|
||||
logLevel log.Level
|
||||
panicError error
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package types
|
||||
|
||||
import "net/http"
|
||||
|
||||
// PreMiddlewareProvider is a special middleware provider which will be executed
|
||||
// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting).
|
||||
// A route can do something (e.g.: set middleware options) at the place where it is declared,
|
||||
// and the code will be executed before other middlewares which are added before the declaration.
|
||||
// Use cases: mark a route with some meta info, set some options for middlewares, etc.
|
||||
type PreMiddlewareProvider func(next http.Handler) http.Handler
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package types
|
||||
|
||||
// ResponseStatusProvider is an interface to get the written status in the response
|
||||
// Many packages need this interface, so put it in the separate package to avoid import cycle
|
||||
type ResponseStatusProvider interface {
|
||||
WrittenStatus() int
|
||||
}
|
||||
Reference in New Issue
Block a user