初始提交: Gitea 项目代码

This commit is contained in:
root
2026-05-30 22:47:36 +08:00
commit f288f76350
6116 changed files with 776822 additions and 0 deletions
+190
View File
@@ -0,0 +1,190 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"errors"
"fmt"
"strings"
"gitea.dev/modules/container"
"gitea.dev/modules/log"
"gitea.dev/modules/setting"
"xorm.io/xorm/schemas"
)
type CheckCollationsResult struct {
ExpectedCollation string
AvailableCollation container.Set[string]
DatabaseCollation string
IsCollationCaseSensitive func(s string) bool
CollationEquals func(a, b string) bool
ExistingTableNumber int
InconsistentCollationColumns []string
}
func findAvailableCollationsMySQL(x EngineMigration) (ret container.Set[string], err error) {
var res []struct {
Collation string
}
if err = x.SQL("SHOW COLLATION WHERE (Collation = 'utf8mb4_bin') OR (Collation LIKE '%\\_as\\_cs%')").Find(&res); err != nil {
return nil, err
}
ret = make(container.Set[string], len(res))
for _, r := range res {
ret.Add(r.Collation)
}
return ret, nil
}
func findAvailableCollationsMSSQL(x EngineMigration) (ret container.Set[string], err error) {
var res []struct {
Name string
}
if err = x.SQL("SELECT * FROM sys.fn_helpcollations() WHERE name LIKE '%[_]CS[_]AS%'").Find(&res); err != nil {
return nil, err
}
ret = make(container.Set[string], len(res))
for _, r := range res {
ret.Add(r.Name)
}
return ret, nil
}
func CheckCollations(x EngineMigration) (*CheckCollationsResult, error) {
dbTables, err := x.DBMetas()
if err != nil {
return nil, err
}
res := &CheckCollationsResult{
ExistingTableNumber: len(dbTables),
CollationEquals: func(a, b string) bool { return a == b },
}
var candidateCollations []string
if x.Dialect().URI().DBType == schemas.MYSQL {
_, err = x.SQL("SELECT DEFAULT_COLLATION_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?", setting.Database.Name).Get(&res.DatabaseCollation)
if err != nil {
return nil, err
}
res.IsCollationCaseSensitive = func(s string) bool {
return s == "utf8mb4_bin" || strings.HasSuffix(s, "_as_cs")
}
candidateCollations = []string{"utf8mb4_0900_as_cs", "uca1400_as_cs", "utf8mb4_bin"}
res.AvailableCollation, err = findAvailableCollationsMySQL(x)
if err != nil {
return nil, err
}
res.CollationEquals = func(a, b string) bool {
// MariaDB adds the "utf8mb4_" prefix, eg: "utf8mb4_uca1400_as_cs", but not the name "uca1400_as_cs" in "SHOW COLLATION"
// At the moment, it's safe to ignore the database difference, just trim the prefix and compare. It could be fixed easily if there is any problem in the future.
return a == b || strings.TrimPrefix(a, "utf8mb4_") == strings.TrimPrefix(b, "utf8mb4_")
}
} else if x.Dialect().URI().DBType == schemas.MSSQL {
if _, err = x.SQL("SELECT DATABASEPROPERTYEX(DB_NAME(), 'Collation')").Get(&res.DatabaseCollation); err != nil {
return nil, err
}
res.IsCollationCaseSensitive = func(s string) bool {
return strings.HasSuffix(s, "_CS_AS")
}
candidateCollations = []string{"Latin1_General_CS_AS"}
res.AvailableCollation, err = findAvailableCollationsMSSQL(x)
if err != nil {
return nil, err
}
} else {
return nil, nil //nolint:nilnil // return nil for unsupported database types
}
if res.DatabaseCollation == "" {
return nil, errors.New("unable to get collation for current database")
}
res.ExpectedCollation = setting.Database.CharsetCollation
if res.ExpectedCollation == "" {
for _, collation := range candidateCollations {
if res.AvailableCollation.Contains(collation) {
res.ExpectedCollation = collation
break
}
}
}
if res.ExpectedCollation == "" {
return nil, errors.New("unable to find a suitable collation for current database")
}
allColumnsMatchExpected := true
allColumnsMatchDatabase := true
for _, table := range dbTables {
for _, col := range table.Columns() {
if col.Collation != "" {
allColumnsMatchExpected = allColumnsMatchExpected && res.CollationEquals(col.Collation, res.ExpectedCollation)
allColumnsMatchDatabase = allColumnsMatchDatabase && res.CollationEquals(col.Collation, res.DatabaseCollation)
if !res.IsCollationCaseSensitive(col.Collation) || !res.CollationEquals(col.Collation, res.DatabaseCollation) {
res.InconsistentCollationColumns = append(res.InconsistentCollationColumns, fmt.Sprintf("%s.%s", table.Name, col.Name))
}
}
}
}
// if all columns match expected collation or all match database collation, then it could also be considered as "consistent"
if allColumnsMatchExpected || allColumnsMatchDatabase {
res.InconsistentCollationColumns = nil
}
return res, nil
}
func CheckCollationsDefaultEngine() (*CheckCollationsResult, error) {
return CheckCollations(xormEngine)
}
func alterDatabaseCollation(x EngineMigration, collation string) error {
if x.Dialect().URI().DBType == schemas.MYSQL {
_, err := x.Exec("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE " + collation)
return err
} else if x.Dialect().URI().DBType == schemas.MSSQL {
// TODO: MSSQL has many limitations on changing database collation, it could fail in many cases.
_, err := x.Exec("ALTER DATABASE CURRENT COLLATE " + collation)
return err
}
return errors.New("unsupported database type")
}
// preprocessDatabaseCollation checks database & table column collation, and alter the database collation if needed
func preprocessDatabaseCollation(x EngineMigration) {
r, err := CheckCollations(x)
if err != nil {
log.Error("Failed to check database collation: %v", err)
}
if r == nil {
return // no check result means the database doesn't need to do such check/process (at the moment ....)
}
// try to alter database collation to expected if the database is empty, it might fail in some cases (and it isn't necessary to succeed)
// at the moment, there is no "altering" solution for MSSQL, site admin should manually change the database collation
if !r.CollationEquals(r.DatabaseCollation, r.ExpectedCollation) && r.ExistingTableNumber == 0 {
if err = alterDatabaseCollation(x, r.ExpectedCollation); err != nil {
log.Error("Failed to change database collation to %q: %v", r.ExpectedCollation, err)
} else {
_, _ = x.Exec("SELECT 1") // after "altering", MSSQL's session becomes invalid, so make a simple query to "refresh" the session
if r, err = CheckCollations(x); err != nil {
log.Error("Failed to check database collation again after altering: %v", err) // impossible case
return
}
log.Warn("Current database has been altered to use collation %q", r.DatabaseCollation)
}
}
// check column collation, and show warning/error to end users -- no need to fatal, do not block the startup
if !r.IsCollationCaseSensitive(r.DatabaseCollation) {
log.Warn("Current database is using a case-insensitive collation %q, although Gitea could work with it, there might be some rare cases which don't work as expected.", r.DatabaseCollation)
}
if len(r.InconsistentCollationColumns) > 0 {
log.Error("There are %d table columns using inconsistent collation, they should use %q. Please go to admin panel Self Check page", len(r.InconsistentCollationColumns), r.DatabaseCollation)
}
}
+55
View File
@@ -0,0 +1,55 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"strings"
"gitea.dev/modules/setting"
"gitea.dev/modules/util"
"xorm.io/builder"
)
// BuildCaseInsensitiveLike returns a case-insensitive LIKE condition for the given key and value.
// Cast the search value and the database column value to the same case for case-insensitive matching.
// * SQLite: only cast ASCII chars because it doesn't handle complete Unicode case folding
// * Other databases: use database's string function, assuming that they are able to handle complete Unicode case folding correctly
func BuildCaseInsensitiveLike(key, value string) builder.Cond {
// ToLowerASCII is about 7% faster than ToUpperASCII (according to Golang's benchmark)
if setting.Database.Type.IsSQLite3() {
return builder.Like{"LOWER(" + key + ")", util.ToLowerASCII(value)}
}
return builder.Like{"LOWER(" + key + ")", strings.ToLower(value)}
}
// BuildCaseInsensitiveIn returns a condition to check if the given value is in the given values case-insensitively.
// See BuildCaseInsensitiveLike for more details
func BuildCaseInsensitiveIn(key string, values []string) builder.Cond {
incaseValues := make([]string, len(values))
caseCast := strings.ToLower
if setting.Database.Type.IsSQLite3() {
caseCast = util.ToLowerASCII
}
for i, value := range values {
incaseValues[i] = caseCast(value)
}
return builder.In("LOWER("+key+")", incaseValues)
}
// BuilderDialect returns the xorm.Builder dialect of the engine
func BuilderDialect() string {
switch {
case setting.Database.Type.IsMySQL():
return builder.MYSQL
case setting.Database.Type.IsSQLite3():
return builder.SQLITE
case setting.Database.Type.IsPostgreSQL():
return builder.POSTGRES
case setting.Database.Type.IsMSSQL():
return builder.MSSQL
default:
return ""
}
}
+187
View File
@@ -0,0 +1,187 @@
// Copyright 2026 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"errors"
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"gitea.dev/modules/setting"
"gitea.dev/modules/util"
)
type ConnOptions struct {
Type setting.DatabaseType
Host string
Database string
User string
Passwd string
Schema string
SSLMode string
SQLitePath string
SQLiteBusyTimeout int
SQLiteJournalMode string
}
type SQLiteConnStrOptions struct {
FilePath string
// how long a concurrent query can wait for others (milliseconds),
// if timeout is reached, the error is something like "database is locked (SQLITE_BUSY)"
BusyTimeout int
JournalMode string
}
func GlobalConnOptions() ConnOptions {
return ConnOptions{
Type: setting.Database.Type,
Host: setting.Database.Host,
Database: setting.Database.Name,
User: setting.Database.User,
Passwd: setting.Database.Passwd,
Schema: setting.Database.Schema,
SSLMode: setting.Database.SSLMode,
SQLitePath: setting.Database.Path,
SQLiteBusyTimeout: setting.Database.SQLiteBusyTimeout,
SQLiteJournalMode: setting.Database.SQLiteJournalMode,
}
}
const (
sqlDriverPostgresSchema = "postgresschema"
sqlDriverSQLite3 = "sqlite3" // although database type also has "sqlite3", they are different, for different purposes
)
var makeSQLiteConnStr = func(opts SQLiteConnStrOptions) (string, string, error) {
return "", "", errors.New(`this Gitea binary was not built with SQLite3 support, get an official release or rebuild with correct "-tags"`)
}
func registerSQLiteConnStrMaker(fn func(opts SQLiteConnStrOptions) (string, string, error)) {
if slices.Contains(setting.SupportedDatabaseTypes, setting.DatabaseTypeSQLite3) {
panic("another sqlite3 driver has been registered")
}
setting.SupportedDatabaseTypes = append(setting.SupportedDatabaseTypes, setting.DatabaseTypeSQLite3)
makeSQLiteConnStr = fn
}
func ConnStrDefaultDatabase(opts ConnOptions) (string, string, error) {
opts.Database, opts.Schema = "", ""
return ConnStr(opts)
}
func ConnStr(opts ConnOptions) (string, string, error) {
switch {
case opts.Type.IsMySQL():
// use unix socket or tcp socket
connType := util.Iif(strings.HasPrefix(opts.Host, "/"), "unix", "tcp")
// allow (Postgres-inspired) default value to work in MySQL
tls := util.Iif(opts.SSLMode == "disable", "false", opts.SSLMode)
// in case the database name is a partial connection string which contains "?" parameters
paramSep := util.Iif(strings.Contains(opts.Database, "?"), "&", "?")
connStr := fmt.Sprintf("%s:%s@%s(%s)/%s%sparseTime=true&tls=%s", opts.User, opts.Passwd, connType, opts.Host, opts.Database, paramSep, tls)
return "mysql", connStr, nil
case opts.Type.IsPostgreSQL():
connStr := makePgSQLConnStr(opts.Host, opts.User, opts.Passwd, opts.Database, opts.SSLMode)
driver := util.Iif(opts.Schema == "", "postgres", sqlDriverPostgresSchema)
registerPostgresSchemaDriver()
return driver, connStr, nil
case opts.Type.IsMSSQL():
host, port := parseMSSQLHostPort(opts.Host)
connStr := fmt.Sprintf("server=%s; port=%s; user id=%s; password=%s;", host, port, opts.User, opts.Passwd)
if opts.Database != "" {
connStr += "; database=" + opts.Database
}
return "mssql", connStr, nil
case opts.Type.IsSQLite3():
if opts.SQLitePath == "" {
return "", "", errors.New("sqlite3 database path cannot be empty")
}
if err := os.MkdirAll(filepath.Dir(opts.SQLitePath), os.ModePerm); err != nil {
return "", "", fmt.Errorf("failed to create directories: %w", err)
}
return makeSQLiteConnStr(SQLiteConnStrOptions{
FilePath: opts.SQLitePath,
JournalMode: opts.SQLiteJournalMode,
BusyTimeout: opts.SQLiteBusyTimeout,
})
}
return "", "", fmt.Errorf("unknown database type: %s", opts.Type)
}
// parsePgSQLHostPort parses given input in various forms defined in
// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
// and returns proper host and port number.
func parsePgSQLHostPort(info string) (host, port string) {
if h, p, err := net.SplitHostPort(info); err == nil {
host, port = h, p
} else {
// treat the "info" as "host", if it's an IPv6 address, remove the wrapper
host = info
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = host[1 : len(host)-1]
}
}
// set fallback values
if host == "" {
host = "127.0.0.1"
}
if port == "" {
port = "5432"
}
return host, port
}
func makePgSQLConnStr(dbHost, dbUser, dbPasswd, dbName, dbsslMode string) (connStr string) {
dbName, dbParam, _ := strings.Cut(dbName, "?")
host, port := parsePgSQLHostPort(dbHost)
connURL := url.URL{
Scheme: "postgres",
User: url.UserPassword(dbUser, dbPasswd),
Host: net.JoinHostPort(host, port),
Path: dbName,
OmitHost: false,
RawQuery: dbParam,
}
query := connURL.Query()
if strings.HasPrefix(host, "/") { // looks like a unix socket
query.Add("host", host)
connURL.Host = ":" + port
}
query.Set("sslmode", dbsslMode)
connURL.RawQuery = query.Encode()
return connURL.String()
}
// parseMSSQLHostPort splits the host into host and port
func parseMSSQLHostPort(info string) (string, string) {
// the default port "0" might be related to MSSQL's dynamic port, maybe it should be double-confirmed in the future
host, port := "127.0.0.1", "0"
if strings.Contains(info, ":") {
host = strings.Split(info, ":")[0]
port = strings.Split(info, ":")[1]
} else if strings.Contains(info, ",") {
host = strings.Split(info, ",")[0]
port = strings.TrimSpace(strings.Split(info, ",")[1])
} else if len(info) > 0 {
host = info
}
if host == "" {
host = "127.0.0.1"
}
if port == "" {
port = "0"
}
return host, port
}
+109
View File
@@ -0,0 +1,109 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParsePgSQLHostPort(t *testing.T) {
tests := map[string]struct {
HostPort string
Host string
Port string
}{
"host-port": {
HostPort: "127.0.0.1:1234",
Host: "127.0.0.1",
Port: "1234",
},
"no-port": {
HostPort: "127.0.0.1",
Host: "127.0.0.1",
Port: "5432",
},
"ipv6-port": {
HostPort: "[::1]:1234",
Host: "::1",
Port: "1234",
},
"ipv6-no-port": {
HostPort: "[::1]",
Host: "::1",
Port: "5432",
},
"unix-socket": {
HostPort: "/tmp/pg.sock:1234",
Host: "/tmp/pg.sock",
Port: "1234",
},
"unix-socket-no-port": {
HostPort: "/tmp/pg.sock",
Host: "/tmp/pg.sock",
Port: "5432",
},
}
for k, test := range tests {
t.Run(k, func(t *testing.T) {
t.Log(test.HostPort)
host, port := parsePgSQLHostPort(test.HostPort)
assert.Equal(t, test.Host, host)
assert.Equal(t, test.Port, port)
})
}
}
func TestMakePgSQLConnStr(t *testing.T) {
tests := []struct {
Host string
User string
Passwd string
Name string
SSLMode string
Output string
}{
{
Host: "", // empty means default
Output: "postgres://:@127.0.0.1:5432?sslmode=",
},
{
Host: "/tmp/pg.sock",
User: "testuser",
Passwd: "space space !#$%^^%^```-=?=",
Name: "gitea",
SSLMode: "false",
Output: "postgres://testuser:space%20space%20%21%23$%25%5E%5E%25%5E%60%60%60-=%3F=@:5432/gitea?host=%2Ftmp%2Fpg.sock&sslmode=false",
},
{
Host: "/tmp/pg.sock:6432",
User: "testuser",
Passwd: "pass",
Name: "gitea",
SSLMode: "false",
Output: "postgres://testuser:pass@:6432/gitea?host=%2Ftmp%2Fpg.sock&sslmode=false",
},
{
Host: "localhost",
User: "pgsqlusername",
Passwd: "I love Gitea!",
Name: "gitea",
SSLMode: "true",
Output: "postgres://pgsqlusername:I%20love%20Gitea%21@localhost:5432/gitea?sslmode=true",
},
{
Host: "localhost:1234",
User: "user",
Passwd: "pass",
Name: "gitea?param=1",
Output: "postgres://user:pass@localhost:1234/gitea?param=1&sslmode=",
},
}
for _, test := range tests {
connStr := makePgSQLConnStr(test.Host, test.User, test.Passwd, test.Name, test.SSLMode)
assert.Equal(t, test.Output, connStr)
}
}
+31
View File
@@ -0,0 +1,31 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"xorm.io/builder"
)
// CountOrphanedObjects count subjects with have no existing refobject anymore
func CountOrphanedObjects(ctx context.Context, subject, refObject, joinCond string) (int64, error) {
return GetEngine(ctx).
Table("`"+subject+"`").
Join("LEFT", "`"+refObject+"`", joinCond).
Where(builder.IsNull{"`" + refObject + "`.id"}).
Select("COUNT(`" + subject + "`.`id`)").
Count()
}
// DeleteOrphanedObjects delete subjects with have no existing refobject anymore
func DeleteOrphanedObjects(ctx context.Context, subject, refObject, joinCond string) error {
subQuery := builder.Select("`"+subject+"`.id").
From("`"+subject+"`").
Join("LEFT", "`"+refObject+"`", joinCond).
Where(builder.IsNull{"`" + refObject + "`.id"})
b := builder.Delete(builder.In("id", subQuery)).From("`" + subject + "`")
_, err := GetEngine(ctx).Exec(b)
return err
}
+320
View File
@@ -0,0 +1,320 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"database/sql"
"errors"
"runtime"
"slices"
"sync"
"gitea.dev/modules/setting"
"xorm.io/builder"
"xorm.io/xorm"
)
type contextKey struct{ key string }
var (
contextKeyEngine = contextKey{"engine"}
ContextKeyTestFixtures = contextKey{"test-fixtures"}
)
func withContextEngine(ctx context.Context, e Engine) context.Context {
return context.WithValue(ctx, contextKeyEngine, e)
}
var (
contextSafetyOnce sync.Once
contextSafetyDeniedFuncPCs []uintptr
)
func contextSafetyCheck(e Engine) {
if setting.IsProd && !setting.IsInTesting {
return
}
if e == nil {
return
}
// Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed.
contextSafetyOnce.Do(func() {
// try to figure out the bad functions to deny
type m struct{}
_ = e.SQL("SELECT 1").Iterate(&m{}, func(int, any) error {
callers := make([]uintptr, 32)
callerNum := runtime.Callers(1, callers)
for i := range callerNum {
if funcName := runtime.FuncForPC(callers[i]).Name(); funcName == "xorm.io/xorm.(*Session).Iterate" {
contextSafetyDeniedFuncPCs = append(contextSafetyDeniedFuncPCs, callers[i])
}
}
return nil
})
if len(contextSafetyDeniedFuncPCs) != 1 {
panic(errors.New("unable to determine the functions to deny"))
}
})
// it should be very fast: xxxx ns/op
callers := make([]uintptr, 32)
callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
for i := range callerNum {
if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) {
panic(errors.New("using session context in an iterator would cause corrupted results"))
}
}
}
// GetEngine gets an existing db Engine/Statement or creates a new Session
func GetEngine(ctx context.Context) Engine {
if engine, ok := ctx.Value(contextKeyEngine).(Engine); ok {
// if reusing the existing session, need to do "contextSafetyCheck" because the Iterate creates a "autoResetStatement=false" session
contextSafetyCheck(engine)
return engine
}
// no need to do "contextSafetyCheck" because it's a new Session
return xormEngine.Context(ctx)
}
func GetXORMEngineForTesting() *xorm.Engine {
return xormEngine
}
// Committer represents an interface to Commit or Close the Context
type Committer interface {
Commit() error
Close() error
}
// halfCommitter is a wrapper of Committer.
// It can be closed early, but can't be committed early, it is useful for reusing a transaction.
type halfCommitter struct {
committer Committer
committed bool
}
func (c *halfCommitter) Commit() error {
c.committed = true
// should do nothing, and the parent committer will commit later
return nil
}
func (c *halfCommitter) Close() error {
if c.committed {
// it's "commit and close", should do nothing, and the parent committer will commit later
return nil
}
// it's "rollback and close", let the parent committer rollback right now
return c.committer.Close()
}
// TxContext represents a transaction Context,
// it will reuse the existing transaction in the parent context or create a new one.
// Some tips to use:
//
// 1 It's always recommended to use `WithTx` in new code instead of `TxContext`, since `WithTx` will handle the transaction automatically.
// 2. To maintain the old code which uses `TxContext`:
// a. Always call `Close()` before returning regardless of whether `Commit()` has been called.
// b. Always call `Commit()` before returning if there are no errors, even if the code did not change any data.
// c. Remember the `Committer` will be a halfCommitter when a transaction is being reused.
// So calling `Commit()` will do nothing, but calling `Close()` without calling `Commit()` will rollback the transaction.
// And all operations submitted by the caller stack will be rollbacked as well, not only the operations in the current function.
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
func TxContext(parentCtx context.Context) (context.Context, Committer, error) {
if sess := getTransactionSession(parentCtx); sess != nil {
return withContextEngine(parentCtx, sess), &halfCommitter{committer: sess}, nil
}
sess := xormEngine.NewSession()
if err := sess.Begin(); err != nil {
_ = sess.Close()
return nil, nil, err
}
return withContextEngine(parentCtx, sess), sess, nil
}
// WithTx represents executing database operations on a transaction, if the transaction exist,
// this function will reuse it otherwise will create a new one and close it when finished.
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if sess := getTransactionSession(parentCtx); sess != nil {
err := f(withContextEngine(parentCtx, sess))
if err != nil {
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
_ = sess.Close()
}
return err
}
return txWithNoCheck(parentCtx, f)
}
// WithTx2 is similar to WithTx, but it has two return values: result and error.
func WithTx2[T any](parentCtx context.Context, f func(ctx context.Context) (T, error)) (ret T, errRet error) {
errRet = WithTx(parentCtx, func(ctx context.Context) (errInner error) {
ret, errInner = f(ctx)
return errInner
})
return ret, errRet
}
func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
sess := xormEngine.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
if err := f(withContextEngine(parentCtx, sess)); err != nil {
return err
}
return sess.Commit()
}
// Insert inserts records into database
func Insert(ctx context.Context, beans ...any) error {
_, err := GetEngine(ctx).Insert(beans...)
return err
}
// Exec executes a sql with args
func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
return GetEngine(ctx).Exec(sqlAndArgs...)
}
func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) {
if !cond.IsValid() {
panic("cond is invalid in db.Get(ctx, cond). This should not be possible.")
}
var bean T
has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
if err != nil {
return nil, false, err
} else if !has {
return nil, false, nil
}
return &bean, true, nil
}
func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) {
var bean T
has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean)
if err != nil {
return nil, false, err
} else if !has {
return nil, false, nil
}
return &bean, true, nil
}
func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
if !cond.IsValid() {
panic("cond is invalid in db.Exist(ctx, cond). This should not be possible.")
}
var bean T
return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
}
func ExistByID[T any](ctx context.Context, id int64) (bool, error) {
var bean T
return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean)
}
// DeleteByID deletes the given bean with the given ID
func DeleteByID[T any](ctx context.Context, id int64) (int64, error) {
var bean T
return GetEngine(ctx).ID(id).NoAutoCondition().NoAutoTime().Delete(&bean)
}
func DeleteByIDs[T any](ctx context.Context, ids ...int64) error {
if len(ids) == 0 {
return nil
}
var bean T
_, err := GetEngine(ctx).In("id", ids).NoAutoCondition().NoAutoTime().Delete(&bean)
return err
}
func Delete[T any](ctx context.Context, opts FindOptions) (int64, error) {
if opts == nil || !opts.ToConds().IsValid() {
panic("opts are empty or invalid in db.Delete(ctx, opts). This should not be possible.")
}
var bean T
return GetEngine(ctx).Where(opts.ToConds()).NoAutoCondition().NoAutoTime().Delete(&bean)
}
// DeleteByBean deletes all records according non-empty fields of the bean as conditions.
func DeleteByBean(ctx context.Context, bean any) (int64, error) {
return GetEngine(ctx).Delete(bean)
}
// FindIDs finds the IDs for the given table name satisfying the given condition
// By passing a different value than "id" for "idCol", you can query for foreign IDs, i.e. the repo IDs which satisfy the condition
func FindIDs(ctx context.Context, tableName, idCol string, cond builder.Cond) ([]int64, error) {
ids := make([]int64, 0, 10)
if err := GetEngine(ctx).Table(tableName).
Cols(idCol).
Where(cond).
Find(&ids); err != nil {
return nil, err
}
return ids, nil
}
// DecrByIDs decreases the given column for entities of the "bean" type with one of the given ids by one
// Timestamps of the entities won't be updated
func DecrByIDs(ctx context.Context, ids []int64, decrCol string, bean any) error {
if len(ids) == 0 {
return nil
}
_, err := GetEngine(ctx).Decr(decrCol).In("id", ids).NoAutoCondition().NoAutoTime().Update(bean)
return err
}
// DeleteBeans deletes all given beans, beans must contain delete conditions.
func DeleteBeans(ctx context.Context, beans ...any) (err error) {
e := GetEngine(ctx)
for i := range beans {
if _, err = e.Delete(beans[i]); err != nil {
return err
}
}
return nil
}
// TruncateBeans deletes all given beans, beans may contain delete conditions.
func TruncateBeans(ctx context.Context, beans ...any) (err error) {
e := GetEngine(ctx)
for i := range beans {
if _, err = e.Truncate(beans[i]); err != nil {
return err
}
}
return nil
}
// CountByBean counts the number of database records according non-empty fields of the bean as conditions.
func CountByBean(ctx context.Context, bean any) (int64, error) {
return GetEngine(ctx).Count(bean)
}
// InTransaction returns true if the engine is in a transaction otherwise return false
func InTransaction(ctx context.Context) bool {
return getTransactionSession(ctx) != nil
}
func getTransactionSession(ctx context.Context) *xorm.Session {
e, _ := ctx.Value(contextKeyEngine).(Engine)
if sess, ok := e.(*xorm.Session); ok && sess.IsInTx() {
return sess
}
return nil
}
+102
View File
@@ -0,0 +1,102 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db // it's not db_test, because this file is for testing the private type halfCommitter
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
type MockCommitter struct {
wants []string
gots []string
}
func NewMockCommitter(wants ...string) *MockCommitter {
return &MockCommitter{
wants: wants,
}
}
func (c *MockCommitter) Commit() error {
c.gots = append(c.gots, "commit")
return nil
}
func (c *MockCommitter) Close() error {
c.gots = append(c.gots, "close")
return nil
}
func (c *MockCommitter) Assert(t *testing.T) {
assert.Equal(t, c.wants, c.gots, "want operations %v, but got %v", c.wants, c.gots)
}
func Test_halfCommitter(t *testing.T) {
/*
Do something like:
ctx, committer, err := db.TxContext(t.Context())
if err != nil {
return nil
}
defer committer.Close()
// ...
if err != nil {
return nil
}
// ...
return committer.Commit()
*/
testWithCommitter := func(committer Committer, f func(committer Committer) error) {
if err := f(&halfCommitter{committer: committer}); err == nil {
committer.Commit()
}
committer.Close()
}
t.Run("commit and close", func(t *testing.T) {
mockCommitter := NewMockCommitter("commit", "close")
testWithCommitter(mockCommitter, func(committer Committer) error {
defer committer.Close()
return committer.Commit()
})
mockCommitter.Assert(t)
})
t.Run("rollback and close", func(t *testing.T) {
mockCommitter := NewMockCommitter("close", "close")
testWithCommitter(mockCommitter, func(committer Committer) error {
defer committer.Close()
if true {
return errors.New("error")
}
return committer.Commit()
})
mockCommitter.Assert(t)
})
t.Run("close and commit", func(t *testing.T) {
mockCommitter := NewMockCommitter("close", "close")
testWithCommitter(mockCommitter, func(committer Committer) error {
committer.Close()
committer.Commit()
return errors.New("error")
})
mockCommitter.Assert(t)
})
}
+135
View File
@@ -0,0 +1,135 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"context"
"testing"
"gitea.dev/models/db"
"gitea.dev/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestInTransaction(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.False(t, db.InTransaction(t.Context()))
assert.NoError(t, db.WithTx(t.Context(), func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))
ctx, committer, err := db.TxContext(t.Context())
assert.NoError(t, err)
defer committer.Close()
assert.True(t, db.InTransaction(ctx))
assert.NoError(t, db.WithTx(ctx, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))
}
func TestTxContext(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
{ // create new transaction
ctx, committer, err := db.TxContext(t.Context())
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
assert.NoError(t, committer.Commit())
}
{ // reuse the transaction created by TxContext and commit it
ctx, committer, err := db.TxContext(t.Context())
engine := db.GetEngine(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
{
ctx, committer, err := db.TxContext(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
assert.Equal(t, engine, db.GetEngine(ctx))
assert.NoError(t, committer.Commit())
}
assert.NoError(t, committer.Commit())
}
{ // reuse the transaction created by TxContext and close it
ctx, committer, err := db.TxContext(t.Context())
engine := db.GetEngine(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
{
ctx, committer, err := db.TxContext(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
assert.Equal(t, engine, db.GetEngine(ctx))
assert.NoError(t, committer.Close())
}
assert.NoError(t, committer.Close())
}
{ // reuse the transaction created by WithTx
assert.NoError(t, db.WithTx(t.Context(), func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
{
ctx, committer, err := db.TxContext(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
assert.NoError(t, committer.Commit())
}
return nil
}))
}
}
func TestContextSafety(t *testing.T) {
type TestModel1 struct {
ID int64
}
type TestModel2 struct {
ID int64
}
assert.NoError(t, unittest.GetXORMEngine().Sync(&TestModel1{}, &TestModel2{}))
assert.NoError(t, db.TruncateBeans(t.Context(), &TestModel1{}, &TestModel2{}))
testCount := 10
for i := 1; i <= testCount; i++ {
assert.NoError(t, db.Insert(t.Context(), &TestModel1{ID: int64(i)}))
assert.NoError(t, db.Insert(t.Context(), &TestModel2{ID: int64(-i)}))
}
t.Run("Show-XORM-Bug", func(t *testing.T) {
actualCount := 0
// here: db.GetEngine(t.Context()) is a new *Session created from *Engine
_ = db.WithTx(t.Context(), func(ctx context.Context) error {
_ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
// here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false,
// and the internal states (including "cond" and others) are always there and not be reset in this callback.
m1 := bean.(*TestModel1)
assert.EqualValues(t, i+1, m1.ID)
// here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ...
// and it conflicts with the "Iterate"'s internal states.
// has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID})
actualCount++
return nil
})
return nil
})
assert.Equal(t, testCount, actualCount)
})
t.Run("DenyBadUsage", func(t *testing.T) {
assert.PanicsWithError(t, "using session context in an iterator would cause corrupted results", func() {
_ = db.WithTx(t.Context(), func(ctx context.Context) error {
return db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
_ = db.GetEngine(ctx)
return nil
})
})
})
})
}
+86
View File
@@ -0,0 +1,86 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"fmt"
"gitea.dev/modules/setting"
"xorm.io/xorm"
"xorm.io/xorm/convert"
"xorm.io/xorm/schemas"
)
// ConvertDatabaseTable converts database and tables from utf8 to utf8mb4 if it's mysql and set ROW_FORMAT=dynamic
func ConvertDatabaseTable() error {
if xormEngine.Dialect().URI().DBType != schemas.MYSQL {
return nil
}
r, err := CheckCollations(xormEngine)
if err != nil {
return err
}
_, err = xormEngine.Exec(fmt.Sprintf("ALTER DATABASE `%s` CHARACTER SET utf8mb4 COLLATE %s", setting.Database.Name, r.ExpectedCollation))
if err != nil {
return err
}
tables, err := xormEngine.DBMetas()
if err != nil {
return err
}
for _, table := range tables {
if _, err := xormEngine.Exec(fmt.Sprintf("ALTER TABLE `%s` ROW_FORMAT=dynamic", table.Name)); err != nil {
return err
}
if _, err := xormEngine.Exec(fmt.Sprintf("ALTER TABLE `%s` CONVERT TO CHARACTER SET utf8mb4 COLLATE %s", table.Name, r.ExpectedCollation)); err != nil {
return err
}
}
return nil
}
// ConvertVarcharToNVarchar converts database and tables from varchar to nvarchar if it's mssql
func ConvertVarcharToNVarchar() error {
if xormEngine.Dialect().URI().DBType != schemas.MSSQL {
return nil
}
sess := xormEngine.NewSession()
defer sess.Close()
res, err := sess.QuerySliceString(`SELECT 'ALTER TABLE ' + OBJECT_NAME(SC.object_id) + ' MODIFY SC.name NVARCHAR(' + CONVERT(VARCHAR(5),SC.max_length) + ')'
FROM SYS.columns SC
JOIN SYS.types ST
ON SC.system_type_id = ST.system_type_id
AND SC.user_type_id = ST.user_type_id
WHERE ST.name ='varchar'`)
if err != nil {
return err
}
for _, row := range res {
if len(row) == 1 {
if _, err = sess.Exec(row[0]); err != nil {
return err
}
}
}
return err
}
// CellToInt converts a xorm.Cell field value to an int value
func CellToInt[T ~int | int64](cell xorm.Cell, def T) (ret T, has bool, err error) {
if *cell == nil {
return def, false, nil
}
val, err := convert.AsInt64(*cell)
if err != nil {
return def, false, err
}
return T(val), true, err
}
+74
View File
@@ -0,0 +1,74 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"database/sql"
"database/sql/driver"
"sync"
"gitea.dev/modules/setting"
"github.com/lib/pq"
"xorm.io/xorm/dialects"
)
var registerOnce sync.Once
func registerPostgresSchemaDriver() {
registerOnce.Do(func() {
sql.Register(sqlDriverPostgresSchema, &postgresSchemaDriver{})
dialects.RegisterDriver(sqlDriverPostgresSchema, dialects.QueryDriver("postgres"))
})
}
type postgresSchemaDriver struct {
pq.Driver
}
// Open opens a new connection to the database. name is a connection string.
// This function opens the postgres connection in the default manner but immediately
// runs set_config to set the search_path appropriately
func (d *postgresSchemaDriver) Open(name string) (driver.Conn, error) {
conn, err := d.Driver.Open(name)
if err != nil {
return conn, err
}
schemaValue, _ := driver.String.ConvertValue(setting.Database.Schema)
// golangci lint is incorrect here - there is no benefit to using driver.ExecerContext here
// and in any case pq does not implement it
if execer, ok := conn.(driver.Execer); ok { //nolint:staticcheck // see above
_, err := execer.Exec(`SELECT set_config(
'search_path',
$1 || ',' || current_setting('search_path'),
false)`, []driver.Value{schemaValue})
if err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
stmt, err := conn.Prepare(`SELECT set_config(
'search_path',
$1 || ',' || current_setting('search_path'),
false)`)
if err != nil {
_ = conn.Close()
return nil, err
}
defer stmt.Close()
// driver.String.ConvertValue will never return err for string
// golangci lint is incorrect here - there is no benefit to using stmt.ExecWithContext here
_, err = stmt.Exec([]driver.Value{schemaValue}) //nolint:staticcheck // see above
if err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
+31
View File
@@ -0,0 +1,31 @@
// Copyright 2026 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
//go:build sqlite_mattn && sqlite_unlock_notify
package db
import (
"fmt"
"strconv"
"strings"
_ "github.com/mattn/go-sqlite3"
)
func init() {
registerSQLiteConnStrMaker(makeSQLiteConnStrMattnCGO)
}
func makeSQLiteConnStrMattnCGO(opts SQLiteConnStrOptions) (string, string, error) {
var params []string
params = append(params, "cache=shared")
params = append(params, "mode=rwc")
params = append(params, "_busy_timeout="+strconv.Itoa(opts.BusyTimeout))
params = append(params, "_txlock=immediate")
if opts.JournalMode != "" {
params = append(params, "_journal_mode="+opts.JournalMode)
}
connStr := fmt.Sprintf("file:%s?%s", opts.FilePath, strings.Join(params, "&"))
return sqlDriverSQLite3, connStr, nil
}
+41
View File
@@ -0,0 +1,41 @@
// Copyright 2026 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
//go:build !sqlite_mattn
// modernc driver is chosen as the default one (compared to mattn, ncruces)
// * mattn was used as default, but it requires CGO
// * the CI times are almost the same for these three (race detector must be disabled)
// * modernc increases the binary size about 2MB, ncruces increases about 7MB
// * compiling time: modernc is slightly slower than mattn, ncruces is the slowest
package db
import (
"database/sql"
"fmt"
"strings"
"modernc.org/sqlite"
)
func init() {
// this driver contains huge amount of Golang code, so it is much slower when "-race" check is enabled.
registerSQLiteConnStrMaker(makeSQLiteConnStrModerncCCGO)
sql.Register(sqlDriverSQLite3, &sqlite.Driver{})
}
func makeSQLiteConnStrModerncCCGO(opts SQLiteConnStrOptions) (string, string, error) {
var params []string
// TODO: there is a changed behavior from mattn driver:
// * mattn driver can wait for pretty long time for concurrent accesses (not limited by the busy timeout)
// * but other drivers will report something like "database is locked (5) (SQLITE_BUSY)" if the timeout is reached
// Maybe we need to relax the busy timeout to a reasonable long time in the future
params = append(params, fmt.Sprintf("_pragma=busy_timeout(%d)", opts.BusyTimeout))
params = append(params, "_txlock=immediate")
if opts.JournalMode != "" {
params = append(params, fmt.Sprintf("_pragma=journal_mode(%s)", opts.JournalMode))
}
connStr := fmt.Sprintf("file:%s?%s", opts.FilePath, strings.Join(params, "&"))
return sqlDriverSQLite3, connStr, nil
}
+184
View File
@@ -0,0 +1,184 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2018 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
_ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver
_ "github.com/lib/pq" // Needed for the Postgresql driver
_ "github.com/microsoft/go-mssqldb" // Needed for the MSSQL driver
"xorm.io/xorm"
"xorm.io/xorm/core"
"xorm.io/xorm/dialects"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)
var (
xormEngine *xorm.Engine
registeredModels []any
registeredInitFuncs []func() error
)
// Engine represents a xorm engine or session.
type Engine interface {
Table(tableNameOrBean any) *xorm.Session
Count(...any) (int64, error)
Decr(column string, arg ...any) *xorm.Session
Delete(...any) (int64, error)
Truncate(...any) (int64, error)
Exec(...any) (sql.Result, error)
Find(any, ...any) error
FindAndCount(any, ...any) (int64, error)
Get(beans ...any) (bool, error)
ID(any) *xorm.Session
In(string, ...any) *xorm.Session
Incr(column string, arg ...any) *xorm.Session
Insert(...any) (int64, error)
Iterate(any, xorm.IterFunc) error
Join(joinOperator string, tablename, condition any, args ...any) *xorm.Session
SQL(any, ...any) *xorm.Session
Where(any, ...any) *xorm.Session
Asc(colNames ...string) *xorm.Session
Desc(colNames ...string) *xorm.Session
Limit(limit int, start ...int) *xorm.Session
NoAutoTime() *xorm.Session
SumInt(bean any, columnName string) (res int64, err error)
Sync(...any) error
Select(string) *xorm.Session
SetExpr(string, any) *xorm.Session
NotIn(string, ...any) *xorm.Session
OrderBy(any, ...any) *xorm.Session
Exist(...any) (bool, error)
Distinct(...string) *xorm.Session
Query(...any) ([]map[string][]byte, error)
Cols(...string) *xorm.Session
Context(ctx context.Context) *xorm.Session
Ping() error
IsTableExist(tableNameOrBean any) (bool, error)
}
// Session represents a xorm session interface, used as an abstraction over *xorm.Session.
type Session interface {
Engine
And(query any, args ...any) *xorm.Session
Begin() error
Close() error
Commit() error
IsInTx() bool
Rollback() error
Engine() *xorm.Engine
}
// EngineMigration is a xorm engine interface used for migrations.
// It extends Engine with additional methods that are only available on the engine (not on the session)
// and are needed by the migration packages.
type EngineMigration interface {
Engine
Close() error
DB() *core.DB
DBMetas() ([]*schemas.Table, error)
Dialect() dialects.Dialect
DropTables(beans ...any) error
NewSession() *xorm.Session
QueryInterface(sqlOrArgs ...any) ([]map[string]any, error)
SetMapper(mapper names.Mapper)
SyncWithOptions(opts xorm.SyncOptions, beans ...any) (*xorm.SyncResult, error)
TableInfo(bean any) (*schemas.Table, error)
TableName(bean any, includeSchema ...bool) string
}
var (
_ Engine = (*xorm.Engine)(nil)
_ Engine = (*xorm.Session)(nil)
_ Session = (*xorm.Session)(nil)
_ EngineMigration = (*xorm.Engine)(nil)
)
// RegisterModel registers model, if initFuncs provided, it will be invoked after data model sync
func RegisterModel(bean any, initFunc ...func() error) {
registeredModels = append(registeredModels, bean)
if len(registeredInitFuncs) > 0 && initFunc[0] != nil {
registeredInitFuncs = append(registeredInitFuncs, initFunc[0])
}
}
// SyncAllTables sync the schemas of all tables, is required by unit test code
func SyncAllTables() error {
_, err := xormEngine.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{
WarnIfDatabaseColumnMissed: true,
}, registeredModels...)
return err
}
// NamesToBean return a list of beans or an error
func NamesToBean(names ...string) ([]any, error) {
beans := []any{}
if len(names) == 0 {
beans = append(beans, registeredModels...)
return beans, nil
}
// Need to map provided names to beans...
beanMap := make(map[string]any)
for _, bean := range registeredModels {
beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean
beanMap[strings.ToLower(xormEngine.TableName(bean))] = bean
beanMap[strings.ToLower(xormEngine.TableName(bean, true))] = bean
}
gotBean := make(map[any]bool)
for _, name := range names {
bean, ok := beanMap[strings.ToLower(strings.TrimSpace(name))]
if !ok {
return nil, fmt.Errorf("no table found that matches: %s", name)
}
if !gotBean[bean] {
beans = append(beans, bean)
gotBean[bean] = true
}
}
return beans, nil
}
// MaxBatchInsertSize returns the table's max batch insert size
func MaxBatchInsertSize(bean any) int {
t, err := xormEngine.TableInfo(bean)
if err != nil {
return 50
}
return 999 / len(t.ColumnsSeq())
}
// IsTableNotEmpty returns true if table has at least one record
func IsTableNotEmpty(beanOrTableName any) (bool, error) {
return xormEngine.Table(beanOrTableName).Exist()
}
// DeleteAllRecords will delete all the records of this table
func DeleteAllRecords(tableName string) error {
_, err := xormEngine.Exec("DELETE FROM " + tableName)
return err
}
// GetMaxID will return max id of the table
func GetMaxID(beanOrTableName any) (maxID int64, err error) {
_, err = xormEngine.Select("MAX(id)").Table(beanOrTableName).Get(&maxID)
return maxID, err
}
func SetLogSQL(ctx context.Context, on bool) {
e := GetEngine(ctx)
if x, ok := e.(*xorm.Engine); ok {
x.ShowSQL(on)
} else if sess, ok := e.(*xorm.Session); ok {
sess.Engine().ShowSQL(on)
}
}
+37
View File
@@ -0,0 +1,37 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"gitea.dev/modules/setting"
"xorm.io/xorm/schemas"
)
// DumpDatabase dumps all data from database according the special database SQL syntax to file system.
func DumpDatabase(filePath string, dbType setting.DatabaseType) error {
var tbs []*schemas.Table
for _, t := range registeredModels {
t, err := xormEngine.TableInfo(t)
if err != nil {
return err
}
tbs = append(tbs, t)
}
type Version struct {
ID int64 `xorm:"pk autoincr"`
Version int64
}
t, err := xormEngine.TableInfo(&Version{})
if err != nil {
return err
}
tbs = append(tbs, t)
if dbType != "" {
return xormEngine.DumpTablesToFile(tbs, filePath, schemas.DBType(dbType))
}
return xormEngine.DumpTablesToFile(tbs, filePath)
}
+53
View File
@@ -0,0 +1,53 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"time"
"gitea.dev/modules/gtprof"
"gitea.dev/modules/log"
"gitea.dev/modules/setting"
"xorm.io/xorm/contexts"
)
type EngineHook struct {
Threshold time.Duration
Logger log.Logger
}
var _ contexts.Hook = (*EngineHook)(nil)
func (*EngineHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) {
if c.Ctx.Value(ContextKeyTestFixtures) != nil {
return c.Ctx, nil
}
ctx, _ := gtprof.GetTracer().Start(c.Ctx, gtprof.TraceSpanDatabase)
return ctx, nil
}
func (h *EngineHook) AfterProcess(c *contexts.ContextHook) error {
if c.Ctx.Value(ContextKeyTestFixtures) != nil {
return nil
}
span := gtprof.GetContextSpan(c.Ctx)
if span != nil {
// Do not record SQL parameters here:
// * It shouldn't expose the parameters because they contain sensitive information, end users need to report the trace details safely.
// * Some parameters contain quite long texts, waste memory and are difficult to display.
span.SetAttributeString(gtprof.TraceAttrDbSQL, c.SQL)
span.End()
} else {
setting.PanicInDevOrTesting("span in database engine hook is nil")
}
if c.ExecuteTime >= h.Threshold {
// 8 is the amount of skips passed to runtime.Caller, so that in the log the correct function
// is being displayed (the function that ultimately wants to execute the query in the code)
// instead of the function of the slow query hook being called.
h.Logger.Log(8, &log.Event{Level: log.WARN}, "[Slow SQL Query] %s %v - %v", c.SQL, c.Args, c.ExecuteTime)
}
return nil
}
+127
View File
@@ -0,0 +1,127 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"fmt"
"gitea.dev/modules/log"
"gitea.dev/modules/setting"
"xorm.io/xorm"
"xorm.io/xorm/names"
)
func init() {
gonicNames := []string{"SSL", "UID"}
for _, name := range gonicNames {
names.LintGonicMapper[name] = true
}
}
// newXORMEngine returns a new XORM engine from the configuration
func newXORMEngine() (*xorm.Engine, error) {
connOpts := GlobalConnOptions()
driver, connStr, err := ConnStr(connOpts)
if err != nil {
return nil, err
}
engine, err := xorm.NewEngine(driver, connStr)
if err != nil {
return nil, err
}
switch {
case connOpts.Type.IsMySQL():
engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"})
case connOpts.Type.IsMSSQL():
engine.Dialect().SetParams(map[string]string{"DEFAULT_VARCHAR": "nvarchar"})
}
engine.SetSchema(connOpts.Schema)
return engine, nil
}
// InitEngine initializes the xorm.Engine and sets it as XORM's default context
func InitEngine(ctx context.Context) error {
xe, err := newXORMEngine()
if err != nil {
return fmt.Errorf("failed to init database engine: %w", err)
}
xe.SetMapper(names.GonicMapper{})
// WARNING: for serv command, MUST remove the output to os.stdout,
// so use log file to instead print to stdout.
xe.SetLogger(NewXORMLogger(setting.Database.LogSQL))
xe.ShowSQL(setting.Database.LogSQL)
xe.SetMaxOpenConns(setting.Database.MaxOpenConns)
xe.SetMaxIdleConns(setting.Database.MaxIdleConns)
xe.SetConnMaxLifetime(setting.Database.ConnMaxLifetime)
if setting.Database.SlowQueryThreshold > 0 {
xe.AddHook(&EngineHook{
Threshold: setting.Database.SlowQueryThreshold,
Logger: log.GetLogger("xorm"),
})
}
SetDefaultEngine(ctx, xe)
return nil
}
// SetDefaultEngine sets the default engine for db
func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) {
xormEngine = eng
xormEngine.SetDefaultContext(ctx)
}
// UnsetDefaultEngine closes and unsets the default engine
// We hope the SetDefaultEngine and UnsetDefaultEngine can be paired, but it's impossible now,
// there are many calls to InitEngine -> SetDefaultEngine directly to overwrite the `xormEngine` and `xormContext` without close
// Global database engine related functions are all racy and there is no graceful close right now.
func UnsetDefaultEngine() {
if xormEngine != nil {
_ = xormEngine.Close()
xormEngine = nil
}
}
// InitEngineWithMigration initializes a new xorm.Engine and sets it as the XORM's default context
// This function must never call .Sync() if the provided migration function fails.
// When called from the "doctor" command, the migration function is a version check
// that prevents the doctor from fixing anything in the database if the migration level
// is different from the expected value.
func InitEngineWithMigration(ctx context.Context, migrateFunc func(context.Context, EngineMigration) error) (err error) {
if err = InitEngine(ctx); err != nil {
return err
}
if err = xormEngine.Ping(); err != nil {
return err
}
preprocessDatabaseCollation(xormEngine)
// We have to run migrateFunc here in case the user is re-running installation on a previously created DB.
// If we do not then table schemas will be changed and there will be conflicts when the migrations run properly.
//
// Installation should only be being re-run if users want to recover an old database.
// However, we should think carefully about should we support re-install on an installed instance,
// as there may be other problems due to secret reinitialization.
if err = migrateFunc(ctx, xormEngine); err != nil {
return fmt.Errorf("migrate: %w", err)
}
if err = SyncAllTables(); err != nil {
return fmt.Errorf("sync database struct error: %w", err)
}
for _, initFunc := range registeredInitFuncs {
if err := initFunc(); err != nil {
return fmt.Errorf("initFunc failed: %w", err)
}
}
return nil
}
+83
View File
@@ -0,0 +1,83 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"path/filepath"
"testing"
"gitea.dev/models/db"
issues_model "gitea.dev/models/issues"
"gitea.dev/models/unittest"
"gitea.dev/modules/setting"
_ "gitea.dev/cmd" // for TestPrimaryKeys
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDumpDatabase(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
dir := t.TempDir()
type Version struct {
ID int64 `xorm:"pk autoincr"`
Version int64
}
assert.NoError(t, db.GetEngine(t.Context()).Sync(new(Version)))
for _, dbType := range setting.SupportedDatabaseTypes {
assert.NoError(t, db.DumpDatabase(filepath.Join(dir, dbType+".sql"), setting.DatabaseType(dbType)))
}
}
func TestDeleteOrphanedObjects(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
countBefore, err := db.GetEngine(t.Context()).Count(&issues_model.PullRequest{})
assert.NoError(t, err)
_, err = db.GetEngine(t.Context()).Insert(&issues_model.PullRequest{IssueID: 1000}, &issues_model.PullRequest{IssueID: 1001}, &issues_model.PullRequest{IssueID: 1003})
assert.NoError(t, err)
orphaned, err := db.CountOrphanedObjects(t.Context(), "pull_request", "issue", "pull_request.issue_id=issue.id")
assert.NoError(t, err)
assert.EqualValues(t, 3, orphaned)
err = db.DeleteOrphanedObjects(t.Context(), "pull_request", "issue", "pull_request.issue_id=issue.id")
assert.NoError(t, err)
countAfter, err := db.GetEngine(t.Context()).Count(&issues_model.PullRequest{})
assert.NoError(t, err)
assert.Equal(t, countBefore, countAfter)
}
func TestPrimaryKeys(t *testing.T) {
// Some dbs require that all tables have primary keys, see
// https://github.com/go-gitea/gitea/issues/21086
// https://github.com/go-gitea/gitea/issues/16802
// To avoid creating tables without primary key again, this test will check them.
// Import "gitea.dev/cmd" to make sure each db.RegisterModel in init functions has been called.
beans, err := db.NamesToBean()
require.NoError(t, err)
whitelist := map[string]string{
"the_table_name_to_skip_checking": "Write a note here to explain why",
}
for _, bean := range beans {
table, err := db.GetXORMEngineForTesting().TableInfo(bean)
if err != nil {
t.Fatal(err)
}
if why, ok := whitelist[table.Name]; ok {
t.Logf("ignore %q because %q", table.Name, why)
continue
}
assert.NotEmpty(t, table.PrimaryKeys, "table %q has no primary key", table.Name)
}
}
+74
View File
@@ -0,0 +1,74 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"fmt"
"gitea.dev/modules/util"
)
// ErrCancelled represents an error due to context cancellation
type ErrCancelled struct {
Message string
}
// IsErrCancelled checks if an error is a ErrCancelled.
func IsErrCancelled(err error) bool {
_, ok := err.(ErrCancelled)
return ok
}
func (err ErrCancelled) Error() string {
return "Cancelled: " + err.Message
}
// ErrCancelledf returns an ErrCancelled for the provided format and args
func ErrCancelledf(format string, args ...any) error {
return ErrCancelled{
fmt.Sprintf(format, args...),
}
}
// ErrSSHDisabled represents an "SSH disabled" error.
type ErrSSHDisabled struct{}
// IsErrSSHDisabled checks if an error is a ErrSSHDisabled.
func IsErrSSHDisabled(err error) bool {
_, ok := err.(ErrSSHDisabled)
return ok
}
func (err ErrSSHDisabled) Error() string {
return "SSH is disabled"
}
// ErrNotExist represents a non-exist error.
type ErrNotExist struct {
Resource string
ID int64
}
// IsErrNotExist checks if an error is an ErrNotExist
func IsErrNotExist(err error) bool {
_, ok := err.(ErrNotExist)
return ok
}
func (err ErrNotExist) Error() string {
name := "record"
if err.Resource != "" {
name = err.Resource
}
if err.ID != 0 {
return fmt.Sprintf("%s does not exist [id: %d]", name, err.ID)
}
return name + " does not exist"
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrNotExist) Unwrap() error {
return util.ErrNotExist
}
+172
View File
@@ -0,0 +1,172 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"errors"
"fmt"
"strconv"
"gitea.dev/modules/setting"
)
// ResourceIndex represents a resource index which could be used as issue/release and others
// We can create different tables i.e. issue_index, release_index, etc.
type ResourceIndex struct {
GroupID int64 `xorm:"pk"`
MaxIndex int64 `xorm:"index"`
}
var ErrGetResourceIndexFailed = errors.New("get resource index failed")
// SyncMaxResourceIndex sync the max index with the resource
func SyncMaxResourceIndex(ctx context.Context, tableName string, groupID, maxIndex int64) (err error) {
e := GetEngine(ctx)
// try to update the max_index and acquire the write-lock for the record
res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=? WHERE group_id=? AND max_index<?", tableName), maxIndex, groupID, maxIndex)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
// if nothing is updated, the record might not exist or might be larger, it's safe to try to insert it again and then check whether the record exists
_, errIns := e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) VALUES (?, ?)", tableName), groupID, maxIndex)
var savedIdx int64
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&savedIdx)
if err != nil {
return err
}
// if the record still doesn't exist, there must be some errors (insert error)
if !has {
if errIns == nil {
return errors.New("impossible error when SyncMaxResourceIndex, insert succeeded but no record is saved")
}
return errIns
}
}
return nil
}
func postgresGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
res, err := GetEngine(ctx).Query(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
"VALUES (?,1) ON CONFLICT (group_id) DO UPDATE SET max_index = %s.max_index+1 RETURNING max_index",
tableName, tableName), groupID)
if err != nil {
return 0, err
}
if len(res) == 0 {
return 0, ErrGetResourceIndexFailed
}
return strconv.ParseInt(string(res[0]["max_index"]), 10, 64)
}
func mysqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
if _, err := GetEngine(ctx).Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
"VALUES (?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1",
tableName), groupID); err != nil {
return 0, err
}
var idx int64
_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
if err != nil {
return 0, err
}
if idx == 0 {
return 0, errors.New("cannot get the correct index")
}
return idx, nil
}
func mssqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
if _, err := GetEngine(ctx).Exec(fmt.Sprintf(`
MERGE INTO %s WITH (HOLDLOCK) AS target
USING (SELECT %d AS group_id) AS source
(group_id)
ON target.group_id = source.group_id
WHEN MATCHED
THEN UPDATE
SET max_index = max_index + 1
WHEN NOT MATCHED
THEN INSERT (group_id, max_index)
VALUES (%d, 1);
`, tableName, groupID, groupID)); err != nil {
return 0, err
}
var idx int64
_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
if err != nil {
return 0, err
}
if idx == 0 {
return 0, errors.New("cannot get the correct index")
}
return idx, nil
}
// GetNextResourceIndex generates a resource index, it must run in the same transaction where the resource is created
func GetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
switch {
case setting.Database.Type.IsPostgreSQL():
return postgresGetNextResourceIndex(ctx, tableName, groupID)
case setting.Database.Type.IsMySQL():
return mysqlGetNextResourceIndex(ctx, tableName, groupID)
case setting.Database.Type.IsMSSQL():
return mssqlGetNextResourceIndex(ctx, tableName, groupID)
}
e := GetEngine(ctx)
// try to update the max_index to next value, and acquire the write-lock for the record
res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=max_index+1 WHERE group_id=?", tableName), groupID)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
if affected == 0 {
// this slow path is only for the first time of creating a resource index
_, errIns := e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) VALUES (?, 0)", tableName), groupID)
res, err = e.Exec(fmt.Sprintf("UPDATE %s SET max_index=max_index+1 WHERE group_id=?", tableName), groupID)
if err != nil {
return 0, err
}
affected, err = res.RowsAffected()
if err != nil {
return 0, err
}
// if the update still can not update any records, the record must not exist and there must be some errors (insert error)
if affected == 0 {
if errIns == nil {
return 0, errors.New("impossible error when GetNextResourceIndex, insert and update both succeeded but no record is updated")
}
return 0, errIns
}
}
// now, the new index is in database (protected by the transaction and write-lock)
var newIdx int64
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&newIdx)
if err != nil {
return 0, err
}
if !has {
return 0, errors.New("impossible error when GetNextResourceIndex, upsert succeeded but no record can be selected")
}
return newIdx, nil
}
// DeleteResourceIndex delete resource index
func DeleteResourceIndex(ctx context.Context, tableName string, groupID int64) error {
_, err := Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
return err
}
+126
View File
@@ -0,0 +1,126 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"context"
"errors"
"fmt"
"testing"
"gitea.dev/models/db"
"gitea.dev/models/unittest"
"github.com/stretchr/testify/assert"
)
type TestIndex db.ResourceIndex
func getCurrentResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
e := db.GetEngine(ctx)
var idx int64
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&idx)
if err != nil {
return 0, err
}
if !has {
return 0, errors.New("no record")
}
return idx, nil
}
func TestSyncMaxResourceIndex(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine()
assert.NoError(t, xe.Sync(&TestIndex{}))
err := db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 51)
assert.NoError(t, err)
// sync new max index
maxIndex, err := getCurrentResourceIndex(t.Context(), "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 51, maxIndex)
// smaller index doesn't change
err = db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 30)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 51, maxIndex)
// larger index changes
err = db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 62)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 62, maxIndex)
// commit transaction
err = db.WithTx(t.Context(), func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 73, maxIndex)
return nil
})
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 73, maxIndex)
// rollback transaction
err = db.WithTx(t.Context(), func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 84, maxIndex)
return errors.New("test rollback")
})
assert.Error(t, err)
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 73, maxIndex) // the max index doesn't change because the transaction was rolled back
}
func TestGetNextResourceIndex(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine()
assert.NoError(t, xe.Sync(&TestIndex{}))
// create a new record
maxIndex, err := db.GetNextResourceIndex(t.Context(), "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 1, maxIndex)
// increase the existing record
maxIndex, err = db.GetNextResourceIndex(t.Context(), "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 2, maxIndex)
// commit transaction
err = db.WithTx(t.Context(), func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex)
return nil
})
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex)
// rollback transaction
err = db.WithTx(t.Context(), func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 4, maxIndex)
return errors.New("test rollback")
})
assert.Error(t, err)
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex) // the max index doesn't change because the transaction was rolled back
}
+59
View File
@@ -0,0 +1,59 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package install
import (
"context"
"gitea.dev/models/db"
"gitea.dev/modules/setting"
)
// CheckDatabaseConnection checks the database connection
func CheckDatabaseConnection(ctx context.Context) error {
_, err := db.GetEngine(ctx).Exec("SELECT 1")
return err
}
// GetMigrationVersion gets the database migration version
func GetMigrationVersion(ctx context.Context) (int64, error) {
var installedDbVersion int64
x := db.GetEngine(ctx)
exist, err := x.IsTableExist("version")
if err != nil {
return 0, err
}
if !exist {
return 0, nil
}
_, err = x.Table("version").Cols("version").Get(&installedDbVersion)
if err != nil {
return 0, err
}
return installedDbVersion, nil
}
// HasPostInstallationUsers checks whether there are users after installation
func HasPostInstallationUsers(ctx context.Context) (bool, error) {
x := db.GetEngine(ctx)
exist, err := x.IsTableExist("user")
if err != nil {
return false, err
}
if !exist {
return false, nil
}
// if there are 2 or more users in database, we consider there are users created after installation
threshold := 2
if !setting.IsProd {
// to debug easily, with non-prod RUN_MODE, we only check the count to 1
threshold = 1
}
res, err := x.Table("user").Cols("id").Limit(threshold).Query()
if err != nil {
return false, err
}
return len(res) >= threshold, nil
}
+43
View File
@@ -0,0 +1,43 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"gitea.dev/modules/setting"
"xorm.io/builder"
)
// Iterate iterates all the Bean object
func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error {
var start int
batchSize := setting.Database.IterateBufferSize
sess := GetEngine(ctx)
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
beans := make([]*Bean, 0, batchSize)
if cond != nil {
sess = sess.Where(cond)
}
if err := sess.Limit(batchSize, start).Find(&beans); err != nil {
return err
}
if len(beans) == 0 {
return nil
}
start += len(beans)
for _, bean := range beans {
if err := f(ctx, bean); err != nil {
return err
}
}
}
}
}
+44
View File
@@ -0,0 +1,44 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"context"
"testing"
"gitea.dev/models/db"
repo_model "gitea.dev/models/repo"
"gitea.dev/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestIterate(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine()
assert.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
cnt, err := db.GetEngine(t.Context()).Count(&repo_model.RepoUnit{})
assert.NoError(t, err)
var repoUnitCnt int
err = db.Iterate(t.Context(), nil, func(ctx context.Context, repo *repo_model.RepoUnit) error {
repoUnitCnt++
return nil
})
assert.NoError(t, err)
assert.EqualValues(t, cnt, repoUnitCnt)
err = db.Iterate(t.Context(), nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error {
has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID)
if err != nil {
return err
}
if !has {
return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID}
}
return nil
})
assert.NoError(t, err)
}
+214
View File
@@ -0,0 +1,214 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"gitea.dev/modules/setting"
"xorm.io/builder"
)
const (
// DefaultMaxInSize represents default variables number on IN () in SQL
DefaultMaxInSize = 50
defaultFindSliceSize = 10
)
// Paginator is the base for different ListOptions types
type Paginator interface {
GetSkipTake() (skip, take int)
IsListAll() bool
}
// SetSessionPagination sets pagination for a database session
func SetSessionPagination(sess Engine, p Paginator) Session {
skip, take := p.GetSkipTake()
return sess.Limit(take, skip)
}
// ListOptions options to paginate results
type ListOptions struct {
PageSize int
Page int // start from 1
ListAll bool // if true, then PageSize and Page will not be taken
}
var ListOptionsAll = ListOptions{ListAll: true}
var (
_ Paginator = &ListOptions{}
_ FindOptions = ListOptions{}
)
// GetSkipTake returns the skip and take values
func (opts *ListOptions) GetSkipTake() (skip, take int) {
opts.SetDefaultValues()
return (opts.Page - 1) * opts.PageSize, opts.PageSize
}
func (opts ListOptions) GetPage() int {
return opts.Page
}
func (opts ListOptions) GetPageSize() int {
return opts.PageSize
}
// IsListAll indicates PageSize and Page will be ignored
func (opts ListOptions) IsListAll() bool {
return opts.ListAll
}
// SetDefaultValues sets default values
func (opts *ListOptions) SetDefaultValues() {
if opts.PageSize <= 0 {
opts.PageSize = setting.API.DefaultPagingNum
}
if opts.PageSize > setting.API.MaxResponseItems {
opts.PageSize = setting.API.MaxResponseItems
}
if opts.Page <= 0 {
opts.Page = 1
}
}
func (opts ListOptions) ToConds() builder.Cond {
return builder.NewCond()
}
// AbsoluteListOptions absolute options to paginate results
type AbsoluteListOptions struct {
skip int
take int
}
var _ Paginator = &AbsoluteListOptions{}
// NewAbsoluteListOptions creates a list option with applied limits
func NewAbsoluteListOptions(skip, take int) *AbsoluteListOptions {
if skip < 0 {
skip = 0
}
if take <= 0 {
take = setting.API.DefaultPagingNum
}
if take > setting.API.MaxResponseItems {
take = setting.API.MaxResponseItems
}
return &AbsoluteListOptions{skip, take}
}
// IsListAll will always return false
func (opts *AbsoluteListOptions) IsListAll() bool {
return false
}
// GetSkipTake returns the skip and take values
func (opts *AbsoluteListOptions) GetSkipTake() (skip, take int) {
return opts.skip, opts.take
}
// FindOptions represents a find options
type FindOptions interface {
GetPage() int
GetPageSize() int
IsListAll() bool
ToConds() builder.Cond
}
type JoinFunc func(sess Engine) error
type FindOptionsJoin interface {
ToJoins() []JoinFunc
}
type FindOptionsOrder interface {
ToOrders() string
}
// Find represents a common find function which accept an options interface
func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {
sess := GetEngine(ctx).Where(opts.ToConds())
if joinOpt, ok := opts.(FindOptionsJoin); ok {
for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return nil, err
}
}
}
if orderOpt, ok := opts.(FindOptionsOrder); ok {
if order := orderOpt.ToOrders(); order != "" {
sess.OrderBy(order)
}
}
page, pageSize := opts.GetPage(), opts.GetPageSize()
if !opts.IsListAll() && pageSize > 0 {
if page == 0 {
page = 1
}
sess.Limit(pageSize, (page-1)*pageSize)
}
findPageSize := defaultFindSliceSize
if pageSize > 0 {
findPageSize = pageSize
}
objects := make([]*T, 0, findPageSize)
if err := sess.Find(&objects); err != nil {
return nil, err
}
return objects, nil
}
// Count represents a common count function which accept an options interface
func Count[T any](ctx context.Context, opts FindOptions) (int64, error) {
sess := GetEngine(ctx).Where(opts.ToConds())
if joinOpt, ok := opts.(FindOptionsJoin); ok {
for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return 0, err
}
}
}
var object T
return sess.Count(&object)
}
// FindAndCount represents a common findandcount function which accept an options interface
func FindAndCount[T any](ctx context.Context, opts FindOptions) ([]*T, int64, error) {
sess := GetEngine(ctx).Where(opts.ToConds())
page, pageSize := opts.GetPage(), opts.GetPageSize()
if !opts.IsListAll() && pageSize > 0 && page >= 1 {
sess.Limit(pageSize, (page-1)*pageSize)
}
if joinOpt, ok := opts.(FindOptionsJoin); ok {
for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return nil, 0, err
}
}
}
if orderOpt, ok := opts.(FindOptionsOrder); ok {
if order := orderOpt.ToOrders(); order != "" {
sess.OrderBy(order)
}
}
findPageSize := defaultFindSliceSize
if pageSize > 0 {
findPageSize = pageSize
}
objects := make([]*T, 0, findPageSize)
cnt, err := sess.FindAndCount(&objects)
if err != nil {
return nil, 0, err
}
return objects, cnt, nil
}
+52
View File
@@ -0,0 +1,52 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"testing"
"gitea.dev/models/db"
repo_model "gitea.dev/models/repo"
"gitea.dev/models/unittest"
"github.com/stretchr/testify/assert"
"xorm.io/builder"
)
type mockListOptions struct {
db.ListOptions
}
func (opts mockListOptions) IsListAll() bool {
return true
}
func (opts mockListOptions) ToConds() builder.Cond {
return builder.NewCond()
}
func TestFind(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine()
assert.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
var repoUnitCount int
_, err := db.GetEngine(t.Context()).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount)
assert.NoError(t, err)
assert.NotEmpty(t, repoUnitCount)
opts := mockListOptions{}
repoUnits, err := db.Find[repo_model.RepoUnit](t.Context(), opts)
assert.NoError(t, err)
assert.Len(t, repoUnits, repoUnitCount)
cnt, err := db.Count[repo_model.RepoUnit](t.Context(), opts)
assert.NoError(t, err)
assert.EqualValues(t, repoUnitCount, cnt)
repoUnits, newCnt, err := db.FindAndCount[repo_model.RepoUnit](t.Context(), opts)
assert.NoError(t, err)
assert.Equal(t, cnt, newCnt)
assert.Len(t, repoUnits, repoUnitCount)
}
+107
View File
@@ -0,0 +1,107 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"fmt"
"sync/atomic"
"gitea.dev/modules/log"
xormlog "xorm.io/xorm/log"
)
// XORMLogBridge a logger bridge from Logger to xorm
type XORMLogBridge struct {
showSQL atomic.Bool
logger log.Logger
}
// NewXORMLogger inits a log bridge for xorm
func NewXORMLogger(showSQL bool) xormlog.Logger {
l := &XORMLogBridge{logger: log.GetLogger("xorm")}
l.showSQL.Store(showSQL)
return l
}
const stackLevel = 8
// Log a message with defined skip and at logging level
func (l *XORMLogBridge) Log(skip int, level log.Level, format string, v ...any) {
l.logger.Log(skip+1, &log.Event{Level: level}, format, v...)
}
// Debug show debug log
func (l *XORMLogBridge) Debug(v ...any) {
l.Log(stackLevel, log.DEBUG, "%s", fmt.Sprint(v...))
}
// Debugf show debug log
func (l *XORMLogBridge) Debugf(format string, v ...any) {
l.Log(stackLevel, log.DEBUG, format, v...)
}
// Error show error log
func (l *XORMLogBridge) Error(v ...any) {
l.Log(stackLevel, log.ERROR, "%s", fmt.Sprint(v...))
}
// Errorf show error log
func (l *XORMLogBridge) Errorf(format string, v ...any) {
l.Log(stackLevel, log.ERROR, format, v...)
}
// Info show information level log
func (l *XORMLogBridge) Info(v ...any) {
l.Log(stackLevel, log.INFO, "%s", fmt.Sprint(v...))
}
// Infof show information level log
func (l *XORMLogBridge) Infof(format string, v ...any) {
l.Log(stackLevel, log.INFO, format, v...)
}
// Warn show warning log
func (l *XORMLogBridge) Warn(v ...any) {
l.Log(stackLevel, log.WARN, "%s", fmt.Sprint(v...))
}
// Warnf show warning log
func (l *XORMLogBridge) Warnf(format string, v ...any) {
l.Log(stackLevel, log.WARN, format, v...)
}
// Level get logger level
func (l *XORMLogBridge) Level() xormlog.LogLevel {
switch l.logger.GetLevel() {
case log.TRACE, log.DEBUG:
return xormlog.LOG_DEBUG
case log.INFO:
return xormlog.LOG_INFO
case log.WARN:
return xormlog.LOG_WARNING
case log.ERROR:
return xormlog.LOG_ERR
case log.NONE:
return xormlog.LOG_OFF
}
return xormlog.LOG_UNKNOWN
}
// SetLevel set the logger level
func (l *XORMLogBridge) SetLevel(lvl xormlog.LogLevel) {
}
// ShowSQL set if record SQL
func (l *XORMLogBridge) ShowSQL(show ...bool) {
if len(show) == 0 {
show = []bool{true}
}
l.showSQL.Store(show[0])
}
// IsShowSQL if record SQL
func (l *XORMLogBridge) IsShowSQL() bool {
return l.showSQL.Load()
}
+17
View File
@@ -0,0 +1,17 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"testing"
"gitea.dev/models/unittest"
_ "gitea.dev/models"
_ "gitea.dev/models/repo"
)
func TestMain(m *testing.M) {
unittest.MainTest(m)
}
+96
View File
@@ -0,0 +1,96 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"fmt"
"slices"
"strings"
"unicode/utf8"
"gitea.dev/modules/util"
)
// ErrNameReserved represents a "reserved name" error.
type ErrNameReserved struct {
Name string
}
// IsErrNameReserved checks if an error is a ErrNameReserved.
func IsErrNameReserved(err error) bool {
_, ok := err.(ErrNameReserved)
return ok
}
func (err ErrNameReserved) Error() string {
return fmt.Sprintf("name is reserved [name: %s]", err.Name)
}
// Unwrap unwraps this as a ErrInvalid err
func (err ErrNameReserved) Unwrap() error {
return util.ErrInvalidArgument
}
// ErrNamePatternNotAllowed represents a "pattern not allowed" error.
type ErrNamePatternNotAllowed struct {
Pattern string
}
// IsErrNamePatternNotAllowed checks if an error is an ErrNamePatternNotAllowed.
func IsErrNamePatternNotAllowed(err error) bool {
_, ok := err.(ErrNamePatternNotAllowed)
return ok
}
func (err ErrNamePatternNotAllowed) Error() string {
return fmt.Sprintf("name pattern is not allowed [pattern: %s]", err.Pattern)
}
// Unwrap unwraps this as a ErrInvalid err
func (err ErrNamePatternNotAllowed) Unwrap() error {
return util.ErrInvalidArgument
}
// ErrNameCharsNotAllowed represents a "character not allowed in name" error.
type ErrNameCharsNotAllowed struct {
Name string
}
// IsErrNameCharsNotAllowed checks if an error is an ErrNameCharsNotAllowed.
func IsErrNameCharsNotAllowed(err error) bool {
_, ok := err.(ErrNameCharsNotAllowed)
return ok
}
func (err ErrNameCharsNotAllowed) Error() string {
return fmt.Sprintf("name is invalid [%s]: must be valid alpha or numeric or dash(-_) or dot characters", err.Name)
}
// Unwrap unwraps this as a ErrInvalid err
func (err ErrNameCharsNotAllowed) Unwrap() error {
return util.ErrInvalidArgument
}
// IsUsableName checks if name is reserved or pattern of name is not allowed
// based on given reserved names and patterns.
// Names are exact match, patterns can be a prefix or suffix match with placeholder '*'.
func IsUsableName(reservedNames, reservedPatterns []string, name string) error {
name = strings.TrimSpace(strings.ToLower(name))
if utf8.RuneCountInString(name) == 0 {
return util.NewInvalidArgumentErrorf("name is empty")
}
if slices.Contains(reservedNames, name) {
return ErrNameReserved{name}
}
for _, pat := range reservedPatterns {
if pat[0] == '*' && strings.HasSuffix(name, pat[1:]) ||
(pat[len(pat)-1] == '*' && strings.HasPrefix(name, pat[:len(pat)-1])) {
return ErrNamePatternNotAllowed{pat}
}
}
return nil
}
+14
View File
@@ -0,0 +1,14 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package paginator
import (
"testing"
"gitea.dev/models/unittest"
)
func TestMain(m *testing.M) {
unittest.MainTest(m)
}
+7
View File
@@ -0,0 +1,7 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package paginator
// dummy only. in the future, the models/db/list_options.go should be moved here to decouple from db package
// otherwise the unit test will cause cycle import
+59
View File
@@ -0,0 +1,59 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package paginator
import (
"testing"
"gitea.dev/models/db"
"gitea.dev/modules/setting"
"github.com/stretchr/testify/assert"
)
func TestPaginator(t *testing.T) {
cases := []struct {
db.Paginator
Skip int
Take int
Start int
End int
}{
{
Paginator: &db.ListOptions{Page: -1, PageSize: -1},
Skip: 0,
Take: setting.API.DefaultPagingNum,
Start: 0,
End: setting.API.DefaultPagingNum,
},
{
Paginator: &db.ListOptions{Page: 2, PageSize: 10},
Skip: 10,
Take: 10,
Start: 10,
End: 20,
},
{
Paginator: db.NewAbsoluteListOptions(-1, -1),
Skip: 0,
Take: setting.API.DefaultPagingNum,
Start: 0,
End: setting.API.DefaultPagingNum,
},
{
Paginator: db.NewAbsoluteListOptions(2, 10),
Skip: 2,
Take: 10,
Start: 2,
End: 12,
},
}
for _, c := range cases {
skip, take := c.Paginator.GetSkipTake()
assert.Equal(t, c.Skip, skip)
assert.Equal(t, c.Take, take)
}
}
+31
View File
@@ -0,0 +1,31 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
// SearchOrderBy is used to sort the result
type SearchOrderBy string
func (s SearchOrderBy) String() string {
return string(s)
}
// Strings for sorting result
const (
SearchOrderByAlphabetically SearchOrderBy = "name ASC"
SearchOrderByAlphabeticallyReverse SearchOrderBy = "name DESC"
SearchOrderByLeastUpdated SearchOrderBy = "updated_unix ASC"
SearchOrderByRecentUpdated SearchOrderBy = "updated_unix DESC"
SearchOrderByOldest SearchOrderBy = "created_unix ASC"
SearchOrderByNewest SearchOrderBy = "created_unix DESC"
SearchOrderByID SearchOrderBy = "id ASC"
SearchOrderByIDReverse SearchOrderBy = "id DESC"
SearchOrderByStars SearchOrderBy = "num_stars ASC"
SearchOrderByStarsReverse SearchOrderBy = "num_stars DESC"
SearchOrderByForks SearchOrderBy = "num_forks ASC"
SearchOrderByForksReverse SearchOrderBy = "num_forks DESC"
)
// NoConditionID means a condition to filter the records which don't match any id.
// eg: "milestone_id=-1" means "find the items without any milestone.
const NoConditionID int64 = -1
+70
View File
@@ -0,0 +1,70 @@
// Copyright 2018 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"fmt"
"regexp"
"gitea.dev/modules/setting"
)
// CountBadSequences looks for broken sequences from recreate-table mistakes
func CountBadSequences(_ context.Context) (int64, error) {
if !setting.Database.Type.IsPostgreSQL() {
return 0, nil
}
sess := xormEngine.NewSession()
defer sess.Close()
var sequences []string
schema := xormEngine.Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return 0, err
}
sess.Engine().SetSchema(schema)
return int64(len(sequences)), nil
}
// FixBadSequences fixes for broken sequences from recreate-table mistakes
func FixBadSequences(_ context.Context) error {
if !setting.Database.Type.IsPostgreSQL() {
return nil
}
sess := xormEngine.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
var sequences []string
schema := sess.Engine().Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return err
}
sess.Engine().SetSchema(schema)
sequenceRegexp := regexp.MustCompile(`tmp_recreate__(\w+)_id_seq.*`)
for _, sequence := range sequences {
tableName := sequenceRegexp.FindStringSubmatch(sequence)[1]
newSequenceName := tableName + "_id_seq"
if _, err := sess.Exec(fmt.Sprintf("ALTER SEQUENCE `%s` RENAME TO `%s`", sequence, newSequenceName)); err != nil {
return err
}
if _, err := sess.Exec(fmt.Sprintf("SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM `%s`), 1), false)", newSequenceName, tableName)); err != nil {
return err
}
}
return sess.Commit()
}