@ -13,6 +13,10 @@ import (
"gorm.io/gorm/clause"
)
const (
ErrUserNotFound = "user"
)
type CreateUserConfig func ( u * model . User )
func WithID ( id string ) CreateUserConfig {
@ -140,13 +144,13 @@ func GetUserByProvider(p provider.OAuth2Provider, puid string) (*model.User, err
err := db . Joins ( "JOIN user_providers ON users.id = user_providers.user_id" ) .
Where ( "user_providers.provider = ? AND user_providers.provider_user_id = ?" , p , puid ) .
First ( & user ) . Error
return & user , HandleNotFound ( err , "user" )
return & user , HandleNotFound ( err , ErrUserNotFound )
}
func GetUserByEmail ( email string ) ( * model . User , error ) {
var user model . User
err := db . Where ( "email = ?" , email ) . First ( & user ) . Error
return & user , HandleNotFound ( err , "user" )
return & user , HandleNotFound ( err , ErrUserNotFound )
}
func GetProviderUserID ( p provider . OAuth2Provider , puid string ) ( string , error ) {
@ -155,7 +159,7 @@ func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) {
Where ( "provider = ? AND provider_user_id = ?" , p , puid ) .
Select ( "user_id" ) .
First ( & userID ) . Error
return userID , HandleNotFound ( err , "user" )
return userID , HandleNotFound ( err , ErrUserNotFound )
}
func BindProvider ( uid string , p provider . OAuth2Provider , puid string ) error {
@ -177,7 +181,7 @@ func UnBindProvider(uid string, p provider.OAuth2Provider) error {
return Transactional ( func ( tx * gorm . DB ) error {
var user model . User
if err := tx . Preload ( "UserProviders" ) . Where ( "id = ?" , uid ) . First ( & user ) . Error ; err != nil {
return HandleNotFound ( err , "user" )
return HandleNotFound ( err , ErrUserNotFound )
}
if user . RegisteredByProvider && len ( user . UserProviders ) <= 1 {
return errors . New ( "user must have at least one provider" )
@ -189,14 +193,14 @@ func UnBindProvider(uid string, p provider.OAuth2Provider) error {
func BindEmail ( id string , email string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , id ) . Update ( "email" , model . EmptyNullString ( email ) )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func UnbindEmail ( uid string ) error {
return Transactional ( func ( tx * gorm . DB ) error {
var user model . User
if err := tx . Select ( "email" , "registered_by_email" ) . Where ( "id = ?" , uid ) . First ( & user ) . Error ; err != nil {
return HandleNotFound ( err , "user" )
return HandleNotFound ( err , ErrUserNotFound )
}
if user . RegisteredByEmail {
return errors . New ( "user must have one email" )
@ -205,7 +209,7 @@ func UnbindEmail(uid string) error {
return nil
}
result := tx . Model ( & model . User { } ) . Where ( "id = ?" , uid ) . Update ( "email" , model . EmptyNullString ( "" ) )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
} )
}
@ -221,7 +225,7 @@ func GetBindProviders(uid string) ([]*model.UserProvider, error) {
func GetUserByUsername ( username string ) ( * model . User , error ) {
var user model . User
err := db . Where ( "username = ?" , username ) . First ( & user ) . Error
return & user , HandleNotFound ( err , "user" )
return & user , HandleNotFound ( err , ErrUserNotFound )
}
func GetUserByUsernameLike ( username string , scopes ... func ( * gorm . DB ) * gorm . DB ) ( [ ] * model . User , error ) {
@ -266,7 +270,7 @@ func GetUserByID(id string) (*model.User, error) {
}
var user model . User
err := db . Where ( "id = ?" , id ) . First ( & user ) . Error
return & user , HandleNotFound ( err , "user" )
return & user , HandleNotFound ( err , ErrUserNotFound )
}
func BanUser ( u * model . User ) error {
@ -279,7 +283,7 @@ func BanUser(u *model.User) error {
func BanUserByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleBanned )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func UnbanUser ( u * model . User ) error {
@ -292,12 +296,12 @@ func UnbanUser(u *model.User) error {
func UnbanUserByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleUser )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func DeleteUserByID ( userID string ) error {
result := db . Unscoped ( ) . Select ( clause . Associations ) . Delete ( & model . User { ID : userID } )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func LoadAndDeleteUserByID ( userID string , columns ... clause . Column ) ( * model . User , error ) {
@ -307,12 +311,12 @@ func LoadAndDeleteUserByID(userID string, columns ...clause.Column) (*model.User
Select ( clause . Associations ) .
Where ( "id = ?" , userID ) .
Delete ( & user )
return & user , Handle NotFound( result . Error , "user" )
return & user , Handle UpdateResult( result , ErrUserNotFound )
}
func SaveUser ( u * model . User ) error {
result := db . Omit ( "created_at" ) . Save ( u )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func AddAdmin ( u * model . User ) error {
@ -342,12 +346,12 @@ func GetAdmins() ([]*model.User, error) {
func AddAdminByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleAdmin )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func RemoveAdminByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleUser )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func AddRoot ( u * model . User ) error {
@ -368,12 +372,12 @@ func RemoveRoot(u *model.User) error {
func AddRootByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleRoot )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func RemoveRootByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleUser )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func GetRoots ( ) [ ] * model . User {
@ -384,22 +388,22 @@ func GetRoots() []*model.User {
func SetAdminRoleByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleAdmin )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func SetRootRoleByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleRoot )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func SetUserRoleByID ( userID string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "role" , model . RoleUser )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func SetUsernameByID ( userID string , username string ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , userID ) . Update ( "username" , username )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}
func GetUserCount ( scopes ... func ( * gorm . DB ) * gorm . DB ) ( int64 , error ) {
@ -422,5 +426,5 @@ func GetUsers(scopes ...func(*gorm.DB) *gorm.DB) ([]*model.User, error) {
func SetUserHashedPassword ( id string , hashedPassword [ ] byte ) error {
result := db . Model ( & model . User { } ) . Where ( "id = ?" , id ) . Update ( "hashed_password" , hashedPassword )
return HandleUpdateResult ( result , "user" )
return HandleUpdateResult ( result , ErrUserNotFound )
}