初始提交: Gitea 项目代码
This commit is contained in:
@@ -0,0 +1,230 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/timeutil"
|
||||
"gitea.dev/modules/util"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// ErrAccessTokenNotExist represents a "AccessTokenNotExist" kind of error.
|
||||
type ErrAccessTokenNotExist struct {
|
||||
Token string
|
||||
}
|
||||
|
||||
// IsErrAccessTokenNotExist checks if an error is a ErrAccessTokenNotExist.
|
||||
func IsErrAccessTokenNotExist(err error) bool {
|
||||
_, ok := err.(ErrAccessTokenNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrAccessTokenNotExist) Error() string {
|
||||
return fmt.Sprintf("access token does not exist [sha: %s]", err.Token)
|
||||
}
|
||||
|
||||
func (err ErrAccessTokenNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrAccessTokenEmpty represents a "AccessTokenEmpty" kind of error.
|
||||
type ErrAccessTokenEmpty struct{}
|
||||
|
||||
// IsErrAccessTokenEmpty checks if an error is a ErrAccessTokenEmpty.
|
||||
func IsErrAccessTokenEmpty(err error) bool {
|
||||
_, ok := err.(ErrAccessTokenEmpty)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrAccessTokenEmpty) Error() string {
|
||||
return "access token is empty"
|
||||
}
|
||||
|
||||
func (err ErrAccessTokenEmpty) Unwrap() error {
|
||||
return util.ErrInvalidArgument
|
||||
}
|
||||
|
||||
var successfulAccessTokenCache *lru.Cache[string, any]
|
||||
|
||||
// AccessToken represents a personal access token.
|
||||
type AccessToken struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UID int64 `xorm:"INDEX"`
|
||||
Name string
|
||||
Token string `xorm:"-"`
|
||||
TokenHash string `xorm:"UNIQUE"` // sha256 of token
|
||||
TokenSalt string
|
||||
TokenLastEight string `xorm:"INDEX token_last_eight"`
|
||||
Scope AccessTokenScope
|
||||
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
HasRecentActivity bool `xorm:"-"`
|
||||
HasUsed bool `xorm:"-"`
|
||||
}
|
||||
|
||||
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
|
||||
func (t *AccessToken) AfterLoad() {
|
||||
t.HasUsed = t.UpdatedUnix > t.CreatedUnix
|
||||
t.HasRecentActivity = t.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(AccessToken), func() error {
|
||||
if setting.SuccessfulTokensCacheSize > 0 {
|
||||
var err error
|
||||
successfulAccessTokenCache, err = lru.New[string, any](setting.SuccessfulTokensCacheSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to allocate AccessToken cache: %w", err)
|
||||
}
|
||||
} else {
|
||||
successfulAccessTokenCache = nil
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// NewAccessToken creates new access token.
|
||||
func NewAccessToken(ctx context.Context, t *AccessToken) error {
|
||||
salt := util.CryptoRandomString(10)
|
||||
token := util.CryptoRandomBytes(20)
|
||||
t.TokenSalt = salt
|
||||
t.Token = hex.EncodeToString(token)
|
||||
t.TokenHash = HashToken(t.Token, t.TokenSalt)
|
||||
t.TokenLastEight = t.Token[len(t.Token)-8:]
|
||||
_, err := db.GetEngine(ctx).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// DisplayPublicOnly whether to display this as a public-only token.
|
||||
func (t *AccessToken) DisplayPublicOnly() bool {
|
||||
publicOnly, err := t.Scope.PublicOnly()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return publicOnly
|
||||
}
|
||||
|
||||
func getAccessTokenIDFromCache(token string) int64 {
|
||||
if successfulAccessTokenCache == nil {
|
||||
return 0
|
||||
}
|
||||
tInterface, ok := successfulAccessTokenCache.Get(token)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
t, ok := tInterface.(int64)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// GetAccessTokenBySHA returns access token by given token value
|
||||
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
|
||||
if token == "" {
|
||||
return nil, ErrAccessTokenEmpty{}
|
||||
}
|
||||
// A token is defined as being SHA1 sum these are 40 hexadecimal bytes long
|
||||
if len(token) != 40 {
|
||||
return nil, ErrAccessTokenNotExist{token}
|
||||
}
|
||||
for _, x := range []byte(token) {
|
||||
if x < '0' || (x > '9' && x < 'a') || x > 'f' {
|
||||
return nil, ErrAccessTokenNotExist{token}
|
||||
}
|
||||
}
|
||||
|
||||
lastEight := token[len(token)-8:]
|
||||
|
||||
if id := getAccessTokenIDFromCache(token); id > 0 {
|
||||
accessToken := &AccessToken{
|
||||
TokenLastEight: lastEight,
|
||||
}
|
||||
// Re-get the token from the db in case it has been deleted in the intervening period
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if has {
|
||||
return accessToken, nil
|
||||
}
|
||||
successfulAccessTokenCache.Remove(token)
|
||||
}
|
||||
|
||||
var tokens []AccessToken
|
||||
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(tokens) == 0 {
|
||||
return nil, ErrAccessTokenNotExist{token}
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
tempHash := HashToken(token, t.TokenSalt)
|
||||
if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 {
|
||||
if successfulAccessTokenCache != nil {
|
||||
successfulAccessTokenCache.Add(token, t.ID)
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrAccessTokenNotExist{token}
|
||||
}
|
||||
|
||||
// AccessTokenByNameExists checks if a token name has been used already by a user.
|
||||
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
|
||||
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
|
||||
}
|
||||
|
||||
// ListAccessTokensOptions contain filter options
|
||||
type ListAccessTokensOptions struct {
|
||||
db.ListOptions
|
||||
Name string
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (opts ListAccessTokensOptions) ToConds() builder.Cond {
|
||||
cond := builder.NewCond()
|
||||
// user id is required, otherwise it will return all result which maybe a possible bug
|
||||
cond = cond.And(builder.Eq{"uid": opts.UserID})
|
||||
if len(opts.Name) > 0 {
|
||||
cond = cond.And(builder.Eq{"name": opts.Name})
|
||||
}
|
||||
return cond
|
||||
}
|
||||
|
||||
func (opts ListAccessTokensOptions) ToOrders() string {
|
||||
return "created_unix DESC"
|
||||
}
|
||||
|
||||
// UpdateAccessToken updates information of access token.
|
||||
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
|
||||
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteAccessTokenByID deletes access token by given ID.
|
||||
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
|
||||
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
|
||||
UID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
} else if cnt != 1 {
|
||||
return ErrAccessTokenNotExist{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"gitea.dev/models/perm"
|
||||
)
|
||||
|
||||
// AccessTokenScopeCategory represents the scope category for an access token
|
||||
type AccessTokenScopeCategory int
|
||||
|
||||
const (
|
||||
AccessTokenScopeCategoryActivityPub AccessTokenScopeCategory = iota
|
||||
AccessTokenScopeCategoryAdmin
|
||||
AccessTokenScopeCategoryMisc // WARN: this is now just a placeholder, don't remove it which will change the following values
|
||||
AccessTokenScopeCategoryNotification
|
||||
AccessTokenScopeCategoryOrganization
|
||||
AccessTokenScopeCategoryPackage
|
||||
AccessTokenScopeCategoryIssue
|
||||
AccessTokenScopeCategoryRepository
|
||||
AccessTokenScopeCategoryUser
|
||||
)
|
||||
|
||||
// AllAccessTokenScopeCategories contains all access token scope categories
|
||||
var AllAccessTokenScopeCategories = []AccessTokenScopeCategory{
|
||||
AccessTokenScopeCategoryActivityPub,
|
||||
AccessTokenScopeCategoryAdmin,
|
||||
AccessTokenScopeCategoryMisc,
|
||||
AccessTokenScopeCategoryNotification,
|
||||
AccessTokenScopeCategoryOrganization,
|
||||
AccessTokenScopeCategoryPackage,
|
||||
AccessTokenScopeCategoryIssue,
|
||||
AccessTokenScopeCategoryRepository,
|
||||
AccessTokenScopeCategoryUser,
|
||||
}
|
||||
|
||||
// AccessTokenScopeLevel represents the access levels without a given scope category
|
||||
type AccessTokenScopeLevel int
|
||||
|
||||
const (
|
||||
NoAccess AccessTokenScopeLevel = iota
|
||||
Read
|
||||
Write
|
||||
)
|
||||
|
||||
// AccessTokenScope represents the scope for an access token.
|
||||
type AccessTokenScope string
|
||||
|
||||
// for all categories, write implies read
|
||||
const (
|
||||
AccessTokenScopeAll AccessTokenScope = "all"
|
||||
AccessTokenScopePublicOnly AccessTokenScope = "public-only" // limited to public orgs/repos
|
||||
|
||||
AccessTokenScopeReadActivityPub AccessTokenScope = "read:activitypub"
|
||||
AccessTokenScopeWriteActivityPub AccessTokenScope = "write:activitypub"
|
||||
|
||||
AccessTokenScopeReadAdmin AccessTokenScope = "read:admin"
|
||||
AccessTokenScopeWriteAdmin AccessTokenScope = "write:admin"
|
||||
|
||||
AccessTokenScopeReadMisc AccessTokenScope = "read:misc"
|
||||
AccessTokenScopeWriteMisc AccessTokenScope = "write:misc"
|
||||
|
||||
AccessTokenScopeReadNotification AccessTokenScope = "read:notification"
|
||||
AccessTokenScopeWriteNotification AccessTokenScope = "write:notification"
|
||||
|
||||
AccessTokenScopeReadOrganization AccessTokenScope = "read:organization"
|
||||
AccessTokenScopeWriteOrganization AccessTokenScope = "write:organization"
|
||||
|
||||
AccessTokenScopeReadPackage AccessTokenScope = "read:package"
|
||||
AccessTokenScopeWritePackage AccessTokenScope = "write:package"
|
||||
|
||||
AccessTokenScopeReadIssue AccessTokenScope = "read:issue"
|
||||
AccessTokenScopeWriteIssue AccessTokenScope = "write:issue"
|
||||
|
||||
AccessTokenScopeReadRepository AccessTokenScope = "read:repository"
|
||||
AccessTokenScopeWriteRepository AccessTokenScope = "write:repository"
|
||||
|
||||
AccessTokenScopeReadUser AccessTokenScope = "read:user"
|
||||
AccessTokenScopeWriteUser AccessTokenScope = "write:user"
|
||||
)
|
||||
|
||||
// accessTokenScopeBitmap represents a bitmap of access token scopes.
|
||||
type accessTokenScopeBitmap uint64
|
||||
|
||||
// Bitmap of each scope, including the child scopes.
|
||||
const (
|
||||
// AccessTokenScopeAllBits is the bitmap of all access token scopes
|
||||
accessTokenScopeAllBits accessTokenScopeBitmap = accessTokenScopeWriteActivityPubBits |
|
||||
accessTokenScopeWriteAdminBits | accessTokenScopeWriteMiscBits | accessTokenScopeWriteNotificationBits |
|
||||
accessTokenScopeWriteOrganizationBits | accessTokenScopeWritePackageBits | accessTokenScopeWriteIssueBits |
|
||||
accessTokenScopeWriteRepositoryBits | accessTokenScopeWriteUserBits
|
||||
|
||||
accessTokenScopePublicOnlyBits accessTokenScopeBitmap = 1 << iota
|
||||
|
||||
accessTokenScopeReadActivityPubBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteActivityPubBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadActivityPubBits
|
||||
|
||||
accessTokenScopeReadAdminBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteAdminBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadAdminBits
|
||||
|
||||
accessTokenScopeReadMiscBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteMiscBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadMiscBits
|
||||
|
||||
accessTokenScopeReadNotificationBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteNotificationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadNotificationBits
|
||||
|
||||
accessTokenScopeReadOrganizationBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteOrganizationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadOrganizationBits
|
||||
|
||||
accessTokenScopeReadPackageBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWritePackageBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadPackageBits
|
||||
|
||||
accessTokenScopeReadIssueBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteIssueBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadIssueBits
|
||||
|
||||
accessTokenScopeReadRepositoryBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteRepositoryBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadRepositoryBits
|
||||
|
||||
accessTokenScopeReadUserBits accessTokenScopeBitmap = 1 << iota
|
||||
accessTokenScopeWriteUserBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadUserBits
|
||||
|
||||
// The current implementation only supports up to 64 token scopes.
|
||||
// If we need to support > 64 scopes,
|
||||
// refactoring the whole implementation in this file (and only this file) is needed.
|
||||
)
|
||||
|
||||
// allAccessTokenScopes contains all access token scopes.
|
||||
// The order is important: parent scope must precede child scopes.
|
||||
var allAccessTokenScopes = []AccessTokenScope{
|
||||
AccessTokenScopePublicOnly,
|
||||
AccessTokenScopeWriteActivityPub, AccessTokenScopeReadActivityPub,
|
||||
AccessTokenScopeWriteAdmin, AccessTokenScopeReadAdmin,
|
||||
AccessTokenScopeWriteMisc, AccessTokenScopeReadMisc,
|
||||
AccessTokenScopeWriteNotification, AccessTokenScopeReadNotification,
|
||||
AccessTokenScopeWriteOrganization, AccessTokenScopeReadOrganization,
|
||||
AccessTokenScopeWritePackage, AccessTokenScopeReadPackage,
|
||||
AccessTokenScopeWriteIssue, AccessTokenScopeReadIssue,
|
||||
AccessTokenScopeWriteRepository, AccessTokenScopeReadRepository,
|
||||
AccessTokenScopeWriteUser, AccessTokenScopeReadUser,
|
||||
}
|
||||
|
||||
// allAccessTokenScopeBits contains all access token scopes.
|
||||
var allAccessTokenScopeBits = map[AccessTokenScope]accessTokenScopeBitmap{
|
||||
AccessTokenScopeAll: accessTokenScopeAllBits,
|
||||
AccessTokenScopePublicOnly: accessTokenScopePublicOnlyBits,
|
||||
AccessTokenScopeReadActivityPub: accessTokenScopeReadActivityPubBits,
|
||||
AccessTokenScopeWriteActivityPub: accessTokenScopeWriteActivityPubBits,
|
||||
AccessTokenScopeReadAdmin: accessTokenScopeReadAdminBits,
|
||||
AccessTokenScopeWriteAdmin: accessTokenScopeWriteAdminBits,
|
||||
AccessTokenScopeReadMisc: accessTokenScopeReadMiscBits,
|
||||
AccessTokenScopeWriteMisc: accessTokenScopeWriteMiscBits,
|
||||
AccessTokenScopeReadNotification: accessTokenScopeReadNotificationBits,
|
||||
AccessTokenScopeWriteNotification: accessTokenScopeWriteNotificationBits,
|
||||
AccessTokenScopeReadOrganization: accessTokenScopeReadOrganizationBits,
|
||||
AccessTokenScopeWriteOrganization: accessTokenScopeWriteOrganizationBits,
|
||||
AccessTokenScopeReadPackage: accessTokenScopeReadPackageBits,
|
||||
AccessTokenScopeWritePackage: accessTokenScopeWritePackageBits,
|
||||
AccessTokenScopeReadIssue: accessTokenScopeReadIssueBits,
|
||||
AccessTokenScopeWriteIssue: accessTokenScopeWriteIssueBits,
|
||||
AccessTokenScopeReadRepository: accessTokenScopeReadRepositoryBits,
|
||||
AccessTokenScopeWriteRepository: accessTokenScopeWriteRepositoryBits,
|
||||
AccessTokenScopeReadUser: accessTokenScopeReadUserBits,
|
||||
AccessTokenScopeWriteUser: accessTokenScopeWriteUserBits,
|
||||
}
|
||||
|
||||
// readAccessTokenScopes maps a scope category to the read permission scope
|
||||
var accessTokenScopes = map[AccessTokenScopeLevel]map[AccessTokenScopeCategory]AccessTokenScope{
|
||||
Read: {
|
||||
AccessTokenScopeCategoryActivityPub: AccessTokenScopeReadActivityPub,
|
||||
AccessTokenScopeCategoryAdmin: AccessTokenScopeReadAdmin,
|
||||
AccessTokenScopeCategoryMisc: AccessTokenScopeReadMisc,
|
||||
AccessTokenScopeCategoryNotification: AccessTokenScopeReadNotification,
|
||||
AccessTokenScopeCategoryOrganization: AccessTokenScopeReadOrganization,
|
||||
AccessTokenScopeCategoryPackage: AccessTokenScopeReadPackage,
|
||||
AccessTokenScopeCategoryIssue: AccessTokenScopeReadIssue,
|
||||
AccessTokenScopeCategoryRepository: AccessTokenScopeReadRepository,
|
||||
AccessTokenScopeCategoryUser: AccessTokenScopeReadUser,
|
||||
},
|
||||
Write: {
|
||||
AccessTokenScopeCategoryActivityPub: AccessTokenScopeWriteActivityPub,
|
||||
AccessTokenScopeCategoryAdmin: AccessTokenScopeWriteAdmin,
|
||||
AccessTokenScopeCategoryMisc: AccessTokenScopeWriteMisc,
|
||||
AccessTokenScopeCategoryNotification: AccessTokenScopeWriteNotification,
|
||||
AccessTokenScopeCategoryOrganization: AccessTokenScopeWriteOrganization,
|
||||
AccessTokenScopeCategoryPackage: AccessTokenScopeWritePackage,
|
||||
AccessTokenScopeCategoryIssue: AccessTokenScopeWriteIssue,
|
||||
AccessTokenScopeCategoryRepository: AccessTokenScopeWriteRepository,
|
||||
AccessTokenScopeCategoryUser: AccessTokenScopeWriteUser,
|
||||
},
|
||||
}
|
||||
|
||||
func GetAccessTokenCategories() (res []string) {
|
||||
for _, cat := range accessTokenScopes[Read] {
|
||||
res = append(res, strings.TrimPrefix(string(cat), "read:"))
|
||||
}
|
||||
slices.Sort(res)
|
||||
return res
|
||||
}
|
||||
|
||||
// GetRequiredScopes gets the specific scopes for a given level and categories
|
||||
func GetRequiredScopes(level AccessTokenScopeLevel, scopeCategories ...AccessTokenScopeCategory) []AccessTokenScope {
|
||||
scopes := make([]AccessTokenScope, 0, len(scopeCategories))
|
||||
for _, cat := range scopeCategories {
|
||||
scopes = append(scopes, accessTokenScopes[level][cat])
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
|
||||
// ContainsCategory checks if a list of categories contains a specific category
|
||||
func ContainsCategory(categories []AccessTokenScopeCategory, category AccessTokenScopeCategory) bool {
|
||||
return slices.Contains(categories, category)
|
||||
}
|
||||
|
||||
// GetScopeLevelFromAccessMode converts permission access mode to scope level
|
||||
func GetScopeLevelFromAccessMode(mode perm.AccessMode) AccessTokenScopeLevel {
|
||||
switch mode {
|
||||
case perm.AccessModeNone:
|
||||
return NoAccess
|
||||
case perm.AccessModeRead:
|
||||
return Read
|
||||
case perm.AccessModeWrite:
|
||||
return Write
|
||||
case perm.AccessModeAdmin:
|
||||
return Write
|
||||
case perm.AccessModeOwner:
|
||||
return Write
|
||||
default:
|
||||
return NoAccess
|
||||
}
|
||||
}
|
||||
|
||||
// parse the scope string into a bitmap, thus removing possible duplicates.
|
||||
func (s AccessTokenScope) parse() (accessTokenScopeBitmap, error) {
|
||||
var bitmap accessTokenScopeBitmap
|
||||
|
||||
// The following is the more performant equivalent of 'for _, v := range strings.Split(remainingScope, ",")' as this is hot code
|
||||
remainingScopes := string(s)
|
||||
for len(remainingScopes) > 0 {
|
||||
i := strings.IndexByte(remainingScopes, ',')
|
||||
var v string
|
||||
if i < 0 {
|
||||
v = remainingScopes
|
||||
remainingScopes = ""
|
||||
} else if i+1 >= len(remainingScopes) {
|
||||
v = remainingScopes[:i]
|
||||
remainingScopes = ""
|
||||
} else {
|
||||
v = remainingScopes[:i]
|
||||
remainingScopes = remainingScopes[i+1:]
|
||||
}
|
||||
singleScope := AccessTokenScope(v)
|
||||
if singleScope == "" {
|
||||
continue
|
||||
}
|
||||
if singleScope == AccessTokenScopeAll {
|
||||
bitmap |= accessTokenScopeAllBits
|
||||
continue
|
||||
}
|
||||
|
||||
bits, ok := allAccessTokenScopeBits[singleScope]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("invalid access token scope: %s", singleScope)
|
||||
}
|
||||
bitmap |= bits
|
||||
}
|
||||
|
||||
return bitmap, nil
|
||||
}
|
||||
|
||||
// StringSlice returns the AccessTokenScope as a []string
|
||||
func (s AccessTokenScope) StringSlice() []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(string(s), ",")
|
||||
}
|
||||
|
||||
// Normalize returns a normalized scope string without any duplicates.
|
||||
func (s AccessTokenScope) Normalize() (AccessTokenScope, error) {
|
||||
bitmap, err := s.parse()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return bitmap.toScope(), nil
|
||||
}
|
||||
|
||||
func (s AccessTokenScope) HasPermissionScope() bool {
|
||||
return s != "" && s != AccessTokenScopePublicOnly
|
||||
}
|
||||
|
||||
// PublicOnly checks if this token scope is limited to public resources
|
||||
func (s AccessTokenScope) PublicOnly() (bool, error) {
|
||||
bitmap, err := s.parse()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return bitmap.hasScope(AccessTokenScopePublicOnly)
|
||||
}
|
||||
|
||||
// HasScope returns true if the string has the given scope
|
||||
func (s AccessTokenScope) HasScope(scopes ...AccessTokenScope) (bool, error) {
|
||||
bitmap, err := s.parse()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, s := range scopes {
|
||||
if has, err := bitmap.hasScope(s); !has || err != nil {
|
||||
return has, err
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// HasAnyScope returns true if any of the scopes is contained in the string
|
||||
func (s AccessTokenScope) HasAnyScope(scopes ...AccessTokenScope) (bool, error) {
|
||||
bitmap, err := s.parse()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, s := range scopes {
|
||||
if has, err := bitmap.hasScope(s); has || err != nil {
|
||||
return has, err
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// hasScope returns true if the string has the given scope
|
||||
func (bitmap accessTokenScopeBitmap) hasScope(scope AccessTokenScope) (bool, error) {
|
||||
expectedBits, ok := allAccessTokenScopeBits[scope]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("invalid access token scope: %s", scope)
|
||||
}
|
||||
|
||||
return bitmap&expectedBits == expectedBits, nil
|
||||
}
|
||||
|
||||
// toScope returns a normalized scope string without any duplicates.
|
||||
func (bitmap accessTokenScopeBitmap) toScope() AccessTokenScope {
|
||||
var scopes []string
|
||||
|
||||
// iterate over all scopes, and reconstruct the bitmap
|
||||
// if the reconstructed bitmap doesn't change, then the scope is already included
|
||||
var reconstruct accessTokenScopeBitmap
|
||||
|
||||
for _, singleScope := range allAccessTokenScopes {
|
||||
// no need for error checking here, since we know the scope is valid
|
||||
if ok, _ := bitmap.hasScope(singleScope); ok {
|
||||
current := reconstruct | allAccessTokenScopeBits[singleScope]
|
||||
if current == reconstruct {
|
||||
continue
|
||||
}
|
||||
|
||||
reconstruct = current
|
||||
scopes = append(scopes, string(singleScope))
|
||||
}
|
||||
}
|
||||
|
||||
scope := AccessTokenScope(strings.Join(scopes, ","))
|
||||
scope = AccessTokenScope(strings.ReplaceAll(
|
||||
string(scope),
|
||||
"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user",
|
||||
"all",
|
||||
))
|
||||
return scope
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
// Copyright 2022 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type scopeTestNormalize struct {
|
||||
in AccessTokenScope
|
||||
out AccessTokenScope
|
||||
err error
|
||||
}
|
||||
|
||||
func TestAccessTokenScope_Normalize(t *testing.T) {
|
||||
assert.Equal(t, []string{"activitypub", "admin", "issue", "misc", "notification", "organization", "package", "repository", "user"}, GetAccessTokenCategories())
|
||||
tests := []scopeTestNormalize{
|
||||
{"", "", nil},
|
||||
{"write:misc,write:notification,read:package,write:notification,public-only", "public-only,write:misc,write:notification,read:package", nil},
|
||||
{"all", "all", nil},
|
||||
{"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user", "all", nil},
|
||||
{"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user,public-only", "public-only,all", nil},
|
||||
}
|
||||
|
||||
for _, scope := range GetAccessTokenCategories() {
|
||||
tests = append(tests,
|
||||
scopeTestNormalize{AccessTokenScope("read:" + scope), AccessTokenScope("read:" + scope), nil},
|
||||
scopeTestNormalize{AccessTokenScope("write:" + scope), AccessTokenScope("write:" + scope), nil},
|
||||
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%[1]s,read:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
|
||||
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
|
||||
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s,write:%[1]s", scope)), AccessTokenScope("write:" + scope), nil},
|
||||
)
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(string(test.in), func(t *testing.T) {
|
||||
scope, err := test.in.Normalize()
|
||||
assert.Equal(t, test.out, scope)
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type scopeTestHasScope struct {
|
||||
in AccessTokenScope
|
||||
scope AccessTokenScope
|
||||
out bool
|
||||
err error
|
||||
}
|
||||
|
||||
func TestAccessTokenScope_HasScope(t *testing.T) {
|
||||
tests := []scopeTestHasScope{
|
||||
{"read:admin", "write:package", false, nil},
|
||||
{"all", "write:package", true, nil},
|
||||
{"write:package", "all", false, nil},
|
||||
{"public-only", "read:issue", false, nil},
|
||||
}
|
||||
|
||||
for _, scope := range GetAccessTokenCategories() {
|
||||
tests = append(tests,
|
||||
scopeTestHasScope{
|
||||
AccessTokenScope("read:" + scope),
|
||||
AccessTokenScope("read:" + scope), true, nil,
|
||||
},
|
||||
scopeTestHasScope{
|
||||
AccessTokenScope("write:" + scope),
|
||||
AccessTokenScope("write:" + scope), true, nil,
|
||||
},
|
||||
scopeTestHasScope{
|
||||
AccessTokenScope("write:" + scope),
|
||||
AccessTokenScope("read:" + scope), true, nil,
|
||||
},
|
||||
scopeTestHasScope{
|
||||
AccessTokenScope("read:" + scope),
|
||||
AccessTokenScope("write:" + scope), false, nil,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(string(test.in), func(t *testing.T) {
|
||||
hasScope, err := test.in.HasScope(test.scope)
|
||||
assert.Equal(t, test.out, hasScope)
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
// Copyright 2016 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
auth_model "gitea.dev/models/auth"
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/models/unittest"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewAccessToken(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token := &auth_model.AccessToken{
|
||||
UID: 3,
|
||||
Name: "Token C",
|
||||
}
|
||||
assert.NoError(t, auth_model.NewAccessToken(t.Context(), token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
|
||||
invalidToken := &auth_model.AccessToken{
|
||||
ID: token.ID, // duplicate
|
||||
UID: 2,
|
||||
Name: "Token F",
|
||||
}
|
||||
assert.Error(t, auth_model.NewAccessToken(t.Context(), invalidToken))
|
||||
}
|
||||
|
||||
func TestAccessTokenByNameExists(t *testing.T) {
|
||||
name := "Token Gitea"
|
||||
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token := &auth_model.AccessToken{
|
||||
UID: 3,
|
||||
Name: name,
|
||||
}
|
||||
|
||||
// Check to make sure it doesn't exists already
|
||||
exist, err := auth_model.AccessTokenByNameExists(t.Context(), token)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exist)
|
||||
|
||||
// Save it to the database
|
||||
assert.NoError(t, auth_model.NewAccessToken(t.Context(), token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
|
||||
// This token must be found by name in the DB now
|
||||
exist, err = auth_model.AccessTokenByNameExists(t.Context(), token)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exist)
|
||||
|
||||
user4Token := &auth_model.AccessToken{
|
||||
UID: 4,
|
||||
Name: name,
|
||||
}
|
||||
|
||||
// Name matches but different user ID, this shouldn't exists in the
|
||||
// database
|
||||
exist, err = auth_model.AccessTokenByNameExists(t.Context(), user4Token)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exist)
|
||||
}
|
||||
|
||||
func TestGetAccessTokenBySHA(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), token.UID)
|
||||
assert.Equal(t, "Token A", token.Name)
|
||||
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
|
||||
assert.Equal(t, "e4efbf36", token.TokenLastEight)
|
||||
|
||||
_, err = auth_model.GetAccessTokenBySHA(t.Context(), "notahash")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
|
||||
|
||||
_, err = auth_model.GetAccessTokenBySHA(t.Context(), "")
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
|
||||
}
|
||||
|
||||
func TestListAccessTokens(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
tokens, err := db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 1})
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, tokens, 2) {
|
||||
assert.Equal(t, int64(1), tokens[0].UID)
|
||||
assert.Equal(t, int64(1), tokens[1].UID)
|
||||
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token A")
|
||||
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
|
||||
}
|
||||
|
||||
tokens, err = db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 2})
|
||||
assert.NoError(t, err)
|
||||
if assert.Len(t, tokens, 1) {
|
||||
assert.Equal(t, int64(2), tokens[0].UID)
|
||||
assert.Equal(t, "Token A", tokens[0].Name)
|
||||
}
|
||||
|
||||
tokens, err = db.Find[auth_model.AccessToken](t.Context(), auth_model.ListAccessTokensOptions{UserID: 100})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, tokens)
|
||||
}
|
||||
|
||||
func TestUpdateAccessToken(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "4c6f36e6cf498e2a448662f915d932c09c5a146c")
|
||||
assert.NoError(t, err)
|
||||
token.Name = "Token Z"
|
||||
|
||||
assert.NoError(t, auth_model.UpdateAccessToken(t.Context(), token))
|
||||
unittest.AssertExistsAndLoadBean(t, token)
|
||||
}
|
||||
|
||||
func TestDeleteAccessTokenByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
token, err := auth_model.GetAccessTokenBySHA(t.Context(), "4c6f36e6cf498e2a448662f915d932c09c5a146c")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), token.UID)
|
||||
|
||||
assert.NoError(t, auth_model.DeleteAccessTokenByID(t.Context(), token.ID, 1))
|
||||
unittest.AssertNotExistsBean(t, token)
|
||||
|
||||
err = auth_model.DeleteAccessTokenByID(t.Context(), 100, 100)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/timeutil"
|
||||
"gitea.dev/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
var ErrAuthTokenNotExist = util.NewNotExistErrorf("auth token does not exist")
|
||||
|
||||
type AuthToken struct { //nolint:revive // export stutter
|
||||
ID string `xorm:"pk"`
|
||||
TokenHash string
|
||||
UserID int64 `xorm:"INDEX"`
|
||||
ExpiresUnix timeutil.TimeStamp `xorm:"INDEX"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(AuthToken))
|
||||
}
|
||||
|
||||
func InsertAuthToken(ctx context.Context, t *AuthToken) error {
|
||||
_, err := db.GetEngine(ctx).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
func GetAuthTokenByID(ctx context.Context, id string) (*AuthToken, error) {
|
||||
at := &AuthToken{}
|
||||
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(at)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !has {
|
||||
return nil, ErrAuthTokenNotExist
|
||||
}
|
||||
return at, nil
|
||||
}
|
||||
|
||||
func UpdateAuthTokenByID(ctx context.Context, t *AuthToken) error {
|
||||
_, err := db.GetEngine(ctx).ID(t.ID).Cols("token_hash", "expires_unix").Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteAuthTokenByID(ctx context.Context, id string) error {
|
||||
_, err := db.GetEngine(ctx).ID(id).Delete(&AuthToken{})
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteAuthTokensByUserID(ctx context.Context, uid int64) error {
|
||||
_, err := db.GetEngine(ctx).Where(builder.Eq{"user_id": uid}).Delete(&AuthToken{})
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteExpiredAuthTokens(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).Where(builder.Lt{"expires_unix": timeutil.TimeStampNow()}).Delete(&AuthToken{})
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.dev/models/unittest"
|
||||
|
||||
_ "gitea.dev/models"
|
||||
_ "gitea.dev/models/actions"
|
||||
_ "gitea.dev/models/activities"
|
||||
_ "gitea.dev/models/auth"
|
||||
_ "gitea.dev/models/perm/access"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
unittest.MainTest(m)
|
||||
}
|
||||
@@ -0,0 +1,680 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/base32"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/container"
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/timeutil"
|
||||
"gitea.dev/modules/util"
|
||||
|
||||
uuid "github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/oauth2"
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
// Authorization codes should expire within 10 minutes per https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2
|
||||
const oauth2AuthorizationCodeValidity = 10 * time.Minute
|
||||
|
||||
var (
|
||||
ErrOAuth2AuthorizationCodeInvalidated = errors.New("oauth2 authorization code already invalidated")
|
||||
ErrOAuth2GrantStaleCounter = errors.New("oauth2 grant state changed during token refresh")
|
||||
)
|
||||
|
||||
// OAuth2Application represents an OAuth2 client (RFC 6749)
|
||||
type OAuth2Application struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UID int64 `xorm:"INDEX"`
|
||||
Name string
|
||||
ClientID string `xorm:"unique"`
|
||||
ClientSecret string
|
||||
// OAuth defines both Confidential and Public client types
|
||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-2.1
|
||||
// "Authorization servers MUST record the client type in the client registration details"
|
||||
// https://datatracker.ietf.org/doc/html/rfc8252#section-8.4
|
||||
ConfidentialClient bool `xorm:"NOT NULL DEFAULT TRUE"`
|
||||
SkipSecondaryAuthorization bool `xorm:"NOT NULL DEFAULT FALSE"`
|
||||
RedirectURIs []string `xorm:"redirect_uris JSON TEXT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(OAuth2Application))
|
||||
db.RegisterModel(new(OAuth2AuthorizationCode))
|
||||
db.RegisterModel(new(OAuth2Grant))
|
||||
}
|
||||
|
||||
type BuiltinOAuth2Application struct {
|
||||
ConfigName string
|
||||
DisplayName string
|
||||
RedirectURIs []string
|
||||
}
|
||||
|
||||
func BuiltinApplications() map[string]*BuiltinOAuth2Application {
|
||||
m := make(map[string]*BuiltinOAuth2Application)
|
||||
m["a4792ccc-144e-407e-86c9-5e7d8d9c3269"] = &BuiltinOAuth2Application{
|
||||
ConfigName: "git-credential-oauth",
|
||||
DisplayName: "git-credential-oauth",
|
||||
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
|
||||
}
|
||||
m["e90ee53c-94e2-48ac-9358-a874fb9e0662"] = &BuiltinOAuth2Application{
|
||||
ConfigName: "git-credential-manager",
|
||||
DisplayName: "Git Credential Manager",
|
||||
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
|
||||
}
|
||||
m["d57cb8c4-630c-4168-8324-ec79935e18d4"] = &BuiltinOAuth2Application{
|
||||
ConfigName: "tea",
|
||||
DisplayName: "tea",
|
||||
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func Init(ctx context.Context) error {
|
||||
builtinApps := BuiltinApplications()
|
||||
var builtinAllClientIDs []string
|
||||
for clientID := range builtinApps {
|
||||
builtinAllClientIDs = append(builtinAllClientIDs, clientID)
|
||||
}
|
||||
|
||||
var registeredApps []*OAuth2Application
|
||||
if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(®isteredApps); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientIDsToAdd := container.Set[string]{}
|
||||
for _, configName := range setting.OAuth2.DefaultApplications {
|
||||
found := false
|
||||
for clientID, builtinApp := range builtinApps {
|
||||
if builtinApp.ConfigName == configName {
|
||||
clientIDsToAdd.Add(clientID) // add all user-configured apps to the "add" list
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("unknown oauth2 application: %q", configName)
|
||||
}
|
||||
}
|
||||
clientIDsToDelete := container.Set[string]{}
|
||||
for _, app := range registeredApps {
|
||||
if !clientIDsToAdd.Contains(app.ClientID) {
|
||||
clientIDsToDelete.Add(app.ClientID) // if a registered app is not in the "add" list, it should be deleted
|
||||
}
|
||||
}
|
||||
for _, app := range registeredApps {
|
||||
clientIDsToAdd.Remove(app.ClientID) // no need to re-add existing (registered) apps, so remove them from the set
|
||||
}
|
||||
|
||||
for _, app := range registeredApps {
|
||||
if clientIDsToDelete.Contains(app.ClientID) {
|
||||
if err := deleteOAuth2Application(ctx, app.ID, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for clientID := range clientIDsToAdd {
|
||||
builtinApp := builtinApps[clientID]
|
||||
if err := db.Insert(ctx, &OAuth2Application{
|
||||
Name: builtinApp.DisplayName,
|
||||
ClientID: clientID,
|
||||
RedirectURIs: builtinApp.RedirectURIs,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_application`
|
||||
func (app *OAuth2Application) TableName() string {
|
||||
return "oauth2_application"
|
||||
}
|
||||
|
||||
// ContainsRedirectURI checks if redirectURI is allowed for app
|
||||
func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool {
|
||||
// OAuth2 requires the redirect URI to be an exact match, no dynamic parts are allowed.
|
||||
// https://stackoverflow.com/questions/55524480/should-dynamic-query-parameters-be-present-in-the-redirection-uri-for-an-oauth2
|
||||
// https://www.rfc-editor.org/rfc/rfc6819#section-5.2.3.3
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics-12#section-3.1
|
||||
redirectCandidates := []string{redirectURI}
|
||||
if !app.ConfidentialClient {
|
||||
loopbackRedirect, ok := normalizePublicClientRedirectURI(redirectURI)
|
||||
if ok {
|
||||
redirectCandidates = append(redirectCandidates, loopbackRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
for _, candidate := range redirectCandidates {
|
||||
normalizedCandidate := normalizeRedirectURIForComparison(candidate)
|
||||
for _, registeredURI := range app.RedirectURIs {
|
||||
if normalizeRedirectURIForComparison(registeredURI) == normalizedCandidate {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeRedirectURIForComparison(redirectURI string) string {
|
||||
return strings.TrimSuffix(util.ToLowerASCII(redirectURI), "/")
|
||||
}
|
||||
|
||||
func normalizePublicClientRedirectURI(redirectURI string) (string, bool) {
|
||||
parsedURI, err := url.Parse(redirectURI)
|
||||
if err != nil || parsedURI.Scheme != "http" || parsedURI.Port() == "" {
|
||||
return "", false
|
||||
}
|
||||
if ip := net.ParseIP(parsedURI.Hostname()); ip == nil || !ip.IsLoopback() {
|
||||
return "", false
|
||||
}
|
||||
parsedURI.Host = parsedURI.Hostname()
|
||||
return parsedURI.String(), true
|
||||
}
|
||||
|
||||
// Base32 characters, but lowercased.
|
||||
const lowerBase32Chars = "abcdefghijklmnopqrstuvwxyz234567"
|
||||
|
||||
// base32 encoder that uses lowered characters without padding.
|
||||
var base32Lower = base32.NewEncoding(lowerBase32Chars).WithPadding(base32.NoPadding)
|
||||
|
||||
// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database
|
||||
func (app *OAuth2Application) GenerateClientSecret(ctx context.Context) (string, error) {
|
||||
rBytes := util.CryptoRandomBytes(32)
|
||||
// Add a prefix to the base32, this is in order to make it easier
|
||||
// for code scanners to grab sensitive tokens.
|
||||
clientSecret := "gto_" + base32Lower.EncodeToString(rBytes)
|
||||
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
app.ClientSecret = string(hashedSecret)
|
||||
if _, err := db.GetEngine(ctx).ID(app.ID).Cols("client_secret").Update(app); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return clientSecret, nil
|
||||
}
|
||||
|
||||
// ValidateClientSecret validates the given secret by the hash saved in database
|
||||
func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil
|
||||
}
|
||||
|
||||
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
|
||||
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil //nolint:nilnil // return nil to indicate that the object does not exist
|
||||
}
|
||||
return grant, nil
|
||||
}
|
||||
|
||||
// CreateGrant generates a grant for a user
|
||||
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
|
||||
grant := &OAuth2Grant{
|
||||
ApplicationID: app.ID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
}
|
||||
err := db.Insert(ctx, grant)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return grant, nil
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
|
||||
if !has {
|
||||
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
|
||||
}
|
||||
return app, err
|
||||
}
|
||||
|
||||
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
|
||||
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
|
||||
app = new(OAuth2Application)
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(app)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !has {
|
||||
return nil, ErrOAuthApplicationNotFound{ID: id}
|
||||
}
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// CreateOAuth2ApplicationOptions holds options to create an oauth2 application
|
||||
type CreateOAuth2ApplicationOptions struct {
|
||||
Name string
|
||||
UserID int64
|
||||
ConfidentialClient bool
|
||||
SkipSecondaryAuthorization bool
|
||||
RedirectURIs []string
|
||||
}
|
||||
|
||||
// CreateOAuth2Application inserts a new oauth2 application
|
||||
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
clientID := uuid.New().String()
|
||||
app := &OAuth2Application{
|
||||
UID: opts.UserID,
|
||||
Name: opts.Name,
|
||||
ClientID: clientID,
|
||||
RedirectURIs: opts.RedirectURIs,
|
||||
ConfidentialClient: opts.ConfidentialClient,
|
||||
SkipSecondaryAuthorization: opts.SkipSecondaryAuthorization,
|
||||
}
|
||||
if err := db.Insert(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application
|
||||
type UpdateOAuth2ApplicationOptions struct {
|
||||
ID int64
|
||||
Name string
|
||||
UserID int64
|
||||
ConfidentialClient bool
|
||||
SkipSecondaryAuthorization bool
|
||||
RedirectURIs []string
|
||||
}
|
||||
|
||||
// UpdateOAuth2Application updates an oauth2 application
|
||||
func UpdateOAuth2Application(ctx context.Context, opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*OAuth2Application, error) {
|
||||
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if app.UID != opts.UserID {
|
||||
return nil, errors.New("UID mismatch")
|
||||
}
|
||||
builtinApps := BuiltinApplications()
|
||||
if _, builtin := builtinApps[app.ClientID]; builtin {
|
||||
return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID)
|
||||
}
|
||||
|
||||
app.Name = opts.Name
|
||||
app.RedirectURIs = opts.RedirectURIs
|
||||
app.ConfidentialClient = opts.ConfidentialClient
|
||||
app.SkipSecondaryAuthorization = opts.SkipSecondaryAuthorization
|
||||
|
||||
if err = updateOAuth2Application(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
app.ClientSecret = ""
|
||||
|
||||
return app, nil
|
||||
})
|
||||
}
|
||||
|
||||
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
|
||||
if _, err := db.GetEngine(ctx).ID(app.ID).UseBool("confidential_client", "skip_secondary_authorization").Update(app); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
|
||||
sess := db.GetEngine(ctx)
|
||||
// the userid could be 0 if the app is instance-wide
|
||||
if deleted, err := sess.Where(builder.Eq{"id": id, "uid": userid}).Delete(&OAuth2Application{}); err != nil {
|
||||
return err
|
||||
} else if deleted == 0 {
|
||||
return ErrOAuthApplicationNotFound{ID: id}
|
||||
}
|
||||
codes := make([]*OAuth2AuthorizationCode, 0)
|
||||
// delete correlating auth codes
|
||||
if err := sess.Join("INNER", "oauth2_grant",
|
||||
"oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil {
|
||||
return err
|
||||
}
|
||||
codeIDs := make([]int64, 0, len(codes))
|
||||
for _, grant := range codes {
|
||||
codeIDs = append(codeIDs, grant.ID)
|
||||
}
|
||||
|
||||
if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
|
||||
func DeleteOAuth2Application(ctx context.Context, id, userid int64) error {
|
||||
return db.WithTx(ctx, func(ctx context.Context) error {
|
||||
app, err := GetOAuth2ApplicationByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
builtinApps := BuiltinApplications()
|
||||
if _, builtin := builtinApps[app.ClientID]; builtin {
|
||||
return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID)
|
||||
}
|
||||
return deleteOAuth2Application(ctx, id, userid)
|
||||
})
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////
|
||||
|
||||
// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime.
|
||||
type OAuth2AuthorizationCode struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Grant *OAuth2Grant `xorm:"-"`
|
||||
GrantID int64
|
||||
Code string `xorm:"INDEX unique"`
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod string
|
||||
RedirectURI string
|
||||
ValidUntil timeutil.TimeStamp `xorm:"index"`
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_authorization_code`
|
||||
func (code *OAuth2AuthorizationCode) TableName() string {
|
||||
return "oauth2_authorization_code"
|
||||
}
|
||||
|
||||
// IsExpired reports whether the authorization code is expired.
|
||||
func (code *OAuth2AuthorizationCode) IsExpired() bool {
|
||||
if code.ValidUntil.IsZero() {
|
||||
return true
|
||||
}
|
||||
return code.ValidUntil <= timeutil.TimeStampNow()
|
||||
}
|
||||
|
||||
// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty.
|
||||
func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (*url.URL, error) {
|
||||
redirect, err := url.Parse(code.RedirectURI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := redirect.Query()
|
||||
if state != "" {
|
||||
q.Set("state", state)
|
||||
}
|
||||
q.Set("code", code.Code)
|
||||
redirect.RawQuery = q.Encode()
|
||||
return redirect, err
|
||||
}
|
||||
|
||||
// Invalidate deletes the auth code from the database to invalidate this code
|
||||
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
|
||||
affected, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return ErrOAuth2AuthorizationCodeInvalidated
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (code *OAuth2AuthorizationCode) requiresCodeVerifier() bool {
|
||||
return code.CodeChallengeMethod != "" || code.CodeChallenge != ""
|
||||
}
|
||||
|
||||
func deriveCodeChallenge(method, verifier string) (string, bool) {
|
||||
switch method {
|
||||
case "S256":
|
||||
return oauth2.S256ChallengeFromVerifier(verifier), true
|
||||
case "plain":
|
||||
return verifier, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
|
||||
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
|
||||
if !code.requiresCodeVerifier() {
|
||||
return true
|
||||
}
|
||||
if verifier == "" || code.CodeChallengeMethod == "" {
|
||||
return false
|
||||
}
|
||||
expectedChallenge, ok := deriveCodeChallenge(code.CodeChallengeMethod, verifier)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(expectedChallenge), []byte(code.CodeChallenge)) == 1
|
||||
}
|
||||
|
||||
// GetOAuth2AuthorizationByCode returns an authorization by its code
|
||||
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
|
||||
auth = new(OAuth2AuthorizationCode)
|
||||
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil //nolint:nilnil // return nil to indicate that the object does not exist
|
||||
}
|
||||
auth.Grant = new(OAuth2Grant)
|
||||
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil //nolint:nilnil // return nil to indicate that the object does not exist
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////
|
||||
|
||||
// OAuth2Grant represents the permission of a user for a specific application to access resources
|
||||
type OAuth2Grant struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UserID int64 `xorm:"INDEX unique(user_application)"`
|
||||
Application *OAuth2Application `xorm:"-"`
|
||||
ApplicationID int64 `xorm:"INDEX unique(user_application)"`
|
||||
Counter int64 `xorm:"NOT NULL DEFAULT 1"`
|
||||
Scope string `xorm:"TEXT"`
|
||||
Nonce string `xorm:"TEXT"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
|
||||
}
|
||||
|
||||
// TableName sets the table name to `oauth2_grant`
|
||||
func (grant *OAuth2Grant) TableName() string {
|
||||
return "oauth2_grant"
|
||||
}
|
||||
|
||||
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
|
||||
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
|
||||
rBytes := util.CryptoRandomBytes(32)
|
||||
// Add a prefix to the base32, this is in order to make it easier
|
||||
// for code scanners to grab sensitive tokens.
|
||||
codeSecret := "gta_" + base32Lower.EncodeToString(rBytes)
|
||||
|
||||
validUntil := time.Now().Add(oauth2AuthorizationCodeValidity)
|
||||
code = &OAuth2AuthorizationCode{
|
||||
Grant: grant,
|
||||
GrantID: grant.ID,
|
||||
RedirectURI: redirectURI,
|
||||
Code: codeSecret,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ValidUntil: timeutil.TimeStamp(validUntil.Unix()),
|
||||
}
|
||||
if err := db.Insert(ctx, code); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// IncreaseCounter increases the counter and updates the grant
|
||||
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
|
||||
affected, err := db.GetEngine(ctx).
|
||||
Where("id = ?", grant.ID).
|
||||
And("counter = ?", grant.Counter).
|
||||
Incr("counter").
|
||||
Update(new(OAuth2Grant))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return ErrOAuth2GrantStaleCounter
|
||||
}
|
||||
grant.Counter++
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScopeContains returns true if the grant scope contains the specified scope
|
||||
func (grant *OAuth2Grant) ScopeContains(scope string) bool {
|
||||
return slices.Contains(strings.Split(grant.Scope, " "), scope)
|
||||
}
|
||||
|
||||
// SetNonce updates the current nonce value of a grant
|
||||
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
|
||||
grant.Nonce = nonce
|
||||
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOAuth2GrantByID returns the grant with the given ID
|
||||
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
|
||||
grant = new(OAuth2Grant)
|
||||
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, nil //nolint:nilnil // return nil to indicate that the object does not exist
|
||||
}
|
||||
return grant, err
|
||||
}
|
||||
|
||||
// GetOAuth2GrantsByUserID lists all grants of a certain user
|
||||
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
|
||||
type joinedOAuth2Grant struct {
|
||||
Grant *OAuth2Grant `xorm:"extends"`
|
||||
Application *OAuth2Application `xorm:"extends"`
|
||||
}
|
||||
var results *xorm.Rows
|
||||
var err error
|
||||
if results, err = db.GetEngine(ctx).
|
||||
Table("oauth2_grant").
|
||||
Where("user_id = ?", uid).
|
||||
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
|
||||
Rows(new(joinedOAuth2Grant)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer results.Close()
|
||||
grants := make([]*OAuth2Grant, 0)
|
||||
for results.Next() {
|
||||
joinedGrant := new(joinedOAuth2Grant)
|
||||
if err := results.Scan(joinedGrant); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
joinedGrant.Grant.Application = joinedGrant.Application
|
||||
grants = append(grants, joinedGrant.Grant)
|
||||
}
|
||||
return grants, nil
|
||||
}
|
||||
|
||||
// RevokeOAuth2Grant deletes the grant with grantID and userID
|
||||
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
|
||||
_, err := db.GetEngine(ctx).Where(builder.Eq{"id": grantID, "user_id": userID}).Delete(&OAuth2Grant{})
|
||||
return err
|
||||
}
|
||||
|
||||
// ErrOAuthClientIDInvalid will be thrown if client id cannot be found
|
||||
type ErrOAuthClientIDInvalid struct {
|
||||
ClientID string
|
||||
}
|
||||
|
||||
// IsErrOauthClientIDInvalid checks if an error is a ErrOAuthClientIDInvalid.
|
||||
func IsErrOauthClientIDInvalid(err error) bool {
|
||||
_, ok := err.(ErrOAuthClientIDInvalid)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Error returns the error message
|
||||
func (err ErrOAuthClientIDInvalid) Error() string {
|
||||
return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrOAuthClientIDInvalid) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrOAuthApplicationNotFound will be thrown if id cannot be found
|
||||
type ErrOAuthApplicationNotFound struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist.
|
||||
func IsErrOAuthApplicationNotFound(err error) bool {
|
||||
_, ok := err.(ErrOAuthApplicationNotFound)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Error returns the error message
|
||||
func (err ErrOAuthApplicationNotFound) Error() string {
|
||||
return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrOAuthApplicationNotFound) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// GetActiveOAuth2SourceByAuthName returns a OAuth2 AuthSource based on the given name
|
||||
func GetActiveOAuth2SourceByAuthName(ctx context.Context, name string) (*Source, error) {
|
||||
authSource := new(Source)
|
||||
has, err := db.GetEngine(ctx).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !has {
|
||||
return nil, util.NewNotExistErrorf("oauth2 source not found, name: %q", name)
|
||||
}
|
||||
|
||||
return authSource, nil
|
||||
}
|
||||
|
||||
func DeleteOAuth2RelictsByUserID(ctx context.Context, userID int64) error {
|
||||
deleteCond := builder.Select("id").From("oauth2_grant").Where(builder.Eq{"oauth2_grant.user_id": userID})
|
||||
|
||||
if _, err := db.GetEngine(ctx).In("grant_id", deleteCond).
|
||||
Delete(&OAuth2AuthorizationCode{}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.DeleteBeans(ctx,
|
||||
&OAuth2Application{UID: userID},
|
||||
&OAuth2Grant{UserID: userID},
|
||||
); err != nil {
|
||||
return fmt.Errorf("DeleteBeans: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright 2023 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"gitea.dev/models/db"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
type FindOAuth2ApplicationsOptions struct {
|
||||
db.ListOptions
|
||||
// OwnerID is the user id or org id of the owner of the application
|
||||
OwnerID int64
|
||||
// find global applications, if true, then OwnerID will be igonred
|
||||
IsGlobal bool
|
||||
}
|
||||
|
||||
func (opts FindOAuth2ApplicationsOptions) ToConds() builder.Cond {
|
||||
conds := builder.NewCond()
|
||||
if opts.IsGlobal {
|
||||
conds = conds.And(builder.Eq{"uid": 0})
|
||||
} else if opts.OwnerID != 0 {
|
||||
conds = conds.And(builder.Eq{"uid": opts.OwnerID})
|
||||
}
|
||||
return conds
|
||||
}
|
||||
|
||||
func (opts FindOAuth2ApplicationsOptions) ToOrders() string {
|
||||
return "id DESC"
|
||||
}
|
||||
@@ -0,0 +1,343 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
auth_model "gitea.dev/models/auth"
|
||||
"gitea.dev/models/unittest"
|
||||
"gitea.dev/modules/timeutil"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestOAuth2AuthorizationCode(t *testing.T) {
|
||||
require.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
t.Run("GenerateSetsValidUntil", func(t *testing.T) {
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
|
||||
expectedValidUntil := timeutil.TimeStamp(time.Now().Unix() + 600)
|
||||
code, err := grant.GenerateNewAuthorizationCode(t.Context(), "http://127.0.0.1/", "", "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValidUntil, code.ValidUntil)
|
||||
assert.False(t, code.IsExpired())
|
||||
assert.Equal(t, int64(1), code.ID)
|
||||
|
||||
code2, err := auth_model.GetOAuth2AuthorizationByCode(t.Context(), code.Code)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, code.Code, code2.Code)
|
||||
|
||||
assert.NoError(t, code.Invalidate(t.Context()))
|
||||
|
||||
code, err = auth_model.GetOAuth2AuthorizationByCode(t.Context(), "does not exist")
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, code)
|
||||
})
|
||||
|
||||
t.Run("Expired", func(t *testing.T) {
|
||||
defer timeutil.MockSet(time.Unix(2, 0).UTC())()
|
||||
|
||||
code := &auth_model.OAuth2AuthorizationCode{ValidUntil: timeutil.TimeStamp(1)}
|
||||
assert.True(t, code.IsExpired())
|
||||
})
|
||||
|
||||
t.Run("Invalidate", func(t *testing.T) {
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
|
||||
code, err := grant.GenerateNewAuthorizationCode(t.Context(), "http://127.0.0.1/", "", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, code)
|
||||
require.NoError(t, code.Invalidate(t.Context()))
|
||||
unittest.AssertNotExistsBean(t, &auth_model.OAuth2AuthorizationCode{Code: code.Code})
|
||||
assert.ErrorIs(t, code.Invalidate(t.Context()), auth_model.ErrOAuth2AuthorizationCodeInvalidated)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
|
||||
secret, err := app.GenerateClientSecret(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, secret)
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
|
||||
}
|
||||
|
||||
func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
|
||||
assert.NoError(b, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(b, &auth_model.OAuth2Application{ID: 1})
|
||||
for b.Loop() {
|
||||
_, _ = app.GenerateClientSecret(b.Context())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
|
||||
app := &auth_model.OAuth2Application{
|
||||
RedirectURIs: []string{"a", "b", "c"},
|
||||
}
|
||||
assert.True(t, app.ContainsRedirectURI("a"))
|
||||
assert.True(t, app.ContainsRedirectURI("b"))
|
||||
assert.True(t, app.ContainsRedirectURI("c"))
|
||||
assert.False(t, app.ContainsRedirectURI("d"))
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ContainsRedirectURI_WithPort(t *testing.T) {
|
||||
app := &auth_model.OAuth2Application{
|
||||
RedirectURIs: []string{"http://127.0.0.1/", "http://::1/", "http://192.168.0.1/", "http://intranet/", "https://127.0.0.1/"},
|
||||
ConfidentialClient: false,
|
||||
}
|
||||
|
||||
// http loopback uris should ignore port
|
||||
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1:3456/"))
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
|
||||
assert.True(t, app.ContainsRedirectURI("http://[::1]:3456/"))
|
||||
|
||||
// not http
|
||||
assert.False(t, app.ContainsRedirectURI("https://127.0.0.1:3456/"))
|
||||
// not loopback
|
||||
assert.False(t, app.ContainsRedirectURI("http://192.168.0.1:9954/"))
|
||||
assert.False(t, app.ContainsRedirectURI("http://intranet:3456/"))
|
||||
// unparseable
|
||||
assert.False(t, app.ContainsRedirectURI(":"))
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ContainsRedirect_Slash(t *testing.T) {
|
||||
app := &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1"}}
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
|
||||
assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
|
||||
|
||||
app = &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1/"}}
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
|
||||
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
|
||||
assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ContainsRedirectURI_ASCIIOnlyNormalization(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
registered []string
|
||||
redirectURI string
|
||||
allowed bool
|
||||
}{
|
||||
{
|
||||
name: "exact-match",
|
||||
registered: []string{"https://signin.example.test/callback"},
|
||||
redirectURI: "https://signin.example.test/callback",
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "ascii-case-insensitive",
|
||||
registered: []string{"https://signin.example.test/callback"},
|
||||
redirectURI: "https://signIN.example.test/callback",
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "non-ascii-not-folded",
|
||||
registered: []string{"https://signin.example.test/callback"},
|
||||
redirectURI: "https://signİn.example.test/callback",
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "loopback-strips-port",
|
||||
registered: []string{"http://127.0.0.1/callback"},
|
||||
redirectURI: "http://127.0.0.1:12345/callback",
|
||||
allowed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
app := &auth_model.OAuth2Application{RedirectURIs: tc.registered}
|
||||
assert.Equal(t, tc.allowed, app.ContainsRedirectURI(tc.redirectURI))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
|
||||
secret, err := app.GenerateClientSecret(t.Context())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, app.ValidateClientSecret([]byte(secret)))
|
||||
assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew")))
|
||||
}
|
||||
|
||||
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app, err := auth_model.GetOAuth2ApplicationByClientID(t.Context(), "da7da3ba-9a13-4167-856f-3899de0b0138")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
|
||||
|
||||
app, err = auth_model.GetOAuth2ApplicationByClientID(t.Context(), "invalid client id")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, app)
|
||||
}
|
||||
|
||||
func TestCreateOAuth2Application(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app, err := auth_model.CreateOAuth2Application(t.Context(), auth_model.CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "newapp", app.Name)
|
||||
assert.Len(t, app.ClientID, 36)
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{Name: "newapp"})
|
||||
}
|
||||
|
||||
func TestOAuth2Application_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_application", new(auth_model.OAuth2Application).TableName())
|
||||
}
|
||||
|
||||
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
|
||||
grant, err := app.GetGrantByUserID(t.Context(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), grant.UserID)
|
||||
|
||||
grant, err = app.GetGrantByUserID(t.Context(), 34923458)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, grant)
|
||||
}
|
||||
|
||||
func TestOAuth2Application_CreateGrant(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
|
||||
grant, err := app.CreateGrant(t.Context(), 2, "")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, grant)
|
||||
assert.Equal(t, int64(2), grant.UserID)
|
||||
assert.Equal(t, int64(1), grant.ApplicationID)
|
||||
assert.Empty(t, grant.Scope)
|
||||
}
|
||||
|
||||
//////////////////// Grant
|
||||
|
||||
func TestGetOAuth2GrantByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant, err := auth_model.GetOAuth2GrantByID(t.Context(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(1), grant.ID)
|
||||
|
||||
grant, err = auth_model.GetOAuth2GrantByID(t.Context(), 34923458)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, grant)
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 1})
|
||||
assert.NoError(t, grant.IncreaseCounter(t.Context()))
|
||||
assert.Equal(t, int64(2), grant.Counter)
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 2})
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_IncreaseCounterRejectsStaleCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 1})
|
||||
stale := *grant
|
||||
|
||||
assert.NoError(t, grant.IncreaseCounter(t.Context()))
|
||||
err := stale.IncreaseCounter(t.Context())
|
||||
assert.ErrorIs(t, err, auth_model.ErrOAuth2GrantStaleCounter)
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_ScopeContains(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Scope: "openid profile"})
|
||||
assert.True(t, grant.ScopeContains("openid"))
|
||||
assert.True(t, grant.ScopeContains("profile"))
|
||||
assert.False(t, grant.ScopeContains("profil"))
|
||||
assert.False(t, grant.ScopeContains("profile2"))
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
|
||||
code, err := grant.GenerateNewAuthorizationCode(t.Context(), "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, code)
|
||||
assert.Greater(t, len(code.Code), 32) // secret length > 32
|
||||
}
|
||||
|
||||
func TestOAuth2Grant_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_grant", new(auth_model.OAuth2Grant).TableName())
|
||||
}
|
||||
|
||||
func TestGetOAuth2GrantsByUserID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
result, err := auth_model.GetOAuth2GrantsByUserID(t.Context(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, int64(1), result[0].ID)
|
||||
assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
|
||||
|
||||
result, err = auth_model.GetOAuth2GrantsByUserID(t.Context(), 34134)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestRevokeOAuth2Grant(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
assert.NoError(t, auth_model.RevokeOAuth2Grant(t.Context(), 1, 1))
|
||||
unittest.AssertNotExistsBean(t, &auth_model.OAuth2Grant{ID: 1, UserID: 1})
|
||||
}
|
||||
|
||||
//////////////////// Authorization Code
|
||||
|
||||
func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) {
|
||||
s256Verifier := "s256-verifier"
|
||||
s256Challenge := oauth2.S256ChallengeFromVerifier(s256Verifier)
|
||||
missingVerifierChallenge := oauth2.S256ChallengeFromVerifier("verifier-not-supplied")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
challenge string
|
||||
verifier string
|
||||
valid bool
|
||||
}{
|
||||
{"plain-success", "plain", "plain-secret", "plain-secret", true},
|
||||
{"plain-failure", "plain", "plain-secret", "ierwgjoergjio", false},
|
||||
{"s256-success", "S256", s256Challenge, s256Verifier, true},
|
||||
{"s256-failure", "S256", s256Challenge, "wiogjerogorewngoenrgoiuenorg", false},
|
||||
{"unsupported-method", "monkey", "foiwgjioriogeiogjerger", "foiwgjioriogeiogjerger", false},
|
||||
{"no-pkce-configured", "", "", "", true},
|
||||
{"s256-missing-verifier", "S256", missingVerifierChallenge, "", false},
|
||||
{"plain-missing-verifier", "plain", "plain-secret", "", false},
|
||||
{"missing-method-with-challenge", "", "foierjiogerogerg", "", false},
|
||||
{"missing-method-rejects-even-matching-verifier", "", "foierjiogerogerg", "foierjiogerogerg", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code := &auth_model.OAuth2AuthorizationCode{
|
||||
CodeChallengeMethod: tc.method,
|
||||
CodeChallenge: tc.challenge,
|
||||
}
|
||||
assert.Equal(t, tc.valid, code.ValidateCodeChallenge(tc.verifier))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
|
||||
code := &auth_model.OAuth2AuthorizationCode{
|
||||
RedirectURI: "https://example.com/callback",
|
||||
Code: "thecode",
|
||||
}
|
||||
|
||||
redirect, err := code.GenerateRedirectURI("thestate")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String())
|
||||
|
||||
redirect, err = code.GenerateRedirectURI("")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String())
|
||||
}
|
||||
|
||||
func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
|
||||
assert.Equal(t, "oauth2_authorization_code", new(auth_model.OAuth2AuthorizationCode).TableName())
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/timeutil"
|
||||
|
||||
"xorm.io/builder"
|
||||
)
|
||||
|
||||
// Session represents a session compatible for go-chi session
|
||||
type Session struct {
|
||||
Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session
|
||||
Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased
|
||||
Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Session))
|
||||
}
|
||||
|
||||
// UpdateSession updates the session with provided id
|
||||
func UpdateSession(ctx context.Context, key string, data []byte) error {
|
||||
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
|
||||
Data: data,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// ReadSession reads the data for the provided session
|
||||
func ReadSession(ctx context.Context, key string) (*Session, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) {
|
||||
session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !exist {
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
}
|
||||
if err := db.Insert(ctx, session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return session, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ExistSession checks if a session exists
|
||||
func ExistSession(ctx context.Context, key string) (bool, error) {
|
||||
return db.Exist[Session](ctx, builder.Eq{"`key`": key})
|
||||
}
|
||||
|
||||
// DestroySession destroys a session
|
||||
func DestroySession(ctx context.Context, key string) error {
|
||||
_, err := db.GetEngine(ctx).Delete(&Session{
|
||||
Key: key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// RegenerateSession regenerates a session from the old id
|
||||
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
|
||||
return db.WithTx2(ctx, func(ctx context.Context) (*Session, error) {
|
||||
if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil {
|
||||
return nil, err
|
||||
} else if has {
|
||||
return nil, fmt.Errorf("session Key: %s already exists", newKey)
|
||||
}
|
||||
|
||||
if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
if err := db.Insert(ctx, &Session{
|
||||
Key: oldKey,
|
||||
Expiry: timeutil.TimeStampNow(),
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := db.Exec(ctx, "UPDATE `session` SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey})
|
||||
if err != nil {
|
||||
// is not exist, it should be impossible
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, nil
|
||||
})
|
||||
}
|
||||
|
||||
// CountSessions returns the number of sessions
|
||||
func CountSessions(ctx context.Context) (int64, error) {
|
||||
return db.GetEngine(ctx).Count(&Session{})
|
||||
}
|
||||
|
||||
// CleanupSessions cleans up expired sessions
|
||||
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
|
||||
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,402 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/optional"
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/timeutil"
|
||||
"gitea.dev/modules/util"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm"
|
||||
"xorm.io/xorm/convert"
|
||||
)
|
||||
|
||||
// Type represents an login type.
|
||||
type Type int
|
||||
|
||||
// Note: new type must append to the end of list to maintain compatibility.
|
||||
const (
|
||||
NoType Type = iota
|
||||
Plain // 1
|
||||
LDAP // 2
|
||||
SMTP // 3
|
||||
PAM // 4
|
||||
DLDAP // 5
|
||||
OAuth2 // 6
|
||||
SSPI // 7
|
||||
)
|
||||
|
||||
// String returns the string name of the LoginType
|
||||
func (typ Type) String() string {
|
||||
return Names[typ]
|
||||
}
|
||||
|
||||
// Int returns the int value of the LoginType
|
||||
func (typ Type) Int() int {
|
||||
return int(typ)
|
||||
}
|
||||
|
||||
// Names contains the name of LoginType values.
|
||||
var Names = map[Type]string{
|
||||
LDAP: "LDAP (via BindDN)",
|
||||
DLDAP: "LDAP (simple auth)", // Via direct bind
|
||||
SMTP: "SMTP",
|
||||
PAM: "PAM",
|
||||
OAuth2: "OAuth2",
|
||||
SSPI: "SPNEGO with SSPI",
|
||||
}
|
||||
|
||||
// Config represents login config as far as the db is concerned
|
||||
type Config interface {
|
||||
convert.Conversion
|
||||
SetAuthSource(*Source)
|
||||
}
|
||||
|
||||
type ConfigBase struct {
|
||||
AuthSource *Source
|
||||
}
|
||||
|
||||
func (p *ConfigBase) SetAuthSource(s *Source) {
|
||||
p.AuthSource = s
|
||||
}
|
||||
|
||||
// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set
|
||||
type SkipVerifiable interface {
|
||||
IsSkipVerify() bool
|
||||
}
|
||||
|
||||
// HasTLSer configurations provide a HasTLS to check if TLS can be enabled
|
||||
type HasTLSer interface {
|
||||
HasTLS() bool
|
||||
}
|
||||
|
||||
// UseTLSer configurations provide a HasTLS to check if TLS is enabled
|
||||
type UseTLSer interface {
|
||||
UseTLS() bool
|
||||
}
|
||||
|
||||
// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys
|
||||
type SSHKeyProvider interface {
|
||||
ProvidesSSHKeys() bool
|
||||
}
|
||||
|
||||
// RegisterableSource configurations provide RegisterSource which needs to be run on creation
|
||||
type RegisterableSource interface {
|
||||
RegisterSource() error
|
||||
UnregisterSource() error
|
||||
}
|
||||
|
||||
var registeredConfigs = map[Type]func() Config{}
|
||||
|
||||
// RegisterTypeConfig register a config for a provided type
|
||||
func RegisterTypeConfig(typ Type, exemplar Config) {
|
||||
if reflect.TypeOf(exemplar).Kind() == reflect.Pointer {
|
||||
// Pointer:
|
||||
registeredConfigs[typ] = func() Config {
|
||||
return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Not a Pointer
|
||||
registeredConfigs[typ] = func() Config {
|
||||
return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
|
||||
}
|
||||
}
|
||||
|
||||
// Source represents an external way for authorizing users.
|
||||
type Source struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Type Type
|
||||
Name string `xorm:"UNIQUE"` // it can be the OIDC's provider name, see services/auth/source/oauth2/source_register.go: RegisterSource
|
||||
IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"`
|
||||
IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"`
|
||||
TwoFactorPolicy string `xorm:"two_factor_policy NOT NULL DEFAULT ''"`
|
||||
Cfg Config `xorm:"TEXT"`
|
||||
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
// TableName xorm will read the table name from this method
|
||||
func (Source) TableName() string {
|
||||
return "login_source"
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(Source))
|
||||
}
|
||||
|
||||
// BeforeSet is invoked from XORM before setting the value of a field of this object.
|
||||
func (source *Source) BeforeSet(colName string, val xorm.Cell) {
|
||||
if colName == "type" {
|
||||
typ, _, err := db.CellToInt(val, NoType)
|
||||
if err != nil {
|
||||
setting.PanicInDevOrTesting("Unable to convert login source (id=%d) type: %v", source.ID, err)
|
||||
}
|
||||
constructor, ok := registeredConfigs[typ]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
source.Cfg = constructor()
|
||||
source.Cfg.SetAuthSource(source)
|
||||
}
|
||||
}
|
||||
|
||||
// TypeName return name of this login source type.
|
||||
func (source *Source) TypeName() string {
|
||||
return Names[source.Type]
|
||||
}
|
||||
|
||||
// IsLDAP returns true of this source is of the LDAP type.
|
||||
func (source *Source) IsLDAP() bool {
|
||||
return source.Type == LDAP
|
||||
}
|
||||
|
||||
// IsDLDAP returns true of this source is of the DLDAP type.
|
||||
func (source *Source) IsDLDAP() bool {
|
||||
return source.Type == DLDAP
|
||||
}
|
||||
|
||||
// IsSMTP returns true of this source is of the SMTP type.
|
||||
func (source *Source) IsSMTP() bool {
|
||||
return source.Type == SMTP
|
||||
}
|
||||
|
||||
// IsPAM returns true of this source is of the PAM type.
|
||||
func (source *Source) IsPAM() bool {
|
||||
return source.Type == PAM
|
||||
}
|
||||
|
||||
// IsOAuth2 returns true of this source is of the OAuth2 type.
|
||||
func (source *Source) IsOAuth2() bool {
|
||||
return source.Type == OAuth2
|
||||
}
|
||||
|
||||
// IsSSPI returns true of this source is of the SSPI type.
|
||||
func (source *Source) IsSSPI() bool {
|
||||
return source.Type == SSPI
|
||||
}
|
||||
|
||||
// HasTLS returns true of this source supports TLS.
|
||||
func (source *Source) HasTLS() bool {
|
||||
hasTLSer, ok := source.Cfg.(HasTLSer)
|
||||
return ok && hasTLSer.HasTLS()
|
||||
}
|
||||
|
||||
// UseTLS returns true of this source is configured to use TLS.
|
||||
func (source *Source) UseTLS() bool {
|
||||
useTLSer, ok := source.Cfg.(UseTLSer)
|
||||
return ok && useTLSer.UseTLS()
|
||||
}
|
||||
|
||||
// SkipVerify returns true if this source is configured to skip SSL
|
||||
// verification.
|
||||
func (source *Source) SkipVerify() bool {
|
||||
skipVerifiable, ok := source.Cfg.(SkipVerifiable)
|
||||
return ok && skipVerifiable.IsSkipVerify()
|
||||
}
|
||||
|
||||
func (source *Source) TwoFactorShouldSkip() bool {
|
||||
return source.TwoFactorPolicy == "skip"
|
||||
}
|
||||
|
||||
// CreateSource inserts a AuthSource in the DB if not already
|
||||
// existing with the given name.
|
||||
func CreateSource(ctx context.Context, source *Source) error {
|
||||
has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrSourceAlreadyExist{source.Name}
|
||||
}
|
||||
// Synchronization is only available with LDAP for now
|
||||
if !source.IsLDAP() && !source.IsOAuth2() {
|
||||
source.IsSyncEnabled = false
|
||||
}
|
||||
|
||||
_, err = db.GetEngine(ctx).Insert(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !source.IsActive {
|
||||
return nil
|
||||
}
|
||||
|
||||
source.Cfg.SetAuthSource(source)
|
||||
|
||||
registerableSource, ok := source.Cfg.(RegisterableSource)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = registerableSource.RegisterSource()
|
||||
if err != nil {
|
||||
// remove the AuthSource in case of errors while registering configuration
|
||||
if _, err := db.GetEngine(ctx).ID(source.ID).Delete(new(Source)); err != nil {
|
||||
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type FindSourcesOptions struct {
|
||||
db.ListOptions
|
||||
IsActive optional.Option[bool]
|
||||
LoginType Type
|
||||
}
|
||||
|
||||
func (opts FindSourcesOptions) ToConds() builder.Cond {
|
||||
conds := builder.NewCond()
|
||||
if opts.IsActive.Has() {
|
||||
conds = conds.And(builder.Eq{"is_active": opts.IsActive.Value()})
|
||||
}
|
||||
if opts.LoginType != NoType {
|
||||
conds = conds.And(builder.Eq{"`type`": opts.LoginType})
|
||||
}
|
||||
return conds
|
||||
}
|
||||
|
||||
// IsSSPIEnabled returns true if there is at least one activated login
|
||||
// source of type LoginSSPI
|
||||
func IsSSPIEnabled(ctx context.Context) bool {
|
||||
exist, err := db.Exist[Source](ctx, FindSourcesOptions{
|
||||
IsActive: optional.Some(true),
|
||||
LoginType: SSPI,
|
||||
}.ToConds())
|
||||
if err != nil {
|
||||
log.Error("IsSSPIEnabled: failed to query active SSPI sources: %v", err)
|
||||
return false
|
||||
}
|
||||
return exist
|
||||
}
|
||||
|
||||
// GetSourceByID returns login source by given ID.
|
||||
func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
|
||||
source := new(Source)
|
||||
if id == 0 {
|
||||
source.Cfg = registeredConfigs[NoType]()
|
||||
// Set this source to active
|
||||
// FIXME: allow disabling of db based password authentication in future
|
||||
source.IsActive = true
|
||||
return source, nil
|
||||
}
|
||||
|
||||
has, err := db.GetEngine(ctx).ID(id).Get(source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrSourceNotExist{id}
|
||||
}
|
||||
return source, nil
|
||||
}
|
||||
|
||||
// UpdateSource updates a Source record in DB.
|
||||
func UpdateSource(ctx context.Context, source *Source) error {
|
||||
var originalSource *Source
|
||||
if source.IsOAuth2() {
|
||||
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
|
||||
var err error
|
||||
if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if has {
|
||||
return ErrSourceAlreadyExist{source.Name}
|
||||
}
|
||||
|
||||
_, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !source.IsActive {
|
||||
return nil
|
||||
}
|
||||
|
||||
source.Cfg.SetAuthSource(source)
|
||||
|
||||
registerableSource, ok := source.Cfg.(RegisterableSource)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = registerableSource.RegisterSource()
|
||||
if err != nil {
|
||||
// restore original values since we cannot update the provider itself
|
||||
if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
|
||||
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ErrSourceNotExist represents a "SourceNotExist" kind of error.
|
||||
type ErrSourceNotExist struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
|
||||
func IsErrSourceNotExist(err error) bool {
|
||||
_, ok := err.(ErrSourceNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceNotExist) Error() string {
|
||||
return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrSourceNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
|
||||
type ErrSourceAlreadyExist struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
|
||||
func IsErrSourceAlreadyExist(err error) bool {
|
||||
_, ok := err.(ErrSourceAlreadyExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceAlreadyExist) Error() string {
|
||||
return fmt.Sprintf("login source already exists [name: %s]", err.Name)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrExist err
|
||||
func (err ErrSourceAlreadyExist) Unwrap() error {
|
||||
return util.ErrAlreadyExist
|
||||
}
|
||||
|
||||
// ErrSourceInUse represents a "SourceInUse" kind of error.
|
||||
type ErrSourceInUse struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
// IsErrSourceInUse checks if an error is a ErrSourceInUse.
|
||||
func IsErrSourceInUse(err error) bool {
|
||||
_, ok := err.(ErrSourceInUse)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrSourceInUse) Error() string {
|
||||
return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
auth_model "gitea.dev/models/auth"
|
||||
"gitea.dev/models/unittest"
|
||||
"gitea.dev/modules/json"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
type TestSource struct {
|
||||
auth_model.ConfigBase `json:"-"`
|
||||
|
||||
TestField string
|
||||
}
|
||||
|
||||
// FromDB fills up a LDAPConfig from serialized format.
|
||||
func (source *TestSource) FromDB(bs []byte) error {
|
||||
return json.Unmarshal(bs, &source)
|
||||
}
|
||||
|
||||
// ToDB exports a LDAPConfig to a serialized format.
|
||||
func (source *TestSource) ToDB() ([]byte, error) {
|
||||
return json.Marshal(source)
|
||||
}
|
||||
|
||||
func TestDumpAuthSource(t *testing.T) {
|
||||
require.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
authSourceSchema, err := unittest.GetXORMEngine().TableInfo(new(auth_model.Source))
|
||||
require.NoError(t, err)
|
||||
|
||||
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
|
||||
source := &auth_model.Source{
|
||||
Type: auth_model.OAuth2,
|
||||
Name: "TestSource",
|
||||
Cfg: &TestSource{TestField: "TestValue"},
|
||||
}
|
||||
require.NoError(t, auth_model.CreateSource(t.Context(), source))
|
||||
|
||||
// intentionally test the "dump" to make sure the dumped JSON is correct: https://github.com/go-gitea/gitea/pull/16847
|
||||
sb := &strings.Builder{}
|
||||
require.NoError(t, unittest.GetXORMEngine().DumpTables([]*schemas.Table{authSourceSchema}, sb))
|
||||
// the dumped SQL is something like:
|
||||
// INSERT INTO `login_source` (`id`, `type`, `name`, `is_active`, `is_sync_enabled`, `two_factor_policy`, `cfg`, `created_unix`, `updated_unix`) VALUES (1,6,'TestSource',0,0,'','{"TestField":"TestValue"}',1774179784,1774179784);
|
||||
assert.Contains(t, sb.String(), `'{"TestField":"TestValue"}'`)
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/secret"
|
||||
"gitea.dev/modules/setting"
|
||||
"gitea.dev/modules/timeutil"
|
||||
"gitea.dev/modules/util"
|
||||
|
||||
"github.com/pquerna/otp/totp"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
//
|
||||
// Two-factor authentication
|
||||
//
|
||||
|
||||
// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication.
|
||||
type ErrTwoFactorNotEnrolled struct {
|
||||
UID int64
|
||||
}
|
||||
|
||||
// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled.
|
||||
func IsErrTwoFactorNotEnrolled(err error) bool {
|
||||
_, ok := err.(ErrTwoFactorNotEnrolled)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (err ErrTwoFactorNotEnrolled) Error() string {
|
||||
return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrTwoFactorNotEnrolled) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// TwoFactor represents a two-factor authentication token.
|
||||
type TwoFactor struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
UID int64 `xorm:"UNIQUE"`
|
||||
Secret string
|
||||
ScratchSalt string
|
||||
ScratchHash string
|
||||
LastUsedPasscode string `xorm:"VARCHAR(10)"`
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(TwoFactor))
|
||||
}
|
||||
|
||||
// GenerateScratchToken recreates the scratch token the user is using.
|
||||
func (t *TwoFactor) GenerateScratchToken() (string, error) {
|
||||
tokenBytes := util.CryptoRandomBytes(6)
|
||||
// these chars are specially chosen, avoid ambiguous chars like `0`, `O`, `1`, `I`.
|
||||
const base32Chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
token := base32.NewEncoding(base32Chars).WithPadding(base32.NoPadding).EncodeToString(tokenBytes)
|
||||
t.ScratchSalt = util.CryptoRandomString(10)
|
||||
t.ScratchHash = HashToken(token, t.ScratchSalt)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// HashToken return the hashable salt
|
||||
func HashToken(token, salt string) string {
|
||||
tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New)
|
||||
return hex.EncodeToString(tempHash)
|
||||
}
|
||||
|
||||
// VerifyScratchToken verifies if the specified scratch token is valid.
|
||||
func (t *TwoFactor) VerifyScratchToken(token string) bool {
|
||||
if len(token) == 0 {
|
||||
return false
|
||||
}
|
||||
tempHash := HashToken(token, t.ScratchSalt)
|
||||
return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1
|
||||
}
|
||||
|
||||
func (t *TwoFactor) getEncryptionKey() []byte {
|
||||
k := md5.Sum([]byte(setting.SecretKey))
|
||||
return k[:]
|
||||
}
|
||||
|
||||
// SetSecret sets the 2FA secret.
|
||||
func (t *TwoFactor) SetSecret(secretString string) error {
|
||||
secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.Secret = base64.StdEncoding.EncodeToString(secretBytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateTOTP validates the provided passcode.
|
||||
func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
|
||||
decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("ValidateTOTP invalid base64: %w", err)
|
||||
}
|
||||
secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("ValidateTOTP unable to decrypt (maybe SECRET_KEY is wrong): %w", err)
|
||||
}
|
||||
secretStr := string(secretBytes)
|
||||
return totp.Validate(passcode, secretStr), nil
|
||||
}
|
||||
|
||||
// NewTwoFactor creates a new two-factor authentication token.
|
||||
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
|
||||
_, err := db.GetEngine(ctx).Insert(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTwoFactor updates a two-factor authentication token.
|
||||
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
|
||||
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTwoFactorByUID returns the two-factor authentication token associated with
|
||||
// the user, if any.
|
||||
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
|
||||
twofa := &TwoFactor{}
|
||||
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if !has {
|
||||
return nil, ErrTwoFactorNotEnrolled{uid}
|
||||
}
|
||||
return twofa, nil
|
||||
}
|
||||
|
||||
// HasTwoFactorByUID returns the two-factor authentication token associated with
|
||||
// the user, if any.
|
||||
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
|
||||
}
|
||||
|
||||
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
|
||||
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
|
||||
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
|
||||
UID: userID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
} else if cnt != 1 {
|
||||
return ErrTwoFactorNotEnrolled{userID}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func HasTwoFactorOrWebAuthn(ctx context.Context, id int64) (bool, error) {
|
||||
has, err := HasTwoFactorByUID(ctx, id)
|
||||
if err != nil {
|
||||
return false, err
|
||||
} else if has {
|
||||
return true, nil
|
||||
}
|
||||
return HasWebAuthnRegistrationsByUID(ctx, id)
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/timeutil"
|
||||
"gitea.dev/modules/util"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
)
|
||||
|
||||
// ErrWebAuthnCredentialNotExist represents a "ErrWebAuthnCRedentialNotExist" kind of error.
|
||||
type ErrWebAuthnCredentialNotExist struct {
|
||||
ID int64
|
||||
CredentialID []byte
|
||||
}
|
||||
|
||||
func (err ErrWebAuthnCredentialNotExist) Error() string {
|
||||
if len(err.CredentialID) == 0 {
|
||||
return fmt.Sprintf("WebAuthn credential does not exist [id: %d]", err.ID)
|
||||
}
|
||||
return fmt.Sprintf("WebAuthn credential does not exist [credential_id: %x]", err.CredentialID)
|
||||
}
|
||||
|
||||
// Unwrap unwraps this as a ErrNotExist err
|
||||
func (err ErrWebAuthnCredentialNotExist) Unwrap() error {
|
||||
return util.ErrNotExist
|
||||
}
|
||||
|
||||
// IsErrWebAuthnCredentialNotExist checks if an error is a ErrWebAuthnCredentialNotExist.
|
||||
func IsErrWebAuthnCredentialNotExist(err error) bool {
|
||||
_, ok := err.(ErrWebAuthnCredentialNotExist)
|
||||
return ok
|
||||
}
|
||||
|
||||
// WebAuthnCredential represents the WebAuthn credential data for a public-key
|
||||
// credential conformant to WebAuthn Level 1
|
||||
type WebAuthnCredential struct {
|
||||
ID int64 `xorm:"pk autoincr"`
|
||||
Name string
|
||||
LowerName string `xorm:"unique(s)"`
|
||||
UserID int64 `xorm:"INDEX unique(s)"`
|
||||
CredentialID []byte `xorm:"INDEX VARBINARY(1024)"`
|
||||
PublicKey []byte
|
||||
AttestationType string
|
||||
AAGUID []byte
|
||||
SignCount uint32 `xorm:"BIGINT"`
|
||||
CloneWarning bool
|
||||
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
|
||||
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
db.RegisterModel(new(WebAuthnCredential))
|
||||
}
|
||||
|
||||
// TableName returns a better table name for WebAuthnCredential
|
||||
func (cred WebAuthnCredential) TableName() string {
|
||||
return "webauthn_credential"
|
||||
}
|
||||
|
||||
// UpdateSignCount will update the database value of SignCount
|
||||
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
|
||||
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
|
||||
return err
|
||||
}
|
||||
|
||||
// BeforeInsert will be invoked by XORM before updating a record
|
||||
func (cred *WebAuthnCredential) BeforeInsert() {
|
||||
cred.LowerName = strings.ToLower(cred.Name)
|
||||
}
|
||||
|
||||
// BeforeUpdate will be invoked by XORM before updating a record
|
||||
func (cred *WebAuthnCredential) BeforeUpdate() {
|
||||
cred.LowerName = strings.ToLower(cred.Name)
|
||||
}
|
||||
|
||||
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
|
||||
func (cred *WebAuthnCredential) AfterLoad() {
|
||||
cred.LowerName = strings.ToLower(cred.Name)
|
||||
}
|
||||
|
||||
// WebAuthnCredentialList is a list of *WebAuthnCredential
|
||||
type WebAuthnCredentialList []*WebAuthnCredential
|
||||
|
||||
// newCredentialFlagsFromAuthenticatorFlags is copied from https://github.com/go-webauthn/webauthn/pull/337
|
||||
// to convert protocol.AuthenticatorFlags to webauthn.CredentialFlags
|
||||
func newCredentialFlagsFromAuthenticatorFlags(flags protocol.AuthenticatorFlags) webauthn.CredentialFlags {
|
||||
return webauthn.CredentialFlags{
|
||||
UserPresent: flags.HasUserPresent(),
|
||||
UserVerified: flags.HasUserVerified(),
|
||||
BackupEligible: flags.HasBackupEligible(),
|
||||
BackupState: flags.HasBackupState(),
|
||||
}
|
||||
}
|
||||
|
||||
// ToCredentials will convert all WebAuthnCredentials to webauthn.Credentials
|
||||
func (list WebAuthnCredentialList) ToCredentials(defaultAuthFlags ...protocol.AuthenticatorFlags) []webauthn.Credential {
|
||||
// TODO: at the moment, Gitea doesn't store or check the flags
|
||||
// so we need to use the default flags from the authenticator to make the login validation pass
|
||||
// In the future, we should:
|
||||
// 1. store the flags when registering the credential
|
||||
// 2. provide the stored flags when converting the credentials (for login)
|
||||
// 3. for old users, still use this fallback to the default flags
|
||||
defAuthFlags := util.OptionalArg(defaultAuthFlags)
|
||||
creds := make([]webauthn.Credential, 0, len(list))
|
||||
for _, cred := range list {
|
||||
creds = append(creds, webauthn.Credential{
|
||||
ID: cred.CredentialID,
|
||||
PublicKey: cred.PublicKey,
|
||||
AttestationType: cred.AttestationType,
|
||||
Flags: newCredentialFlagsFromAuthenticatorFlags(defAuthFlags),
|
||||
Authenticator: webauthn.Authenticator{
|
||||
AAGUID: cred.AAGUID,
|
||||
SignCount: cred.SignCount,
|
||||
CloneWarning: cred.CloneWarning,
|
||||
},
|
||||
})
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
|
||||
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
|
||||
creds := make(WebAuthnCredentialList, 0)
|
||||
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
|
||||
}
|
||||
|
||||
// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
|
||||
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialByName returns WebAuthn credential by id
|
||||
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
|
||||
return nil, err
|
||||
} else if !found {
|
||||
return nil, ErrWebAuthnCredentialNotExist{}
|
||||
}
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialByID returns WebAuthn credential by id
|
||||
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
|
||||
return nil, err
|
||||
} else if !found {
|
||||
return nil, ErrWebAuthnCredentialNotExist{ID: id}
|
||||
}
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
|
||||
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
|
||||
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
|
||||
}
|
||||
|
||||
// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
|
||||
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
|
||||
cred := new(WebAuthnCredential)
|
||||
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
|
||||
return nil, err
|
||||
} else if !found {
|
||||
return nil, ErrWebAuthnCredentialNotExist{CredentialID: credID}
|
||||
}
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
// CreateCredential will create a new WebAuthnCredential from the given Credential
|
||||
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
|
||||
c := &WebAuthnCredential{
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
CredentialID: cred.ID,
|
||||
PublicKey: cred.PublicKey,
|
||||
AttestationType: cred.AttestationType,
|
||||
AAGUID: cred.Authenticator.AAGUID,
|
||||
SignCount: cred.Authenticator.SignCount,
|
||||
CloneWarning: false,
|
||||
}
|
||||
|
||||
if err := db.Insert(ctx, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// DeleteCredential will delete WebAuthnCredential
|
||||
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
|
||||
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
|
||||
return had > 0, err
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
auth_model "gitea.dev/models/auth"
|
||||
"gitea.dev/models/unittest"
|
||||
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetWebAuthnCredentialByID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.GetWebAuthnCredentialByID(t.Context(), 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "WebAuthn credential", res.Name)
|
||||
|
||||
_, err = auth_model.GetWebAuthnCredentialByID(t.Context(), 342432)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err))
|
||||
}
|
||||
|
||||
func TestGetWebAuthnCredentialsByUID(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.GetWebAuthnCredentialsByUID(t.Context(), 32)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, "WebAuthn credential", res[0].Name)
|
||||
}
|
||||
|
||||
func TestWebAuthnCredential_TableName(t *testing.T) {
|
||||
assert.Equal(t, "webauthn_credential", auth_model.WebAuthnCredential{}.TableName())
|
||||
}
|
||||
|
||||
func TestWebAuthnCredential_UpdateSignCount(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
|
||||
cred.SignCount = 1
|
||||
assert.NoError(t, cred.UpdateSignCount(t.Context()))
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1})
|
||||
}
|
||||
|
||||
func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
|
||||
cred.SignCount = 0xffffffff
|
||||
assert.NoError(t, cred.UpdateSignCount(t.Context()))
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff})
|
||||
}
|
||||
|
||||
func TestCreateCredential(t *testing.T) {
|
||||
assert.NoError(t, unittest.PrepareTestDatabase())
|
||||
|
||||
res, err := auth_model.CreateCredential(t.Context(), 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "WebAuthn Created Credential", res.Name)
|
||||
assert.Equal(t, []byte("Test"), res.CredentialID)
|
||||
|
||||
unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{Name: "WebAuthn Created Credential", UserID: 1})
|
||||
}
|
||||
Reference in New Issue
Block a user