Feat: disable user signup setting

pull/21/head
zijiren233 1 year ago
parent c16fc972b1
commit 66ba37ac59

@ -27,7 +27,7 @@ var FixCmd = &cobra.Command{
_, err := s.Interface() _, err := s.Interface()
if err != nil { if err != nil {
fmt.Printf("setting %s, interface error: %v\n", k, err) fmt.Printf("setting %s, interface error: %v\n", k, err)
err = s.SetString(s.DefaultString()) err = s.SetRaw(s.DefaultRaw())
if err != nil { if err != nil {
errorCount++ errorCount++
fmt.Printf("setting %s fix error: %v\n", k, err) fmt.Printf("setting %s fix error: %v\n", k, err)

@ -29,14 +29,14 @@ var SetCmd = &cobra.Command{
if !ok { if !ok {
return errors.New("setting not found") return errors.New("setting not found")
} }
current := s.String() current := s.Raw()
err := s.SetString(args[1]) err := s.SetRaw(args[1])
if err != nil { if err != nil {
s.SetString(current) s.SetRaw(current)
fmt.Printf("set setting %s error: %v\n", args[0], err) fmt.Printf("set setting %s error: %v\n", args[0], err)
} }
if v, err := s.Interface(); err != nil { if v, err := s.Interface(); err != nil {
s.SetString(current) s.SetRaw(current)
fmt.Printf("set setting %s error: %v\n", args[0], err) fmt.Printf("set setting %s error: %v\n", args[0], err)
} else { } else {
fmt.Printf("set setting success:\n%s: %v\n", args[0], v) fmt.Printf("set setting success:\n%s: %v\n", args[0], v)

@ -61,13 +61,16 @@ func CreateOrLoadUser(username string, p provider.OAuth2Provider, puid uint, con
return &user, nil return &user, nil
} }
func GetUserByProvider(p provider.OAuth2Provider, puid uint) (*model.User, error) { func GetProviderUserID(p provider.OAuth2Provider, puid uint) (uint, error) {
u := &model.User{} var userProvider model.UserProvider
err := db.Preload("Providers", "provider = ? AND provider_user_id = ?", p, puid).First(u).Error if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).First(&userProvider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return u, errors.New("user not found") return 0, errors.New("user not found")
} else {
return 0, err
}
} }
return u, err return userProvider.UserID, nil
} }
func AddUserToRoom(userID uint, roomID uint, role model.RoomRole, permission model.Permission) error { func AddUserToRoom(userID uint, roomID uint, role model.RoomRole, permission model.Permission) error {

@ -76,6 +76,15 @@ func CreateOrLoadUser(username string, p provider.OAuth2Provider, pid uint, conf
return u2, userCache.SetWithExpire(u.ID, u2, time.Hour) return u2, userCache.SetWithExpire(u.ID, u2, time.Hour)
} }
func GetUserByProvider(p provider.OAuth2Provider, pid uint) (*User, error) {
uid, err := db.GetProviderUserID(p, pid)
if err != nil {
return nil, err
}
return GetUserById(uid)
}
func DeleteUserByID(userID uint) error { func DeleteUserByID(userID uint) error {
err := db.DeleteUserByID(userID) err := db.DeleteUserByID(userID)
if err != nil { if err != nil {

@ -49,7 +49,7 @@ func (b *Bool) Default() bool {
return b.defaultValue return b.defaultValue
} }
func (b *Bool) DefaultString() string { func (b *Bool) DefaultRaw() string {
if b.defaultValue { if b.defaultValue {
return "1" return "1"
} else { } else {
@ -61,7 +61,7 @@ func (b *Bool) DefaultInterface() any {
return b.Default() return b.Default()
} }
func (b *Bool) SetString(value string) error { func (b *Bool) SetRaw(value string) error {
if b.value == value { if b.value == value {
return nil return nil
} }
@ -71,9 +71,9 @@ func (b *Bool) SetString(value string) error {
func (b *Bool) Set(value bool) error { func (b *Bool) Set(value bool) error {
if value { if value {
return b.SetString("1") return b.SetRaw("1")
} else { } else {
return b.SetString("0") return b.SetRaw("0")
} }
} }
@ -81,7 +81,7 @@ func (b *Bool) Get() (bool, error) {
return b.value == "1", nil return b.value == "1", nil
} }
func (b *Bool) String() string { func (b *Bool) Raw() string {
return b.value return b.value
} }

@ -18,9 +18,9 @@ type Setting interface {
Type() model.SettingType Type() model.SettingType
Group() model.SettingGroup Group() model.SettingGroup
Init(string) Init(string)
String() string Raw() string
SetString(string) error SetRaw(string) error
DefaultString() string DefaultRaw() string
DefaultInterface() any DefaultInterface() any
Interface() (any, error) Interface() (any, error)
} }
@ -115,7 +115,7 @@ func initSettings(i ...Setting) error {
for _, b := range i { for _, b := range i {
s := &model.Setting{ s := &model.Setting{
Name: b.Name(), Name: b.Name(),
Value: b.String(), Value: b.Raw(),
Type: b.Type(), Type: b.Type(),
Group: b.Group(), Group: b.Group(),
} }

@ -5,3 +5,7 @@ import "github.com/synctv-org/synctv/internal/model"
var ( var (
DisableCreateRoom = newBoolSetting("disable_create_room", false, model.SettingGroupRoom) DisableCreateRoom = newBoolSetting("disable_create_room", false, model.SettingGroupRoom)
) )
var (
DisableUserSignup = newBoolSetting("disable_user_signup", false, model.SettingGroupUser)
)

@ -8,6 +8,7 @@ import (
"github.com/synctv-org/synctv/internal/op" "github.com/synctv-org/synctv/internal/op"
"github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/internal/provider"
"github.com/synctv-org/synctv/internal/provider/providers" "github.com/synctv-org/synctv/internal/provider/providers"
"github.com/synctv-org/synctv/internal/settings"
"github.com/synctv-org/synctv/server/middlewares" "github.com/synctv-org/synctv/server/middlewares"
"github.com/synctv-org/synctv/server/model" "github.com/synctv-org/synctv/server/model"
"github.com/synctv-org/synctv/utils" "github.com/synctv-org/synctv/utils"
@ -83,7 +84,17 @@ func OAuth2Callback(ctx *gin.Context) {
return return
} }
user, err := op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID) disable, err := settings.DisableUserSignup.Get()
if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return
}
var user *op.User
if disable {
user, err = op.GetUserByProvider(p, ui.ProviderUserID)
} else {
user, err = op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID)
}
if err != nil { if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return return
@ -130,7 +141,17 @@ func OAuth2CallbackApi(ctx *gin.Context) {
return return
} }
user, err := op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID) disable, err := settings.DisableUserSignup.Get()
if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return
}
var user *op.User
if disable {
user, err = op.GetUserByProvider(p, ui.ProviderUserID)
} else {
user, err = op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID)
}
if err != nil { if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return return

Loading…
Cancel
Save