Feat: boot init provider

pull/21/head
zijiren233 2 years ago
parent 4e8c9e44b4
commit 7e1c6ab5d8

@ -28,6 +28,7 @@ var ServerCmd = &cobra.Command{
bootstrap.InitLog,
bootstrap.InitGinMode,
bootstrap.InitDatabase,
bootstrap.InitProvider,
bootstrap.InitOp,
bootstrap.InitRtmp,
bootstrap.InitRoom,

@ -28,6 +28,7 @@ require (
github.com/zijiren233/stream v0.5.1
github.com/zijiren233/yaml-comment v0.2.0
golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
golang.org/x/oauth2 v0.13.0
google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v3 v3.0.1
@ -77,7 +78,6 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect
go.uber.org/mock v0.3.0 // indirect
golang.org/x/arch v0.5.0 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/mod v0.13.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect

@ -0,0 +1,18 @@
package bootstrap
import (
"context"
"github.com/synctv-org/synctv/internal/conf"
"github.com/synctv-org/synctv/internal/provider"
)
func InitProvider(ctx context.Context) error {
for op, v := range conf.Conf.OAuth2 {
err := provider.InitProvider(op, v.ClientID, v.ClientSecret)
if err != nil {
return err
}
}
return nil
}

@ -14,7 +14,7 @@ type OAuth2ProviderConfig struct {
func DefaultOAuth2Config() OAuth2Config {
return OAuth2Config{
provider.GithubProvider{}.Provider(): {
(&provider.GithubProvider{}).Provider(): {
ClientID: "github_client_id",
ClientSecret: "github_client_secret",
},

@ -10,22 +10,29 @@ import (
"golang.org/x/oauth2/github"
)
type GithubProvider struct{}
type GithubProvider struct {
ClientID, ClientSecret string
}
func (p *GithubProvider) Init(ClientID, ClientSecret string) {
p.ClientID = ClientID
p.ClientSecret = ClientSecret
}
func (p GithubProvider) Provider() OAuth2Provider {
func (p *GithubProvider) Provider() OAuth2Provider {
return "github"
}
func (p GithubProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
func (p *GithubProvider) NewConfig() *oauth2.Config {
return &oauth2.Config{
ClientID: ClientID,
ClientSecret: ClientSecret,
ClientID: p.ClientID,
ClientSecret: p.ClientSecret,
Scopes: []string{"user"},
Endpoint: github.Endpoint,
}
}
func (p GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
func (p *GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
oauth2Token, err := config.Exchange(ctx, code)
if err != nil {
return nil, err
@ -51,10 +58,6 @@ func (p GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config,
}, nil
}
func init() {
RegisterProvider(GithubProvider{})
}
type githubUserInfo struct {
Login string `json:"login"`
ID uint `json:"id"`
@ -102,3 +105,7 @@ type Plan struct {
Collaborators int `json:"collaborators"`
PrivateRepos int `json:"private_repos"`
}
func init() {
registerProvider(new(GithubProvider))
}

@ -8,22 +8,29 @@ import (
"golang.org/x/oauth2/gitlab"
)
type GitlabProvider struct{}
type GitlabProvider struct {
ClientID, ClientSecret string
}
func (g *GitlabProvider) Init(ClientID, ClientSecret string) {
g.ClientID = ClientID
g.ClientSecret = ClientSecret
}
func (g GitlabProvider) Provider() OAuth2Provider {
func (g *GitlabProvider) Provider() OAuth2Provider {
return "gitlab"
}
func (g GitlabProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
func (g *GitlabProvider) NewConfig() *oauth2.Config {
return &oauth2.Config{
ClientID: ClientID,
ClientSecret: ClientSecret,
ClientID: g.ClientID,
ClientSecret: g.ClientSecret,
Scopes: []string{"read_user"},
Endpoint: gitlab.Endpoint,
}
}
func (g GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
func (g *GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
oauth2Token, err := config.Exchange(ctx, code)
if err != nil {
return nil, err
@ -42,5 +49,5 @@ func (g GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config,
}
func init() {
RegisterProvider(GitlabProvider{})
registerProvider(new(GitlabProvider))
}

@ -8,22 +8,29 @@ import (
"golang.org/x/oauth2/google"
)
type GoogleProvider struct{}
type GoogleProvider struct {
ClientID, ClientSecret string
}
func (g *GoogleProvider) Init(ClientID, ClientSecret string) {
g.ClientID = ClientID
g.ClientSecret = ClientSecret
}
func (g GoogleProvider) Provider() OAuth2Provider {
func (g *GoogleProvider) Provider() OAuth2Provider {
return "google"
}
func (g GoogleProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
func (g *GoogleProvider) NewConfig() *oauth2.Config {
return &oauth2.Config{
ClientID: ClientID,
ClientSecret: ClientSecret,
ClientID: g.ClientID,
ClientSecret: g.ClientSecret,
Scopes: []string{"profile"},
Endpoint: google.Endpoint,
}
}
func (g GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
func (g *GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
oauth2Token, err := config.Exchange(ctx, code)
if err != nil {
return nil, err
@ -42,5 +49,5 @@ func (g GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config,
}
func init() {
RegisterProvider(GoogleProvider{})
registerProvider(new(GoogleProvider))
}

@ -3,7 +3,6 @@ package provider
import (
"context"
"fmt"
"sync"
"golang.org/x/oauth2"
)
@ -11,8 +10,8 @@ import (
type OAuth2Provider string
var (
providers = make(map[OAuth2Provider]ProviderInterface)
lock sync.Mutex
enabledProviders map[OAuth2Provider]ProviderInterface
allowedProviders = make(map[OAuth2Provider]ProviderInterface)
)
type UserInfo struct {
@ -21,27 +20,47 @@ type UserInfo struct {
}
type ProviderInterface interface {
Init(ClientID, ClientSecret string)
Provider() OAuth2Provider
NewConfig(ClientID, ClientSecret string) *oauth2.Config
NewConfig() *oauth2.Config
GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error)
}
func RegisterProvider(provider ProviderInterface) {
lock.Lock()
defer lock.Unlock()
providers[provider.Provider()] = provider
func InitProvider(p OAuth2Provider, ClientID, ClientSecret string) error {
pi, ok := allowedProviders[p]
if !ok {
return FormatErrNotImplemented(p)
}
pi.Init(ClientID, ClientSecret)
if enabledProviders == nil {
enabledProviders = make(map[OAuth2Provider]ProviderInterface)
}
enabledProviders[pi.Provider()] = pi
return nil
}
func registerProvider(ps ...ProviderInterface) {
for _, p := range ps {
allowedProviders[p.Provider()] = p
}
}
func (p OAuth2Provider) GetProvider() (ProviderInterface, error) {
lock.Lock()
defer lock.Unlock()
pi, ok := providers[p]
func GetProvider(p OAuth2Provider) (ProviderInterface, error) {
pi, ok := enabledProviders[p]
if !ok {
return nil, FormatErrNotImplemented(p)
}
return pi, nil
}
func AllowedProvider() map[OAuth2Provider]ProviderInterface {
return allowedProviders
}
func EnabledProvider() map[OAuth2Provider]ProviderInterface {
return enabledProviders
}
type FormatErrNotImplemented string
func (f FormatErrNotImplemented) Error() string {

@ -5,7 +5,6 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/synctv-org/synctv/internal/conf"
"github.com/synctv-org/synctv/internal/op"
"github.com/synctv-org/synctv/internal/provider"
"github.com/synctv-org/synctv/server/middlewares"
@ -16,33 +15,24 @@ import (
// /oauth2/login/:type
func OAuth2(ctx *gin.Context) {
t := ctx.Param("type")
p := provider.OAuth2Provider(t)
c, ok := conf.Conf.OAuth2[p]
if !ok {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider"))
}
p := provider.OAuth2Provider(ctx.Param("type"))
pi, err := p.GetProvider()
pi, err := provider.GetProvider(p)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
state := utils.RandString(16)
states.Store(state, struct{}{}, time.Minute*5)
RenderRedirect(ctx, pi.NewConfig(c.ClientID, c.ClientSecret).AuthCodeURL(state, oauth2.AccessTypeOnline))
RenderRedirect(ctx, pi.NewConfig().AuthCodeURL(state, oauth2.AccessTypeOnline))
}
func OAuth2Api(ctx *gin.Context) {
t := ctx.Param("type")
p := provider.OAuth2Provider(t)
c, ok := conf.Conf.OAuth2[p]
if !ok {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider"))
}
p := provider.OAuth2Provider(ctx.Param("type"))
pi, err := p.GetProvider()
pi, err := provider.GetProvider(p)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
}
@ -51,18 +41,13 @@ func OAuth2Api(ctx *gin.Context) {
states.Store(state, struct{}{}, time.Minute*5)
ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{
"url": pi.NewConfig(c.ClientID, c.ClientSecret).AuthCodeURL(state, oauth2.AccessTypeOnline),
"url": pi.NewConfig().AuthCodeURL(state, oauth2.AccessTypeOnline),
}))
}
// /oauth2/callback/:type
func OAuth2Callback(ctx *gin.Context) {
t := ctx.Param("type")
p := provider.OAuth2Provider(t)
c, ok := conf.Conf.OAuth2[p]
if !ok {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider"))
}
p := provider.OAuth2Provider(ctx.Param("type"))
code := ctx.Query("code")
if code == "" {
@ -82,12 +67,12 @@ func OAuth2Callback(ctx *gin.Context) {
return
}
pi, err := p.GetProvider()
pi, err := provider.GetProvider(p)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
}
ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), code)
ui, err := pi.GetUserInfo(ctx, pi.NewConfig(), code)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
@ -110,12 +95,7 @@ func OAuth2Callback(ctx *gin.Context) {
// /oauth2/callback/:type
func OAuth2CallbackApi(ctx *gin.Context) {
t := ctx.Param("type")
p := provider.OAuth2Provider(t)
c, ok := conf.Conf.OAuth2[p]
if !ok {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider"))
}
p := provider.OAuth2Provider(ctx.Param("type"))
req := model.OAuth2CallbackReq{}
if err := req.Decode(ctx); err != nil {
@ -129,12 +109,12 @@ func OAuth2CallbackApi(ctx *gin.Context) {
return
}
pi, err := p.GetProvider()
pi, err := provider.GetProvider(p)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
}
ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), req.Code)
ui, err := pi.GetUserInfo(ctx, pi.NewConfig(), req.Code)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return

@ -2,12 +2,12 @@ package auth
import (
"github.com/gin-gonic/gin"
"github.com/synctv-org/synctv/internal/conf"
"github.com/synctv-org/synctv/internal/provider"
"golang.org/x/exp/maps"
)
func OAuth2EnabledApi(ctx *gin.Context) {
ctx.JSON(200, gin.H{
"enabled": maps.Keys(conf.Conf.OAuth2),
"enabled": maps.Keys(provider.EnabledProvider()),
})
}

Loading…
Cancel
Save