初始提交: Gitea 项目代码
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gitea.dev/modules/container"
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
type CheckCollationsResult struct {
|
||||
ExpectedCollation string
|
||||
AvailableCollation container.Set[string]
|
||||
DatabaseCollation string
|
||||
IsCollationCaseSensitive func(s string) bool
|
||||
CollationEquals func(a, b string) bool
|
||||
ExistingTableNumber int
|
||||
|
||||
InconsistentCollationColumns []string
|
||||
}
|
||||
|
||||
func findAvailableCollationsMySQL(x EngineMigration) (ret container.Set[string], err error) {
|
||||
var res []struct {
|
||||
Collation string
|
||||
}
|
||||
if err = x.SQL("SHOW COLLATION WHERE (Collation = 'utf8mb4_bin') OR (Collation LIKE '%\\_as\\_cs%')").Find(&res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret = make(container.Set[string], len(res))
|
||||
for _, r := range res {
|
||||
ret.Add(r.Collation)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func findAvailableCollationsMSSQL(x EngineMigration) (ret container.Set[string], err error) {
|
||||
var res []struct {
|
||||
Name string
|
||||
}
|
||||
if err = x.SQL("SELECT * FROM sys.fn_helpcollations() WHERE name LIKE '%[_]CS[_]AS%'").Find(&res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret = make(container.Set[string], len(res))
|
||||
for _, r := range res {
|
||||
ret.Add(r.Name)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func CheckCollations(x EngineMigration) (*CheckCollationsResult, error) {
|
||||
dbTables, err := x.DBMetas()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &CheckCollationsResult{
|
||||
ExistingTableNumber: len(dbTables),
|
||||
CollationEquals: func(a, b string) bool { return a == b },
|
||||
}
|
||||
|
||||
var candidateCollations []string
|
||||
if x.Dialect().URI().DBType == schemas.MYSQL {
|
||||
_, err = x.SQL("SELECT DEFAULT_COLLATION_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?", setting.Database.Name).Get(&res.DatabaseCollation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.IsCollationCaseSensitive = func(s string) bool {
|
||||
return s == "utf8mb4_bin" || strings.HasSuffix(s, "_as_cs")
|
||||
}
|
||||
candidateCollations = []string{"utf8mb4_0900_as_cs", "uca1400_as_cs", "utf8mb4_bin"}
|
||||
res.AvailableCollation, err = findAvailableCollationsMySQL(x)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.CollationEquals = func(a, b string) bool {
|
||||
// MariaDB adds the "utf8mb4_" prefix, eg: "utf8mb4_uca1400_as_cs", but not the name "uca1400_as_cs" in "SHOW COLLATION"
|
||||
// At the moment, it's safe to ignore the database difference, just trim the prefix and compare. It could be fixed easily if there is any problem in the future.
|
||||
return a == b || strings.TrimPrefix(a, "utf8mb4_") == strings.TrimPrefix(b, "utf8mb4_")
|
||||
}
|
||||
} else if x.Dialect().URI().DBType == schemas.MSSQL {
|
||||
if _, err = x.SQL("SELECT DATABASEPROPERTYEX(DB_NAME(), 'Collation')").Get(&res.DatabaseCollation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.IsCollationCaseSensitive = func(s string) bool {
|
||||
return strings.HasSuffix(s, "_CS_AS")
|
||||
}
|
||||
candidateCollations = []string{"Latin1_General_CS_AS"}
|
||||
res.AvailableCollation, err = findAvailableCollationsMSSQL(x)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, nil //nolint:nilnil // return nil for unsupported database types
|
||||
}
|
||||
|
||||
if res.DatabaseCollation == "" {
|
||||
return nil, errors.New("unable to get collation for current database")
|
||||
}
|
||||
|
||||
res.ExpectedCollation = setting.Database.CharsetCollation
|
||||
if res.ExpectedCollation == "" {
|
||||
for _, collation := range candidateCollations {
|
||||
if res.AvailableCollation.Contains(collation) {
|
||||
res.ExpectedCollation = collation
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if res.ExpectedCollation == "" {
|
||||
return nil, errors.New("unable to find a suitable collation for current database")
|
||||
}
|
||||
|
||||
allColumnsMatchExpected := true
|
||||
allColumnsMatchDatabase := true
|
||||
for _, table := range dbTables {
|
||||
for _, col := range table.Columns() {
|
||||
if col.Collation != "" {
|
||||
allColumnsMatchExpected = allColumnsMatchExpected && res.CollationEquals(col.Collation, res.ExpectedCollation)
|
||||
allColumnsMatchDatabase = allColumnsMatchDatabase && res.CollationEquals(col.Collation, res.DatabaseCollation)
|
||||
if !res.IsCollationCaseSensitive(col.Collation) || !res.CollationEquals(col.Collation, res.DatabaseCollation) {
|
||||
res.InconsistentCollationColumns = append(res.InconsistentCollationColumns, fmt.Sprintf("%s.%s", table.Name, col.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// if all columns match expected collation or all match database collation, then it could also be considered as "consistent"
|
||||
if allColumnsMatchExpected || allColumnsMatchDatabase {
|
||||
res.InconsistentCollationColumns = nil
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func CheckCollationsDefaultEngine() (*CheckCollationsResult, error) {
|
||||
return CheckCollations(xormEngine)
|
||||
}
|
||||
|
||||
func alterDatabaseCollation(x EngineMigration, collation string) error {
|
||||
if x.Dialect().URI().DBType == schemas.MYSQL {
|
||||
_, err := x.Exec("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE " + collation)
|
||||
return err
|
||||
} else if x.Dialect().URI().DBType == schemas.MSSQL {
|
||||
// TODO: MSSQL has many limitations on changing database collation, it could fail in many cases.
|
||||
_, err := x.Exec("ALTER DATABASE CURRENT COLLATE " + collation)
|
||||
return err
|
||||
}
|
||||
return errors.New("unsupported database type")
|
||||
}
|
||||
|
||||
// preprocessDatabaseCollation checks database & table column collation, and alter the database collation if needed
|
||||
func preprocessDatabaseCollation(x EngineMigration) {
|
||||
r, err := CheckCollations(x)
|
||||
if err != nil {
|
||||
log.Error("Failed to check database collation: %v", err)
|
||||
}
|
||||
if r == nil {
|
||||
return // no check result means the database doesn't need to do such check/process (at the moment ....)
|
||||
}
|
||||
|
||||
// try to alter database collation to expected if the database is empty, it might fail in some cases (and it isn't necessary to succeed)
|
||||
// at the moment, there is no "altering" solution for MSSQL, site admin should manually change the database collation
|
||||
if !r.CollationEquals(r.DatabaseCollation, r.ExpectedCollation) && r.ExistingTableNumber == 0 {
|
||||
if err = alterDatabaseCollation(x, r.ExpectedCollation); err != nil {
|
||||
log.Error("Failed to change database collation to %q: %v", r.ExpectedCollation, err)
|
||||
} else {
|
||||
_, _ = x.Exec("SELECT 1") // after "altering", MSSQL's session becomes invalid, so make a simple query to "refresh" the session
|
||||
if r, err = CheckCollations(x); err != nil {
|
||||
log.Error("Failed to check database collation again after altering: %v", err) // impossible case
|
||||
return
|
||||
}
|
||||
log.Warn("Current database has been altered to use collation %q", r.DatabaseCollation)
|
||||
}
|
||||
}
|
||||
|
||||
// check column collation, and show warning/error to end users -- no need to fatal, do not block the startup
|
||||
if !r.IsCollationCaseSensitive(r.DatabaseCollation) {
|
||||
log.Warn("Current database is using a case-insensitive collation %q, although Gitea could work with it, there might be some rare cases which don't work as expected.", r.DatabaseCollation)
|
||||
}
|
||||
|
||||
if len(r.InconsistentCollationColumns) > 0 {
|
||||
log.Error("There are %d table columns using inconsistent collation, they should use %q. Please go to admin panel Self Check page", len(r.InconsistentCollationColumns), r.DatabaseCollation)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// BuildCaseInsensitiveLike returns a case-insensitive LIKE condition for the given key and value.
|
||||
// Cast the search value and the database column value to the same case for case-insensitive matching.
|
||||
// * SQLite: only cast ASCII chars because it doesn't handle complete Unicode case folding
|
||||
// * Other databases: use database's string function, assuming that they are able to handle complete Unicode case folding correctly
|
||||
func BuildCaseInsensitiveLike(key, value string) builder.Cond {
|
||||
// ToLowerASCII is about 7% faster than ToUpperASCII (according to Golang's benchmark)
|
||||
if setting.Database.Type.IsSQLite3() {
|
||||
return builder.Like{"LOWER(" + key + ")", util.ToLowerASCII(value)}
|
||||
}
|
||||
return builder.Like{"LOWER(" + key + ")", strings.ToLower(value)}
|
||||
}
|
||||
|
||||
// BuildCaseInsensitiveIn returns a condition to check if the given value is in the given values case-insensitively.
|
||||
// See BuildCaseInsensitiveLike for more details
|
||||
func BuildCaseInsensitiveIn(key string, values []string) builder.Cond {
|
||||
incaseValues := make([]string, len(values))
|
||||
caseCast := strings.ToLower
|
||||
if setting.Database.Type.IsSQLite3() {
|
||||
caseCast = util.ToLowerASCII
|
||||
}
|
||||
for i, value := range values {
|
||||
incaseValues[i] = caseCast(value)
|
||||
}
|
||||
return builder.In("LOWER("+key+")", incaseValues)
|
||||
}
|
||||
|
||||
// BuilderDialect returns the xorm.Builder dialect of the engine
|
||||
func BuilderDialect() string {
|
||||
switch {
|
||||
case setting.Database.Type.IsMySQL():
|
||||
return builder.MYSQL
|
||||
case setting.Database.Type.IsSQLite3():
|
||||
return builder.SQLITE
|
||||
case setting.Database.Type.IsPostgreSQL():
|
||||
return builder.POSTGRES
|
||||
case setting.Database.Type.IsMSSQL():
|
||||
return builder.MSSQL
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/util"
|
||||
)
|
||||
|
||||
type ConnOptions struct {
|
||||
Type setting.DatabaseType
|
||||
Host string
|
||||
Database string
|
||||
User string
|
||||
Passwd string
|
||||
Schema string
|
||||
SSLMode string
|
||||
|
||||
SQLitePath string
|
||||
SQLiteBusyTimeout int
|
||||
SQLiteJournalMode string
|
||||
}
|
||||
|
||||
type SQLiteConnStrOptions struct {
|
||||
FilePath string
|
||||
// how long a concurrent query can wait for others (milliseconds),
|
||||
// if timeout is reached, the error is something like "database is locked (SQLITE_BUSY)"
|
||||
BusyTimeout int
|
||||
JournalMode string
|
||||
}
|
||||
|
||||
func GlobalConnOptions() ConnOptions {
|
||||
return ConnOptions{
|
||||
Type: setting.Database.Type,
|
||||
Host: setting.Database.Host,
|
||||
Database: setting.Database.Name,
|
||||
User: setting.Database.User,
|
||||
Passwd: setting.Database.Passwd,
|
||||
Schema: setting.Database.Schema,
|
||||
SSLMode: setting.Database.SSLMode,
|
||||
|
||||
SQLitePath: setting.Database.Path,
|
||||
SQLiteBusyTimeout: setting.Database.SQLiteBusyTimeout,
|
||||
SQLiteJournalMode: setting.Database.SQLiteJournalMode,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
sqlDriverPostgresSchema = "postgresschema"
|
||||
sqlDriverSQLite3 = "sqlite3" // although database type also has "sqlite3", they are different, for different purposes
|
||||
)
|
||||
|
||||
var makeSQLiteConnStr = func(opts SQLiteConnStrOptions) (string, string, error) {
|
||||
return "", "", errors.New(`this Gitea binary was not built with SQLite3 support, get an official release or rebuild with correct "-tags"`)
|
||||
}
|
||||
|
||||
func registerSQLiteConnStrMaker(fn func(opts SQLiteConnStrOptions) (string, string, error)) {
|
||||
if slices.Contains(setting.SupportedDatabaseTypes, setting.DatabaseTypeSQLite3) {
|
||||
panic("another sqlite3 driver has been registered")
|
||||
}
|
||||
setting.SupportedDatabaseTypes = append(setting.SupportedDatabaseTypes, setting.DatabaseTypeSQLite3)
|
||||
makeSQLiteConnStr = fn
|
||||
}
|
||||
|
||||
func ConnStrDefaultDatabase(opts ConnOptions) (string, string, error) {
|
||||
opts.Database, opts.Schema = "", ""
|
||||
return ConnStr(opts)
|
||||
}
|
||||
|
||||
func ConnStr(opts ConnOptions) (string, string, error) {
|
||||
switch {
|
||||
case opts.Type.IsMySQL():
|
||||
// use unix socket or tcp socket
|
||||
connType := util.Iif(strings.HasPrefix(opts.Host, "/"), "unix", "tcp")
|
||||
// allow (Postgres-inspired) default value to work in MySQL
|
||||
tls := util.Iif(opts.SSLMode == "disable", "false", opts.SSLMode)
|
||||
// in case the database name is a partial connection string which contains "?" parameters
|
||||
paramSep := util.Iif(strings.Contains(opts.Database, "?"), "&", "?")
|
||||
connStr := fmt.Sprintf("%s:%s@%s(%s)/%s%sparseTime=true&tls=%s", opts.User, opts.Passwd, connType, opts.Host, opts.Database, paramSep, tls)
|
||||
return "mysql", connStr, nil
|
||||
|
||||
case opts.Type.IsPostgreSQL():
|
||||
connStr := makePgSQLConnStr(opts.Host, opts.User, opts.Passwd, opts.Database, opts.SSLMode)
|
||||
driver := util.Iif(opts.Schema == "", "postgres", sqlDriverPostgresSchema)
|
||||
registerPostgresSchemaDriver()
|
||||
return driver, connStr, nil
|
||||
|
||||
case opts.Type.IsMSSQL():
|
||||
host, port := parseMSSQLHostPort(opts.Host)
|
||||
connStr := fmt.Sprintf("server=%s; port=%s; user id=%s; password=%s;", host, port, opts.User, opts.Passwd)
|
||||
if opts.Database != "" {
|
||||
connStr += "; database=" + opts.Database
|
||||
}
|
||||
return "mssql", connStr, nil
|
||||
|
||||
case opts.Type.IsSQLite3():
|
||||
if opts.SQLitePath == "" {
|
||||
return "", "", errors.New("sqlite3 database path cannot be empty")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(opts.SQLitePath), os.ModePerm); err != nil {
|
||||
return "", "", fmt.Errorf("failed to create directories: %w", err)
|
||||
}
|
||||
return makeSQLiteConnStr(SQLiteConnStrOptions{
|
||||
FilePath: opts.SQLitePath,
|
||||
JournalMode: opts.SQLiteJournalMode,
|
||||
BusyTimeout: opts.SQLiteBusyTimeout,
|
||||
})
|
||||
}
|
||||
return "", "", fmt.Errorf("unknown database type: %s", opts.Type)
|
||||
}
|
||||
|
||||
// parsePgSQLHostPort parses given input in various forms defined in
|
||||
// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
|
||||
// and returns proper host and port number.
|
||||
func parsePgSQLHostPort(info string) (host, port string) {
|
||||
if h, p, err := net.SplitHostPort(info); err == nil {
|
||||
host, port = h, p
|
||||
} else {
|
||||
// treat the "info" as "host", if it's an IPv6 address, remove the wrapper
|
||||
host = info
|
||||
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = host[1 : len(host)-1]
|
||||
}
|
||||
}
|
||||
|
||||
// set fallback values
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
if port == "" {
|
||||
port = "5432"
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func makePgSQLConnStr(dbHost, dbUser, dbPasswd, dbName, dbsslMode string) (connStr string) {
|
||||
dbName, dbParam, _ := strings.Cut(dbName, "?")
|
||||
host, port := parsePgSQLHostPort(dbHost)
|
||||
connURL := url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.UserPassword(dbUser, dbPasswd),
|
||||
Host: net.JoinHostPort(host, port),
|
||||
Path: dbName,
|
||||
OmitHost: false,
|
||||
RawQuery: dbParam,
|
||||
}
|
||||
query := connURL.Query()
|
||||
if strings.HasPrefix(host, "/") { // looks like a unix socket
|
||||
query.Add("host", host)
|
||||
connURL.Host = ":" + port
|
||||
}
|
||||
query.Set("sslmode", dbsslMode)
|
||||
connURL.RawQuery = query.Encode()
|
||||
return connURL.String()
|
||||
}
|
||||
|
||||
// parseMSSQLHostPort splits the host into host and port
|
||||
func parseMSSQLHostPort(info string) (string, string) {
|
||||
// the default port "0" might be related to MSSQL's dynamic port, maybe it should be double-confirmed in the future
|
||||
host, port := "127.0.0.1", "0"
|
||||
if strings.Contains(info, ":") {
|
||||
host = strings.Split(info, ":")[0]
|
||||
port = strings.Split(info, ":")[1]
|
||||
} else if strings.Contains(info, ",") {
|
||||
host = strings.Split(info, ",")[0]
|
||||
port = strings.TrimSpace(strings.Split(info, ",")[1])
|
||||
} else if len(info) > 0 {
|
||||
host = info
|
||||
}
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
if port == "" {
|
||||
port = "0"
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParsePgSQLHostPort(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
HostPort string
|
||||
Host string
|
||||
Port string
|
||||
}{
|
||||
"host-port": {
|
||||
HostPort: "127.0.0.1:1234",
|
||||
Host: "127.0.0.1",
|
||||
Port: "1234",
|
||||
},
|
||||
"no-port": {
|
||||
HostPort: "127.0.0.1",
|
||||
Host: "127.0.0.1",
|
||||
Port: "5432",
|
||||
},
|
||||
"ipv6-port": {
|
||||
HostPort: "[::1]:1234",
|
||||
Host: "::1",
|
||||
Port: "1234",
|
||||
},
|
||||
"ipv6-no-port": {
|
||||
HostPort: "[::1]",
|
||||
Host: "::1",
|
||||
Port: "5432",
|
||||
},
|
||||
"unix-socket": {
|
||||
HostPort: "/tmp/pg.sock:1234",
|
||||
Host: "/tmp/pg.sock",
|
||||
Port: "1234",
|
||||
},
|
||||
"unix-socket-no-port": {
|
||||
HostPort: "/tmp/pg.sock",
|
||||
Host: "/tmp/pg.sock",
|
||||
Port: "5432",
|
||||
},
|
||||
}
|
||||
for k, test := range tests {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Log(test.HostPort)
|
||||
host, port := parsePgSQLHostPort(test.HostPort)
|
||||
assert.Equal(t, test.Host, host)
|
||||
assert.Equal(t, test.Port, port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakePgSQLConnStr(t *testing.T) {
|
||||
tests := []struct {
|
||||
Host string
|
||||
User string
|
||||
Passwd string
|
||||
Name string
|
||||
SSLMode string
|
||||
Output string
|
||||
}{
|
||||
{
|
||||
Host: "", // empty means default
|
||||
Output: "postgres://:@127.0.0.1:5432?sslmode=",
|
||||
},
|
||||
{
|
||||
Host: "/tmp/pg.sock",
|
||||
User: "testuser",
|
||||
Passwd: "space space !#$%^^%^```-=?=",
|
||||
Name: "gitea",
|
||||
SSLMode: "false",
|
||||
Output: "postgres://testuser:space%20space%20%21%23$%25%5E%5E%25%5E%60%60%60-=%3F=@:5432/gitea?host=%2Ftmp%2Fpg.sock&sslmode=false",
|
||||
},
|
||||
{
|
||||
Host: "/tmp/pg.sock:6432",
|
||||
User: "testuser",
|
||||
Passwd: "pass",
|
||||
Name: "gitea",
|
||||
SSLMode: "false",
|
||||
Output: "postgres://testuser:pass@:6432/gitea?host=%2Ftmp%2Fpg.sock&sslmode=false",
|
||||
},
|
||||
{
|
||||
Host: "localhost",
|
||||
User: "pgsqlusername",
|
||||
Passwd: "I love Gitea!",
|
||||
Name: "gitea",
|
||||
SSLMode: "true",
|
||||
Output: "postgres://pgsqlusername:I%20love%20Gitea%21@localhost:5432/gitea?sslmode=true",
|
||||
},
|
||||
{
|
||||
Host: "localhost:1234",
|
||||
User: "user",
|
||||
Passwd: "pass",
|
||||
Name: "gitea?param=1",
|
||||
Output: "postgres://user:pass@localhost:1234/gitea?param=1&sslmode=",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
connStr := makePgSQLConnStr(test.Host, test.User, test.Passwd, test.Name, test.SSLMode)
|
||||
assert.Equal(t, test.Output, connStr)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// CountOrphanedObjects count subjects with have no existing refobject anymore
|
||||
func CountOrphanedObjects(ctx context.Context, subject, refObject, joinCond string) (int64, error) {
|
||||
return GetEngine(ctx).
|
||||
Table("`"+subject+"`").
|
||||
Join("LEFT", "`"+refObject+"`", joinCond).
|
||||
Where(builder.IsNull{"`" + refObject + "`.id"}).
|
||||
Select("COUNT(`" + subject + "`.`id`)").
|
||||
Count()
|
||||
}
|
||||
|
||||
// DeleteOrphanedObjects delete subjects with have no existing refobject anymore
|
||||
func DeleteOrphanedObjects(ctx context.Context, subject, refObject, joinCond string) error {
|
||||
subQuery := builder.Select("`"+subject+"`.id").
|
||||
From("`"+subject+"`").
|
||||
Join("LEFT", "`"+refObject+"`", joinCond).
|
||||
Where(builder.IsNull{"`" + refObject + "`.id"})
|
||||
b := builder.Delete(builder.In("id", subQuery)).From("`" + subject + "`")
|
||||
_, err := GetEngine(ctx).Exec(b)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
type contextKey struct{ key string }
|
||||
|
||||
var (
|
||||
contextKeyEngine = contextKey{"engine"}
|
||||
ContextKeyTestFixtures = contextKey{"test-fixtures"}
|
||||
)
|
||||
|
||||
func withContextEngine(ctx context.Context, e Engine) context.Context {
|
||||
return context.WithValue(ctx, contextKeyEngine, e)
|
||||
}
|
||||
|
||||
var (
|
||||
contextSafetyOnce sync.Once
|
||||
contextSafetyDeniedFuncPCs []uintptr
|
||||
)
|
||||
|
||||
func contextSafetyCheck(e Engine) {
|
||||
if setting.IsProd && !setting.IsInTesting {
|
||||
return
|
||||
}
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
// Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed.
|
||||
contextSafetyOnce.Do(func() {
|
||||
// try to figure out the bad functions to deny
|
||||
type m struct{}
|
||||
_ = e.SQL("SELECT 1").Iterate(&m{}, func(int, any) error {
|
||||
callers := make([]uintptr, 32)
|
||||
callerNum := runtime.Callers(1, callers)
|
||||
for i := range callerNum {
|
||||
if funcName := runtime.FuncForPC(callers[i]).Name(); funcName == "xorm.io/xorm.(*Session).Iterate" {
|
||||
contextSafetyDeniedFuncPCs = append(contextSafetyDeniedFuncPCs, callers[i])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if len(contextSafetyDeniedFuncPCs) != 1 {
|
||||
panic(errors.New("unable to determine the functions to deny"))
|
||||
}
|
||||
})
|
||||
|
||||
// it should be very fast: xxxx ns/op
|
||||
callers := make([]uintptr, 32)
|
||||
callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
|
||||
for i := range callerNum {
|
||||
if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) {
|
||||
panic(errors.New("using session context in an iterator would cause corrupted results"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetEngine gets an existing db Engine/Statement or creates a new Session
|
||||
func GetEngine(ctx context.Context) Engine {
|
||||
if engine, ok := ctx.Value(contextKeyEngine).(Engine); ok {
|
||||
// if reusing the existing session, need to do "contextSafetyCheck" because the Iterate creates a "autoResetStatement=false" session
|
||||
contextSafetyCheck(engine)
|
||||
return engine
|
||||
}
|
||||
// no need to do "contextSafetyCheck" because it's a new Session
|
||||
return xormEngine.Context(ctx)
|
||||
}
|
||||
|
||||
func GetXORMEngineForTesting() *xorm.Engine {
|
||||
return xormEngine
|
||||
}
|
||||
|
||||
// Committer represents an interface to Commit or Close the Context
|
||||
type Committer interface {
|
||||
Commit() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// halfCommitter is a wrapper of Committer.
|
||||
// It can be closed early, but can't be committed early, it is useful for reusing a transaction.
|
||||
type halfCommitter struct {
|
||||
committer Committer
|
||||
committed bool
|
||||
}
|
||||
|
||||
func (c *halfCommitter) Commit() error {
|
||||
c.committed = true
|
||||
// should do nothing, and the parent committer will commit later
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *halfCommitter) Close() error {
|
||||
if c.committed {
|
||||
// it's "commit and close", should do nothing, and the parent committer will commit later
|
||||
return nil
|
||||
}
|
||||
|
||||
// it's "rollback and close", let the parent committer rollback right now
|
||||
return c.committer.Close()
|
||||
}
|
||||
|
||||
// TxContext represents a transaction Context,
|
||||
// it will reuse the existing transaction in the parent context or create a new one.
|
||||
// Some tips to use:
|
||||
//
|
||||
// 1 It's always recommended to use `WithTx` in new code instead of `TxContext`, since `WithTx` will handle the transaction automatically.
|
||||
// 2. To maintain the old code which uses `TxContext`:
|
||||
// a. Always call `Close()` before returning regardless of whether `Commit()` has been called.
|
||||
// b. Always call `Commit()` before returning if there are no errors, even if the code did not change any data.
|
||||
// c. Remember the `Committer` will be a halfCommitter when a transaction is being reused.
|
||||
// So calling `Commit()` will do nothing, but calling `Close()` without calling `Commit()` will rollback the transaction.
|
||||
// And all operations submitted by the caller stack will be rollbacked as well, not only the operations in the current function.
|
||||
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
|
||||
func TxContext(parentCtx context.Context) (context.Context, Committer, error) {
|
||||
if sess := getTransactionSession(parentCtx); sess != nil {
|
||||
return withContextEngine(parentCtx, sess), &halfCommitter{committer: sess}, nil
|
||||
}
|
||||
|
||||
sess := xormEngine.NewSession()
|
||||
if err := sess.Begin(); err != nil {
|
||||
_ = sess.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
return withContextEngine(parentCtx, sess), sess, nil
|
||||
}
|
||||
|
||||
// WithTx represents executing database operations on a transaction, if the transaction exist,
|
||||
// this function will reuse it otherwise will create a new one and close it when finished.
|
||||
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
|
||||
if sess := getTransactionSession(parentCtx); sess != nil {
|
||||
err := f(withContextEngine(parentCtx, sess))
|
||||
if err != nil {
|
||||
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
|
||||
_ = sess.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
return txWithNoCheck(parentCtx, f)
|
||||
}
|
||||
|
||||
// WithTx2 is similar to WithTx, but it has two return values: result and error.
|
||||
func WithTx2[T any](parentCtx context.Context, f func(ctx context.Context) (T, error)) (ret T, errRet error) {
|
||||
errRet = WithTx(parentCtx, func(ctx context.Context) (errInner error) {
|
||||
ret, errInner = f(ctx)
|
||||
return errInner
|
||||
})
|
||||
return ret, errRet
|
||||
}
|
||||
|
||||
func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
|
||||
sess := xormEngine.NewSession()
|
||||
defer sess.Close()
|
||||
if err := sess.Begin(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := f(withContextEngine(parentCtx, sess)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sess.Commit()
|
||||
}
|
||||
|
||||
// Insert inserts records into database
|
||||
func Insert(ctx context.Context, beans ...any) error {
|
||||
_, err := GetEngine(ctx).Insert(beans...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Exec executes a sql with args
|
||||
func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
|
||||
return GetEngine(ctx).Exec(sqlAndArgs...)
|
||||
}
|
||||
|
||||
func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) {
|
||||
if !cond.IsValid() {
|
||||
panic("cond is invalid in db.Get(ctx, cond). This should not be possible.")
|
||||
}
|
||||
|
||||
var bean T
|
||||
has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
} else if !has {
|
||||
return nil, false, nil
|
||||
}
|
||||
return &bean, true, nil
|
||||
}
|
||||
|
||||
func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) {
|
||||
var bean T
|
||||
has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
} else if !has {
|
||||
return nil, false, nil
|
||||
}
|
||||
return &bean, true, nil
|
||||
}
|
||||
|
||||
func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
|
||||
if !cond.IsValid() {
|
||||
panic("cond is invalid in db.Exist(ctx, cond). This should not be possible.")
|
||||
}
|
||||
|
||||
var bean T
|
||||
return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
|
||||
}
|
||||
|
||||
func ExistByID[T any](ctx context.Context, id int64) (bool, error) {
|
||||
var bean T
|
||||
return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean)
|
||||
}
|
||||
|
||||
// DeleteByID deletes the given bean with the given ID
|
||||
func DeleteByID[T any](ctx context.Context, id int64) (int64, error) {
|
||||
var bean T
|
||||
return GetEngine(ctx).ID(id).NoAutoCondition().NoAutoTime().Delete(&bean)
|
||||
}
|
||||
|
||||
func DeleteByIDs[T any](ctx context.Context, ids ...int64) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var bean T
|
||||
_, err := GetEngine(ctx).In("id", ids).NoAutoCondition().NoAutoTime().Delete(&bean)
|
||||
return err
|
||||
}
|
||||
|
||||
func Delete[T any](ctx context.Context, opts FindOptions) (int64, error) {
|
||||
if opts == nil || !opts.ToConds().IsValid() {
|
||||
panic("opts are empty or invalid in db.Delete(ctx, opts). This should not be possible.")
|
||||
}
|
||||
|
||||
var bean T
|
||||
return GetEngine(ctx).Where(opts.ToConds()).NoAutoCondition().NoAutoTime().Delete(&bean)
|
||||
}
|
||||
|
||||
// DeleteByBean deletes all records according non-empty fields of the bean as conditions.
|
||||
func DeleteByBean(ctx context.Context, bean any) (int64, error) {
|
||||
return GetEngine(ctx).Delete(bean)
|
||||
}
|
||||
|
||||
// FindIDs finds the IDs for the given table name satisfying the given condition
|
||||
// By passing a different value than "id" for "idCol", you can query for foreign IDs, i.e. the repo IDs which satisfy the condition
|
||||
func FindIDs(ctx context.Context, tableName, idCol string, cond builder.Cond) ([]int64, error) {
|
||||
ids := make([]int64, 0, 10)
|
||||
if err := GetEngine(ctx).Table(tableName).
|
||||
Cols(idCol).
|
||||
Where(cond).
|
||||
Find(&ids); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// DecrByIDs decreases the given column for entities of the "bean" type with one of the given ids by one
|
||||
// Timestamps of the entities won't be updated
|
||||
func DecrByIDs(ctx context.Context, ids []int64, decrCol string, bean any) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := GetEngine(ctx).Decr(decrCol).In("id", ids).NoAutoCondition().NoAutoTime().Update(bean)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteBeans deletes all given beans, beans must contain delete conditions.
|
||||
func DeleteBeans(ctx context.Context, beans ...any) (err error) {
|
||||
e := GetEngine(ctx)
|
||||
for i := range beans {
|
||||
if _, err = e.Delete(beans[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TruncateBeans deletes all given beans, beans may contain delete conditions.
|
||||
func TruncateBeans(ctx context.Context, beans ...any) (err error) {
|
||||
e := GetEngine(ctx)
|
||||
for i := range beans {
|
||||
if _, err = e.Truncate(beans[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountByBean counts the number of database records according non-empty fields of the bean as conditions.
|
||||
func CountByBean(ctx context.Context, bean any) (int64, error) {
|
||||
return GetEngine(ctx).Count(bean)
|
||||
}
|
||||
|
||||
// InTransaction returns true if the engine is in a transaction otherwise return false
|
||||
func InTransaction(ctx context.Context) bool {
|
||||
return getTransactionSession(ctx) != nil
|
||||
}
|
||||
|
||||
func getTransactionSession(ctx context.Context) *xorm.Session {
|
||||
e, _ := ctx.Value(contextKeyEngine).(Engine)
|
||||
if sess, ok := e.(*xorm.Session); ok && sess.IsInTx() {
|
||||
return sess
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db // it's not db_test, because this file is for testing the private type halfCommitter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type MockCommitter struct {
|
||||
wants []string
|
||||
gots []string
|
||||
}
|
||||
|
||||
func NewMockCommitter(wants ...string) *MockCommitter {
|
||||
return &MockCommitter{
|
||||
wants: wants,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *MockCommitter) Commit() error {
|
||||
c.gots = append(c.gots, "commit")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockCommitter) Close() error {
|
||||
c.gots = append(c.gots, "close")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *MockCommitter) Assert(t *testing.T) {
|
||||
assert.Equal(t, c.wants, c.gots, "want operations %v, but got %v", c.wants, c.gots)
|
||||
}
|
||||
|
||||
func Test_halfCommitter(t *testing.T) {
|
||||
/*
|
||||
Do something like:
|
||||
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer committer.Close()
|
||||
|
||||
// ...
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ...
|
||||
|
||||
return committer.Commit()
|
||||
*/
|
||||
|
||||
testWithCommitter := func(committer Committer, f func(committer Committer) error) {
|
||||
if err := f(&halfCommitter{committer: committer}); err == nil {
|
||||
committer.Commit()
|
||||
}
|
||||
committer.Close()
|
||||
}
|
||||
|
||||
t.Run("commit and close", func(t *testing.T) {
|
||||
mockCommitter := NewMockCommitter("commit", "close")
|
||||
|
||||
testWithCommitter(mockCommitter, func(committer Committer) error {
|
||||
defer committer.Close()
|
||||
return committer.Commit()
|
||||
})
|
||||
|
||||
mockCommitter.Assert(t)
|
||||
})
|
||||
|
||||
t.Run("rollback and close", func(t *testing.T) {
|
||||
mockCommitter := NewMockCommitter("close", "close")
|
||||
|
||||
testWithCommitter(mockCommitter, func(committer Committer) error {
|
||||
defer committer.Close()
|
||||
if true {
|
||||
return errors.New("error")
|
||||
}
|
||||
return committer.Commit()
|
||||
})
|
||||
|
||||
mockCommitter.Assert(t)
|
||||
})
|
||||
|
||||
t.Run("close and commit", func(t *testing.T) {
|
||||
mockCommitter := NewMockCommitter("close", "close")
|
||||
|
||||
testWithCommitter(mockCommitter, func(committer Committer) error {
|
||||
committer.Close()
|
||||
committer.Commit()
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
mockCommitter.Assert(t)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInTransaction(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
assert.False(t, db.InTransaction(t.Context()))
|
||||
assert.NoError(t, db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
return nil
|
||||
}))
|
||||
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
assert.NoError(t, err)
|
||||
defer committer.Close()
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.NoError(t, db.WithTx(ctx, func(ctx context.Context) error {
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestTxContext(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
{ // create new transaction
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.NoError(t, committer.Commit())
|
||||
}
|
||||
|
||||
{ // reuse the transaction created by TxContext and commit it
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
engine := db.GetEngine(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
{
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.Equal(t, engine, db.GetEngine(ctx))
|
||||
assert.NoError(t, committer.Commit())
|
||||
}
|
||||
assert.NoError(t, committer.Commit())
|
||||
}
|
||||
|
||||
{ // reuse the transaction created by TxContext and close it
|
||||
ctx, committer, err := db.TxContext(t.Context())
|
||||
engine := db.GetEngine(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
{
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.Equal(t, engine, db.GetEngine(ctx))
|
||||
assert.NoError(t, committer.Close())
|
||||
}
|
||||
assert.NoError(t, committer.Close())
|
||||
}
|
||||
|
||||
{ // reuse the transaction created by WithTx
|
||||
assert.NoError(t, db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
{
|
||||
ctx, committer, err := db.TxContext(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, db.InTransaction(ctx))
|
||||
assert.NoError(t, committer.Commit())
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextSafety(t *testing.T) {
|
||||
type TestModel1 struct {
|
||||
ID int64
|
||||
}
|
||||
type TestModel2 struct {
|
||||
ID int64
|
||||
}
|
||||
assert.NoError(t, unittest.GetXORMEngine().Sync(&TestModel1{}, &TestModel2{}))
|
||||
assert.NoError(t, db.TruncateBeans(t.Context(), &TestModel1{}, &TestModel2{}))
|
||||
testCount := 10
|
||||
for i := 1; i <= testCount; i++ {
|
||||
assert.NoError(t, db.Insert(t.Context(), &TestModel1{ID: int64(i)}))
|
||||
assert.NoError(t, db.Insert(t.Context(), &TestModel2{ID: int64(-i)}))
|
||||
}
|
||||
|
||||
t.Run("Show-XORM-Bug", func(t *testing.T) {
|
||||
actualCount := 0
|
||||
// here: db.GetEngine(t.Context()) is a new *Session created from *Engine
|
||||
_ = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
_ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
|
||||
// here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false,
|
||||
// and the internal states (including "cond" and others) are always there and not be reset in this callback.
|
||||
m1 := bean.(*TestModel1)
|
||||
assert.EqualValues(t, i+1, m1.ID)
|
||||
|
||||
// here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ...
|
||||
// and it conflicts with the "Iterate"'s internal states.
|
||||
// has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID})
|
||||
|
||||
actualCount++
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
assert.Equal(t, testCount, actualCount)
|
||||
})
|
||||
|
||||
t.Run("DenyBadUsage", func(t *testing.T) {
|
||||
assert.PanicsWithError(t, "using session context in an iterator would cause corrupted results", func() {
|
||||
_ = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
return db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
|
||||
_ = db.GetEngine(ctx)
|
||||
return nil
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/convert"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
// ConvertDatabaseTable converts database and tables from utf8 to utf8mb4 if it's mysql and set ROW_FORMAT=dynamic
|
||||
func ConvertDatabaseTable() error {
|
||||
if xormEngine.Dialect().URI().DBType != schemas.MYSQL {
|
||||
return nil
|
||||
}
|
||||
|
||||
r, err := CheckCollations(xormEngine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = xormEngine.Exec(fmt.Sprintf("ALTER DATABASE `%s` CHARACTER SET utf8mb4 COLLATE %s", setting.Database.Name, r.ExpectedCollation))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tables, err := xormEngine.DBMetas()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, table := range tables {
|
||||
if _, err := xormEngine.Exec(fmt.Sprintf("ALTER TABLE `%s` ROW_FORMAT=dynamic", table.Name)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := xormEngine.Exec(fmt.Sprintf("ALTER TABLE `%s` CONVERT TO CHARACTER SET utf8mb4 COLLATE %s", table.Name, r.ExpectedCollation)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConvertVarcharToNVarchar converts database and tables from varchar to nvarchar if it's mssql
|
||||
func ConvertVarcharToNVarchar() error {
|
||||
if xormEngine.Dialect().URI().DBType != schemas.MSSQL {
|
||||
return nil
|
||||
}
|
||||
|
||||
sess := xormEngine.NewSession()
|
||||
defer sess.Close()
|
||||
res, err := sess.QuerySliceString(`SELECT 'ALTER TABLE ' + OBJECT_NAME(SC.object_id) + ' MODIFY SC.name NVARCHAR(' + CONVERT(VARCHAR(5),SC.max_length) + ')'
|
||||
FROM SYS.columns SC
|
||||
JOIN SYS.types ST
|
||||
ON SC.system_type_id = ST.system_type_id
|
||||
AND SC.user_type_id = ST.user_type_id
|
||||
WHERE ST.name ='varchar'`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, row := range res {
|
||||
if len(row) == 1 {
|
||||
if _, err = sess.Exec(row[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CellToInt converts a xorm.Cell field value to an int value
|
||||
func CellToInt[T ~int | int64](cell xorm.Cell, def T) (ret T, has bool, err error) {
|
||||
if *cell == nil {
|
||||
return def, false, nil
|
||||
}
|
||||
val, err := convert.AsInt64(*cell)
|
||||
if err != nil {
|
||||
return def, false, err
|
||||
}
|
||||
return T(val), true, err
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"sync"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"xorm.io/xorm/dialects"
|
||||
)
|
||||
|
||||
var registerOnce sync.Once
|
||||
|
||||
func registerPostgresSchemaDriver() {
|
||||
registerOnce.Do(func() {
|
||||
sql.Register(sqlDriverPostgresSchema, &postgresSchemaDriver{})
|
||||
dialects.RegisterDriver(sqlDriverPostgresSchema, dialects.QueryDriver("postgres"))
|
||||
})
|
||||
}
|
||||
|
||||
type postgresSchemaDriver struct {
|
||||
pq.Driver
|
||||
}
|
||||
|
||||
// Open opens a new connection to the database. name is a connection string.
|
||||
// This function opens the postgres connection in the default manner but immediately
|
||||
// runs set_config to set the search_path appropriately
|
||||
func (d *postgresSchemaDriver) Open(name string) (driver.Conn, error) {
|
||||
conn, err := d.Driver.Open(name)
|
||||
if err != nil {
|
||||
return conn, err
|
||||
}
|
||||
schemaValue, _ := driver.String.ConvertValue(setting.Database.Schema)
|
||||
|
||||
// golangci lint is incorrect here - there is no benefit to using driver.ExecerContext here
|
||||
// and in any case pq does not implement it
|
||||
if execer, ok := conn.(driver.Execer); ok { //nolint:staticcheck // see above
|
||||
_, err := execer.Exec(`SELECT set_config(
|
||||
'search_path',
|
||||
$1 || ',' || current_setting('search_path'),
|
||||
false)`, []driver.Value{schemaValue})
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
stmt, err := conn.Prepare(`SELECT set_config(
|
||||
'search_path',
|
||||
$1 || ',' || current_setting('search_path'),
|
||||
false)`)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
// driver.String.ConvertValue will never return err for string
|
||||
|
||||
// golangci lint is incorrect here - there is no benefit to using stmt.ExecWithContext here
|
||||
_, err = stmt.Exec([]driver.Value{schemaValue}) //nolint:staticcheck // see above
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build sqlite_mattn && sqlite_unlock_notify
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerSQLiteConnStrMaker(makeSQLiteConnStrMattnCGO)
|
||||
}
|
||||
|
||||
func makeSQLiteConnStrMattnCGO(opts SQLiteConnStrOptions) (string, string, error) {
|
||||
var params []string
|
||||
params = append(params, "cache=shared")
|
||||
params = append(params, "mode=rwc")
|
||||
params = append(params, "_busy_timeout="+strconv.Itoa(opts.BusyTimeout))
|
||||
params = append(params, "_txlock=immediate")
|
||||
if opts.JournalMode != "" {
|
||||
params = append(params, "_journal_mode="+opts.JournalMode)
|
||||
}
|
||||
connStr := fmt.Sprintf("file:%s?%s", opts.FilePath, strings.Join(params, "&"))
|
||||
return sqlDriverSQLite3, connStr, nil
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright 2026 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build !sqlite_mattn
|
||||
|
||||
// modernc driver is chosen as the default one (compared to mattn, ncruces)
|
||||
// * mattn was used as default, but it requires CGO
|
||||
// * the CI times are almost the same for these three (race detector must be disabled)
|
||||
// * modernc increases the binary size about 2MB, ncruces increases about 7MB
|
||||
// * compiling time: modernc is slightly slower than mattn, ncruces is the slowest
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// this driver contains huge amount of Golang code, so it is much slower when "-race" check is enabled.
|
||||
registerSQLiteConnStrMaker(makeSQLiteConnStrModerncCCGO)
|
||||
sql.Register(sqlDriverSQLite3, &sqlite.Driver{})
|
||||
}
|
||||
|
||||
func makeSQLiteConnStrModerncCCGO(opts SQLiteConnStrOptions) (string, string, error) {
|
||||
var params []string
|
||||
// TODO: there is a changed behavior from mattn driver:
|
||||
// * mattn driver can wait for pretty long time for concurrent accesses (not limited by the busy timeout)
|
||||
// * but other drivers will report something like "database is locked (5) (SQLITE_BUSY)" if the timeout is reached
|
||||
// Maybe we need to relax the busy timeout to a reasonable long time in the future
|
||||
params = append(params, fmt.Sprintf("_pragma=busy_timeout(%d)", opts.BusyTimeout))
|
||||
params = append(params, "_txlock=immediate")
|
||||
if opts.JournalMode != "" {
|
||||
params = append(params, fmt.Sprintf("_pragma=journal_mode(%s)", opts.JournalMode))
|
||||
}
|
||||
connStr := fmt.Sprintf("file:%s?%s", opts.FilePath, strings.Join(params, "&"))
|
||||
return sqlDriverSQLite3, connStr, nil
|
||||
}
|
||||
Executable
+184
@@ -0,0 +1,184 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2018 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver
|
||||
_ "github.com/lib/pq" // Needed for the Postgresql driver
|
||||
_ "github.com/microsoft/go-mssqldb" // Needed for the MSSQL driver
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/core"
|
||||
"xorm.io/xorm/dialects"
|
||||
"xorm.io/xorm/names"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
var (
|
||||
xormEngine *xorm.Engine
|
||||
registeredModels []any
|
||||
registeredInitFuncs []func() error
|
||||
)
|
||||
|
||||
// Engine represents a xorm engine or session.
|
||||
type Engine interface {
|
||||
Table(tableNameOrBean any) *xorm.Session
|
||||
Count(...any) (int64, error)
|
||||
Decr(column string, arg ...any) *xorm.Session
|
||||
Delete(...any) (int64, error)
|
||||
Truncate(...any) (int64, error)
|
||||
Exec(...any) (sql.Result, error)
|
||||
Find(any, ...any) error
|
||||
FindAndCount(any, ...any) (int64, error)
|
||||
Get(beans ...any) (bool, error)
|
||||
ID(any) *xorm.Session
|
||||
In(string, ...any) *xorm.Session
|
||||
Incr(column string, arg ...any) *xorm.Session
|
||||
Insert(...any) (int64, error)
|
||||
Iterate(any, xorm.IterFunc) error
|
||||
Join(joinOperator string, tablename, condition any, args ...any) *xorm.Session
|
||||
SQL(any, ...any) *xorm.Session
|
||||
Where(any, ...any) *xorm.Session
|
||||
Asc(colNames ...string) *xorm.Session
|
||||
Desc(colNames ...string) *xorm.Session
|
||||
Limit(limit int, start ...int) *xorm.Session
|
||||
NoAutoTime() *xorm.Session
|
||||
SumInt(bean any, columnName string) (res int64, err error)
|
||||
Sync(...any) error
|
||||
Select(string) *xorm.Session
|
||||
SetExpr(string, any) *xorm.Session
|
||||
NotIn(string, ...any) *xorm.Session
|
||||
OrderBy(any, ...any) *xorm.Session
|
||||
Exist(...any) (bool, error)
|
||||
Distinct(...string) *xorm.Session
|
||||
Query(...any) ([]map[string][]byte, error)
|
||||
Cols(...string) *xorm.Session
|
||||
Context(ctx context.Context) *xorm.Session
|
||||
Ping() error
|
||||
IsTableExist(tableNameOrBean any) (bool, error)
|
||||
}
|
||||
|
||||
// Session represents a xorm session interface, used as an abstraction over *xorm.Session.
|
||||
type Session interface {
|
||||
Engine
|
||||
And(query any, args ...any) *xorm.Session
|
||||
Begin() error
|
||||
Close() error
|
||||
Commit() error
|
||||
IsInTx() bool
|
||||
Rollback() error
|
||||
Engine() *xorm.Engine
|
||||
}
|
||||
|
||||
// EngineMigration is a xorm engine interface used for migrations.
|
||||
// It extends Engine with additional methods that are only available on the engine (not on the session)
|
||||
// and are needed by the migration packages.
|
||||
type EngineMigration interface {
|
||||
Engine
|
||||
Close() error
|
||||
DB() *core.DB
|
||||
DBMetas() ([]*schemas.Table, error)
|
||||
Dialect() dialects.Dialect
|
||||
DropTables(beans ...any) error
|
||||
NewSession() *xorm.Session
|
||||
QueryInterface(sqlOrArgs ...any) ([]map[string]any, error)
|
||||
SetMapper(mapper names.Mapper)
|
||||
SyncWithOptions(opts xorm.SyncOptions, beans ...any) (*xorm.SyncResult, error)
|
||||
TableInfo(bean any) (*schemas.Table, error)
|
||||
TableName(bean any, includeSchema ...bool) string
|
||||
}
|
||||
|
||||
var (
|
||||
_ Engine = (*xorm.Engine)(nil)
|
||||
_ Engine = (*xorm.Session)(nil)
|
||||
_ Session = (*xorm.Session)(nil)
|
||||
_ EngineMigration = (*xorm.Engine)(nil)
|
||||
)
|
||||
|
||||
// RegisterModel registers model, if initFuncs provided, it will be invoked after data model sync
|
||||
func RegisterModel(bean any, initFunc ...func() error) {
|
||||
registeredModels = append(registeredModels, bean)
|
||||
if len(registeredInitFuncs) > 0 && initFunc[0] != nil {
|
||||
registeredInitFuncs = append(registeredInitFuncs, initFunc[0])
|
||||
}
|
||||
}
|
||||
|
||||
// SyncAllTables sync the schemas of all tables, is required by unit test code
|
||||
func SyncAllTables() error {
|
||||
_, err := xormEngine.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{
|
||||
WarnIfDatabaseColumnMissed: true,
|
||||
}, registeredModels...)
|
||||
return err
|
||||
}
|
||||
|
||||
// NamesToBean return a list of beans or an error
|
||||
func NamesToBean(names ...string) ([]any, error) {
|
||||
beans := []any{}
|
||||
if len(names) == 0 {
|
||||
beans = append(beans, registeredModels...)
|
||||
return beans, nil
|
||||
}
|
||||
// Need to map provided names to beans...
|
||||
beanMap := make(map[string]any)
|
||||
for _, bean := range registeredModels {
|
||||
beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean
|
||||
beanMap[strings.ToLower(xormEngine.TableName(bean))] = bean
|
||||
beanMap[strings.ToLower(xormEngine.TableName(bean, true))] = bean
|
||||
}
|
||||
|
||||
gotBean := make(map[any]bool)
|
||||
for _, name := range names {
|
||||
bean, ok := beanMap[strings.ToLower(strings.TrimSpace(name))]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no table found that matches: %s", name)
|
||||
}
|
||||
if !gotBean[bean] {
|
||||
beans = append(beans, bean)
|
||||
gotBean[bean] = true
|
||||
}
|
||||
}
|
||||
return beans, nil
|
||||
}
|
||||
|
||||
// MaxBatchInsertSize returns the table's max batch insert size
|
||||
func MaxBatchInsertSize(bean any) int {
|
||||
t, err := xormEngine.TableInfo(bean)
|
||||
if err != nil {
|
||||
return 50
|
||||
}
|
||||
return 999 / len(t.ColumnsSeq())
|
||||
}
|
||||
|
||||
// IsTableNotEmpty returns true if table has at least one record
|
||||
func IsTableNotEmpty(beanOrTableName any) (bool, error) {
|
||||
return xormEngine.Table(beanOrTableName).Exist()
|
||||
}
|
||||
|
||||
// DeleteAllRecords will delete all the records of this table
|
||||
func DeleteAllRecords(tableName string) error {
|
||||
_, err := xormEngine.Exec("DELETE FROM " + tableName)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetMaxID will return max id of the table
|
||||
func GetMaxID(beanOrTableName any) (maxID int64, err error) {
|
||||
_, err = xormEngine.Select("MAX(id)").Table(beanOrTableName).Get(&maxID)
|
||||
return maxID, err
|
||||
}
|
||||
|
||||
func SetLogSQL(ctx context.Context, on bool) {
|
||||
e := GetEngine(ctx)
|
||||
if x, ok := e.(*xorm.Engine); ok {
|
||||
x.ShowSQL(on)
|
||||
} else if sess, ok := e.(*xorm.Session); ok {
|
||||
sess.Engine().ShowSQL(on)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
// DumpDatabase dumps all data from database according the special database SQL syntax to file system.
|
||||
func DumpDatabase(filePath string, dbType setting.DatabaseType) error {
|
||||
var tbs []*schemas.Table
|
||||
for _, t := range registeredModels {
|
||||
t, err := xormEngine.TableInfo(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tbs = append(tbs, t)
|
||||
}
|
||||
|
||||
type Version struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Version int64
|
||||
}
|
||||
t, err := xormEngine.TableInfo(&Version{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tbs = append(tbs, t)
|
||||
|
||||
if dbType != "" {
|
||||
return xormEngine.DumpTablesToFile(tbs, filePath, schemas.DBType(dbType))
|
||||
}
|
||||
return xormEngine.DumpTablesToFile(tbs, filePath)
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gitea.dev/modules/gtprof"
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"xorm.io/xorm/contexts"
|
||||
)
|
||||
|
||||
type EngineHook struct {
|
||||
Threshold time.Duration
|
||||
Logger log.Logger
|
||||
}
|
||||
|
||||
var _ contexts.Hook = (*EngineHook)(nil)
|
||||
|
||||
func (*EngineHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) {
|
||||
if c.Ctx.Value(ContextKeyTestFixtures) != nil {
|
||||
return c.Ctx, nil
|
||||
}
|
||||
ctx, _ := gtprof.GetTracer().Start(c.Ctx, gtprof.TraceSpanDatabase)
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
func (h *EngineHook) AfterProcess(c *contexts.ContextHook) error {
|
||||
if c.Ctx.Value(ContextKeyTestFixtures) != nil {
|
||||
return nil
|
||||
}
|
||||
span := gtprof.GetContextSpan(c.Ctx)
|
||||
if span != nil {
|
||||
// Do not record SQL parameters here:
|
||||
// * It shouldn't expose the parameters because they contain sensitive information, end users need to report the trace details safely.
|
||||
// * Some parameters contain quite long texts, waste memory and are difficult to display.
|
||||
span.SetAttributeString(gtprof.TraceAttrDbSQL, c.SQL)
|
||||
span.End()
|
||||
} else {
|
||||
setting.PanicInDevOrTesting("span in database engine hook is nil")
|
||||
}
|
||||
if c.ExecuteTime >= h.Threshold {
|
||||
// 8 is the amount of skips passed to runtime.Caller, so that in the log the correct function
|
||||
// is being displayed (the function that ultimately wants to execute the query in the code)
|
||||
// instead of the function of the slow query hook being called.
|
||||
h.Logger.Log(8, &log.Event{Level: log.WARN}, "[Slow SQL Query] %s %v - %v", c.SQL, c.Args, c.ExecuteTime)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
// Copyright 2024 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/names"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gonicNames := []string{"SSL", "UID"}
|
||||
for _, name := range gonicNames {
|
||||
names.LintGonicMapper[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
// newXORMEngine returns a new XORM engine from the configuration
|
||||
func newXORMEngine() (*xorm.Engine, error) {
|
||||
connOpts := GlobalConnOptions()
|
||||
driver, connStr, err := ConnStr(connOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
engine, err := xorm.NewEngine(driver, connStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case connOpts.Type.IsMySQL():
|
||||
engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"})
|
||||
case connOpts.Type.IsMSSQL():
|
||||
engine.Dialect().SetParams(map[string]string{"DEFAULT_VARCHAR": "nvarchar"})
|
||||
}
|
||||
engine.SetSchema(connOpts.Schema)
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// InitEngine initializes the xorm.Engine and sets it as XORM's default context
|
||||
func InitEngine(ctx context.Context) error {
|
||||
xe, err := newXORMEngine()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to init database engine: %w", err)
|
||||
}
|
||||
|
||||
xe.SetMapper(names.GonicMapper{})
|
||||
// WARNING: for serv command, MUST remove the output to os.stdout,
|
||||
// so use log file to instead print to stdout.
|
||||
xe.SetLogger(NewXORMLogger(setting.Database.LogSQL))
|
||||
xe.ShowSQL(setting.Database.LogSQL)
|
||||
xe.SetMaxOpenConns(setting.Database.MaxOpenConns)
|
||||
xe.SetMaxIdleConns(setting.Database.MaxIdleConns)
|
||||
xe.SetConnMaxLifetime(setting.Database.ConnMaxLifetime)
|
||||
|
||||
if setting.Database.SlowQueryThreshold > 0 {
|
||||
xe.AddHook(&EngineHook{
|
||||
Threshold: setting.Database.SlowQueryThreshold,
|
||||
Logger: log.GetLogger("xorm"),
|
||||
})
|
||||
}
|
||||
|
||||
SetDefaultEngine(ctx, xe)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultEngine sets the default engine for db
|
||||
func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) {
|
||||
xormEngine = eng
|
||||
xormEngine.SetDefaultContext(ctx)
|
||||
}
|
||||
|
||||
// UnsetDefaultEngine closes and unsets the default engine
|
||||
// We hope the SetDefaultEngine and UnsetDefaultEngine can be paired, but it's impossible now,
|
||||
// there are many calls to InitEngine -> SetDefaultEngine directly to overwrite the `xormEngine` and `xormContext` without close
|
||||
// Global database engine related functions are all racy and there is no graceful close right now.
|
||||
func UnsetDefaultEngine() {
|
||||
if xormEngine != nil {
|
||||
_ = xormEngine.Close()
|
||||
xormEngine = nil
|
||||
}
|
||||
}
|
||||
|
||||
// InitEngineWithMigration initializes a new xorm.Engine and sets it as the XORM's default context
|
||||
// This function must never call .Sync() if the provided migration function fails.
|
||||
// When called from the "doctor" command, the migration function is a version check
|
||||
// that prevents the doctor from fixing anything in the database if the migration level
|
||||
// is different from the expected value.
|
||||
func InitEngineWithMigration(ctx context.Context, migrateFunc func(context.Context, EngineMigration) error) (err error) {
|
||||
if err = InitEngine(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = xormEngine.Ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
preprocessDatabaseCollation(xormEngine)
|
||||
|
||||
// We have to run migrateFunc here in case the user is re-running installation on a previously created DB.
|
||||
// If we do not then table schemas will be changed and there will be conflicts when the migrations run properly.
|
||||
//
|
||||
// Installation should only be being re-run if users want to recover an old database.
|
||||
// However, we should think carefully about should we support re-install on an installed instance,
|
||||
// as there may be other problems due to secret reinitialization.
|
||||
if err = migrateFunc(ctx, xormEngine); err != nil {
|
||||
return fmt.Errorf("migrate: %w", err)
|
||||
}
|
||||
|
||||
if err = SyncAllTables(); err != nil {
|
||||
return fmt.Errorf("sync database struct error: %w", err)
|
||||
}
|
||||
|
||||
for _, initFunc := range registeredInitFuncs {
|
||||
if err := initFunc(); err != nil {
|
||||
return fmt.Errorf("initFunc failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
issues_model "gitea.dev/models/issues"
|
||||
"gitea.dev/models/unittest"
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
_ "gitea.dev/cmd" // for TestPrimaryKeys
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDumpDatabase(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
type Version struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Version int64
|
||||
}
|
||||
assert.NoError(t, db.GetEngine(t.Context()).Sync(new(Version)))
|
||||
|
||||
for _, dbType := range setting.SupportedDatabaseTypes {
|
||||
assert.NoError(t, db.DumpDatabase(filepath.Join(dir, dbType+".sql"), setting.DatabaseType(dbType)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteOrphanedObjects(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
countBefore, err := db.GetEngine(t.Context()).Count(&issues_model.PullRequest{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.GetEngine(t.Context()).Insert(&issues_model.PullRequest{IssueID: 1000}, &issues_model.PullRequest{IssueID: 1001}, &issues_model.PullRequest{IssueID: 1003})
|
||||
assert.NoError(t, err)
|
||||
|
||||
orphaned, err := db.CountOrphanedObjects(t.Context(), "pull_request", "issue", "pull_request.issue_id=issue.id")
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, orphaned)
|
||||
|
||||
err = db.DeleteOrphanedObjects(t.Context(), "pull_request", "issue", "pull_request.issue_id=issue.id")
|
||||
assert.NoError(t, err)
|
||||
|
||||
countAfter, err := db.GetEngine(t.Context()).Count(&issues_model.PullRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, countBefore, countAfter)
|
||||
}
|
||||
|
||||
func TestPrimaryKeys(t *testing.T) {
|
||||
// Some dbs require that all tables have primary keys, see
|
||||
// https://github.com/go-gitea/gitea/issues/21086
|
||||
// https://github.com/go-gitea/gitea/issues/16802
|
||||
// To avoid creating tables without primary key again, this test will check them.
|
||||
// Import "gitea.dev/cmd" to make sure each db.RegisterModel in init functions has been called.
|
||||
|
||||
beans, err := db.NamesToBean()
|
||||
require.NoError(t, err)
|
||||
|
||||
whitelist := map[string]string{
|
||||
"the_table_name_to_skip_checking": "Write a note here to explain why",
|
||||
}
|
||||
|
||||
for _, bean := range beans {
|
||||
table, err := db.GetXORMEngineForTesting().TableInfo(bean)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if why, ok := whitelist[table.Name]; ok {
|
||||
t.Logf("ignore %q because %q", table.Name, why)
|
||||
continue
|
||||
}
|
||||
assert.NotEmpty(t, table.PrimaryKeys, "table %q has no primary key", table.Name)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gitea.dev/modules/util"
|
||||
)
|
||||
|
||||
// ErrCancelled represents an error due to context cancellation
|
||||
type ErrCancelled struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// IsErrCancelled checks if an error is a ErrCancelled.
|
||||
func IsErrCancelled(err error) bool {
|
||||
_, ok := err.(ErrCancelled)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrCancelled) Error() string {
|
||||
return "Cancelled: " + err.Message
|
||||
}
|
||||
|
||||
// ErrCancelledf returns an ErrCancelled for the provided format and args
|
||||
func ErrCancelledf(format string, args ...any) error {
|
||||
return ErrCancelled{
|
||||
fmt.Sprintf(format, args...),
|
||||
}
|
||||
}
|
||||
|
||||
// ErrSSHDisabled represents an "SSH disabled" error.
|
||||
type ErrSSHDisabled struct{}
|
||||
|
||||
// IsErrSSHDisabled checks if an error is a ErrSSHDisabled.
|
||||
func IsErrSSHDisabled(err error) bool {
|
||||
_, ok := err.(ErrSSHDisabled)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSSHDisabled) Error() string {
|
||||
return "SSH is disabled"
|
||||
}
|
||||
|
||||
// ErrNotExist represents a non-exist error.
|
||||
type ErrNotExist struct {
|
||||
Resource string
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrNotExist checks if an error is an ErrNotExist
|
||||
func IsErrNotExist(err error) bool {
|
||||
_, ok := err.(ErrNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrNotExist) Error() string {
|
||||
name := "record"
|
||||
if err.Resource != "" {
|
||||
name = err.Resource
|
||||
}
|
||||
|
||||
if err.ID != 0 {
|
||||
return fmt.Sprintf("%s does not exist [id: %d]", name, err.ID)
|
||||
}
|
||||
return name + " does not exist"
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
)
|
||||
|
||||
// ResourceIndex represents a resource index which could be used as issue/release and others
|
||||
// We can create different tables i.e. issue_index, release_index, etc.
|
||||
type ResourceIndex struct {
|
||||
GroupID int64 `xorm:"pk"`
|
||||
MaxIndex int64 `xorm:"index"`
|
||||
}
|
||||
|
||||
var ErrGetResourceIndexFailed = errors.New("get resource index failed")
|
||||
|
||||
// SyncMaxResourceIndex sync the max index with the resource
|
||||
func SyncMaxResourceIndex(ctx context.Context, tableName string, groupID, maxIndex int64) (err error) {
|
||||
e := GetEngine(ctx)
|
||||
|
||||
// try to update the max_index and acquire the write-lock for the record
|
||||
res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=? WHERE group_id=? AND max_index<?", tableName), maxIndex, groupID, maxIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
// if nothing is updated, the record might not exist or might be larger, it's safe to try to insert it again and then check whether the record exists
|
||||
_, errIns := e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) VALUES (?, ?)", tableName), groupID, maxIndex)
|
||||
var savedIdx int64
|
||||
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&savedIdx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// if the record still doesn't exist, there must be some errors (insert error)
|
||||
if !has {
|
||||
if errIns == nil {
|
||||
return errors.New("impossible error when SyncMaxResourceIndex, insert succeeded but no record is saved")
|
||||
}
|
||||
return errIns
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postgresGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
res, err := GetEngine(ctx).Query(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
|
||||
"VALUES (?,1) ON CONFLICT (group_id) DO UPDATE SET max_index = %s.max_index+1 RETURNING max_index",
|
||||
tableName, tableName), groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return 0, ErrGetResourceIndexFailed
|
||||
}
|
||||
return strconv.ParseInt(string(res[0]["max_index"]), 10, 64)
|
||||
}
|
||||
|
||||
func mysqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
if _, err := GetEngine(ctx).Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
|
||||
"VALUES (?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1",
|
||||
tableName), groupID); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var idx int64
|
||||
_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if idx == 0 {
|
||||
return 0, errors.New("cannot get the correct index")
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
func mssqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
if _, err := GetEngine(ctx).Exec(fmt.Sprintf(`
|
||||
MERGE INTO %s WITH (HOLDLOCK) AS target
|
||||
USING (SELECT %d AS group_id) AS source
|
||||
(group_id)
|
||||
ON target.group_id = source.group_id
|
||||
WHEN MATCHED
|
||||
THEN UPDATE
|
||||
SET max_index = max_index + 1
|
||||
WHEN NOT MATCHED
|
||||
THEN INSERT (group_id, max_index)
|
||||
VALUES (%d, 1);
|
||||
`, tableName, groupID, groupID)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var idx int64
|
||||
_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if idx == 0 {
|
||||
return 0, errors.New("cannot get the correct index")
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
// GetNextResourceIndex generates a resource index, it must run in the same transaction where the resource is created
|
||||
func GetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
switch {
|
||||
case setting.Database.Type.IsPostgreSQL():
|
||||
return postgresGetNextResourceIndex(ctx, tableName, groupID)
|
||||
case setting.Database.Type.IsMySQL():
|
||||
return mysqlGetNextResourceIndex(ctx, tableName, groupID)
|
||||
case setting.Database.Type.IsMSSQL():
|
||||
return mssqlGetNextResourceIndex(ctx, tableName, groupID)
|
||||
}
|
||||
|
||||
e := GetEngine(ctx)
|
||||
|
||||
// try to update the max_index to next value, and acquire the write-lock for the record
|
||||
res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=max_index+1 WHERE group_id=?", tableName), groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected == 0 {
|
||||
// this slow path is only for the first time of creating a resource index
|
||||
_, errIns := e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) VALUES (?, 0)", tableName), groupID)
|
||||
res, err = e.Exec(fmt.Sprintf("UPDATE %s SET max_index=max_index+1 WHERE group_id=?", tableName), groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, err = res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
// if the update still can not update any records, the record must not exist and there must be some errors (insert error)
|
||||
if affected == 0 {
|
||||
if errIns == nil {
|
||||
return 0, errors.New("impossible error when GetNextResourceIndex, insert and update both succeeded but no record is updated")
|
||||
}
|
||||
return 0, errIns
|
||||
}
|
||||
}
|
||||
|
||||
// now, the new index is in database (protected by the transaction and write-lock)
|
||||
var newIdx int64
|
||||
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&newIdx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !has {
|
||||
return 0, errors.New("impossible error when GetNextResourceIndex, upsert succeeded but no record can be selected")
|
||||
}
|
||||
return newIdx, nil
|
||||
}
|
||||
|
||||
// DeleteResourceIndex delete resource index
|
||||
func DeleteResourceIndex(ctx context.Context, tableName string, groupID int64) error {
|
||||
_, err := Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type TestIndex db.ResourceIndex
|
||||
|
||||
func getCurrentResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
|
||||
e := db.GetEngine(ctx)
|
||||
var idx int64
|
||||
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&idx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !has {
|
||||
return 0, errors.New("no record")
|
||||
}
|
||||
return idx, nil
|
||||
}
|
||||
|
||||
func TestSyncMaxResourceIndex(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
xe := unittest.GetXORMEngine()
|
||||
assert.NoError(t, xe.Sync(&TestIndex{}))
|
||||
|
||||
err := db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 51)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// sync new max index
|
||||
maxIndex, err := getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 51, maxIndex)
|
||||
|
||||
// smaller index doesn't change
|
||||
err = db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 30)
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 51, maxIndex)
|
||||
|
||||
// larger index changes
|
||||
err = db.SyncMaxResourceIndex(t.Context(), "test_index", 10, 62)
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 62, maxIndex)
|
||||
|
||||
// commit transaction
|
||||
err = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 73, maxIndex)
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 73, maxIndex)
|
||||
|
||||
// rollback transaction
|
||||
err = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
|
||||
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 84, maxIndex)
|
||||
return errors.New("test rollback")
|
||||
})
|
||||
assert.Error(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 10)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 73, maxIndex) // the max index doesn't change because the transaction was rolled back
|
||||
}
|
||||
|
||||
func TestGetNextResourceIndex(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
xe := unittest.GetXORMEngine()
|
||||
assert.NoError(t, xe.Sync(&TestIndex{}))
|
||||
|
||||
// create a new record
|
||||
maxIndex, err := db.GetNextResourceIndex(t.Context(), "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, maxIndex)
|
||||
|
||||
// increase the existing record
|
||||
maxIndex, err = db.GetNextResourceIndex(t.Context(), "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 2, maxIndex)
|
||||
|
||||
// commit transaction
|
||||
err = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, maxIndex)
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, maxIndex)
|
||||
|
||||
// rollback transaction
|
||||
err = db.WithTx(t.Context(), func(ctx context.Context) error {
|
||||
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 4, maxIndex)
|
||||
return errors.New("test rollback")
|
||||
})
|
||||
assert.Error(t, err)
|
||||
maxIndex, err = getCurrentResourceIndex(t.Context(), "test_index", 20)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 3, maxIndex) // the max index doesn't change because the transaction was rolled back
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package install
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/setting"
|
||||
)
|
||||
|
||||
// CheckDatabaseConnection checks the database connection
|
||||
func CheckDatabaseConnection(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).Exec("SELECT 1")
|
||||
return err
|
||||
}
|
||||
|
||||
// GetMigrationVersion gets the database migration version
|
||||
func GetMigrationVersion(ctx context.Context) (int64, error) {
|
||||
var installedDbVersion int64
|
||||
x := db.GetEngine(ctx)
|
||||
exist, err := x.IsTableExist("version")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !exist {
|
||||
return 0, nil
|
||||
}
|
||||
_, err = x.Table("version").Cols("version").Get(&installedDbVersion)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return installedDbVersion, nil
|
||||
}
|
||||
|
||||
// HasPostInstallationUsers checks whether there are users after installation
|
||||
func HasPostInstallationUsers(ctx context.Context) (bool, error) {
|
||||
x := db.GetEngine(ctx)
|
||||
exist, err := x.IsTableExist("user")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !exist {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// if there are 2 or more users in database, we consider there are users created after installation
|
||||
threshold := 2
|
||||
if !setting.IsProd {
|
||||
// to debug easily, with non-prod RUN_MODE, we only check the count to 1
|
||||
threshold = 1
|
||||
}
|
||||
res, err := x.Table("user").Cols("id").Limit(threshold).Query()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(res) >= threshold, nil
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// Iterate iterates all the Bean object
|
||||
func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error {
|
||||
var start int
|
||||
batchSize := setting.Database.IterateBufferSize
|
||||
sess := GetEngine(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
beans := make([]*Bean, 0, batchSize)
|
||||
if cond != nil {
|
||||
sess = sess.Where(cond)
|
||||
}
|
||||
if err := sess.Limit(batchSize, start).Find(&beans); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(beans) == 0 {
|
||||
return nil
|
||||
}
|
||||
start += len(beans)
|
||||
|
||||
for _, bean := range beans {
|
||||
if err := f(ctx, bean); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
repo_model "gitea.dev/models/repo"
|
||||
"gitea.dev/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIterate(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
xe := unittest.GetXORMEngine()
|
||||
assert.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
|
||||
|
||||
cnt, err := db.GetEngine(t.Context()).Count(&repo_model.RepoUnit{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var repoUnitCnt int
|
||||
err = db.Iterate(t.Context(), nil, func(ctx context.Context, repo *repo_model.RepoUnit) error {
|
||||
repoUnitCnt++
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, cnt, repoUnitCnt)
|
||||
|
||||
err = db.Iterate(t.Context(), nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error {
|
||||
has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !has {
|
||||
return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultMaxInSize represents default variables number on IN () in SQL
|
||||
DefaultMaxInSize = 50
|
||||
defaultFindSliceSize = 10
|
||||
)
|
||||
|
||||
// Paginator is the base for different ListOptions types
|
||||
type Paginator interface {
|
||||
GetSkipTake() (skip, take int)
|
||||
IsListAll() bool
|
||||
}
|
||||
|
||||
// SetSessionPagination sets pagination for a database session
|
||||
func SetSessionPagination(sess Engine, p Paginator) Session {
|
||||
skip, take := p.GetSkipTake()
|
||||
|
||||
return sess.Limit(take, skip)
|
||||
}
|
||||
|
||||
// ListOptions options to paginate results
|
||||
type ListOptions struct {
|
||||
PageSize int
|
||||
Page int // start from 1
|
||||
ListAll bool // if true, then PageSize and Page will not be taken
|
||||
}
|
||||
|
||||
var ListOptionsAll = ListOptions{ListAll: true}
|
||||
|
||||
var (
|
||||
_ Paginator = &ListOptions{}
|
||||
_ FindOptions = ListOptions{}
|
||||
)
|
||||
|
||||
// GetSkipTake returns the skip and take values
|
||||
func (opts *ListOptions) GetSkipTake() (skip, take int) {
|
||||
opts.SetDefaultValues()
|
||||
return (opts.Page - 1) * opts.PageSize, opts.PageSize
|
||||
}
|
||||
|
||||
func (opts ListOptions) GetPage() int {
|
||||
return opts.Page
|
||||
}
|
||||
|
||||
func (opts ListOptions) GetPageSize() int {
|
||||
return opts.PageSize
|
||||
}
|
||||
|
||||
// IsListAll indicates PageSize and Page will be ignored
|
||||
func (opts ListOptions) IsListAll() bool {
|
||||
return opts.ListAll
|
||||
}
|
||||
|
||||
// SetDefaultValues sets default values
|
||||
func (opts *ListOptions) SetDefaultValues() {
|
||||
if opts.PageSize <= 0 {
|
||||
opts.PageSize = setting.API.DefaultPagingNum
|
||||
}
|
||||
if opts.PageSize > setting.API.MaxResponseItems {
|
||||
opts.PageSize = setting.API.MaxResponseItems
|
||||
}
|
||||
if opts.Page <= 0 {
|
||||
opts.Page = 1
|
||||
}
|
||||
}
|
||||
|
||||
func (opts ListOptions) ToConds() builder.Cond {
|
||||
return builder.NewCond()
|
||||
}
|
||||
|
||||
// AbsoluteListOptions absolute options to paginate results
|
||||
type AbsoluteListOptions struct {
|
||||
skip int
|
||||
take int
|
||||
}
|
||||
|
||||
var _ Paginator = &AbsoluteListOptions{}
|
||||
|
||||
// NewAbsoluteListOptions creates a list option with applied limits
|
||||
func NewAbsoluteListOptions(skip, take int) *AbsoluteListOptions {
|
||||
if skip < 0 {
|
||||
skip = 0
|
||||
}
|
||||
if take <= 0 {
|
||||
take = setting.API.DefaultPagingNum
|
||||
}
|
||||
if take > setting.API.MaxResponseItems {
|
||||
take = setting.API.MaxResponseItems
|
||||
}
|
||||
return &AbsoluteListOptions{skip, take}
|
||||
}
|
||||
|
||||
// IsListAll will always return false
|
||||
func (opts *AbsoluteListOptions) IsListAll() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetSkipTake returns the skip and take values
|
||||
func (opts *AbsoluteListOptions) GetSkipTake() (skip, take int) {
|
||||
return opts.skip, opts.take
|
||||
}
|
||||
|
||||
// FindOptions represents a find options
|
||||
type FindOptions interface {
|
||||
GetPage() int
|
||||
GetPageSize() int
|
||||
IsListAll() bool
|
||||
ToConds() builder.Cond
|
||||
}
|
||||
|
||||
type JoinFunc func(sess Engine) error
|
||||
|
||||
type FindOptionsJoin interface {
|
||||
ToJoins() []JoinFunc
|
||||
}
|
||||
|
||||
type FindOptionsOrder interface {
|
||||
ToOrders() string
|
||||
}
|
||||
|
||||
// Find represents a common find function which accept an options interface
|
||||
func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {
|
||||
sess := GetEngine(ctx).Where(opts.ToConds())
|
||||
|
||||
if joinOpt, ok := opts.(FindOptionsJoin); ok {
|
||||
for _, joinFunc := range joinOpt.ToJoins() {
|
||||
if err := joinFunc(sess); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if orderOpt, ok := opts.(FindOptionsOrder); ok {
|
||||
if order := orderOpt.ToOrders(); order != "" {
|
||||
sess.OrderBy(order)
|
||||
}
|
||||
}
|
||||
|
||||
page, pageSize := opts.GetPage(), opts.GetPageSize()
|
||||
if !opts.IsListAll() && pageSize > 0 {
|
||||
if page == 0 {
|
||||
page = 1
|
||||
}
|
||||
sess.Limit(pageSize, (page-1)*pageSize)
|
||||
}
|
||||
|
||||
findPageSize := defaultFindSliceSize
|
||||
if pageSize > 0 {
|
||||
findPageSize = pageSize
|
||||
}
|
||||
objects := make([]*T, 0, findPageSize)
|
||||
if err := sess.Find(&objects); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return objects, nil
|
||||
}
|
||||
|
||||
// Count represents a common count function which accept an options interface
|
||||
func Count[T any](ctx context.Context, opts FindOptions) (int64, error) {
|
||||
sess := GetEngine(ctx).Where(opts.ToConds())
|
||||
if joinOpt, ok := opts.(FindOptionsJoin); ok {
|
||||
for _, joinFunc := range joinOpt.ToJoins() {
|
||||
if err := joinFunc(sess); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var object T
|
||||
return sess.Count(&object)
|
||||
}
|
||||
|
||||
// FindAndCount represents a common findandcount function which accept an options interface
|
||||
func FindAndCount[T any](ctx context.Context, opts FindOptions) ([]*T, int64, error) {
|
||||
sess := GetEngine(ctx).Where(opts.ToConds())
|
||||
page, pageSize := opts.GetPage(), opts.GetPageSize()
|
||||
if !opts.IsListAll() && pageSize > 0 && page >= 1 {
|
||||
sess.Limit(pageSize, (page-1)*pageSize)
|
||||
}
|
||||
if joinOpt, ok := opts.(FindOptionsJoin); ok {
|
||||
for _, joinFunc := range joinOpt.ToJoins() {
|
||||
if err := joinFunc(sess); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if orderOpt, ok := opts.(FindOptionsOrder); ok {
|
||||
if order := orderOpt.ToOrders(); order != "" {
|
||||
sess.OrderBy(order)
|
||||
}
|
||||
}
|
||||
|
||||
findPageSize := defaultFindSliceSize
|
||||
if pageSize > 0 {
|
||||
findPageSize = pageSize
|
||||
}
|
||||
objects := make([]*T, 0, findPageSize)
|
||||
cnt, err := sess.FindAndCount(&objects)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return objects, cnt, nil
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
repo_model "gitea.dev/models/repo"
|
||||
"gitea.dev/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type mockListOptions struct {
|
||||
db.ListOptions
|
||||
}
|
||||
|
||||
func (opts mockListOptions) IsListAll() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (opts mockListOptions) ToConds() builder.Cond {
|
||||
return builder.NewCond()
|
||||
}
|
||||
|
||||
func TestFind(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
xe := unittest.GetXORMEngine()
|
||||
assert.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
|
||||
|
||||
var repoUnitCount int
|
||||
_, err := db.GetEngine(t.Context()).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, repoUnitCount)
|
||||
|
||||
opts := mockListOptions{}
|
||||
repoUnits, err := db.Find[repo_model.RepoUnit](t.Context(), opts)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, repoUnits, repoUnitCount)
|
||||
|
||||
cnt, err := db.Count[repo_model.RepoUnit](t.Context(), opts)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, repoUnitCount, cnt)
|
||||
|
||||
repoUnits, newCnt, err := db.FindAndCount[repo_model.RepoUnit](t.Context(), opts)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cnt, newCnt)
|
||||
assert.Len(t, repoUnits, repoUnitCount)
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"gitea.dev/modules/log"
|
||||
|
||||
xormlog "xorm.io/xorm/log"
|
||||
)
|
||||
|
||||
// XORMLogBridge a logger bridge from Logger to xorm
|
||||
type XORMLogBridge struct {
|
||||
showSQL atomic.Bool
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// NewXORMLogger inits a log bridge for xorm
|
||||
func NewXORMLogger(showSQL bool) xormlog.Logger {
|
||||
l := &XORMLogBridge{logger: log.GetLogger("xorm")}
|
||||
l.showSQL.Store(showSQL)
|
||||
return l
|
||||
}
|
||||
|
||||
const stackLevel = 8
|
||||
|
||||
// Log a message with defined skip and at logging level
|
||||
func (l *XORMLogBridge) Log(skip int, level log.Level, format string, v ...any) {
|
||||
l.logger.Log(skip+1, &log.Event{Level: level}, format, v...)
|
||||
}
|
||||
|
||||
// Debug show debug log
|
||||
func (l *XORMLogBridge) Debug(v ...any) {
|
||||
l.Log(stackLevel, log.DEBUG, "%s", fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Debugf show debug log
|
||||
func (l *XORMLogBridge) Debugf(format string, v ...any) {
|
||||
l.Log(stackLevel, log.DEBUG, format, v...)
|
||||
}
|
||||
|
||||
// Error show error log
|
||||
func (l *XORMLogBridge) Error(v ...any) {
|
||||
l.Log(stackLevel, log.ERROR, "%s", fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Errorf show error log
|
||||
func (l *XORMLogBridge) Errorf(format string, v ...any) {
|
||||
l.Log(stackLevel, log.ERROR, format, v...)
|
||||
}
|
||||
|
||||
// Info show information level log
|
||||
func (l *XORMLogBridge) Info(v ...any) {
|
||||
l.Log(stackLevel, log.INFO, "%s", fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Infof show information level log
|
||||
func (l *XORMLogBridge) Infof(format string, v ...any) {
|
||||
l.Log(stackLevel, log.INFO, format, v...)
|
||||
}
|
||||
|
||||
// Warn show warning log
|
||||
func (l *XORMLogBridge) Warn(v ...any) {
|
||||
l.Log(stackLevel, log.WARN, "%s", fmt.Sprint(v...))
|
||||
}
|
||||
|
||||
// Warnf show warning log
|
||||
func (l *XORMLogBridge) Warnf(format string, v ...any) {
|
||||
l.Log(stackLevel, log.WARN, format, v...)
|
||||
}
|
||||
|
||||
// Level get logger level
|
||||
func (l *XORMLogBridge) Level() xormlog.LogLevel {
|
||||
switch l.logger.GetLevel() {
|
||||
case log.TRACE, log.DEBUG:
|
||||
return xormlog.LOG_DEBUG
|
||||
case log.INFO:
|
||||
return xormlog.LOG_INFO
|
||||
case log.WARN:
|
||||
return xormlog.LOG_WARNING
|
||||
case log.ERROR:
|
||||
return xormlog.LOG_ERR
|
||||
case log.NONE:
|
||||
return xormlog.LOG_OFF
|
||||
}
|
||||
return xormlog.LOG_UNKNOWN
|
||||
}
|
||||
|
||||
// SetLevel set the logger level
|
||||
func (l *XORMLogBridge) SetLevel(lvl xormlog.LogLevel) {
|
||||
}
|
||||
|
||||
// ShowSQL set if record SQL
|
||||
func (l *XORMLogBridge) ShowSQL(show ...bool) {
|
||||
if len(show) == 0 {
|
||||
show = []bool{true}
|
||||
}
|
||||
l.showSQL.Store(show[0])
|
||||
}
|
||||
|
||||
// IsShowSQL if record SQL
|
||||
func (l *XORMLogBridge) IsShowSQL() bool {
|
||||
return l.showSQL.Load()
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/unittest"
|
||||
|
||||
_ "gitea.dev/models"
|
||||
_ "gitea.dev/models/repo"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m)
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"gitea.dev/modules/util"
|
||||
)
|
||||
|
||||
// ErrNameReserved represents a "reserved name" error.
|
||||
type ErrNameReserved struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrNameReserved checks if an error is a ErrNameReserved.
|
||||
func IsErrNameReserved(err error) bool {
|
||||
_, ok := err.(ErrNameReserved)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrNameReserved) Error() string {
|
||||
return fmt.Sprintf("name is reserved [name: %s]", err.Name)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrInvalid err
|
||||
func (err ErrNameReserved) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// ErrNamePatternNotAllowed represents a "pattern not allowed" error.
|
||||
type ErrNamePatternNotAllowed struct {
|
||||
Pattern string
|
||||
}
|
||||
|
||||
// IsErrNamePatternNotAllowed checks if an error is an ErrNamePatternNotAllowed.
|
||||
func IsErrNamePatternNotAllowed(err error) bool {
|
||||
_, ok := err.(ErrNamePatternNotAllowed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrNamePatternNotAllowed) Error() string {
|
||||
return fmt.Sprintf("name pattern is not allowed [pattern: %s]", err.Pattern)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrInvalid err
|
||||
func (err ErrNamePatternNotAllowed) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// ErrNameCharsNotAllowed represents a "character not allowed in name" error.
|
||||
type ErrNameCharsNotAllowed struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrNameCharsNotAllowed checks if an error is an ErrNameCharsNotAllowed.
|
||||
func IsErrNameCharsNotAllowed(err error) bool {
|
||||
_, ok := err.(ErrNameCharsNotAllowed)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrNameCharsNotAllowed) Error() string {
|
||||
return fmt.Sprintf("name is invalid [%s]: must be valid alpha or numeric or dash(-_) or dot characters", err.Name)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrInvalid err
|
||||
func (err ErrNameCharsNotAllowed) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// IsUsableName checks if name is reserved or pattern of name is not allowed
|
||||
// based on given reserved names and patterns.
|
||||
// Names are exact match, patterns can be a prefix or suffix match with placeholder '*'.
|
||||
func IsUsableName(reservedNames, reservedPatterns []string, name string) error {
|
||||
name = strings.TrimSpace(strings.ToLower(name))
|
||||
if utf8.RuneCountInString(name) == 0 {
|
||||
return util.NewInvalidArgumentErrorf("name is empty")
|
||||
}
|
||||
|
||||
if slices.Contains(reservedNames, name) {
|
||||
return ErrNameReserved{name}
|
||||
}
|
||||
|
||||
for _, pat := range reservedPatterns {
|
||||
if pat[0] == '*' && strings.HasSuffix(name, pat[1:]) ||
|
||||
(pat[len(pat)-1] == '*' && strings.HasPrefix(name, pat[:len(pat)-1])) {
|
||||
return ErrNamePatternNotAllowed{pat}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package paginator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/unittest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m)
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package paginator
|
||||
|
||||
// dummy only. in the future, the models/db/list_options.go should be moved here to decouple from db package
|
||||
// otherwise the unit test will cause cycle import
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package paginator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPaginator(t *testing.T) {
|
||||
cases := []struct {
|
||||
db.Paginator
|
||||
Skip int
|
||||
Take int
|
||||
Start int
|
||||
End int
|
||||
}{
|
||||
{
|
||||
Paginator: &db.ListOptions{Page: -1, PageSize: -1},
|
||||
Skip: 0,
|
||||
Take: setting.API.DefaultPagingNum,
|
||||
Start: 0,
|
||||
End: setting.API.DefaultPagingNum,
|
||||
},
|
||||
{
|
||||
Paginator: &db.ListOptions{Page: 2, PageSize: 10},
|
||||
Skip: 10,
|
||||
Take: 10,
|
||||
Start: 10,
|
||||
End: 20,
|
||||
},
|
||||
{
|
||||
Paginator: db.NewAbsoluteListOptions(-1, -1),
|
||||
Skip: 0,
|
||||
Take: setting.API.DefaultPagingNum,
|
||||
Start: 0,
|
||||
End: setting.API.DefaultPagingNum,
|
||||
},
|
||||
{
|
||||
Paginator: db.NewAbsoluteListOptions(2, 10),
|
||||
Skip: 2,
|
||||
Take: 10,
|
||||
Start: 2,
|
||||
End: 12,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
skip, take := c.Paginator.GetSkipTake()
|
||||
|
||||
assert.Equal(t, c.Skip, skip)
|
||||
assert.Equal(t, c.Take, take)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
// SearchOrderBy is used to sort the result
|
||||
type SearchOrderBy string
|
||||
|
||||
func (s SearchOrderBy) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// Strings for sorting result
|
||||
const (
|
||||
SearchOrderByAlphabetically SearchOrderBy = "name ASC"
|
||||
SearchOrderByAlphabeticallyReverse SearchOrderBy = "name DESC"
|
||||
SearchOrderByLeastUpdated SearchOrderBy = "updated_unix ASC"
|
||||
SearchOrderByRecentUpdated SearchOrderBy = "updated_unix DESC"
|
||||
SearchOrderByOldest SearchOrderBy = "created_unix ASC"
|
||||
SearchOrderByNewest SearchOrderBy = "created_unix DESC"
|
||||
SearchOrderByID SearchOrderBy = "id ASC"
|
||||
SearchOrderByIDReverse SearchOrderBy = "id DESC"
|
||||
SearchOrderByStars SearchOrderBy = "num_stars ASC"
|
||||
SearchOrderByStarsReverse SearchOrderBy = "num_stars DESC"
|
||||
SearchOrderByForks SearchOrderBy = "num_forks ASC"
|
||||
SearchOrderByForksReverse SearchOrderBy = "num_forks DESC"
|
||||
)
|
||||
|
||||
// NoConditionID means a condition to filter the records which don't match any id.
|
||||
// eg: "milestone_id=-1" means "find the items without any milestone.
|
||||
const NoConditionID int64 = -1
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright 2018 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"gitea.dev/modules/setting"
|
||||
)
|
||||
|
||||
// CountBadSequences looks for broken sequences from recreate-table mistakes
|
||||
func CountBadSequences(_ context.Context) (int64, error) {
|
||||
if !setting.Database.Type.IsPostgreSQL() {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
sess := xormEngine.NewSession()
|
||||
defer sess.Close()
|
||||
|
||||
var sequences []string
|
||||
schema := xormEngine.Dialect().URI().Schema
|
||||
|
||||
sess.Engine().SetSchema("")
|
||||
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sess.Engine().SetSchema(schema)
|
||||
|
||||
return int64(len(sequences)), nil
|
||||
}
|
||||
|
||||
// FixBadSequences fixes for broken sequences from recreate-table mistakes
|
||||
func FixBadSequences(_ context.Context) error {
|
||||
if !setting.Database.Type.IsPostgreSQL() {
|
||||
return nil
|
||||
}
|
||||
|
||||
sess := xormEngine.NewSession()
|
||||
defer sess.Close()
|
||||
if err := sess.Begin(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sequences []string
|
||||
schema := sess.Engine().Dialect().URI().Schema
|
||||
|
||||
sess.Engine().SetSchema("")
|
||||
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
|
||||
return err
|
||||
}
|
||||
sess.Engine().SetSchema(schema)
|
||||
|
||||
sequenceRegexp := regexp.MustCompile(`tmp_recreate__(\w+)_id_seq.*`)
|
||||
|
||||
for _, sequence := range sequences {
|
||||
tableName := sequenceRegexp.FindStringSubmatch(sequence)[1]
|
||||
newSequenceName := tableName + "_id_seq"
|
||||
if _, err := sess.Exec(fmt.Sprintf("ALTER SEQUENCE `%s` RENAME TO `%s`", sequence, newSequenceName)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := sess.Exec(fmt.Sprintf("SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM `%s`), 1), false)", newSequenceName, tableName)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return sess.Commit()
|
||||
}
|
||||
Reference in New Issue
Block a user