diff --git a/cmd/server.go b/cmd/server.go index 8bf946c..e5a5f91 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -28,6 +28,7 @@ var ServerCmd = &cobra.Command{ bootstrap.InitLog, bootstrap.InitGinMode, bootstrap.InitDatabase, + bootstrap.InitProvider, bootstrap.InitOp, bootstrap.InitRtmp, bootstrap.InitRoom, diff --git a/go.mod b/go.mod index ea7f254..c8ed759 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/zijiren233/stream v0.5.1 github.com/zijiren233/yaml-comment v0.2.0 golang.org/x/crypto v0.14.0 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/oauth2 v0.13.0 google.golang.org/protobuf v1.31.0 gopkg.in/yaml.v3 v3.0.1 @@ -77,7 +78,6 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect go.uber.org/mock v0.3.0 // indirect golang.org/x/arch v0.5.0 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/mod v0.13.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect diff --git a/internal/bootstrap/provider.go b/internal/bootstrap/provider.go new file mode 100644 index 0000000..4152e3a --- /dev/null +++ b/internal/bootstrap/provider.go @@ -0,0 +1,18 @@ +package bootstrap + +import ( + "context" + + "github.com/synctv-org/synctv/internal/conf" + "github.com/synctv-org/synctv/internal/provider" +) + +func InitProvider(ctx context.Context) error { + for op, v := range conf.Conf.OAuth2 { + err := provider.InitProvider(op, v.ClientID, v.ClientSecret) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/conf/oauth2.go b/internal/conf/oauth2.go index 86107e2..b5aa485 100644 --- a/internal/conf/oauth2.go +++ b/internal/conf/oauth2.go @@ -14,7 +14,7 @@ type OAuth2ProviderConfig struct { func DefaultOAuth2Config() OAuth2Config { return OAuth2Config{ - provider.GithubProvider{}.Provider(): { + (&provider.GithubProvider{}).Provider(): { ClientID: "github_client_id", ClientSecret: "github_client_secret", }, diff --git a/internal/provider/github.go b/internal/provider/github.go index 594890b..9132c3d 100644 --- a/internal/provider/github.go +++ b/internal/provider/github.go @@ -10,22 +10,29 @@ import ( "golang.org/x/oauth2/github" ) -type GithubProvider struct{} +type GithubProvider struct { + ClientID, ClientSecret string +} + +func (p *GithubProvider) Init(ClientID, ClientSecret string) { + p.ClientID = ClientID + p.ClientSecret = ClientSecret +} -func (p GithubProvider) Provider() OAuth2Provider { +func (p *GithubProvider) Provider() OAuth2Provider { return "github" } -func (p GithubProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config { +func (p *GithubProvider) NewConfig() *oauth2.Config { return &oauth2.Config{ - ClientID: ClientID, - ClientSecret: ClientSecret, + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, Scopes: []string{"user"}, Endpoint: github.Endpoint, } } -func (p GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { +func (p *GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { oauth2Token, err := config.Exchange(ctx, code) if err != nil { return nil, err @@ -51,10 +58,6 @@ func (p GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, }, nil } -func init() { - RegisterProvider(GithubProvider{}) -} - type githubUserInfo struct { Login string `json:"login"` ID uint `json:"id"` @@ -102,3 +105,7 @@ type Plan struct { Collaborators int `json:"collaborators"` PrivateRepos int `json:"private_repos"` } + +func init() { + registerProvider(new(GithubProvider)) +} diff --git a/internal/provider/gitlab.go b/internal/provider/gitlab.go index 3e2aff6..2815951 100644 --- a/internal/provider/gitlab.go +++ b/internal/provider/gitlab.go @@ -8,22 +8,29 @@ import ( "golang.org/x/oauth2/gitlab" ) -type GitlabProvider struct{} +type GitlabProvider struct { + ClientID, ClientSecret string +} + +func (g *GitlabProvider) Init(ClientID, ClientSecret string) { + g.ClientID = ClientID + g.ClientSecret = ClientSecret +} -func (g GitlabProvider) Provider() OAuth2Provider { +func (g *GitlabProvider) Provider() OAuth2Provider { return "gitlab" } -func (g GitlabProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config { +func (g *GitlabProvider) NewConfig() *oauth2.Config { return &oauth2.Config{ - ClientID: ClientID, - ClientSecret: ClientSecret, + ClientID: g.ClientID, + ClientSecret: g.ClientSecret, Scopes: []string{"read_user"}, Endpoint: gitlab.Endpoint, } } -func (g GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { +func (g *GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { oauth2Token, err := config.Exchange(ctx, code) if err != nil { return nil, err @@ -42,5 +49,5 @@ func (g GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, } func init() { - RegisterProvider(GitlabProvider{}) + registerProvider(new(GitlabProvider)) } diff --git a/internal/provider/google.go b/internal/provider/google.go index 7c2db84..993510c 100644 --- a/internal/provider/google.go +++ b/internal/provider/google.go @@ -8,22 +8,29 @@ import ( "golang.org/x/oauth2/google" ) -type GoogleProvider struct{} +type GoogleProvider struct { + ClientID, ClientSecret string +} + +func (g *GoogleProvider) Init(ClientID, ClientSecret string) { + g.ClientID = ClientID + g.ClientSecret = ClientSecret +} -func (g GoogleProvider) Provider() OAuth2Provider { +func (g *GoogleProvider) Provider() OAuth2Provider { return "google" } -func (g GoogleProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config { +func (g *GoogleProvider) NewConfig() *oauth2.Config { return &oauth2.Config{ - ClientID: ClientID, - ClientSecret: ClientSecret, + ClientID: g.ClientID, + ClientSecret: g.ClientSecret, Scopes: []string{"profile"}, Endpoint: google.Endpoint, } } -func (g GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { +func (g *GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { oauth2Token, err := config.Exchange(ctx, code) if err != nil { return nil, err @@ -42,5 +49,5 @@ func (g GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, } func init() { - RegisterProvider(GoogleProvider{}) + registerProvider(new(GoogleProvider)) } diff --git a/internal/provider/provider.go b/internal/provider/provider.go index f549c08..adfb923 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -3,7 +3,6 @@ package provider import ( "context" "fmt" - "sync" "golang.org/x/oauth2" ) @@ -11,8 +10,8 @@ import ( type OAuth2Provider string var ( - providers = make(map[OAuth2Provider]ProviderInterface) - lock sync.Mutex + enabledProviders map[OAuth2Provider]ProviderInterface + allowedProviders = make(map[OAuth2Provider]ProviderInterface) ) type UserInfo struct { @@ -21,27 +20,47 @@ type UserInfo struct { } type ProviderInterface interface { + Init(ClientID, ClientSecret string) Provider() OAuth2Provider - NewConfig(ClientID, ClientSecret string) *oauth2.Config + NewConfig() *oauth2.Config GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) } -func RegisterProvider(provider ProviderInterface) { - lock.Lock() - defer lock.Unlock() - providers[provider.Provider()] = provider +func InitProvider(p OAuth2Provider, ClientID, ClientSecret string) error { + pi, ok := allowedProviders[p] + if !ok { + return FormatErrNotImplemented(p) + } + pi.Init(ClientID, ClientSecret) + if enabledProviders == nil { + enabledProviders = make(map[OAuth2Provider]ProviderInterface) + } + enabledProviders[pi.Provider()] = pi + return nil +} + +func registerProvider(ps ...ProviderInterface) { + for _, p := range ps { + allowedProviders[p.Provider()] = p + } } -func (p OAuth2Provider) GetProvider() (ProviderInterface, error) { - lock.Lock() - defer lock.Unlock() - pi, ok := providers[p] +func GetProvider(p OAuth2Provider) (ProviderInterface, error) { + pi, ok := enabledProviders[p] if !ok { return nil, FormatErrNotImplemented(p) } return pi, nil } +func AllowedProvider() map[OAuth2Provider]ProviderInterface { + return allowedProviders +} + +func EnabledProvider() map[OAuth2Provider]ProviderInterface { + return enabledProviders +} + type FormatErrNotImplemented string func (f FormatErrNotImplemented) Error() string { diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go index 813191b..154efc8 100644 --- a/server/oauth2/auth.go +++ b/server/oauth2/auth.go @@ -5,7 +5,6 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/synctv-org/synctv/internal/conf" "github.com/synctv-org/synctv/internal/op" "github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/server/middlewares" @@ -16,33 +15,24 @@ import ( // /oauth2/login/:type func OAuth2(ctx *gin.Context) { - t := ctx.Param("type") - p := provider.OAuth2Provider(t) - c, ok := conf.Conf.OAuth2[p] - if !ok { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) - } + p := provider.OAuth2Provider(ctx.Param("type")) - pi, err := p.GetProvider() + pi, err := provider.GetProvider(p) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return } state := utils.RandString(16) states.Store(state, struct{}{}, time.Minute*5) - RenderRedirect(ctx, pi.NewConfig(c.ClientID, c.ClientSecret).AuthCodeURL(state, oauth2.AccessTypeOnline)) + RenderRedirect(ctx, pi.NewConfig().AuthCodeURL(state, oauth2.AccessTypeOnline)) } func OAuth2Api(ctx *gin.Context) { - t := ctx.Param("type") - p := provider.OAuth2Provider(t) - c, ok := conf.Conf.OAuth2[p] - if !ok { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) - } + p := provider.OAuth2Provider(ctx.Param("type")) - pi, err := p.GetProvider() + pi, err := provider.GetProvider(p) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } @@ -51,18 +41,13 @@ func OAuth2Api(ctx *gin.Context) { states.Store(state, struct{}{}, time.Minute*5) ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ - "url": pi.NewConfig(c.ClientID, c.ClientSecret).AuthCodeURL(state, oauth2.AccessTypeOnline), + "url": pi.NewConfig().AuthCodeURL(state, oauth2.AccessTypeOnline), })) } // /oauth2/callback/:type func OAuth2Callback(ctx *gin.Context) { - t := ctx.Param("type") - p := provider.OAuth2Provider(t) - c, ok := conf.Conf.OAuth2[p] - if !ok { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) - } + p := provider.OAuth2Provider(ctx.Param("type")) code := ctx.Query("code") if code == "" { @@ -82,12 +67,12 @@ func OAuth2Callback(ctx *gin.Context) { return } - pi, err := p.GetProvider() + pi, err := provider.GetProvider(p) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } - ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), code) + ui, err := pi.GetUserInfo(ctx, pi.NewConfig(), code) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return @@ -110,12 +95,7 @@ func OAuth2Callback(ctx *gin.Context) { // /oauth2/callback/:type func OAuth2CallbackApi(ctx *gin.Context) { - t := ctx.Param("type") - p := provider.OAuth2Provider(t) - c, ok := conf.Conf.OAuth2[p] - if !ok { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) - } + p := provider.OAuth2Provider(ctx.Param("type")) req := model.OAuth2CallbackReq{} if err := req.Decode(ctx); err != nil { @@ -129,12 +109,12 @@ func OAuth2CallbackApi(ctx *gin.Context) { return } - pi, err := p.GetProvider() + pi, err := provider.GetProvider(p) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } - ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), req.Code) + ui, err := pi.GetUserInfo(ctx, pi.NewConfig(), req.Code) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return diff --git a/server/oauth2/oauth2.go b/server/oauth2/oauth2.go index 00ac190..702c375 100644 --- a/server/oauth2/oauth2.go +++ b/server/oauth2/oauth2.go @@ -2,12 +2,12 @@ package auth import ( "github.com/gin-gonic/gin" - "github.com/synctv-org/synctv/internal/conf" + "github.com/synctv-org/synctv/internal/provider" "golang.org/x/exp/maps" ) func OAuth2EnabledApi(ctx *gin.Context) { ctx.JSON(200, gin.H{ - "enabled": maps.Keys(conf.Conf.OAuth2), + "enabled": maps.Keys(provider.EnabledProvider()), }) }