初始提交: Gitea 项目代码

This commit is contained in:
root
2026-05-30 22:47:36 +08:00
commit f288f76350
6116 changed files with 776822 additions and 0 deletions
+230
View File
@@ -0,0 +1,230 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"crypto/subtle"
"encoding/hex"
"fmt"
"time"
"gitea.dev/models/db"
"gitea.dev/modules/setting"
"gitea.dev/modules/timeutil"
"gitea.dev/modules/util"
lru "github.com/hashicorp/golang-lru/v2"
"xorm.io/builder"
)
// ErrAccessTokenNotExist represents a "AccessTokenNotExist" kind of error.
type ErrAccessTokenNotExist struct {
Token string
}
// IsErrAccessTokenNotExist checks if an error is a ErrAccessTokenNotExist.
func IsErrAccessTokenNotExist(err error) bool {
_, ok := err.(ErrAccessTokenNotExist)
return ok
}
func (err ErrAccessTokenNotExist) Error() string {
return fmt.Sprintf("access token does not exist [sha: %s]", err.Token)
}
func (err ErrAccessTokenNotExist) Unwrap() error {
return util.ErrNotExist
}
// ErrAccessTokenEmpty represents a "AccessTokenEmpty" kind of error.
type ErrAccessTokenEmpty struct{}
// IsErrAccessTokenEmpty checks if an error is a ErrAccessTokenEmpty.
func IsErrAccessTokenEmpty(err error) bool {
_, ok := err.(ErrAccessTokenEmpty)
return ok
}
func (err ErrAccessTokenEmpty) Error() string {
return "access token is empty"
}
func (err ErrAccessTokenEmpty) Unwrap() error {
return util.ErrInvalidArgument
}
var successfulAccessTokenCache *lru.Cache[string, any]
// AccessToken represents a personal access token.
type AccessToken struct {
ID int64 `xorm:"pk autoincr"`
UID int64 `xorm:"INDEX"`
Name string
Token string `xorm:"-"`
TokenHash string `xorm:"UNIQUE"` // sha256 of token
TokenSalt string
TokenLastEight string `xorm:"INDEX token_last_eight"`
Scope AccessTokenScope
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
HasRecentActivity bool `xorm:"-"`
HasUsed bool `xorm:"-"`
}
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (t *AccessToken) AfterLoad() {
t.HasUsed = t.UpdatedUnix > t.CreatedUnix
t.HasRecentActivity = t.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
}
func init() {
db.RegisterModel(new(AccessToken), func() error {
if setting.SuccessfulTokensCacheSize > 0 {
var err error
successfulAccessTokenCache, err = lru.New[string, any](setting.SuccessfulTokensCacheSize)
if err != nil {
return fmt.Errorf("unable to allocate AccessToken cache: %w", err)
}
} else {
successfulAccessTokenCache = nil
}
return nil
})
}
// NewAccessToken creates new access token.
func NewAccessToken(ctx context.Context, t *AccessToken) error {
salt := util.CryptoRandomString(10)
token := util.CryptoRandomBytes(20)
t.TokenSalt = salt
t.Token = hex.EncodeToString(token)
t.TokenHash = HashToken(t.Token, t.TokenSalt)
t.TokenLastEight = t.Token[len(t.Token)-8:]
_, err := db.GetEngine(ctx).Insert(t)
return err
}
// DisplayPublicOnly whether to display this as a public-only token.
func (t *AccessToken) DisplayPublicOnly() bool {
publicOnly, err := t.Scope.PublicOnly()
if err != nil {
return false
}
return publicOnly
}
func getAccessTokenIDFromCache(token string) int64 {
if successfulAccessTokenCache == nil {
return 0
}
tInterface, ok := successfulAccessTokenCache.Get(token)
if !ok {
return 0
}
t, ok := tInterface.(int64)
if !ok {
return 0
}
return t
}
// GetAccessTokenBySHA returns access token by given token value
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
if token == "" {
return nil, ErrAccessTokenEmpty{}
}
// A token is defined as being SHA1 sum these are 40 hexadecimal bytes long
if len(token) != 40 {
return nil, ErrAccessTokenNotExist{token}
}
for _, x := range []byte(token) {
if x < '0' || (x > '9' && x < 'a') || x > 'f' {
return nil, ErrAccessTokenNotExist{token}
}
}
lastEight := token[len(token)-8:]
if id := getAccessTokenIDFromCache(token); id > 0 {
accessToken := &AccessToken{
TokenLastEight: lastEight,
}
// Re-get the token from the db in case it has been deleted in the intervening period
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
if err != nil {
return nil, err
}
if has {
return accessToken, nil
}
successfulAccessTokenCache.Remove(token)
}
var tokens []AccessToken
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
if err != nil {
return nil, err
} else if len(tokens) == 0 {
return nil, ErrAccessTokenNotExist{token}
}
for _, t := range tokens {
tempHash := HashToken(token, t.TokenSalt)
if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 {
if successfulAccessTokenCache != nil {
successfulAccessTokenCache.Add(token, t.ID)
}
return &t, nil
}
}
return nil, ErrAccessTokenNotExist{token}
}
// AccessTokenByNameExists checks if a token name has been used already by a user.
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
}
// ListAccessTokensOptions contain filter options
type ListAccessTokensOptions struct {
db.ListOptions
Name string
UserID int64
}
func (opts ListAccessTokensOptions) ToConds() builder.Cond {
cond := builder.NewCond()
// user id is required, otherwise it will return all result which maybe a possible bug
cond = cond.And(builder.Eq{"uid": opts.UserID})
if len(opts.Name) > 0 {
cond = cond.And(builder.Eq{"name": opts.Name})
}
return cond
}
func (opts ListAccessTokensOptions) ToOrders() string {
return "created_unix DESC"
}
// UpdateAccessToken updates information of access token.
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}
// DeleteAccessTokenByID deletes access token by given ID.
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
UID: userID,
})
if err != nil {
return err
} else if cnt != 1 {
return ErrAccessTokenNotExist{}
}
return nil
}
+377
View File
@@ -0,0 +1,377 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"fmt"
"slices"
"strings"
"gitea.dev/models/perm"
)
// AccessTokenScopeCategory represents the scope category for an access token
type AccessTokenScopeCategory int
const (
AccessTokenScopeCategoryActivityPub AccessTokenScopeCategory = iota
AccessTokenScopeCategoryAdmin
AccessTokenScopeCategoryMisc // WARN: this is now just a placeholder, don't remove it which will change the following values
AccessTokenScopeCategoryNotification
AccessTokenScopeCategoryOrganization
AccessTokenScopeCategoryPackage
AccessTokenScopeCategoryIssue
AccessTokenScopeCategoryRepository
AccessTokenScopeCategoryUser
)
// AllAccessTokenScopeCategories contains all access token scope categories
var AllAccessTokenScopeCategories = []AccessTokenScopeCategory{
AccessTokenScopeCategoryActivityPub,
AccessTokenScopeCategoryAdmin,
AccessTokenScopeCategoryMisc,
AccessTokenScopeCategoryNotification,
AccessTokenScopeCategoryOrganization,
AccessTokenScopeCategoryPackage,
AccessTokenScopeCategoryIssue,
AccessTokenScopeCategoryRepository,
AccessTokenScopeCategoryUser,
}
// AccessTokenScopeLevel represents the access levels without a given scope category
type AccessTokenScopeLevel int
const (
NoAccess AccessTokenScopeLevel = iota
Read
Write
)
// AccessTokenScope represents the scope for an access token.
type AccessTokenScope string
// for all categories, write implies read
const (
AccessTokenScopeAll AccessTokenScope = "all"
AccessTokenScopePublicOnly AccessTokenScope = "public-only" // limited to public orgs/repos
AccessTokenScopeReadActivityPub AccessTokenScope = "read:activitypub"
AccessTokenScopeWriteActivityPub AccessTokenScope = "write:activitypub"
AccessTokenScopeReadAdmin AccessTokenScope = "read:admin"
AccessTokenScopeWriteAdmin AccessTokenScope = "write:admin"
AccessTokenScopeReadMisc AccessTokenScope = "read:misc"
AccessTokenScopeWriteMisc AccessTokenScope = "write:misc"
AccessTokenScopeReadNotification AccessTokenScope = "read:notification"
AccessTokenScopeWriteNotification AccessTokenScope = "write:notification"
AccessTokenScopeReadOrganization AccessTokenScope = "read:organization"
AccessTokenScopeWriteOrganization AccessTokenScope = "write:organization"
AccessTokenScopeReadPackage AccessTokenScope = "read:package"
AccessTokenScopeWritePackage AccessTokenScope = "write:package"
AccessTokenScopeReadIssue AccessTokenScope = "read:issue"
AccessTokenScopeWriteIssue AccessTokenScope = "write:issue"
AccessTokenScopeReadRepository AccessTokenScope = "read:repository"
AccessTokenScopeWriteRepository AccessTokenScope = "write:repository"
AccessTokenScopeReadUser AccessTokenScope = "read:user"
AccessTokenScopeWriteUser AccessTokenScope = "write:user"
)
// accessTokenScopeBitmap represents a bitmap of access token scopes.
type accessTokenScopeBitmap uint64
// Bitmap of each scope, including the child scopes.
const (
// AccessTokenScopeAllBits is the bitmap of all access token scopes
accessTokenScopeAllBits accessTokenScopeBitmap = accessTokenScopeWriteActivityPubBits |
accessTokenScopeWriteAdminBits | accessTokenScopeWriteMiscBits | accessTokenScopeWriteNotificationBits |
accessTokenScopeWriteOrganizationBits | accessTokenScopeWritePackageBits | accessTokenScopeWriteIssueBits |
accessTokenScopeWriteRepositoryBits | accessTokenScopeWriteUserBits
accessTokenScopePublicOnlyBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeReadActivityPubBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteActivityPubBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadActivityPubBits
accessTokenScopeReadAdminBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteAdminBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadAdminBits
accessTokenScopeReadMiscBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteMiscBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadMiscBits
accessTokenScopeReadNotificationBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteNotificationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadNotificationBits
accessTokenScopeReadOrganizationBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteOrganizationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadOrganizationBits
accessTokenScopeReadPackageBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWritePackageBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadPackageBits
accessTokenScopeReadIssueBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteIssueBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadIssueBits
accessTokenScopeReadRepositoryBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteRepositoryBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadRepositoryBits
accessTokenScopeReadUserBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteUserBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadUserBits
// The current implementation only supports up to 64 token scopes.
// If we need to support > 64 scopes,
// refactoring the whole implementation in this file (and only this file) is needed.
)
// allAccessTokenScopes contains all access token scopes.
// The order is important: parent scope must precede child scopes.
var allAccessTokenScopes = []AccessTokenScope{
AccessTokenScopePublicOnly,
AccessTokenScopeWriteActivityPub, AccessTokenScopeReadActivityPub,
AccessTokenScopeWriteAdmin, AccessTokenScopeReadAdmin,
AccessTokenScopeWriteMisc, AccessTokenScopeReadMisc,
AccessTokenScopeWriteNotification, AccessTokenScopeReadNotification,
AccessTokenScopeWriteOrganization, AccessTokenScopeReadOrganization,
AccessTokenScopeWritePackage, AccessTokenScopeReadPackage,
AccessTokenScopeWriteIssue, AccessTokenScopeReadIssue,
AccessTokenScopeWriteRepository, AccessTokenScopeReadRepository,
AccessTokenScopeWriteUser, AccessTokenScopeReadUser,
}
// allAccessTokenScopeBits contains all access token scopes.
var allAccessTokenScopeBits = map[AccessTokenScope]accessTokenScopeBitmap{
AccessTokenScopeAll: accessTokenScopeAllBits,
AccessTokenScopePublicOnly: accessTokenScopePublicOnlyBits,
AccessTokenScopeReadActivityPub: accessTokenScopeReadActivityPubBits,
AccessTokenScopeWriteActivityPub: accessTokenScopeWriteActivityPubBits,
AccessTokenScopeReadAdmin: accessTokenScopeReadAdminBits,
AccessTokenScopeWriteAdmin: accessTokenScopeWriteAdminBits,
AccessTokenScopeReadMisc: accessTokenScopeReadMiscBits,
AccessTokenScopeWriteMisc: accessTokenScopeWriteMiscBits,
AccessTokenScopeReadNotification: accessTokenScopeReadNotificationBits,
AccessTokenScopeWriteNotification: accessTokenScopeWriteNotificationBits,
AccessTokenScopeReadOrganization: accessTokenScopeReadOrganizationBits,
AccessTokenScopeWriteOrganization: accessTokenScopeWriteOrganizationBits,
AccessTokenScopeReadPackage: accessTokenScopeReadPackageBits,
AccessTokenScopeWritePackage: accessTokenScopeWritePackageBits,
AccessTokenScopeReadIssue: accessTokenScopeReadIssueBits,
AccessTokenScopeWriteIssue: accessTokenScopeWriteIssueBits,
AccessTokenScopeReadRepository: accessTokenScopeReadRepositoryBits,
AccessTokenScopeWriteRepository: accessTokenScopeWriteRepositoryBits,
AccessTokenScopeReadUser: accessTokenScopeReadUserBits,
AccessTokenScopeWriteUser: accessTokenScopeWriteUserBits,
}
// readAccessTokenScopes maps a scope category to the read permission scope
var accessTokenScopes = map[AccessTokenScopeLevel]map[AccessTokenScopeCategory]AccessTokenScope{
Read: {
AccessTokenScopeCategoryActivityPub: AccessTokenScopeReadActivityPub,
AccessTokenScopeCategoryAdmin: AccessTokenScopeReadAdmin,
AccessTokenScopeCategoryMisc: AccessTokenScopeReadMisc,
AccessTokenScopeCategoryNotification: AccessTokenScopeReadNotification,
AccessTokenScopeCategoryOrganization: AccessTokenScopeReadOrganization,
AccessTokenScopeCategoryPackage: AccessTokenScopeReadPackage,
AccessTokenScopeCategoryIssue: AccessTokenScopeReadIssue,
AccessTokenScopeCategoryRepository: AccessTokenScopeReadRepository,
AccessTokenScopeCategoryUser: AccessTokenScopeReadUser,
},
Write: {
AccessTokenScopeCategoryActivityPub: AccessTokenScopeWriteActivityPub,
AccessTokenScopeCategoryAdmin: AccessTokenScopeWriteAdmin,
AccessTokenScopeCategoryMisc: AccessTokenScopeWriteMisc,
AccessTokenScopeCategoryNotification: AccessTokenScopeWriteNotification,
AccessTokenScopeCategoryOrganization: AccessTokenScopeWriteOrganization,
AccessTokenScopeCategoryPackage: AccessTokenScopeWritePackage,
AccessTokenScopeCategoryIssue: AccessTokenScopeWriteIssue,
AccessTokenScopeCategoryRepository: AccessTokenScopeWriteRepository,
AccessTokenScopeCategoryUser: AccessTokenScopeWriteUser,
},
}
func GetAccessTokenCategories() (res []string) {
for _, cat := range accessTokenScopes[Read] {
res = append(res, strings.TrimPrefix(string(cat), "read:"))
}
slices.Sort(res)
return res
}
// GetRequiredScopes gets the specific scopes for a given level and categories
func GetRequiredScopes(level AccessTokenScopeLevel, scopeCategories ...AccessTokenScopeCategory) []AccessTokenScope {
scopes := make([]AccessTokenScope, 0, len(scopeCategories))
for _, cat := range scopeCategories {
scopes = append(scopes, accessTokenScopes[level][cat])
}
return scopes
}
// ContainsCategory checks if a list of categories contains a specific category
func ContainsCategory(categories []AccessTokenScopeCategory, category AccessTokenScopeCategory) bool {
return slices.Contains(categories, category)
}
// GetScopeLevelFromAccessMode converts permission access mode to scope level
func GetScopeLevelFromAccessMode(mode perm.AccessMode) AccessTokenScopeLevel {
switch mode {
case perm.AccessModeNone:
return NoAccess
case perm.AccessModeRead:
return Read
case perm.AccessModeWrite:
return Write
case perm.AccessModeAdmin:
return Write
case perm.AccessModeOwner:
return Write
default:
return NoAccess
}
}
// parse the scope string into a bitmap, thus removing possible duplicates.
func (s AccessTokenScope) parse() (accessTokenScopeBitmap, error) {
var bitmap accessTokenScopeBitmap
// The following is the more performant equivalent of 'for _, v := range strings.Split(remainingScope, ",")' as this is hot code
remainingScopes := string(s)
for len(remainingScopes) > 0 {
i := strings.IndexByte(remainingScopes, ',')
var v string
if i < 0 {
v = remainingScopes
remainingScopes = ""
} else if i+1 >= len(remainingScopes) {
v = remainingScopes[:i]
remainingScopes = ""
} else {
v = remainingScopes[:i]
remainingScopes = remainingScopes[i+1:]
}
singleScope := AccessTokenScope(v)
if singleScope == "" {
continue
}
if singleScope == AccessTokenScopeAll {
bitmap |= accessTokenScopeAllBits
continue
}
bits, ok := allAccessTokenScopeBits[singleScope]
if !ok {
return 0, fmt.Errorf("invalid access token scope: %s", singleScope)
}
bitmap |= bits
}
return bitmap, nil
}
// StringSlice returns the AccessTokenScope as a []string
func (s AccessTokenScope) StringSlice() []string {
if s == "" {
return nil
}
return strings.Split(string(s), ",")
}
// Normalize returns a normalized scope string without any duplicates.
func (s AccessTokenScope) Normalize() (AccessTokenScope, error) {
bitmap, err := s.parse()
if err != nil {
return "", err
}
return bitmap.toScope(), nil
}
func (s AccessTokenScope) HasPermissionScope() bool {
return s != "" && s != AccessTokenScopePublicOnly
}
// PublicOnly checks if this token scope is limited to public resources
func (s AccessTokenScope) PublicOnly() (bool, error) {
bitmap, err := s.parse()
if err != nil {
return false, err
}
return bitmap.hasScope(AccessTokenScopePublicOnly)
}
// HasScope returns true if the string has the given scope
func (s AccessTokenScope) HasScope(scopes ...AccessTokenScope) (bool, error) {
bitmap, err := s.parse()
if err != nil {
return false, err
}
for _, s := range scopes {
if has, err := bitmap.hasScope(s); !has || err != nil {
return has, err
}
}
return true, nil
}
// HasAnyScope returns true if any of the scopes is contained in the string
func (s AccessTokenScope) HasAnyScope(scopes ...AccessTokenScope) (bool, error) {
bitmap, err := s.parse()
if err != nil {
return false, err
}
for _, s := range scopes {
if has, err := bitmap.hasScope(s); has || err != nil {
return has, err
}
}
return false, nil
}
// hasScope returns true if the string has the given scope
func (bitmap accessTokenScopeBitmap) hasScope(scope AccessTokenScope) (bool, error) {
expectedBits, ok := allAccessTokenScopeBits[scope]
if !ok {
return false, fmt.Errorf("invalid access token scope: %s", scope)
}
return bitmap&expectedBits == expectedBits, nil
}
// toScope returns a normalized scope string without any duplicates.
func (bitmap accessTokenScopeBitmap) toScope() AccessTokenScope {
var scopes []string
// iterate over all scopes, and reconstruct the bitmap
// if the reconstructed bitmap doesn't change, then the scope is already included
var reconstruct accessTokenScopeBitmap
for _, singleScope := range allAccessTokenScopes {
// no need for error checking here, since we know the scope is valid
if ok, _ := bitmap.hasScope(singleScope); ok {
current := reconstruct | allAccessTokenScopeBits[singleScope]
if current == reconstruct {
continue
}
reconstruct = current
scopes = append(scopes, string(singleScope))
}
}
scope := AccessTokenScope(strings.Join(scopes, ","))
scope = AccessTokenScope(strings.ReplaceAll(
string(scope),
"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user",
"all",
))
return scope
}
+91
View File
@@ -0,0 +1,91 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
type scopeTestNormalize struct {
in AccessTokenScope
out AccessTokenScope
err error
}
func TestAccessTokenScope_Normalize(t *testing.T) {
assert.Equal(t, []string{"activitypub", "admin", "issue", "misc", "notification", "organization", "package", "repository", "user"}, GetAccessTokenCategories())
tests := []scopeTestNormalize{
{"", "", nil},
{"write:misc,write:notification,read:package,write:notification,public-only", "public-only,write:misc,write:notification,read:package", nil},
{"all", "all", nil},
{"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user", "all", nil},
{"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user,public-only", "public-only,all", nil},
}
for _, scope := range GetAccessTokenCategories() {
tests = append(tests,
scopeTestNormalize{AccessTokenScope("read:" + scope), AccessTokenScope("read:" + scope), nil},
scopeTestNormalize{AccessTokenScope("write:" + scope), AccessTokenScope("write:" + scope), nil},
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%[1]s,read:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s,write:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
)
}
for _, test := range tests {
t.Run(string(test.in), func(t *testing.T) {
scope, err := test.in.Normalize()
assert.Equal(t, test.out, scope)
assert.Equal(t, test.err, err)
})
}
}
type scopeTestHasScope struct {
in AccessTokenScope
scope AccessTokenScope
out bool
err error
}
func TestAccessTokenScope_HasScope(t *testing.T) {
tests := []scopeTestHasScope{
{"read:admin", "write:package", false, nil},
{"all", "write:package", true, nil},
{"write:package", "all", false, nil},
{"public-only", "read:issue", false, nil},
}
for _, scope := range GetAccessTokenCategories() {
tests = append(tests,
scopeTestHasScope{
AccessTokenScope("read:" + scope),
AccessTokenScope("read:" + scope), true, nil,
},
scopeTestHasScope{
AccessTokenScope("write:" + scope),
AccessTokenScope("write:" + scope), true, nil,
},
scopeTestHasScope{
AccessTokenScope("write:" + scope),
AccessTokenScope("read:" + scope), true, nil,
},
scopeTestHasScope{
AccessTokenScope("read:" + scope),
AccessTokenScope("write:" + scope), false, nil,
},
)
}
for _, test := range tests {
t.Run(string(test.in), func(t *testing.T) {
hasScope, err := test.in.HasScope(test.scope)
assert.Equal(t, test.out, hasScope)
assert.Equal(t, test.err, err)
})
}
}
+132
View File
@@ -0,0 +1,132 @@
// Copyright 2016 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"testing"
auth_model "gitea.dev/models/auth"
"gitea.dev/models/db"
"gitea.dev/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestNewAccessToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token := &auth_model.AccessToken{
UID: 3,
Name: "Token C",
}
assert.NoError(t, auth_model.NewAccessToken(t.Context(), token))
unittest.AssertExistsAndLoadBean(t, token)
invalidToken := &auth_model.AccessToken{
ID: token.ID, // duplicate
UID: 2,
Name: "Token F",
}
assert.Error(t, auth_model.NewAccessToken(t.Context(), invalidToken))
}
func TestAccessTokenByNameExists(t *testing.T) {
name := "Token Gitea"
assert.NoError(t, unittest.PrepareTestDatabase())
token := &auth_model.AccessToken{
UID: 3,
Name: name,
}
// Check to make sure it doesn't exists already
exist, err := auth_model.AccessTokenByNameExists(t.Context(), token)
assert.NoError(t, err)
assert.False(t, exist)
// Save it to the database
assert.NoError(t, auth_model.NewAccessToken(t.Context(), token))
unittest.AssertExistsAndLoadBean(t, token)
// This token must be found by name in the DB now
exist, err = auth_model.AccessTokenByNameExists(t.Context(), token)
assert.NoError(t, err)
assert.True(t, exist)
user4Token := &auth_model.AccessToken{
UID: 4,
Name: name,
}
// Name matches but different user ID, this shouldn't exists in the
// database
exist, err = auth_model.AccessTokenByNameExists(t.Context(), user4Token)
assert.NoError(t, err)
assert.False(t, exist)
}
func TestGetAccessTokenBySHA(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.Equal(t, "Token A", token.Name)
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
assert.Equal(t, "e4efbf36", token.TokenLastEight)
_, err = auth_model.GetAccessTokenBySHA(t.Context(), "notahash")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
_, err = auth_model.GetAccessTokenBySHA(t.Context(), "")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
}
func TestListAccessTokens(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
tokens, err := db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 1})
assert.NoError(t, err)
if assert.Len(t, tokens, 2) {
assert.Equal(t, int64(1), tokens[0].UID)
assert.Equal(t, int64(1), tokens[1].UID)
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token A")
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
}
tokens, err = db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 2})
assert.NoError(t, err)
if assert.Len(t, tokens, 1) {
assert.Equal(t, int64(2), tokens[0].UID)
assert.Equal(t, "Token A", tokens[0].Name)
}
tokens, err = db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 100})
assert.NoError(t, err)
assert.Empty(t, tokens)
}
func TestUpdateAccessToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
token.Name = "Token Z"
assert.NoError(t, auth_model.UpdateAccessToken(t.Context(), token))
unittest.AssertExistsAndLoadBean(t, token)
}
func TestDeleteAccessTokenByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.NoError(t, auth_model.DeleteAccessTokenByID(t.Context(), token.ID, 1))
unittest.AssertNotExistsBean(t, token)
err = auth_model.DeleteAccessTokenByID(t.Context(), 100, 100)
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
}
+65
View File
@@ -0,0 +1,65 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"gitea.dev/models/db"
"gitea.dev/modules/timeutil"
"gitea.dev/modules/util"
"xorm.io/builder"
)
var ErrAuthTokenNotExist = util.NewNotExistErrorf("auth token does not exist")
type AuthToken struct { //nolint:revive // export stutter
ID string `xorm:"pk"`
TokenHash string
UserID int64 `xorm:"INDEX"`
ExpiresUnix timeutil.TimeStamp `xorm:"INDEX"`
}
func init() {
db.RegisterModel(new(AuthToken))
}
func InsertAuthToken(ctx context.Context, t *AuthToken) error {
_, err := db.GetEngine(ctx).Insert(t)
return err
}
func GetAuthTokenByID(ctx context.Context, id string) (*AuthToken, error) {
at := &AuthToken{}
has, err := db.GetEngine(ctx).ID(id).Get(at)
if err != nil {
return nil, err
}
if !has {
return nil, ErrAuthTokenNotExist
}
return at, nil
}
func UpdateAuthTokenByID(ctx context.Context, t *AuthToken) error {
_, err := db.GetEngine(ctx).ID(t.ID).Cols("token_hash", "expires_unix").Update(t)
return err
}
func DeleteAuthTokenByID(ctx context.Context, id string) error {
_, err := db.GetEngine(ctx).ID(id).Delete(&AuthToken{})
return err
}
func DeleteAuthTokensByUserID(ctx context.Context, uid int64) error {
_, err := db.GetEngine(ctx).Where(builder.Eq{"user_id": uid}).Delete(&AuthToken{})
return err
}
func DeleteExpiredAuthTokens(ctx context.Context) error {
_, err := db.GetEngine(ctx).Where(builder.Lt{"expires_unix": timeutil.TimeStampNow()}).Delete(&AuthToken{})
return err
}
+20
View File
@@ -0,0 +1,20 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"testing"
"gitea.dev/models/unittest"
_ "gitea.dev/models"
_ "gitea.dev/models/actions"
_ "gitea.dev/models/activities"
_ "gitea.dev/models/auth"
_ "gitea.dev/models/perm/access"
)
func TestMain(m *testing.M) {
unittest.MainTest(m)
}
+680
View File
@@ -0,0 +1,680 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"crypto/subtle"
"encoding/base32"
"errors"
"fmt"
"net"
"net/url"
"slices"
"strings"
"time"
"gitea.dev/models/db"
"gitea.dev/modules/container"
"gitea.dev/modules/setting"
"gitea.dev/modules/timeutil"
"gitea.dev/modules/util"
uuid "github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2"
"xorm.io/builder"
"xorm.io/xorm"
)
// Authorization codes should expire within 10 minutes per https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2
const oauth2AuthorizationCodeValidity = 10 * time.Minute
var (
ErrOAuth2AuthorizationCodeInvalidated = errors.New("oauth2 authorization code already invalidated")
ErrOAuth2GrantStaleCounter = errors.New("oauth2 grant state changed during token refresh")
)
// OAuth2Application represents an OAuth2 client (RFC 6749)
type OAuth2Application struct {
ID int64 `xorm:"pk autoincr"`
UID int64 `xorm:"INDEX"`
Name string
ClientID string `xorm:"unique"`
ClientSecret string
// OAuth defines both Confidential and Public client types
// https://datatracker.ietf.org/doc/html/rfc6749#section-2.1
// "Authorization servers MUST record the client type in the client registration details"
// https://datatracker.ietf.org/doc/html/rfc8252#section-8.4
ConfidentialClient bool `xorm:"NOT NULL DEFAULT TRUE"`
SkipSecondaryAuthorization bool `xorm:"NOT NULL DEFAULT FALSE"`
RedirectURIs []string `xorm:"redirect_uris JSON TEXT"`
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(OAuth2Application))
db.RegisterModel(new(OAuth2AuthorizationCode))
db.RegisterModel(new(OAuth2Grant))
}
type BuiltinOAuth2Application struct {
ConfigName string
DisplayName string
RedirectURIs []string
}
func BuiltinApplications() map[string]*BuiltinOAuth2Application {
m := make(map[string]*BuiltinOAuth2Application)
m["a4792ccc-144e-407e-86c9-5e7d8d9c3269"] = &BuiltinOAuth2Application{
ConfigName: "git-credential-oauth",
DisplayName: "git-credential-oauth",
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
}
m["e90ee53c-94e2-48ac-9358-a874fb9e0662"] = &BuiltinOAuth2Application{
ConfigName: "git-credential-manager",
DisplayName: "Git Credential Manager",
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
}
m["d57cb8c4-630c-4168-8324-ec79935e18d4"] = &BuiltinOAuth2Application{
ConfigName: "tea",
DisplayName: "tea",
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
}
return m
}
func Init(ctx context.Context) error {
builtinApps := BuiltinApplications()
var builtinAllClientIDs []string
for clientID := range builtinApps {
builtinAllClientIDs = append(builtinAllClientIDs, clientID)
}
var registeredApps []*OAuth2Application
if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(&registeredApps); err != nil {
return err
}
clientIDsToAdd := container.Set[string]{}
for _, configName := range setting.OAuth2.DefaultApplications {
found := false
for clientID, builtinApp := range builtinApps {
if builtinApp.ConfigName == configName {
clientIDsToAdd.Add(clientID) // add all user-configured apps to the "add" list
found = true
}
}
if !found {
return fmt.Errorf("unknown oauth2 application: %q", configName)
}
}
clientIDsToDelete := container.Set[string]{}
for _, app := range registeredApps {
if !clientIDsToAdd.Contains(app.ClientID) {
clientIDsToDelete.Add(app.ClientID) // if a registered app is not in the "add" list, it should be deleted
}
}
for _, app := range registeredApps {
clientIDsToAdd.Remove(app.ClientID) // no need to re-add existing (registered) apps, so remove them from the set
}
for _, app := range registeredApps {
if clientIDsToDelete.Contains(app.ClientID) {
if err := deleteOAuth2Application(ctx, app.ID, 0); err != nil {
return err
}
}
}
for clientID := range clientIDsToAdd {
builtinApp := builtinApps[clientID]
if err := db.Insert(ctx, &OAuth2Application{
Name: builtinApp.DisplayName,
ClientID: clientID,
RedirectURIs: builtinApp.RedirectURIs,
}); err != nil {
return err
}
}
return nil
}
// TableName sets the table name to `oauth2_application`
func (app *OAuth2Application) TableName() string {
return "oauth2_application"
}
// ContainsRedirectURI checks if redirectURI is allowed for app
func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool {
// OAuth2 requires the redirect URI to be an exact match, no dynamic parts are allowed.
// https://stackoverflow.com/questions/55524480/should-dynamic-query-parameters-be-present-in-the-redirection-uri-for-an-oauth2
// https://www.rfc-editor.org/rfc/rfc6819#section-5.2.3.3
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics-12#section-3.1
redirectCandidates := []string{redirectURI}
if !app.ConfidentialClient {
loopbackRedirect, ok := normalizePublicClientRedirectURI(redirectURI)
if ok {
redirectCandidates = append(redirectCandidates, loopbackRedirect)
}
}
for _, candidate := range redirectCandidates {
normalizedCandidate := normalizeRedirectURIForComparison(candidate)
for _, registeredURI := range app.RedirectURIs {
if normalizeRedirectURIForComparison(registeredURI) == normalizedCandidate {
return true
}
}
}
return false
}
func normalizeRedirectURIForComparison(redirectURI string) string {
return strings.TrimSuffix(util.ToLowerASCII(redirectURI), "/")
}
func normalizePublicClientRedirectURI(redirectURI string) (string, bool) {
parsedURI, err := url.Parse(redirectURI)
if err != nil || parsedURI.Scheme != "http" || parsedURI.Port() == "" {
return "", false
}
if ip := net.ParseIP(parsedURI.Hostname()); ip == nil || !ip.IsLoopback() {
return "", false
}
parsedURI.Host = parsedURI.Hostname()
return parsedURI.String(), true
}
// Base32 characters, but lowercased.
const lowerBase32Chars = "abcdefghijklmnopqrstuvwxyz234567"
// base32 encoder that uses lowered characters without padding.
var base32Lower = base32.NewEncoding(lowerBase32Chars).WithPadding(base32.NoPadding)
// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database
func (app *OAuth2Application) GenerateClientSecret(ctx context.Context) (string, error) {
rBytes := util.CryptoRandomBytes(32)
// Add a prefix to the base32, this is in order to make it easier
// for code scanners to grab sensitive tokens.
clientSecret := "gto_" + base32Lower.EncodeToString(rBytes)
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
if err != nil {
return "", err
}
app.ClientSecret = string(hashedSecret)
if _, err := db.GetEngine(ctx).ID(app.ID).Cols("client_secret").Update(app); err != nil {
return "", err
}
return clientSecret, nil
}
// ValidateClientSecret validates the given secret by the hash saved in database
func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil
}
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil //nolint:nilnil // return nil to indicate that the object does not exist
}
return grant, nil
}
// CreateGrant generates a grant for a user
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
grant := &OAuth2Grant{
ApplicationID: app.ID,
UserID: userID,
Scope: scope,
}
err := db.Insert(ctx, grant)
if err != nil {
return nil, err
}
return grant, nil
}
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
if !has {
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
}
return app, err
}
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := db.GetEngine(ctx).ID(id).Get(app)
if err != nil {
return nil, err
}
if !has {
return nil, ErrOAuthApplicationNotFound{ID: id}
}
return app, nil
}
// CreateOAuth2ApplicationOptions holds options to create an oauth2 application
type CreateOAuth2ApplicationOptions struct {
Name string
UserID int64
ConfidentialClient bool
SkipSecondaryAuthorization bool
RedirectURIs []string
}
// CreateOAuth2Application inserts a new oauth2 application
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
clientID := uuid.New().String()
app := &OAuth2Application{
UID: opts.UserID,
Name: opts.Name,
ClientID: clientID,
RedirectURIs: opts.RedirectURIs,
ConfidentialClient: opts.ConfidentialClient,
SkipSecondaryAuthorization: opts.SkipSecondaryAuthorization,
}
if err := db.Insert(ctx, app); err != nil {
return nil, err
}
return app, nil
}
// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application
type UpdateOAuth2ApplicationOptions struct {
ID int64
Name string
UserID int64
ConfidentialClient bool
SkipSecondaryAuthorization bool
RedirectURIs []string
}
// UpdateOAuth2Application updates an oauth2 application
func UpdateOAuth2Application(ctx context.Context, opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
return db.WithTx2(ctx, func(ctx context.Context) (*OAuth2Application, error) {
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
if err != nil {
return nil, err
}
if app.UID != opts.UserID {
return nil, errors.New("UID mismatch")
}
builtinApps := BuiltinApplications()
if _, builtin := builtinApps[app.ClientID]; builtin {
return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID)
}
app.Name = opts.Name
app.RedirectURIs = opts.RedirectURIs
app.ConfidentialClient = opts.ConfidentialClient
app.SkipSecondaryAuthorization = opts.SkipSecondaryAuthorization
if err = updateOAuth2Application(ctx, app); err != nil {
return nil, err
}
app.ClientSecret = ""
return app, nil
})
}
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
if _, err := db.GetEngine(ctx).ID(app.ID).UseBool("confidential_client", "skip_secondary_authorization").Update(app); err != nil {
return err
}
return nil
}
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
sess := db.GetEngine(ctx)
// the userid could be 0 if the app is instance-wide
if deleted, err := sess.Where(builder.Eq{"id": id, "uid": userid}).Delete(&OAuth2Application{}); err != nil {
return err
} else if deleted == 0 {
return ErrOAuthApplicationNotFound{ID: id}
}
codes := make([]*OAuth2AuthorizationCode, 0)
// delete correlating auth codes
if err := sess.Join("INNER", "oauth2_grant",
"oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil {
return err
}
codeIDs := make([]int64, 0, len(codes))
for _, grant := range codes {
codeIDs = append(codeIDs, grant.ID)
}
if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil {
return err
}
if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil {
return err
}
return nil
}
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
func DeleteOAuth2Application(ctx context.Context, id, userid int64) error {
return db.WithTx(ctx, func(ctx context.Context) error {
app, err := GetOAuth2ApplicationByID(ctx, id)
if err != nil {
return err
}
builtinApps := BuiltinApplications()
if _, builtin := builtinApps[app.ClientID]; builtin {
return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID)
}
return deleteOAuth2Application(ctx, id, userid)
})
}
//////////////////////////////////////////////////////
// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime.
type OAuth2AuthorizationCode struct {
ID int64 `xorm:"pk autoincr"`
Grant *OAuth2Grant `xorm:"-"`
GrantID int64
Code string `xorm:"INDEX unique"`
CodeChallenge string
CodeChallengeMethod string
RedirectURI string
ValidUntil timeutil.TimeStamp `xorm:"index"`
}
// TableName sets the table name to `oauth2_authorization_code`
func (code *OAuth2AuthorizationCode) TableName() string {
return "oauth2_authorization_code"
}
// IsExpired reports whether the authorization code is expired.
func (code *OAuth2AuthorizationCode) IsExpired() bool {
if code.ValidUntil.IsZero() {
return true
}
return code.ValidUntil <= timeutil.TimeStampNow()
}
// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty.
func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (*url.URL, error) {
redirect, err := url.Parse(code.RedirectURI)
if err != nil {
return nil, err
}
q := redirect.Query()
if state != "" {
q.Set("state", state)
}
q.Set("code", code.Code)
redirect.RawQuery = q.Encode()
return redirect, err
}
// Invalidate deletes the auth code from the database to invalidate this code
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
affected, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
if err != nil {
return err
}
if affected == 0 {
return ErrOAuth2AuthorizationCodeInvalidated
}
return nil
}
func (code *OAuth2AuthorizationCode) requiresCodeVerifier() bool {
return code.CodeChallengeMethod != "" || code.CodeChallenge != ""
}
func deriveCodeChallenge(method, verifier string) (string, bool) {
switch method {
case "S256":
return oauth2.S256ChallengeFromVerifier(verifier), true
case "plain":
return verifier, true
default:
return "", false
}
}
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
if !code.requiresCodeVerifier() {
return true
}
if verifier == "" || code.CodeChallengeMethod == "" {
return false
}
expectedChallenge, ok := deriveCodeChallenge(code.CodeChallengeMethod, verifier)
if !ok {
return false
}
return subtle.ConstantTimeCompare([]byte(expectedChallenge), []byte(code.CodeChallenge)) == 1
}
// GetOAuth2AuthorizationByCode returns an authorization by its code
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
auth = new(OAuth2AuthorizationCode)
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
return nil, err
} else if !has {
return nil, nil //nolint:nilnil // return nil to indicate that the object does not exist
}
auth.Grant = new(OAuth2Grant)
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
return nil, err
} else if !has {
return nil, nil //nolint:nilnil // return nil to indicate that the object does not exist
}
return auth, nil
}
//////////////////////////////////////////////////////
// OAuth2Grant represents the permission of a user for a specific application to access resources
type OAuth2Grant struct {
ID int64 `xorm:"pk autoincr"`
UserID int64 `xorm:"INDEX unique(user_application)"`
Application *OAuth2Application `xorm:"-"`
ApplicationID int64 `xorm:"INDEX unique(user_application)"`
Counter int64 `xorm:"NOT NULL DEFAULT 1"`
Scope string `xorm:"TEXT"`
Nonce string `xorm:"TEXT"`
CreatedUnix timeutil.TimeStamp `xorm:"created"`
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
}
// TableName sets the table name to `oauth2_grant`
func (grant *OAuth2Grant) TableName() string {
return "oauth2_grant"
}
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
rBytes := util.CryptoRandomBytes(32)
// Add a prefix to the base32, this is in order to make it easier
// for code scanners to grab sensitive tokens.
codeSecret := "gta_" + base32Lower.EncodeToString(rBytes)
validUntil := time.Now().Add(oauth2AuthorizationCodeValidity)
code = &OAuth2AuthorizationCode{
Grant: grant,
GrantID: grant.ID,
RedirectURI: redirectURI,
Code: codeSecret,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
ValidUntil: timeutil.TimeStamp(validUntil.Unix()),
}
if err := db.Insert(ctx, code); err != nil {
return nil, err
}
return code, nil
}
// IncreaseCounter increases the counter and updates the grant
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
affected, err := db.GetEngine(ctx).
Where("id = ?", grant.ID).
And("counter = ?", grant.Counter).
Incr("counter").
Update(new(OAuth2Grant))
if err != nil {
return err
}
if affected == 0 {
return ErrOAuth2GrantStaleCounter
}
grant.Counter++
return nil
}
// ScopeContains returns true if the grant scope contains the specified scope
func (grant *OAuth2Grant) ScopeContains(scope string) bool {
return slices.Contains(strings.Split(grant.Scope, " "), scope)
}
// SetNonce updates the current nonce value of a grant
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
grant.Nonce = nonce
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
if err != nil {
return err
}
return nil
}
// GetOAuth2GrantByID returns the grant with the given ID
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil //nolint:nilnil // return nil to indicate that the object does not exist
}
return grant, err
}
// GetOAuth2GrantsByUserID lists all grants of a certain user
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
type joinedOAuth2Grant struct {
Grant *OAuth2Grant `xorm:"extends"`
Application *OAuth2Application `xorm:"extends"`
}
var results *xorm.Rows
var err error
if results, err = db.GetEngine(ctx).
Table("oauth2_grant").
Where("user_id = ?", uid).
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
Rows(new(joinedOAuth2Grant)); err != nil {
return nil, err
}
defer results.Close()
grants := make([]*OAuth2Grant, 0)
for results.Next() {
joinedGrant := new(joinedOAuth2Grant)
if err := results.Scan(joinedGrant); err != nil {
return nil, err
}
joinedGrant.Grant.Application = joinedGrant.Application
grants = append(grants, joinedGrant.Grant)
}
return grants, nil
}
// RevokeOAuth2Grant deletes the grant with grantID and userID
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
_, err := db.GetEngine(ctx).Where(builder.Eq{"id": grantID, "user_id": userID}).Delete(&OAuth2Grant{})
return err
}
// ErrOAuthClientIDInvalid will be thrown if client id cannot be found
type ErrOAuthClientIDInvalid struct {
ClientID string
}
// IsErrOauthClientIDInvalid checks if an error is a ErrOAuthClientIDInvalid.
func IsErrOauthClientIDInvalid(err error) bool {
_, ok := err.(ErrOAuthClientIDInvalid)
return ok
}
// Error returns the error message
func (err ErrOAuthClientIDInvalid) Error() string {
return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrOAuthClientIDInvalid) Unwrap() error {
return util.ErrNotExist
}
// ErrOAuthApplicationNotFound will be thrown if id cannot be found
type ErrOAuthApplicationNotFound struct {
ID int64
}
// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist.
func IsErrOAuthApplicationNotFound(err error) bool {
_, ok := err.(ErrOAuthApplicationNotFound)
return ok
}
// Error returns the error message
func (err ErrOAuthApplicationNotFound) Error() string {
return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrOAuthApplicationNotFound) Unwrap() error {
return util.ErrNotExist
}
// GetActiveOAuth2SourceByAuthName returns a OAuth2 AuthSource based on the given name
func GetActiveOAuth2SourceByAuthName(ctx context.Context, name string) (*Source, error) {
authSource := new(Source)
has, err := db.GetEngine(ctx).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
if err != nil {
return nil, err
}
if !has {
return nil, util.NewNotExistErrorf("oauth2 source not found, name: %q", name)
}
return authSource, nil
}
func DeleteOAuth2RelictsByUserID(ctx context.Context, userID int64) error {
deleteCond := builder.Select("id").From("oauth2_grant").Where(builder.Eq{"oauth2_grant.user_id": userID})
if _, err := db.GetEngine(ctx).In("grant_id", deleteCond).
Delete(&OAuth2AuthorizationCode{}); err != nil {
return err
}
if err := db.DeleteBeans(ctx,
&OAuth2Application{UID: userID},
&OAuth2Grant{UserID: userID},
); err != nil {
return fmt.Errorf("DeleteBeans: %w", err)
}
return nil
}
+32
View File
@@ -0,0 +1,32 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"gitea.dev/models/db"
"xorm.io/builder"
)
type FindOAuth2ApplicationsOptions struct {
db.ListOptions
// OwnerID is the user id or org id of the owner of the application
OwnerID int64
// find global applications, if true, then OwnerID will be igonred
IsGlobal bool
}
func (opts FindOAuth2ApplicationsOptions) ToConds() builder.Cond {
conds := builder.NewCond()
if opts.IsGlobal {
conds = conds.And(builder.Eq{"uid": 0})
} else if opts.OwnerID != 0 {
conds = conds.And(builder.Eq{"uid": opts.OwnerID})
}
return conds
}
func (opts FindOAuth2ApplicationsOptions) ToOrders() string {
return "id DESC"
}
+343
View File
@@ -0,0 +1,343 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"testing"
"time"
auth_model "gitea.dev/models/auth"
"gitea.dev/models/unittest"
"gitea.dev/modules/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
func TestOAuth2AuthorizationCode(t *testing.T) {
require.NoError(t, unittest.PrepareTestDatabase())
t.Run("GenerateSetsValidUntil", func(t *testing.T) {
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
expectedValidUntil := timeutil.TimeStamp(time.Now().Unix() + 600)
code, err := grant.GenerateNewAuthorizationCode(t.Context(), "http://127.0.0.1/", "", "")
require.NoError(t, err)
assert.Equal(t, expectedValidUntil, code.ValidUntil)
assert.False(t, code.IsExpired())
assert.Equal(t, int64(1), code.ID)
code2, err := auth_model.GetOAuth2AuthorizationByCode(t.Context(), code.Code)
require.NoError(t, err)
assert.Equal(t, code.Code, code2.Code)
assert.NoError(t, code.Invalidate(t.Context()))
code, err = auth_model.GetOAuth2AuthorizationByCode(t.Context(), "does not exist")
require.NoError(t, err)
require.Nil(t, code)
})
t.Run("Expired", func(t *testing.T) {
defer timeutil.MockSet(time.Unix(2, 0).UTC())()
code := &auth_model.OAuth2AuthorizationCode{ValidUntil: timeutil.TimeStamp(1)}
assert.True(t, code.IsExpired())
})
t.Run("Invalidate", func(t *testing.T) {
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
code, err := grant.GenerateNewAuthorizationCode(t.Context(), "http://127.0.0.1/", "", "")
require.NoError(t, err)
require.NotNil(t, code)
require.NoError(t, code.Invalidate(t.Context()))
unittest.AssertNotExistsBean(t, &auth_model.OAuth2AuthorizationCode{Code: code.Code})
assert.ErrorIs(t, code.Invalidate(t.Context()), auth_model.ErrOAuth2AuthorizationCodeInvalidated)
})
}
func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
secret, err := app.GenerateClientSecret(t.Context())
assert.NoError(t, err)
assert.NotEmpty(t, secret)
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
}
func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
assert.NoError(b, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(b, &auth_model.OAuth2Application{ID: 1})
for b.Loop() {
_, _ = app.GenerateClientSecret(b.Context())
}
}
func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
app := &auth_model.OAuth2Application{
RedirectURIs: []string{"a", "b", "c"},
}
assert.True(t, app.ContainsRedirectURI("a"))
assert.True(t, app.ContainsRedirectURI("b"))
assert.True(t, app.ContainsRedirectURI("c"))
assert.False(t, app.ContainsRedirectURI("d"))
}
func TestOAuth2Application_ContainsRedirectURI_WithPort(t *testing.T) {
app := &auth_model.OAuth2Application{
RedirectURIs: []string{"http://127.0.0.1/", "http://::1/", "http://192.168.0.1/", "http://intranet/", "https://127.0.0.1/"},
ConfidentialClient: false,
}
// http loopback uris should ignore port
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1:3456/"))
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
assert.True(t, app.ContainsRedirectURI("http://[::1]:3456/"))
// not http
assert.False(t, app.ContainsRedirectURI("https://127.0.0.1:3456/"))
// not loopback
assert.False(t, app.ContainsRedirectURI("http://192.168.0.1:9954/"))
assert.False(t, app.ContainsRedirectURI("http://intranet:3456/"))
// unparseable
assert.False(t, app.ContainsRedirectURI(":"))
}
func TestOAuth2Application_ContainsRedirect_Slash(t *testing.T) {
app := &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1"}}
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
app = &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1/"}}
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
}
func TestOAuth2Application_ContainsRedirectURI_ASCIIOnlyNormalization(t *testing.T) {
testCases := []struct {
name string
registered []string
redirectURI string
allowed bool
}{
{
name: "exact-match",
registered: []string{"https://signin.example.test/callback"},
redirectURI: "https://signin.example.test/callback",
allowed: true,
},
{
name: "ascii-case-insensitive",
registered: []string{"https://signin.example.test/callback"},
redirectURI: "https://signIN.example.test/callback",
allowed: true,
},
{
name: "non-ascii-not-folded",
registered: []string{"https://signin.example.test/callback"},
redirectURI: "https://signİn.example.test/callback",
allowed: false,
},
{
name: "loopback-strips-port",
registered: []string{"http://127.0.0.1/callback"},
redirectURI: "http://127.0.0.1:12345/callback",
allowed: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
app := &auth_model.OAuth2Application{RedirectURIs: tc.registered}
assert.Equal(t, tc.allowed, app.ContainsRedirectURI(tc.redirectURI))
})
}
}
func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
secret, err := app.GenerateClientSecret(t.Context())
assert.NoError(t, err)
assert.True(t, app.ValidateClientSecret([]byte(secret)))
assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew")))
}
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app, err := auth_model.GetOAuth2ApplicationByClientID(t.Context(), "da7da3ba-9a13-4167-856f-3899de0b0138")
assert.NoError(t, err)
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
app, err = auth_model.GetOAuth2ApplicationByClientID(t.Context(), "invalid client id")
assert.Error(t, err)
assert.Nil(t, app)
}
func TestCreateOAuth2Application(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app, err := auth_model.CreateOAuth2Application(t.Context(), auth_model.CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
assert.NoError(t, err)
assert.Equal(t, "newapp", app.Name)
assert.Len(t, app.ClientID, 36)
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{Name: "newapp"})
}
func TestOAuth2Application_TableName(t *testing.T) {
assert.Equal(t, "oauth2_application", new(auth_model.OAuth2Application).TableName())
}
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
grant, err := app.GetGrantByUserID(t.Context(), 1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.UserID)
grant, err = app.GetGrantByUserID(t.Context(), 34923458)
assert.NoError(t, err)
assert.Nil(t, grant)
}
func TestOAuth2Application_CreateGrant(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
grant, err := app.CreateGrant(t.Context(), 2, "")
assert.NoError(t, err)
assert.NotNil(t, grant)
assert.Equal(t, int64(2), grant.UserID)
assert.Equal(t, int64(1), grant.ApplicationID)
assert.Empty(t, grant.Scope)
}
//////////////////// Grant
func TestGetOAuth2GrantByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant, err := auth_model.GetOAuth2GrantByID(t.Context(), 1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.ID)
grant, err = auth_model.GetOAuth2GrantByID(t.Context(), 34923458)
assert.NoError(t, err)
assert.Nil(t, grant)
}
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 1})
assert.NoError(t, grant.IncreaseCounter(t.Context()))
assert.Equal(t, int64(2), grant.Counter)
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 2})
}
func TestOAuth2Grant_IncreaseCounterRejectsStaleCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 1})
stale := *grant
assert.NoError(t, grant.IncreaseCounter(t.Context()))
err := stale.IncreaseCounter(t.Context())
assert.ErrorIs(t, err, auth_model.ErrOAuth2GrantStaleCounter)
}
func TestOAuth2Grant_ScopeContains(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Scope: "openid profile"})
assert.True(t, grant.ScopeContains("openid"))
assert.True(t, grant.ScopeContains("profile"))
assert.False(t, grant.ScopeContains("profil"))
assert.False(t, grant.ScopeContains("profile2"))
}
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
code, err := grant.GenerateNewAuthorizationCode(t.Context(), "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.Greater(t, len(code.Code), 32) // secret length > 32
}
func TestOAuth2Grant_TableName(t *testing.T) {
assert.Equal(t, "oauth2_grant", new(auth_model.OAuth2Grant).TableName())
}
func TestGetOAuth2GrantsByUserID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
result, err := auth_model.GetOAuth2GrantsByUserID(t.Context(), 1)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, int64(1), result[0].ID)
assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
result, err = auth_model.GetOAuth2GrantsByUserID(t.Context(), 34134)
assert.NoError(t, err)
assert.Empty(t, result)
}
func TestRevokeOAuth2Grant(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.NoError(t, auth_model.RevokeOAuth2Grant(t.Context(), 1, 1))
unittest.AssertNotExistsBean(t, &auth_model.OAuth2Grant{ID: 1, UserID: 1})
}
//////////////////// Authorization Code
func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) {
s256Verifier := "s256-verifier"
s256Challenge := oauth2.S256ChallengeFromVerifier(s256Verifier)
missingVerifierChallenge := oauth2.S256ChallengeFromVerifier("verifier-not-supplied")
testCases := []struct {
name string
method string
challenge string
verifier string
valid bool
}{
{"plain-success", "plain", "plain-secret", "plain-secret", true},
{"plain-failure", "plain", "plain-secret", "ierwgjoergjio", false},
{"s256-success", "S256", s256Challenge, s256Verifier, true},
{"s256-failure", "S256", s256Challenge, "wiogjerogorewngoenrgoiuenorg", false},
{"unsupported-method", "monkey", "foiwgjioriogeiogjerger", "foiwgjioriogeiogjerger", false},
{"no-pkce-configured", "", "", "", true},
{"s256-missing-verifier", "S256", missingVerifierChallenge, "", false},
{"plain-missing-verifier", "plain", "plain-secret", "", false},
{"missing-method-with-challenge", "", "foierjiogerogerg", "", false},
{"missing-method-rejects-even-matching-verifier", "", "foierjiogerogerg", "foierjiogerogerg", false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
code := &auth_model.OAuth2AuthorizationCode{
CodeChallengeMethod: tc.method,
CodeChallenge: tc.challenge,
}
assert.Equal(t, tc.valid, code.ValidateCodeChallenge(tc.verifier))
})
}
}
func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
code := &auth_model.OAuth2AuthorizationCode{
RedirectURI: "https://example.com/callback",
Code: "thecode",
}
redirect, err := code.GenerateRedirectURI("thestate")
assert.NoError(t, err)
assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String())
redirect, err = code.GenerateRedirectURI("")
assert.NoError(t, err)
assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String())
}
func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
assert.Equal(t, "oauth2_authorization_code", new(auth_model.OAuth2AuthorizationCode).TableName())
}
+112
View File
@@ -0,0 +1,112 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"fmt"
"gitea.dev/models/db"
"gitea.dev/modules/timeutil"
"xorm.io/builder"
)
// Session represents a session compatible for go-chi session
type Session struct {
Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session
Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased
Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session
}
func init() {
db.RegisterModel(new(Session))
}
// UpdateSession updates the session with provided id
func UpdateSession(ctx context.Context, key string, data []byte) error {
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
Data: data,
Expiry: timeutil.TimeStampNow(),
})
return err
}
// ReadSession reads the data for the provided session
func ReadSession(ctx context.Context, key string) (*Session, error) {
return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) {
session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key})
if err != nil {
return nil, err
} else if !exist {
session = &Session{
Key: key,
Expiry: timeutil.TimeStampNow(),
}
if err := db.Insert(ctx, session); err != nil {
return nil, err
}
}
return session, nil
})
}
// ExistSession checks if a session exists
func ExistSession(ctx context.Context, key string) (bool, error) {
return db.Exist[Session](ctx, builder.Eq{"`key`": key})
}
// DestroySession destroys a session
func DestroySession(ctx context.Context, key string) error {
_, err := db.GetEngine(ctx).Delete(&Session{
Key: key,
})
return err
}
// RegenerateSession regenerates a session from the old id
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) {
if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil {
return nil, err
} else if has {
return nil, fmt.Errorf("session Key: %s already exists", newKey)
}
if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil {
return nil, err
} else if !has {
if err := db.Insert(ctx, &Session{
Key: oldKey,
Expiry: timeutil.TimeStampNow(),
}); err != nil {
return nil, err
}
}
if _, err := db.Exec(ctx, "UPDATE `session` SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
return nil, err
}
s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey})
if err != nil {
// is not exist, it should be impossible
return nil, err
}
return s, nil
})
}
// CountSessions returns the number of sessions
func CountSessions(ctx context.Context) (int64, error) {
return db.GetEngine(ctx).Count(&Session{})
}
// CleanupSessions cleans up expired sessions
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
return err
}
+402
View File
@@ -0,0 +1,402 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"fmt"
"reflect"
"gitea.dev/models/db"
"gitea.dev/modules/log"
"gitea.dev/modules/optional"
"gitea.dev/modules/setting"
"gitea.dev/modules/timeutil"
"gitea.dev/modules/util"
"xorm.io/builder"
"xorm.io/xorm"
"xorm.io/xorm/convert"
)
// Type represents an login type.
type Type int
// Note: new type must append to the end of list to maintain compatibility.
const (
NoType Type = iota
Plain // 1
LDAP // 2
SMTP // 3
PAM // 4
DLDAP // 5
OAuth2 // 6
SSPI // 7
)
// String returns the string name of the LoginType
func (typ Type) String() string {
return Names[typ]
}
// Int returns the int value of the LoginType
func (typ Type) Int() int {
return int(typ)
}
// Names contains the name of LoginType values.
var Names = map[Type]string{
LDAP: "LDAP (via BindDN)",
DLDAP: "LDAP (simple auth)", // Via direct bind
SMTP: "SMTP",
PAM: "PAM",
OAuth2: "OAuth2",
SSPI: "SPNEGO with SSPI",
}
// Config represents login config as far as the db is concerned
type Config interface {
convert.Conversion
SetAuthSource(*Source)
}
type ConfigBase struct {
AuthSource *Source
}
func (p *ConfigBase) SetAuthSource(s *Source) {
p.AuthSource = s
}
// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set
type SkipVerifiable interface {
IsSkipVerify() bool
}
// HasTLSer configurations provide a HasTLS to check if TLS can be enabled
type HasTLSer interface {
HasTLS() bool
}
// UseTLSer configurations provide a HasTLS to check if TLS is enabled
type UseTLSer interface {
UseTLS() bool
}
// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys
type SSHKeyProvider interface {
ProvidesSSHKeys() bool
}
// RegisterableSource configurations provide RegisterSource which needs to be run on creation
type RegisterableSource interface {
RegisterSource() error
UnregisterSource() error
}
var registeredConfigs = map[Type]func() Config{}
// RegisterTypeConfig register a config for a provided type
func RegisterTypeConfig(typ Type, exemplar Config) {
if reflect.TypeOf(exemplar).Kind() == reflect.Pointer {
// Pointer:
registeredConfigs[typ] = func() Config {
return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
}
return
}
// Not a Pointer
registeredConfigs[typ] = func() Config {
return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
}
}
// Source represents an external way for authorizing users.
type Source struct {
ID int64 `xorm:"pk autoincr"`
Type Type
Name string `xorm:"UNIQUE"` // it can be the OIDC's provider name, see services/auth/source/oauth2/source_register.go: RegisterSource
IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"`
IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"`
TwoFactorPolicy string `xorm:"two_factor_policy NOT NULL DEFAULT ''"`
Cfg Config `xorm:"TEXT"`
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
// TableName xorm will read the table name from this method
func (Source) TableName() string {
return "login_source"
}
func init() {
db.RegisterModel(new(Source))
}
// BeforeSet is invoked from XORM before setting the value of a field of this object.
func (source *Source) BeforeSet(colName string, val xorm.Cell) {
if colName == "type" {
typ, _, err := db.CellToInt(val, NoType)
if err != nil {
setting.PanicInDevOrTesting("Unable to convert login source (id=%d) type: %v", source.ID, err)
}
constructor, ok := registeredConfigs[typ]
if !ok {
return
}
source.Cfg = constructor()
source.Cfg.SetAuthSource(source)
}
}
// TypeName return name of this login source type.
func (source *Source) TypeName() string {
return Names[source.Type]
}
// IsLDAP returns true of this source is of the LDAP type.
func (source *Source) IsLDAP() bool {
return source.Type == LDAP
}
// IsDLDAP returns true of this source is of the DLDAP type.
func (source *Source) IsDLDAP() bool {
return source.Type == DLDAP
}
// IsSMTP returns true of this source is of the SMTP type.
func (source *Source) IsSMTP() bool {
return source.Type == SMTP
}
// IsPAM returns true of this source is of the PAM type.
func (source *Source) IsPAM() bool {
return source.Type == PAM
}
// IsOAuth2 returns true of this source is of the OAuth2 type.
func (source *Source) IsOAuth2() bool {
return source.Type == OAuth2
}
// IsSSPI returns true of this source is of the SSPI type.
func (source *Source) IsSSPI() bool {
return source.Type == SSPI
}
// HasTLS returns true of this source supports TLS.
func (source *Source) HasTLS() bool {
hasTLSer, ok := source.Cfg.(HasTLSer)
return ok && hasTLSer.HasTLS()
}
// UseTLS returns true of this source is configured to use TLS.
func (source *Source) UseTLS() bool {
useTLSer, ok := source.Cfg.(UseTLSer)
return ok && useTLSer.UseTLS()
}
// SkipVerify returns true if this source is configured to skip SSL
// verification.
func (source *Source) SkipVerify() bool {
skipVerifiable, ok := source.Cfg.(SkipVerifiable)
return ok && skipVerifiable.IsSkipVerify()
}
func (source *Source) TwoFactorShouldSkip() bool {
return source.TwoFactorPolicy == "skip"
}
// CreateSource inserts a AuthSource in the DB if not already
// existing with the given name.
func CreateSource(ctx context.Context, source *Source) error {
has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
if err != nil {
return err
} else if has {
return ErrSourceAlreadyExist{source.Name}
}
// Synchronization is only available with LDAP for now
if !source.IsLDAP() && !source.IsOAuth2() {
source.IsSyncEnabled = false
}
_, err = db.GetEngine(ctx).Insert(source)
if err != nil {
return err
}
if !source.IsActive {
return nil
}
source.Cfg.SetAuthSource(source)
registerableSource, ok := source.Cfg.(RegisterableSource)
if !ok {
return nil
}
err = registerableSource.RegisterSource()
if err != nil {
// remove the AuthSource in case of errors while registering configuration
if _, err := db.GetEngine(ctx).ID(source.ID).Delete(new(Source)); err != nil {
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
return err
}
type FindSourcesOptions struct {
db.ListOptions
IsActive optional.Option[bool]
LoginType Type
}
func (opts FindSourcesOptions) ToConds() builder.Cond {
conds := builder.NewCond()
if opts.IsActive.Has() {
conds = conds.And(builder.Eq{"is_active": opts.IsActive.Value()})
}
if opts.LoginType != NoType {
conds = conds.And(builder.Eq{"`type`": opts.LoginType})
}
return conds
}
// IsSSPIEnabled returns true if there is at least one activated login
// source of type LoginSSPI
func IsSSPIEnabled(ctx context.Context) bool {
exist, err := db.Exist[Source](ctx, FindSourcesOptions{
IsActive: optional.Some(true),
LoginType: SSPI,
}.ToConds())
if err != nil {
log.Error("IsSSPIEnabled: failed to query active SSPI sources: %v", err)
return false
}
return exist
}
// GetSourceByID returns login source by given ID.
func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
source := new(Source)
if id == 0 {
source.Cfg = registeredConfigs[NoType]()
// Set this source to active
// FIXME: allow disabling of db based password authentication in future
source.IsActive = true
return source, nil
}
has, err := db.GetEngine(ctx).ID(id).Get(source)
if err != nil {
return nil, err
} else if !has {
return nil, ErrSourceNotExist{id}
}
return source, nil
}
// UpdateSource updates a Source record in DB.
func UpdateSource(ctx context.Context, source *Source) error {
var originalSource *Source
if source.IsOAuth2() {
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
var err error
if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
return err
}
}
has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
if err != nil {
return err
} else if has {
return ErrSourceAlreadyExist{source.Name}
}
_, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
if err != nil {
return err
}
if !source.IsActive {
return nil
}
source.Cfg.SetAuthSource(source)
registerableSource, ok := source.Cfg.(RegisterableSource)
if !ok {
return nil
}
err = registerableSource.RegisterSource()
if err != nil {
// restore original values since we cannot update the provider itself
if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
return err
}
// ErrSourceNotExist represents a "SourceNotExist" kind of error.
type ErrSourceNotExist struct {
ID int64
}
// IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
func IsErrSourceNotExist(err error) bool {
_, ok := err.(ErrSourceNotExist)
return ok
}
func (err ErrSourceNotExist) Error() string {
return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrSourceNotExist) Unwrap() error {
return util.ErrNotExist
}
// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
type ErrSourceAlreadyExist struct {
Name string
}
// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
func IsErrSourceAlreadyExist(err error) bool {
_, ok := err.(ErrSourceAlreadyExist)
return ok
}
func (err ErrSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists [name: %s]", err.Name)
}
// Unwrap unwraps this as a ErrExist err
func (err ErrSourceAlreadyExist) Unwrap() error {
return util.ErrAlreadyExist
}
// ErrSourceInUse represents a "SourceInUse" kind of error.
type ErrSourceInUse struct {
ID int64
}
// IsErrSourceInUse checks if an error is a ErrSourceInUse.
func IsErrSourceInUse(err error) bool {
_, ok := err.(ErrSourceInUse)
return ok
}
func (err ErrSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
}
+55
View File
@@ -0,0 +1,55 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"strings"
"testing"
auth_model "gitea.dev/models/auth"
"gitea.dev/models/unittest"
"gitea.dev/modules/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"xorm.io/xorm/schemas"
)
type TestSource struct {
auth_model.ConfigBase `json:"-"`
TestField string
}
// FromDB fills up a LDAPConfig from serialized format.
func (source *TestSource) FromDB(bs []byte) error {
return json.Unmarshal(bs, &source)
}
// ToDB exports a LDAPConfig to a serialized format.
func (source *TestSource) ToDB() ([]byte, error) {
return json.Marshal(source)
}
func TestDumpAuthSource(t *testing.T) {
require.NoError(t, unittest.PrepareTestDatabase())
authSourceSchema, err := unittest.GetXORMEngine().TableInfo(new(auth_model.Source))
require.NoError(t, err)
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
source := &auth_model.Source{
Type: auth_model.OAuth2,
Name: "TestSource",
Cfg: &TestSource{TestField: "TestValue"},
}
require.NoError(t, auth_model.CreateSource(t.Context(), source))
// intentionally test the "dump" to make sure the dumped JSON is correct: https://github.com/go-gitea/gitea/pull/16847
sb := &strings.Builder{}
require.NoError(t, unittest.GetXORMEngine().DumpTables([]*schemas.Table{authSourceSchema}, sb))
// the dumped SQL is something like:
// INSERT INTO `login_source` (`id`, `type`, `name`, `is_active`, `is_sync_enabled`, `two_factor_policy`, `cfg`, `created_unix`, `updated_unix`) VALUES (1,6,'TestSource',0,0,'','{"TestField":"TestValue"}',1774179784,1774179784);
assert.Contains(t, sb.String(), `'{"TestField":"TestValue"}'`)
}
+173
View File
@@ -0,0 +1,173 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"crypto/md5"
"crypto/sha256"
"crypto/subtle"
"encoding/base32"
"encoding/base64"
"encoding/hex"
"fmt"
"gitea.dev/models/db"
"gitea.dev/modules/secret"
"gitea.dev/modules/setting"
"gitea.dev/modules/timeutil"
"gitea.dev/modules/util"
"github.com/pquerna/otp/totp"
"golang.org/x/crypto/pbkdf2"
)
//
// Two-factor authentication
//
// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication.
type ErrTwoFactorNotEnrolled struct {
UID int64
}
// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled.
func IsErrTwoFactorNotEnrolled(err error) bool {
_, ok := err.(ErrTwoFactorNotEnrolled)
return ok
}
func (err ErrTwoFactorNotEnrolled) Error() string {
return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrTwoFactorNotEnrolled) Unwrap() error {
return util.ErrNotExist
}
// TwoFactor represents a two-factor authentication token.
type TwoFactor struct {
ID int64 `xorm:"pk autoincr"`
UID int64 `xorm:"UNIQUE"`
Secret string
ScratchSalt string
ScratchHash string
LastUsedPasscode string `xorm:"VARCHAR(10)"`
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(TwoFactor))
}
// GenerateScratchToken recreates the scratch token the user is using.
func (t *TwoFactor) GenerateScratchToken() (string, error) {
tokenBytes := util.CryptoRandomBytes(6)
// these chars are specially chosen, avoid ambiguous chars like `0`, `O`, `1`, `I`.
const base32Chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
token := base32.NewEncoding(base32Chars).WithPadding(base32.NoPadding).EncodeToString(tokenBytes)
t.ScratchSalt = util.CryptoRandomString(10)
t.ScratchHash = HashToken(token, t.ScratchSalt)
return token, nil
}
// HashToken return the hashable salt
func HashToken(token, salt string) string {
tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New)
return hex.EncodeToString(tempHash)
}
// VerifyScratchToken verifies if the specified scratch token is valid.
func (t *TwoFactor) VerifyScratchToken(token string) bool {
if len(token) == 0 {
return false
}
tempHash := HashToken(token, t.ScratchSalt)
return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1
}
func (t *TwoFactor) getEncryptionKey() []byte {
k := md5.Sum([]byte(setting.SecretKey))
return k[:]
}
// SetSecret sets the 2FA secret.
func (t *TwoFactor) SetSecret(secretString string) error {
secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString))
if err != nil {
return err
}
t.Secret = base64.StdEncoding.EncodeToString(secretBytes)
return nil
}
// ValidateTOTP validates the provided passcode.
func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret)
if err != nil {
return false, fmt.Errorf("ValidateTOTP invalid base64: %w", err)
}
secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret)
if err != nil {
return false, fmt.Errorf("ValidateTOTP unable to decrypt (maybe SECRET_KEY is wrong): %w", err)
}
secretStr := string(secretBytes)
return totp.Validate(passcode, secretStr), nil
}
// NewTwoFactor creates a new two-factor authentication token.
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).Insert(t)
return err
}
// UpdateTwoFactor updates a two-factor authentication token.
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}
// GetTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
twofa := &TwoFactor{}
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
if err != nil {
return nil, err
} else if !has {
return nil, ErrTwoFactorNotEnrolled{uid}
}
return twofa, nil
}
// HasTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
}
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
UID: userID,
})
if err != nil {
return err
} else if cnt != 1 {
return ErrTwoFactorNotEnrolled{userID}
}
return nil
}
func HasTwoFactorOrWebAuthn(ctx context.Context, id int64) (bool, error) {
has, err := HasTwoFactorByUID(ctx, id)
if err != nil {
return false, err
} else if has {
return true, nil
}
return HasWebAuthnRegistrationsByUID(ctx, id)
}
+202
View File
@@ -0,0 +1,202 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"fmt"
"strings"
"gitea.dev/models/db"
"gitea.dev/modules/timeutil"
"gitea.dev/modules/util"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
)
// ErrWebAuthnCredentialNotExist represents a "ErrWebAuthnCRedentialNotExist" kind of error.
type ErrWebAuthnCredentialNotExist struct {
ID int64
CredentialID []byte
}
func (err ErrWebAuthnCredentialNotExist) Error() string {
if len(err.CredentialID) == 0 {
return fmt.Sprintf("WebAuthn credential does not exist [id: %d]", err.ID)
}
return fmt.Sprintf("WebAuthn credential does not exist [credential_id: %x]", err.CredentialID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrWebAuthnCredentialNotExist) Unwrap() error {
return util.ErrNotExist
}
// IsErrWebAuthnCredentialNotExist checks if an error is a ErrWebAuthnCredentialNotExist.
func IsErrWebAuthnCredentialNotExist(err error) bool {
_, ok := err.(ErrWebAuthnCredentialNotExist)
return ok
}
// WebAuthnCredential represents the WebAuthn credential data for a public-key
// credential conformant to WebAuthn Level 1
type WebAuthnCredential struct {
ID int64 `xorm:"pk autoincr"`
Name string
LowerName string `xorm:"unique(s)"`
UserID int64 `xorm:"INDEX unique(s)"`
CredentialID []byte `xorm:"INDEX VARBINARY(1024)"`
PublicKey []byte
AttestationType string
AAGUID []byte
SignCount uint32 `xorm:"BIGINT"`
CloneWarning bool
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(WebAuthnCredential))
}
// TableName returns a better table name for WebAuthnCredential
func (cred WebAuthnCredential) TableName() string {
return "webauthn_credential"
}
// UpdateSignCount will update the database value of SignCount
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
return err
}
// BeforeInsert will be invoked by XORM before updating a record
func (cred *WebAuthnCredential) BeforeInsert() {
cred.LowerName = strings.ToLower(cred.Name)
}
// BeforeUpdate will be invoked by XORM before updating a record
func (cred *WebAuthnCredential) BeforeUpdate() {
cred.LowerName = strings.ToLower(cred.Name)
}
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (cred *WebAuthnCredential) AfterLoad() {
cred.LowerName = strings.ToLower(cred.Name)
}
// WebAuthnCredentialList is a list of *WebAuthnCredential
type WebAuthnCredentialList []*WebAuthnCredential
// newCredentialFlagsFromAuthenticatorFlags is copied from https://github.com/go-webauthn/webauthn/pull/337
// to convert protocol.AuthenticatorFlags to webauthn.CredentialFlags
func newCredentialFlagsFromAuthenticatorFlags(flags protocol.AuthenticatorFlags) webauthn.CredentialFlags {
return webauthn.CredentialFlags{
UserPresent: flags.HasUserPresent(),
UserVerified: flags.HasUserVerified(),
BackupEligible: flags.HasBackupEligible(),
BackupState: flags.HasBackupState(),
}
}
// ToCredentials will convert all WebAuthnCredentials to webauthn.Credentials
func (list WebAuthnCredentialList) ToCredentials(defaultAuthFlags ...protocol.AuthenticatorFlags) []webauthn.Credential {
// TODO: at the moment, Gitea doesn't store or check the flags
// so we need to use the default flags from the authenticator to make the login validation pass
// In the future, we should:
// 1. store the flags when registering the credential
// 2. provide the stored flags when converting the credentials (for login)
// 3. for old users, still use this fallback to the default flags
defAuthFlags := util.OptionalArg(defaultAuthFlags)
creds := make([]webauthn.Credential, 0, len(list))
for _, cred := range list {
creds = append(creds, webauthn.Credential{
ID: cred.CredentialID,
PublicKey: cred.PublicKey,
AttestationType: cred.AttestationType,
Flags: newCredentialFlagsFromAuthenticatorFlags(defAuthFlags),
Authenticator: webauthn.Authenticator{
AAGUID: cred.AAGUID,
SignCount: cred.SignCount,
CloneWarning: cred.CloneWarning,
},
})
}
return creds
}
// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
creds := make(WebAuthnCredentialList, 0)
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
}
// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}
// GetWebAuthnCredentialByName returns WebAuthn credential by id
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
return nil, err
} else if !found {
return nil, ErrWebAuthnCredentialNotExist{}
}
return cred, nil
}
// GetWebAuthnCredentialByID returns WebAuthn credential by id
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
return nil, err
} else if !found {
return nil, ErrWebAuthnCredentialNotExist{ID: id}
}
return cred, nil
}
// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}
// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
return nil, err
} else if !found {
return nil, ErrWebAuthnCredentialNotExist{CredentialID: credID}
}
return cred, nil
}
// CreateCredential will create a new WebAuthnCredential from the given Credential
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
c := &WebAuthnCredential{
UserID: userID,
Name: name,
CredentialID: cred.ID,
PublicKey: cred.PublicKey,
AttestationType: cred.AttestationType,
AAGUID: cred.Authenticator.AAGUID,
SignCount: cred.Authenticator.SignCount,
CloneWarning: false,
}
if err := db.Insert(ctx, c); err != nil {
return nil, err
}
return c, nil
}
// DeleteCredential will delete WebAuthnCredential
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
return had > 0, err
}
+66
View File
@@ -0,0 +1,66 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"testing"
auth_model "gitea.dev/models/auth"
"gitea.dev/models/unittest"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/stretchr/testify/assert"
)
func TestGetWebAuthnCredentialByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.GetWebAuthnCredentialByID(t.Context(), 1)
assert.NoError(t, err)
assert.Equal(t, "WebAuthn credential", res.Name)
_, err = auth_model.GetWebAuthnCredentialByID(t.Context(), 342432)
assert.Error(t, err)
assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err))
}
func TestGetWebAuthnCredentialsByUID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.GetWebAuthnCredentialsByUID(t.Context(), 32)
assert.NoError(t, err)
assert.Len(t, res, 1)
assert.Equal(t, "WebAuthn credential", res[0].Name)
}
func TestWebAuthnCredential_TableName(t *testing.T) {
assert.Equal(t, "webauthn_credential", auth_model.WebAuthnCredential{}.TableName())
}
func TestWebAuthnCredential_UpdateSignCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
cred.SignCount = 1
assert.NoError(t, cred.UpdateSignCount(t.Context()))
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1})
}
func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
cred.SignCount = 0xffffffff
assert.NoError(t, cred.UpdateSignCount(t.Context()))
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff})
}
func TestCreateCredential(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.CreateCredential(t.Context(), 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
assert.NoError(t, err)
assert.Equal(t, "WebAuthn Created Credential", res.Name)
assert.Equal(t, []byte("Test"), res.CredentialID)
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{Name: "WebAuthn Created Credential", UserID: 1})
}