初始提交: Gitea 项目代码
This commit is contained in:
@@ -0,0 +1,246 @@
|
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html"
|
||||
"html/template"
|
||||
"net/url"
|
||||
"slices"
|
||||
"sort"
|
||||
|
||||
"gitea.dev/models/auth"
|
||||
"gitea.dev/models/db"
|
||||
"gitea.dev/modules/log"
|
||||
"gitea.dev/modules/optional"
|
||||
"gitea.dev/modules/setting"
|
||||
|
||||
"github.com/markbates/goth"
|
||||
"github.com/markbates/goth/providers/openidConnect"
|
||||
)
|
||||
|
||||
// Provider is an interface for describing a single OAuth2 provider
|
||||
type Provider interface {
|
||||
Name() string
|
||||
DisplayName() string
|
||||
IconHTML(size int) template.HTML
|
||||
CustomURLSettings() *CustomURLSettings
|
||||
SupportSSHPublicKey() bool
|
||||
}
|
||||
|
||||
// GothProviderCreator provides a function to create a goth.Provider
|
||||
type GothProviderCreator interface {
|
||||
CreateGothProvider(providerName, callbackURL string, source *Source) (goth.Provider, error)
|
||||
}
|
||||
|
||||
// GothProvider is an interface for describing a single OAuth2 provider
|
||||
type GothProvider interface {
|
||||
Provider
|
||||
GothProviderCreator
|
||||
}
|
||||
|
||||
// AuthSourceProvider provides a provider for an AuthSource. Multiple auth sources could use the same registered GothProvider
|
||||
// So each auth source should have its own DisplayName and IconHTML for display.
|
||||
// The Name is the GothProvider's name, to help to find the GothProvider to sign in.
|
||||
// The DisplayName is the auth source config's name, site admin set it on the admin page, the IconURL can also be set there.
|
||||
type AuthSourceProvider struct {
|
||||
GothProvider
|
||||
sourceName, iconURL string
|
||||
}
|
||||
|
||||
func (p *AuthSourceProvider) Name() string {
|
||||
return p.GothProvider.Name()
|
||||
}
|
||||
|
||||
func (p *AuthSourceProvider) DisplayName() string {
|
||||
return p.sourceName
|
||||
}
|
||||
|
||||
func (p *AuthSourceProvider) IconHTML(size int) template.HTML {
|
||||
if p.iconURL != "" {
|
||||
img := fmt.Sprintf(`<img class="tw-object-contain" width="%d" height="%d" src="%s" alt="%s">`,
|
||||
size,
|
||||
size,
|
||||
html.EscapeString(p.iconURL), html.EscapeString(p.DisplayName()),
|
||||
)
|
||||
return template.HTML(img)
|
||||
}
|
||||
return p.GothProvider.IconHTML(size)
|
||||
}
|
||||
|
||||
// Providers contains the map of registered OAuth2 providers in Gitea (based on goth)
|
||||
// key is used to map the OAuth2Provider with the goth provider type (also in AuthSource.OAuth2Config.Provider)
|
||||
// value is used to store display data
|
||||
var gothProviders = map[string]GothProvider{}
|
||||
|
||||
func isAzureProvider(name string) bool {
|
||||
return name == "azuread" || name == "microsoftonline" || name == "azureadv2"
|
||||
}
|
||||
|
||||
// RegisterGothProvider registers a GothProvider
|
||||
func RegisterGothProvider(provider GothProvider) {
|
||||
if _, has := gothProviders[provider.Name()]; has {
|
||||
log.Fatal("Duplicate oauth2provider type provided: %s", provider.Name())
|
||||
}
|
||||
gothProviders[provider.Name()] = provider
|
||||
}
|
||||
|
||||
// getExistingAzureADAuthSources returns a list of Azure AD provider names that are already configured
|
||||
func getExistingAzureADAuthSources(ctx context.Context) ([]string, error) {
|
||||
authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
|
||||
LoginType: auth.OAuth2,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var existingAzureProviders []string
|
||||
for _, source := range authSources {
|
||||
if oauth2Cfg, ok := source.Cfg.(*Source); ok {
|
||||
if isAzureProvider(oauth2Cfg.Provider) {
|
||||
existingAzureProviders = append(existingAzureProviders, oauth2Cfg.Provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
return existingAzureProviders, nil
|
||||
}
|
||||
|
||||
// GetSupportedOAuth2Providers returns the list of supported OAuth2 providers with context for filtering
|
||||
// key is used as technical name (like in the callbackURL)
|
||||
// values to display
|
||||
// Note: Azure AD providers (azuread, microsoftonline, azureadv2) are filtered out
|
||||
// unless they already exist in the system to encourage use of OpenID Connect
|
||||
func GetSupportedOAuth2Providers(ctx context.Context) []Provider {
|
||||
providers := make([]Provider, 0, len(gothProviders))
|
||||
existingAzureSources, err := getExistingAzureADAuthSources(ctx)
|
||||
if err != nil {
|
||||
log.Error("Failed to get existing OAuth2 auth sources: %v", err)
|
||||
}
|
||||
|
||||
for _, provider := range gothProviders {
|
||||
if isAzureProvider(provider.Name()) && !slices.Contains(existingAzureSources, provider.Name()) {
|
||||
continue
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
sort.Slice(providers, func(i, j int) bool {
|
||||
return providers[i].Name() < providers[j].Name()
|
||||
})
|
||||
return providers
|
||||
}
|
||||
|
||||
func CreateProviderFromSource(source *auth.Source) (Provider, error) {
|
||||
oauth2Cfg, ok := source.Cfg.(*Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid OAuth2 source config: %v", oauth2Cfg)
|
||||
}
|
||||
gothProv := gothProviders[oauth2Cfg.Provider]
|
||||
return &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}, nil
|
||||
}
|
||||
|
||||
// GetOAuth2Providers returns the list of configured OAuth2 providers
|
||||
func GetOAuth2Providers(ctx context.Context, isActive optional.Option[bool]) ([]Provider, error) {
|
||||
authSources, err := db.Find[auth.Source](ctx, auth.FindSourcesOptions{
|
||||
IsActive: isActive,
|
||||
LoginType: auth.OAuth2,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
providers := make([]Provider, 0, len(authSources))
|
||||
for _, source := range authSources {
|
||||
provider, err := CreateProviderFromSource(source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
providers = append(providers, provider)
|
||||
}
|
||||
|
||||
sort.Slice(providers, func(i, j int) bool {
|
||||
return providers[i].Name() < providers[j].Name()
|
||||
})
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// RegisterProviderWithGothic register a OAuth2 provider in goth lib
|
||||
func RegisterProviderWithGothic(providerName string, source *Source) error {
|
||||
provider, err := createProvider(providerName, source)
|
||||
|
||||
if err == nil && provider != nil {
|
||||
gothRWMutex.Lock()
|
||||
defer gothRWMutex.Unlock()
|
||||
|
||||
goth.UseProviders(provider)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveProviderFromGothic removes the given OAuth2 provider from the goth lib
|
||||
func RemoveProviderFromGothic(providerName string) {
|
||||
gothRWMutex.Lock()
|
||||
defer gothRWMutex.Unlock()
|
||||
|
||||
delete(goth.GetProviders(), providerName)
|
||||
}
|
||||
|
||||
// ClearProviders clears all OAuth2 providers from the goth lib
|
||||
func ClearProviders() {
|
||||
gothRWMutex.Lock()
|
||||
defer gothRWMutex.Unlock()
|
||||
|
||||
goth.ClearProviders()
|
||||
}
|
||||
|
||||
// GetOIDCEndSessionEndpoint returns the OIDC end_session_endpoint for the
|
||||
// given provider name. Returns "" if the provider is not OIDC or doesn't
|
||||
// advertise an end_session_endpoint in its discovery document.
|
||||
func GetOIDCEndSessionEndpoint(providerName string) string {
|
||||
gothRWMutex.RLock()
|
||||
defer gothRWMutex.RUnlock()
|
||||
|
||||
provider, ok := goth.GetProviders()[providerName]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
oidcProvider, ok := provider.(*openidConnect.Provider)
|
||||
if !ok || oidcProvider.OpenIDConfig == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return oidcProvider.OpenIDConfig.EndSessionEndpoint
|
||||
}
|
||||
|
||||
var ErrAuthSourceNotActivated = errors.New("auth source is not activated")
|
||||
|
||||
// used to create different types of goth providers
|
||||
func createProvider(providerName string, source *Source) (goth.Provider, error) {
|
||||
callbackURL := setting.AppURL + "user/oauth2/" + url.PathEscape(providerName) + "/callback"
|
||||
|
||||
var provider goth.Provider
|
||||
var err error
|
||||
|
||||
p, ok := gothProviders[source.Provider]
|
||||
if !ok {
|
||||
return nil, ErrAuthSourceNotActivated
|
||||
}
|
||||
|
||||
provider, err = p.CreateGothProvider(providerName, callbackURL, source)
|
||||
if err != nil {
|
||||
return provider, err
|
||||
}
|
||||
|
||||
// always set the name if provider is created so we can support multiple setups of 1 provider
|
||||
if provider != nil {
|
||||
provider.SetName(providerName)
|
||||
}
|
||||
|
||||
return provider, err
|
||||
}
|
||||
Reference in New Issue
Block a user