diff --git a/server/model/auth.go b/server/model/auth.go new file mode 100644 index 0000000..b4e2c9e --- /dev/null +++ b/server/model/auth.go @@ -0,0 +1,32 @@ +package model + +import ( + "errors" + + "github.com/gin-gonic/gin" + json "github.com/json-iterator/go" +) + +type OAuth2CallbackReq struct { + Code string `json:"code"` + State string `json:"state"` +} + +var ( + ErrInvalidOAuth2Code = errors.New("invalid oauth2 code") + ErrInvalidOAuth2State = errors.New("invalid oauth2 state") +) + +func (o *OAuth2CallbackReq) Validate() error { + if o.Code == "" { + return ErrInvalidOAuth2Code + } + if o.State == "" { + return ErrInvalidOAuth2State + } + return nil +} + +func (o *OAuth2CallbackReq) Decode(ctx *gin.Context) error { + return json.NewDecoder(ctx.Request.Body).Decode(o) +} diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go index 896c1b8..20eeeab 100644 --- a/server/oauth2/auth.go +++ b/server/oauth2/auth.go @@ -2,6 +2,7 @@ package auth import ( "net/http" + "time" "github.com/gin-gonic/gin" "github.com/synctv-org/synctv/internal/conf" @@ -9,6 +10,7 @@ import ( "github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/server/middlewares" "github.com/synctv-org/synctv/server/model" + "github.com/synctv-org/synctv/utils" "golang.org/x/oauth2" ) @@ -26,7 +28,31 @@ func OAuth2(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } - Render(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), oauth2.AccessTypeOnline) + state := utils.RandString(16) + states.Store(state, struct{}{}, time.Minute*5) + + RenderRedirect(ctx, pi.NewConfig(c.ClientID, c.ClientSecret).AuthCodeURL(state, oauth2.AccessTypeOnline)) +} + +func OAuth2Api(ctx *gin.Context) { + t := ctx.Param("type") + p := provider.OAuth2Provider(t) + c, ok := conf.Conf.OAuth2[p] + if !ok { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) + } + + pi, err := p.GetProvider() + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + } + + state := utils.RandString(16) + states.Store(state, struct{}{}, time.Minute*5) + + ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ + "url": pi.NewConfig(c.ClientID, c.ClientSecret).AuthCodeURL(state, oauth2.AccessTypeOnline), + })) } // /oauth2/callback/:type @@ -38,6 +64,53 @@ func OAuth2Callback(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) } + req := model.OAuth2CallbackReq{} + if err := req.Decode(ctx); err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + _, loaded := states.LoadAndDelete(req.State) + if !loaded { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) + return + } + + pi, err := p.GetProvider() + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + } + + ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), req.Code) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + user, err := op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + token, err := middlewares.NewAuthUserToken(user) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + RenderToken(ctx, "/web/", token) +} + +// /oauth2/callback/:type +func OAuth2CallbackApi(ctx *gin.Context) { + t := ctx.Param("type") + p := provider.OAuth2Provider(t) + c, ok := conf.Conf.OAuth2[p] + if !ok { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) + } + code := ctx.Query("code") if code == "" { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 code")) diff --git a/server/oauth2/init.go b/server/oauth2/init.go index 883640c..3f199fc 100644 --- a/server/oauth2/init.go +++ b/server/oauth2/init.go @@ -6,8 +6,14 @@ func Init(e *gin.Engine) { { auth := e.Group("/oauth2") + auth.GET("/enabled", OAuth2EnabledApi) + auth.GET("/login/:type", OAuth2) + auth.POST("/login/:type", OAuth2Api) + auth.GET("/callback/:type", OAuth2Callback) + + auth.POST("/callback/:type", OAuth2CallbackApi) } } diff --git a/server/oauth2/oauth2.go b/server/oauth2/oauth2.go new file mode 100644 index 0000000..00ac190 --- /dev/null +++ b/server/oauth2/oauth2.go @@ -0,0 +1,13 @@ +package auth + +import ( + "github.com/gin-gonic/gin" + "github.com/synctv-org/synctv/internal/conf" + "golang.org/x/exp/maps" +) + +func OAuth2EnabledApi(ctx *gin.Context) { + ctx.JSON(200, gin.H{ + "enabled": maps.Keys(conf.Conf.OAuth2), + }) +} diff --git a/server/oauth2/render.go b/server/oauth2/render.go index 27c3d1e..0e91fef 100644 --- a/server/oauth2/render.go +++ b/server/oauth2/render.go @@ -6,26 +6,30 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/synctv-org/synctv/utils" synccache "github.com/synctv-org/synctv/utils/syncCache" - "golang.org/x/oauth2" ) -//go:embed templates/redirect.html +//go:embed templates/*.html var temp embed.FS var ( redirectTemplate *template.Template + tokenTemplate *template.Template states *synccache.SyncCache[string, struct{}] ) -func Render(ctx *gin.Context, c *oauth2.Config, option ...oauth2.AuthCodeOption) error { - state := utils.RandString(16) - states.Store(state, struct{}{}, time.Minute*5) - return redirectTemplate.Execute(ctx.Writer, c.AuthCodeURL(state, option...)) +func RenderRedirect(ctx *gin.Context, url string) error { + ctx.Header("Content-Type", "text/html; charset=utf-8") + return redirectTemplate.Execute(ctx.Writer, url) +} + +func RenderToken(ctx *gin.Context, url, token string) error { + ctx.Header("Content-Type", "text/html; charset=utf-8") + return tokenTemplate.Execute(ctx.Writer, map[string]string{"Url": url, "Token": token}) } func init() { redirectTemplate = template.Must(template.ParseFS(temp, "templates/redirect.html")) + tokenTemplate = template.Must(template.ParseFS(temp, "templates/token.html")) states = synccache.NewSyncCache[string, struct{}](time.Minute * 10) } diff --git a/server/oauth2/templates/token.html b/server/oauth2/templates/token.html new file mode 100644 index 0000000..d514f21 --- /dev/null +++ b/server/oauth2/templates/token.html @@ -0,0 +1,16 @@ + + + +
+ + + +If you are not redirected, please click here.
+ + + + \ No newline at end of file