From 1bb8fcaca2e4e2ef996e41851fd599d4e67a7650 Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Sun, 6 Oct 2024 12:56:14 +0800 Subject: [PATCH] refactor: signup api --- internal/bootstrap/provider.go | 25 +++++++++++++- internal/db/user.go | 30 +++++++++++++++-- internal/model/user.go | 20 +++++++++--- internal/op/users.go | 2 +- internal/settings/var.go | 10 +++--- server/handlers/init.go | 2 ++ server/handlers/public.go | 35 +++++++++++++++----- server/handlers/user.go | 60 ++++++++++++++++++++++++++++++++-- server/oauth2/init.go | 2 ++ server/oauth2/oauth2.go | 16 +++++++++ 10 files changed, 178 insertions(+), 24 deletions(-) diff --git a/internal/bootstrap/provider.go b/internal/bootstrap/provider.go index 293e8a8..0cd2483 100644 --- a/internal/bootstrap/provider.go +++ b/internal/bootstrap/provider.go @@ -37,7 +37,7 @@ type ProviderGroupSetting struct { var Oauth2EnabledCache = refreshcache.NewRefreshCache[[]provider.OAuth2Provider](func(context.Context, ...any) ([]provider.OAuth2Provider, error) { ps := providers.EnabledProvider() r := make([]provider.OAuth2Provider, 0, ps.Len()) - providers.EnabledProvider().Range(func(p provider.OAuth2Provider, value struct{}) bool { + ps.Range(func(p provider.OAuth2Provider, value struct{}) bool { r = append(r, p) return true }) @@ -53,6 +53,29 @@ var Oauth2EnabledCache = refreshcache.NewRefreshCache[[]provider.OAuth2Provider] return r, nil }, 0) +var Oauth2SignupEnabledCache = refreshcache.NewRefreshCache[[]provider.OAuth2Provider](func(ctx context.Context, _ ...any) ([]provider.OAuth2Provider, error) { + ps := providers.EnabledProvider() + r := make([]provider.OAuth2Provider, 0, ps.Len()) + ps.Range(func(p provider.OAuth2Provider, value struct{}) bool { + group := model.SettingGroup(fmt.Sprintf("%s_%s", model.SettingGroupOauth2, p)) + groupSettings := ProviderGroupSettings[group] + if groupSettings.Enabled.Get() && !groupSettings.DisableUserSignup.Get() { + r = append(r, p) + } + return true + }) + slices.SortStableFunc(r, func(a, b provider.OAuth2Provider) int { + if a == b { + return 0 + } else if natural.Less(a, b) { + return -1 + } else { + return 1 + } + }) + return r, nil +}, 0) + func InitProvider(ctx context.Context) (err error) { logOur := log.StandardLogger().Writer() logLevle := hclog.Info diff --git a/internal/db/user.go b/internal/db/user.go index 9b6a7fa..99c1dbd 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -76,6 +76,18 @@ func WithRegisteredByEmail(b bool) CreateUserConfig { } } +func WithEnableAutoAddUsernameSuffix() CreateUserConfig { + return func(u *model.User) { + u.EnableAutoAddUsernameSuffix() + } +} + +func WithDisableAutoAddUsernameSuffix() CreateUserConfig { + return func(u *model.User) { + u.DisableAutoAddUsernameSuffix() + } +} + func CreateUserWithHashedPassword(username string, hashedPassword []byte, conf ...CreateUserConfig) (*model.User, error) { if username == "" { return nil, errors.New("username cannot be empty") @@ -153,7 +165,11 @@ func CreateOrLoadUserWithProvider(username, password string, p provider.OAuth2Pr var user model.User if err := db.Where("id = (?)", db.Table("user_providers").Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id")).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return CreateUser(username, password, append(conf, WithSetProvider(p, puid), WithRegisteredByProvider(true))...) + return CreateUser(username, password, append(conf, + WithSetProvider(p, puid), + WithRegisteredByProvider(true), + WithEnableAutoAddUsernameSuffix(), + )...) } else { return nil, err } @@ -169,7 +185,11 @@ func CreateOrLoadUserWithEmail(username, password, email string, conf ...CreateU var user model.User if err := db.Where("email = ?", email).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return CreateUser(username, password, append(conf, WithEmail(email), WithRegisteredByEmail(true))...) + return CreateUser(username, password, append(conf, + WithEmail(email), + WithRegisteredByEmail(true), + WithEnableAutoAddUsernameSuffix(), + )...) } else { return nil, err } @@ -181,7 +201,11 @@ func CreateUserWithEmail(username, password, email string, conf ...CreateUserCon if email == "" { return nil, errors.New("email cannot be empty") } - return CreateUser(username, password, append(conf, WithEmail(email), WithRegisteredByEmail(true))...) + return CreateUser(username, password, append(conf, + WithEmail(email), + WithRegisteredByEmail(true), + WithEnableAutoAddUsernameSuffix(), + )...) } func GetUserByProvider(p provider.OAuth2Provider, puid string) (*model.User, error) { diff --git a/internal/model/user.go b/internal/model/user.go index eeb22f7..d885b18 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -55,6 +55,16 @@ type User struct { BilibiliVendor *BilibiliVendor `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` AlistVendor []*AlistVendor `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` EmbyVendor []*EmbyVendor `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + + autoAddUsernameSuffix bool +} + +func (u *User) EnableAutoAddUsernameSuffix() { + u.autoAddUsernameSuffix = true +} + +func (u *User) DisableAutoAddUsernameSuffix() { + u.autoAddUsernameSuffix = false } func (u *User) CheckPassword(password string) bool { @@ -62,10 +72,12 @@ func (u *User) CheckPassword(password string) bool { } func (u *User) BeforeCreate(tx *gorm.DB) error { - var existingUser User - err := tx.Where("username = ?", u.Username).First(&existingUser).Error - if err == nil { - u.Username = fmt.Sprintf("%s#%d", u.Username, rand.Intn(9999)) + if u.autoAddUsernameSuffix { + var existingUser User + err := tx.Select("username").Where("username = ?", u.Username).First(&existingUser).Error + if err == nil { + u.Username = fmt.Sprintf("%s#%d", u.Username, rand.Intn(9999)) + } } if u.ID == "" { u.ID = utils.SortUUID() diff --git a/internal/op/users.go b/internal/op/users.go index f482425..43901f5 100644 --- a/internal/op/users.go +++ b/internal/op/users.go @@ -52,7 +52,7 @@ func LoadOrInitUserByEmail(email string) (*UserEntry, error) { return LoadOrInitUser(u) } -func LoadUserByUsername(username string) (*UserEntry, error) { +func LoadOrInitUserByUsername(username string) (*UserEntry, error) { u, err := db.GetUserByUsername(username) if err != nil { return nil, err diff --git a/internal/settings/var.go b/internal/settings/var.go index 636f3eb..23da0ff 100644 --- a/internal/settings/var.go +++ b/internal/settings/var.go @@ -42,10 +42,12 @@ func init() { } var ( - DisableUserSignup = NewBoolSetting("disable_user_signup", false, model.SettingGroupUser) - SignupNeedReview = NewBoolSetting("signup_need_review", false, model.SettingGroupUser) - UserMaxRoomCount = NewInt64Setting("user_max_room_count", 3, model.SettingGroupUser) - EnableGuest = NewBoolSetting("enable_guest", true, model.SettingGroupUser) + DisableUserSignup = NewBoolSetting("disable_user_signup", false, model.SettingGroupUser) + SignupNeedReview = NewBoolSetting("signup_need_review", false, model.SettingGroupUser) + EnablePasswordSignup = NewBoolSetting("enable_password_signup", false, model.SettingGroupUser) + PasswordSignupNeedReview = NewBoolSetting("password_signup_need_review", false, model.SettingGroupUser) + UserMaxRoomCount = NewInt64Setting("user_max_room_count", 3, model.SettingGroupUser) + EnableGuest = NewBoolSetting("enable_guest", true, model.SettingGroupUser) ) var ( diff --git a/server/handlers/init.go b/server/handlers/init.go index e688353..03fb12e 100644 --- a/server/handlers/init.go +++ b/server/handlers/init.go @@ -247,6 +247,8 @@ func initMovie(movie *gin.RouterGroup, needAuthMovie *gin.RouterGroup) { func initUser(user *gin.RouterGroup, needAuthUser *gin.RouterGroup) { user.POST("/login", LoginUser) + user.POST("/signup", UserSignupPassword) + user.GET("/signup/email/captcha", GetUserSignupEmailStep1Captcha) user.POST("/signup/email/captcha", SendUserSignupEmailCaptcha) diff --git a/server/handlers/public.go b/server/handlers/public.go index c99a79b..9b3c356 100644 --- a/server/handlers/public.go +++ b/server/handlers/public.go @@ -1,30 +1,49 @@ package handlers import ( + "net/http" "strings" "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/synctv-org/synctv/internal/bootstrap" "github.com/synctv-org/synctv/internal/email" "github.com/synctv-org/synctv/internal/settings" "github.com/synctv-org/synctv/server/model" ) type publicSettings struct { - EmailEnable bool `json:"emailEnable"` - EmailDisableUserSignup bool `json:"emailDisableUserSignup"` - EmailWhitelistEnabled bool `json:"emailWhitelistEnabled"` - EmailWhitelist []string `json:"emailWhitelist,omitempty"` + PasswordDisableSignup bool `json:"passwordDisableSignup"` + + EmailEnable bool `json:"emailEnable"` + EmailDisableSignup bool `json:"emailDisableSignup"` + EmailWhitelistEnabled bool `json:"emailWhitelistEnabled"` + EmailWhitelist []string `json:"emailWhitelist,omitempty"` + + Oauth2DisableSignup bool `json:"oauth2DisableSignup"` GuestEnable bool `json:"guestEnable"` } func Settings(ctx *gin.Context) { + log := ctx.MustGet("log").(*log.Entry) + + oauth2SignupEnabled, err := bootstrap.Oauth2SignupEnabledCache.Get(ctx) + if err != nil { + log.Errorf("failed to get oauth2 signup enabled: %v", err) + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } ctx.JSON(200, model.NewApiDataResp( &publicSettings{ - EmailEnable: email.EnableEmail.Get(), - EmailDisableUserSignup: email.DisableUserSignup.Get(), - EmailWhitelistEnabled: email.EmailSignupWhiteListEnable.Get(), - EmailWhitelist: strings.Split(email.EmailSignupWhiteList.Get(), ","), + PasswordDisableSignup: settings.DisableUserSignup.Get() || !settings.EnablePasswordSignup.Get(), + + EmailEnable: email.EnableEmail.Get(), + EmailDisableSignup: settings.DisableUserSignup.Get() || email.DisableUserSignup.Get(), + EmailWhitelistEnabled: email.EmailSignupWhiteListEnable.Get(), + EmailWhitelist: strings.Split(email.EmailSignupWhiteList.Get(), ","), + + Oauth2DisableSignup: settings.DisableUserSignup.Get() || len(oauth2SignupEnabled) == 0, GuestEnable: settings.EnableGuest.Get(), }, diff --git a/server/handlers/user.go b/server/handlers/user.go index 2184841..9394536 100644 --- a/server/handlers/user.go +++ b/server/handlers/user.go @@ -46,7 +46,7 @@ func LoginUser(ctx *gin.Context) { return } - user, err := op.LoadUserByUsername(req.Username) + user, err := op.LoadOrInitUserByUsername(req.Username) if err != nil { log.Errorf("failed to load user: %v", err) if err == op.ErrUserBanned || err == op.ErrUserPending { @@ -487,9 +487,14 @@ func GetUserSignupEmailStep1Captcha(ctx *gin.Context) { func SendUserSignupEmailCaptcha(ctx *gin.Context) { log := ctx.MustGet("log").(*logrus.Entry) - if settings.DisableUserSignup.Get() || email.DisableUserSignup.Get() { + if settings.DisableUserSignup.Get() { + log.Errorf("user signup disabled") ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("user signup disabled")) return + } else if email.DisableUserSignup.Get() { + log.Errorf("email signup disabled") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("email signup disabled")) + return } req := model.SendUserSignupEmailCaptchaReq{} @@ -545,10 +550,14 @@ func SendUserSignupEmailCaptcha(ctx *gin.Context) { func UserSignupEmail(ctx *gin.Context) { log := ctx.MustGet("log").(*logrus.Entry) - if settings.DisableUserSignup.Get() || email.DisableUserSignup.Get() { + if settings.DisableUserSignup.Get() { log.Errorf("user signup disabled") ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("user signup disabled")) return + } else if email.DisableUserSignup.Get() { + log.Errorf("email signup disabled") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("email signup disabled")) + return } req := model.UserSignupEmailReq{} @@ -736,3 +745,48 @@ func UserDeleteRoom(ctx *gin.Context) { ctx.Status(http.StatusNoContent) } + +func UserSignupPassword(ctx *gin.Context) { + log := ctx.MustGet("log").(*logrus.Entry) + + if settings.DisableUserSignup.Get() { + log.Errorf("user signup disabled") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("user signup disabled")) + return + } else if !settings.EnablePasswordSignup.Get() { + log.Errorf("password signup disabled") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("password signup disabled")) + return + } + + var req model.LoginUserReq + if err := model.Decode(ctx, &req); err != nil { + log.Errorf("failed to decode request: %v", err) + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + var user *op.UserEntry + var err error + if settings.SignupNeedReview.Get() || settings.PasswordSignupNeedReview.Get() { + user, err = op.CreateUser(req.Username, req.Password, db.WithRole(dbModel.RolePending)) + } else { + user, err = op.CreateUser(req.Username, req.Password) + } + if err != nil { + log.Errorf("failed to create user: %v", err) + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + token, err := middlewares.NewAuthUserToken(user.Value()) + if err != nil { + log.Errorf("failed to generate token: %v", err) + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ + "token": token, + })) +} diff --git a/server/oauth2/init.go b/server/oauth2/init.go index c6d8f17..068dec4 100644 --- a/server/oauth2/init.go +++ b/server/oauth2/init.go @@ -13,6 +13,8 @@ func Init(e *gin.Engine) { oauth2.GET("/enabled", OAuth2EnabledApi) + oauth2.GET("/enabled/signup", OAuth2SignupEnabledApi) + oauth2.GET("/login/:type", OAuth2) oauth2.POST("/login/:type", OAuth2Api) diff --git a/server/oauth2/oauth2.go b/server/oauth2/oauth2.go index 0ddd7a0..ac6a23c 100644 --- a/server/oauth2/oauth2.go +++ b/server/oauth2/oauth2.go @@ -18,7 +18,23 @@ func OAuth2EnabledApi(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) return } + ctx.JSON(200, gin.H{ "enabled": data, }) } + +func OAuth2SignupEnabledApi(ctx *gin.Context) { + log := ctx.MustGet("log").(*logrus.Entry) + + oauth2SignupEnabled, err := bootstrap.Oauth2SignupEnabledCache.Get(ctx) + if err != nil { + log.Errorf("failed to get oauth2 signup enabled: %v", err) + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + ctx.JSON(200, gin.H{ + "signupEnabled": oauth2SignupEnabled, + }) +}