diff --git a/internal/db/user.go b/internal/db/user.go index dda2cba..a86f3c6 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -41,23 +41,24 @@ func CreateUser(username string, p provider.OAuth2Provider, puid uint, conf ...C return u, err } +// 只有当provider和puid没有找到对应的user时才会创建 func CreateOrLoadUser(username string, p provider.OAuth2Provider, puid uint, conf ...CreateUserConfig) (*model.User, error) { - u := &model.User{ - Username: username, - Role: model.RoleUser, - Providers: []model.UserProvider{ - { - Provider: p, - ProviderUserID: puid, - }, - }, - } - for _, c := range conf { - c(u) + var user model.User + var userProvider model.UserProvider + + if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).First(&userProvider).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return CreateUser(username, p, puid, conf...) + } else { + return nil, err + } + } else { + if err := db.First(&user, userProvider.UserID).Error; err != nil { + return nil, err + } } - return u, db.Preload("Providers", "provider = ? AND provider_user_id = ?", p, puid). - FirstOrCreate(u). - Error + + return &user, nil } func GetUserByProvider(p provider.OAuth2Provider, puid uint) (*model.User, error) {