diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go index 20eeeab..813191b 100644 --- a/server/oauth2/auth.go +++ b/server/oauth2/auth.go @@ -64,13 +64,19 @@ func OAuth2Callback(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) } - req := model.OAuth2CallbackReq{} - if err := req.Decode(ctx); err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + code := ctx.Query("code") + if code == "" { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 code")) return } - _, loaded := states.LoadAndDelete(req.State) + state := ctx.Query("state") + if state == "" { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) + return + } + + _, loaded := states.LoadAndDelete(state) if !loaded { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) return @@ -81,7 +87,7 @@ func OAuth2Callback(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } - ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), req.Code) + ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), code) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return @@ -111,19 +117,13 @@ func OAuth2CallbackApi(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) } - code := ctx.Query("code") - if code == "" { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 code")) - return - } - - state := ctx.Query("state") - if state == "" { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) + req := model.OAuth2CallbackReq{} + if err := req.Decode(ctx); err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } - _, loaded := states.LoadAndDelete(state) + _, loaded := states.LoadAndDelete(req.State) if !loaded { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) return @@ -134,7 +134,7 @@ func OAuth2CallbackApi(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } - ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), code) + ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), req.Code) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return