From dab669708d0b4ac38821807b5d1a2cad797177df Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Fri, 20 Oct 2023 18:01:32 +0800 Subject: [PATCH] Feat: use oauth2 --- go.mod | 5 ++ go.sum | 17 ++++- internal/conf/config.go | 6 ++ internal/conf/oauth2.go | 22 ++++++ internal/db/db.go | 2 +- internal/db/user.go | 44 +++++++++-- internal/model/model_test.go | 1 - internal/model/oauth2.go | 13 ++++ internal/model/user.go | 25 +++---- internal/op/user.go | 35 --------- internal/op/users.go | 44 ++++++----- internal/provider/github.go | 104 ++++++++++++++++++++++++++ internal/provider/gitlab.go | 46 ++++++++++++ internal/provider/google.go | 46 ++++++++++++ internal/provider/provider.go | 49 ++++++++++++ server/handlers/init.go | 16 ++-- server/handlers/user.go | 79 ------------------- server/middlewares/auth.go | 26 +------ server/oauth2/auth.go | 85 +++++++++++++++++++++ server/oauth2/init.go | 13 ++++ server/oauth2/render.go | 31 ++++++++ server/oauth2/templates/redirect.html | 16 ++++ server/router.go | 2 + utils/syncCache/cache.go | 81 ++++++++++++++++++++ utils/syncCache/item.go | 12 +++ 25 files changed, 624 insertions(+), 196 deletions(-) create mode 100644 internal/conf/oauth2.go create mode 100644 internal/model/oauth2.go create mode 100644 internal/provider/github.go create mode 100644 internal/provider/gitlab.go create mode 100644 internal/provider/google.go create mode 100644 internal/provider/provider.go create mode 100644 server/oauth2/auth.go create mode 100644 server/oauth2/init.go create mode 100644 server/oauth2/render.go create mode 100644 server/oauth2/templates/redirect.html create mode 100644 utils/syncCache/cache.go create mode 100644 utils/syncCache/item.go diff --git a/go.mod b/go.mod index 71296d7..ea7f254 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/zijiren233/stream v0.5.1 github.com/zijiren233/yaml-comment v0.2.0 golang.org/x/crypto v0.14.0 + golang.org/x/oauth2 v0.13.0 google.golang.org/protobuf v1.31.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.4.7 @@ -36,6 +37,8 @@ require ( ) require ( + cloud.google.com/go/compute v1.20.1 // indirect + cloud.google.com/go/compute/metadata v0.2.3 // indirect github.com/BurntSushi/toml v1.3.2 // indirect github.com/bytedance/sonic v1.10.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect @@ -50,6 +53,7 @@ require ( github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/pprof v0.0.0-20230926050212-f7f687d19a98 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -79,6 +83,7 @@ require ( golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.14.0 // indirect + google.golang.org/appengine v1.6.7 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect modernc.org/libc v1.22.5 // indirect modernc.org/mathutil v1.5.0 // indirect diff --git a/go.sum b/go.sum index b673b53..2bdd6fb 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZNbg= +cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM= +cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= +cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw= @@ -59,8 +63,10 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE= github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -172,10 +178,6 @@ github.com/zijiren233/gencontainer v0.0.0-20230930135658-e410015e13cc h1:qEYdClJ github.com/zijiren233/gencontainer v0.0.0-20230930135658-e410015e13cc/go.mod h1:V5oL7PrZxgisuLCblFWd89Jg99O8vM1n58llcxZ2hDY= github.com/zijiren233/go-colorable v0.0.0-20230930131441-997304c961cb h1:0DyOxf/TbbGodHhOVHNoPk+7v/YBJACs22gKpKlatWw= github.com/zijiren233/go-colorable v0.0.0-20230930131441-997304c961cb/go.mod h1:6TCzjDiQ8+5gWZiwsC3pnA5M0vUy2jV2Y7ciHJh729g= -github.com/zijiren233/livelib v0.1.2-0.20231010145337-1651f7b4be26 h1:h7cw3cPQX3VheviU0y0bUVV0CnQ8fJegJgZMBpb/tfw= -github.com/zijiren233/livelib v0.1.2-0.20231010145337-1651f7b4be26/go.mod h1:2wrAAqNIdMZjQrdbO7ERQfqK4VS5fzgUj2xXwrJ8/uo= -github.com/zijiren233/livelib v0.2.0 h1:o2YbXAA4v3WTq97hzIToBg6mvmGXLUHHJSBh7qSmXLE= -github.com/zijiren233/livelib v0.2.0/go.mod h1:2wrAAqNIdMZjQrdbO7ERQfqK4VS5fzgUj2xXwrJ8/uo= github.com/zijiren233/livelib v0.2.1 h1:7a+R/yiq3WJXM+1kwez9w//uWpRDrQN4hT+TC1hqkpI= github.com/zijiren233/livelib v0.2.1/go.mod h1:2wrAAqNIdMZjQrdbO7ERQfqK4VS5fzgUj2xXwrJ8/uo= github.com/zijiren233/stream v0.5.1 h1:9SUwM/fpET6frtBRT5WZBHnan0Hyzkezk/P8N78cgZQ= @@ -201,6 +203,7 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY= golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -210,6 +213,8 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/oauth2 v0.13.0 h1:jDDenyj+WgFtmV3zYVoi8aE2BwtXFLWOA67ZfNWftiY= +golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -235,6 +240,7 @@ golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -252,7 +258,10 @@ golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc= golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= diff --git a/internal/conf/config.go b/internal/conf/config.go index 51c742e..15d10a0 100644 --- a/internal/conf/config.go +++ b/internal/conf/config.go @@ -25,6 +25,9 @@ type Config struct { // Database Database DatabaseConfig `yaml:"database"` + + // OAuth2 + OAuth2 OAuth2Config `yaml:"oauth2"` } func (c *Config) Save(file string) error { @@ -53,5 +56,8 @@ func DefaultConfig() *Config { // Database Database: DefaultDatabaseConfig(), + + // OAuth2 + OAuth2: DefaultOAuth2Config(), } } diff --git a/internal/conf/oauth2.go b/internal/conf/oauth2.go new file mode 100644 index 0000000..86107e2 --- /dev/null +++ b/internal/conf/oauth2.go @@ -0,0 +1,22 @@ +package conf + +import ( + "github.com/synctv-org/synctv/internal/provider" +) + +type OAuth2Config map[provider.OAuth2Provider]OAuth2ProviderConfig + +type OAuth2ProviderConfig struct { + ClientID string `yaml:"client_id" lc:"oauth2 client id"` + ClientSecret string `yaml:"client_secret" lc:"oauth2 client secret"` + // CustomRedirectURL string `yaml:"custom_redirect_url" lc:"oauth2 custom redirect url"` +} + +func DefaultOAuth2Config() OAuth2Config { + return OAuth2Config{ + provider.GithubProvider{}.Provider(): { + ClientID: "github_client_id", + ClientSecret: "github_client_secret", + }, + } +} diff --git a/internal/db/db.go b/internal/db/db.go index ae07f50..19c628b 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -12,7 +12,7 @@ var db *gorm.DB func Init(d *gorm.DB) error { db = d - return AutoMigrate(new(model.Movie), new(model.Room), new(model.User), new(model.RoomUserRelation)) + return AutoMigrate(new(model.Movie), new(model.Room), new(model.User), new(model.RoomUserRelation), new(model.UserProvider)) } func AutoMigrate(dst ...any) error { diff --git a/internal/db/user.go b/internal/db/user.go index 5640b42..f0a174f 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/synctv-org/synctv/internal/model" + "github.com/synctv-org/synctv/internal/provider" "github.com/zijiren233/stream" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" @@ -18,18 +19,51 @@ func WithRole(role model.Role) CreateUserConfig { } } -func CreateUser(username string, hashedPassword []byte, conf ...CreateUserConfig) (*model.User, error) { +func CreateUser(username string, p provider.OAuth2Provider, puid uint, conf ...CreateUserConfig) (*model.User, error) { u := &model.User{ - Username: username, - HashedPassword: hashedPassword, - Role: model.RoleUser, + Username: username, + Role: model.RoleUser, + Providers: []model.UserProvider{ + { + Provider: p, + ProviderUserID: puid, + }, + }, } for _, c := range conf { c(u) } err := db.Create(u).Error if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) { - return u, errors.New("username already exists") + return u, errors.New("user already exists") + } + return u, err +} + +func CreateOrLoadUser(username string, p provider.OAuth2Provider, puid uint, conf ...CreateUserConfig) (*model.User, error) { + u := &model.User{ + Username: username, + Role: model.RoleUser, + Providers: []model.UserProvider{ + { + Provider: p, + ProviderUserID: puid, + }, + }, + } + for _, c := range conf { + c(u) + } + return u, db.Preload("Providers", "provider = ? AND provider_user_id = ?", p, puid). + FirstOrCreate(u). + Error +} + +func GetUserByProvider(p provider.OAuth2Provider, puid uint) (*model.User, error) { + u := &model.User{} + err := db.Preload("Providers", "provider = ? AND provider_user_id = ?", p, puid).First(u).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return u, errors.New("user not found") } return u, err } diff --git a/internal/model/model_test.go b/internal/model/model_test.go index 4d83a37..6d5356f 100644 --- a/internal/model/model_test.go +++ b/internal/model/model_test.go @@ -42,7 +42,6 @@ func TestCreateUser(t *testing.T) { } user := model.User{ Username: "user1", - HashedPassword: nil, GroupUserRelations: []model.RoomUserRelation{}, } err = db.Create(&user).Error diff --git a/internal/model/oauth2.go b/internal/model/oauth2.go new file mode 100644 index 0000000..ddbac32 --- /dev/null +++ b/internal/model/oauth2.go @@ -0,0 +1,13 @@ +package model + +import ( + "github.com/synctv-org/synctv/internal/provider" + "gorm.io/gorm" +) + +type UserProvider struct { + gorm.Model + UserID uint `gorm:"not null"` + Provider provider.OAuth2Provider `gorm:"not null;uniqueIndex:provider_user_id"` + ProviderUserID uint `gorm:"not null;uniqueIndex:provider_user_id"` +} diff --git a/internal/model/user.go b/internal/model/user.go index 090bb21..da87b0b 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -1,8 +1,9 @@ package model import ( - "github.com/zijiren233/stream" - "golang.org/x/crypto/bcrypt" + "fmt" + "math/rand" + "gorm.io/gorm" ) @@ -17,23 +18,19 @@ const ( type User struct { gorm.Model - Username string `gorm:"not null;uniqueIndex"` - Role Role `gorm:"not null"` - HashedPassword []byte + Providers []UserProvider `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + Username string `gorm:"not null;uniqueIndex"` + Role Role `gorm:"not null"` GroupUserRelations []RoomUserRelation `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` Rooms []Room `gorm:"foreignKey:CreatorID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` Movies []Movie `gorm:"foreignKey:CreatorID;constraint:OnUpdate:CASCADE,OnDelete:SET NULL"` } -func (u *User) CheckPassword(password string) bool { - return bcrypt.CompareHashAndPassword(u.HashedPassword, stream.StringToBytes(password)) == nil -} - -func (u *User) SetPassword(password string) error { - hashedPassword, err := bcrypt.GenerateFromPassword(stream.StringToBytes(password), bcrypt.DefaultCost) - if err != nil { - return err +func (u *User) BeforeCreate(tx *gorm.DB) error { + var existingUser User + err := tx.Where("username = ?", u.Username).First(&existingUser).Error + if err == nil { + u.Username = fmt.Sprintf("%s#%d", u.Username, rand.Intn(9999)) } - u.HashedPassword = hashedPassword return nil } diff --git a/internal/op/user.go b/internal/op/user.go index 7068b10..afe5456 100644 --- a/internal/op/user.go +++ b/internal/op/user.go @@ -2,26 +2,13 @@ package op import ( "errors" - "hash/crc32" - "sync/atomic" "github.com/synctv-org/synctv/internal/db" "github.com/synctv-org/synctv/internal/model" - "github.com/zijiren233/stream" - "golang.org/x/crypto/bcrypt" ) type User struct { model.User - version uint32 -} - -func (u *User) Version() uint32 { - return atomic.LoadUint32(&u.version) -} - -func (u *User) CheckVersion(version uint32) bool { - return atomic.LoadUint32(&u.version) == version } func (u *User) CreateRoom(name, password string, conf ...db.CreateRoomConfig) (*model.Room, error) { @@ -45,25 +32,3 @@ func (u *User) DeleteRoom(room *Room) error { } return DeleteRoom(room) } - -func (u *User) NeedPassword() bool { - return len(u.HashedPassword) != 0 -} - -func (u *User) SetPassword(password string) error { - if u.CheckPassword(password) && u.NeedPassword() { - return errors.New("password is the same") - } - var hashedPassword []byte - if password != "" { - var err error - hashedPassword, err = bcrypt.GenerateFromPassword(stream.StringToBytes(password), bcrypt.DefaultCost) - if err != nil { - return err - } - } - u.HashedPassword = hashedPassword - - atomic.StoreUint32(&u.version, crc32.ChecksumIEEE(u.HashedPassword)) - return db.SetUserHashedPassword(u.ID, hashedPassword) -} diff --git a/internal/op/users.go b/internal/op/users.go index 7aa8fed..95cfba3 100644 --- a/internal/op/users.go +++ b/internal/op/users.go @@ -1,14 +1,13 @@ package op import ( - "hash/crc32" + "errors" "time" "github.com/bluele/gcache" "github.com/synctv-org/synctv/internal/db" "github.com/synctv-org/synctv/internal/model" - "github.com/zijiren233/stream" - "golang.org/x/crypto/bcrypt" + "github.com/synctv-org/synctv/internal/provider" ) var userCache gcache.Cache @@ -25,8 +24,7 @@ func GetUserById(id uint) (*User, error) { } u2 := &User{ - User: *u, - version: crc32.ChecksumIEEE(u.HashedPassword), + User: *u, } return u2, userCache.SetWithExpire(id, u2, time.Hour) @@ -40,42 +38,42 @@ func GetUserByUsername(username string) (*User, error) { } u2 := &User{ - User: *u, - version: crc32.ChecksumIEEE(u.HashedPassword), + User: *u, } return u2, userCache.SetWithExpire(u.ID, u2, time.Hour) } -var ErrInvalidUsernameOrPassword = bcrypt.ErrMismatchedHashAndPassword - -func CreateUser(username, password string, conf ...db.CreateUserConfig) (*User, error) { - if username == "" || password == "" { - return nil, ErrInvalidUsernameOrPassword - } - hashedPassword, err := bcrypt.GenerateFromPassword(stream.StringToBytes(password), bcrypt.DefaultCost) - if err != nil { - return nil, err +func CreateUser(username string, p provider.OAuth2Provider, pid uint, conf ...db.CreateUserConfig) (*User, error) { + if username == "" { + return nil, errors.New("username cannot be empty") } - u, err := db.CreateUser(username, hashedPassword, conf...) + u, err := db.CreateUser(username, p, pid, conf...) if err != nil { return nil, err } u2 := &User{ - User: *u, - version: crc32.ChecksumIEEE(u.HashedPassword), + User: *u, } return u2, userCache.SetWithExpire(u.ID, u2, time.Hour) } -func SetUserPassword(userID uint, password string) error { - u, err := GetUserById(userID) +func CreateOrLoadUser(username string, p provider.OAuth2Provider, pid uint, conf ...db.CreateUserConfig) (*User, error) { + if username == "" { + return nil, errors.New("username cannot be empty") + } + u, err := db.CreateOrLoadUser(username, p, pid, conf...) if err != nil { - return err + return nil, err + } + + u2 := &User{ + User: *u, } - return u.SetPassword(password) + + return u2, userCache.SetWithExpire(u.ID, u2, time.Hour) } func DeleteUserByID(userID uint) error { diff --git a/internal/provider/github.go b/internal/provider/github.go new file mode 100644 index 0000000..594890b --- /dev/null +++ b/internal/provider/github.go @@ -0,0 +1,104 @@ +package provider + +import ( + "context" + "net/http" + "time" + + json "github.com/json-iterator/go" + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" +) + +type GithubProvider struct{} + +func (p GithubProvider) Provider() OAuth2Provider { + return "github" +} + +func (p GithubProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config { + return &oauth2.Config{ + ClientID: ClientID, + ClientSecret: ClientSecret, + Scopes: []string{"user"}, + Endpoint: github.Endpoint, + } +} + +func (p GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { + oauth2Token, err := config.Exchange(ctx, code) + if err != nil { + return nil, err + } + client := config.Client(ctx, oauth2Token) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user", nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + ui := githubUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) + if err != nil { + return nil, err + } + return &UserInfo{ + Username: ui.Login, + ProviderUserID: ui.ID, + }, nil +} + +func init() { + RegisterProvider(GithubProvider{}) +} + +type githubUserInfo struct { + Login string `json:"login"` + ID uint `json:"id"` + NodeID string `json:"node_id"` + AvatarURL string `json:"avatar_url"` + GravatarID string `json:"gravatar_id"` + URL string `json:"url"` + HTMLURL string `json:"html_url"` + FollowersURL string `json:"followers_url"` + FollowingURL string `json:"following_url"` + GistsURL string `json:"gists_url"` + StarredURL string `json:"starred_url"` + SubscriptionsURL string `json:"subscriptions_url"` + OrganizationsURL string `json:"organizations_url"` + ReposURL string `json:"repos_url"` + EventsURL string `json:"events_url"` + ReceivedEventsURL string `json:"received_events_url"` + Type string `json:"type"` + SiteAdmin bool `json:"site_admin"` + Name string `json:"name"` + Company interface{} `json:"company"` + Blog string `json:"blog"` + Location string `json:"location"` + Email interface{} `json:"email"` + Hireable interface{} `json:"hireable"` + Bio string `json:"bio"` + TwitterUsername interface{} `json:"twitter_username"` + PublicRepos int `json:"public_repos"` + PublicGists int `json:"public_gists"` + Followers int `json:"followers"` + Following int `json:"following"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + PrivateGists int `json:"private_gists"` + TotalPrivateRepos int `json:"total_private_repos"` + OwnedPrivateRepos int `json:"owned_private_repos"` + DiskUsage int `json:"disk_usage"` + Collaborators int `json:"collaborators"` + TwoFactorAuthentication bool `json:"two_factor_authentication"` + Plan Plan `json:"plan"` +} +type Plan struct { + Name string `json:"name"` + Space int `json:"space"` + Collaborators int `json:"collaborators"` + PrivateRepos int `json:"private_repos"` +} diff --git a/internal/provider/gitlab.go b/internal/provider/gitlab.go new file mode 100644 index 0000000..3e2aff6 --- /dev/null +++ b/internal/provider/gitlab.go @@ -0,0 +1,46 @@ +package provider + +import ( + "context" + "net/http" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/gitlab" +) + +type GitlabProvider struct{} + +func (g GitlabProvider) Provider() OAuth2Provider { + return "gitlab" +} + +func (g GitlabProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config { + return &oauth2.Config{ + ClientID: ClientID, + ClientSecret: ClientSecret, + Scopes: []string{"read_user"}, + Endpoint: gitlab.Endpoint, + } +} + +func (g GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { + oauth2Token, err := config.Exchange(ctx, code) + if err != nil { + return nil, err + } + client := config.Client(ctx, oauth2Token) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://gitlab.com/api/v4/user", nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return nil, FormatErrNotImplemented("gitlab") +} + +func init() { + RegisterProvider(GitlabProvider{}) +} diff --git a/internal/provider/google.go b/internal/provider/google.go new file mode 100644 index 0000000..7c2db84 --- /dev/null +++ b/internal/provider/google.go @@ -0,0 +1,46 @@ +package provider + +import ( + "context" + "net/http" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +type GoogleProvider struct{} + +func (g GoogleProvider) Provider() OAuth2Provider { + return "google" +} + +func (g GoogleProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config { + return &oauth2.Config{ + ClientID: ClientID, + ClientSecret: ClientSecret, + Scopes: []string{"profile"}, + Endpoint: google.Endpoint, + } +} + +func (g GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) { + oauth2Token, err := config.Exchange(ctx, code) + if err != nil { + return nil, err + } + client := config.Client(ctx, oauth2Token) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v2/userinfo", nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return nil, FormatErrNotImplemented("google") +} + +func init() { + RegisterProvider(GoogleProvider{}) +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go new file mode 100644 index 0000000..f549c08 --- /dev/null +++ b/internal/provider/provider.go @@ -0,0 +1,49 @@ +package provider + +import ( + "context" + "fmt" + "sync" + + "golang.org/x/oauth2" +) + +type OAuth2Provider string + +var ( + providers = make(map[OAuth2Provider]ProviderInterface) + lock sync.Mutex +) + +type UserInfo struct { + Username string + ProviderUserID uint +} + +type ProviderInterface interface { + Provider() OAuth2Provider + NewConfig(ClientID, ClientSecret string) *oauth2.Config + GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) +} + +func RegisterProvider(provider ProviderInterface) { + lock.Lock() + defer lock.Unlock() + providers[provider.Provider()] = provider +} + +func (p OAuth2Provider) GetProvider() (ProviderInterface, error) { + lock.Lock() + defer lock.Unlock() + pi, ok := providers[p] + if !ok { + return nil, FormatErrNotImplemented(p) + } + return pi, nil +} + +type FormatErrNotImplemented string + +func (f FormatErrNotImplemented) Error() string { + return fmt.Sprintf("%s not implemented", string(f)) +} diff --git a/server/handlers/init.go b/server/handlers/init.go index adc7401..63dbe32 100644 --- a/server/handlers/init.go +++ b/server/handlers/init.go @@ -11,6 +11,10 @@ import ( func Init(e *gin.Engine) { { + e.GET("/", func(ctx *gin.Context) { + ctx.Redirect(http.StatusMovedPermanently, "/web/") + }) + web := e.Group("/web") web.Use(func(ctx *gin.Context) { @@ -103,22 +107,12 @@ func Init(e *gin.Engine) { } { - user := api.Group("/user") + // user := api.Group("/user") needAuthUser := needAuthUserApi.Group("/user") - user.POST("/login", LoginUser) - - user.POST("/signup", SignupUser) - needAuthUser.POST("/logout", LogoutUser) needAuthUser.GET("/me", Me) - - needAuthUser.POST("/pwd", SetUserPassword) } } - - e.NoRoute(func(c *gin.Context) { - c.Redirect(http.StatusFound, "/web/") - }) } diff --git a/server/handlers/user.go b/server/handlers/user.go index 098fb82..c929719 100644 --- a/server/handlers/user.go +++ b/server/handlers/user.go @@ -5,7 +5,6 @@ import ( "github.com/gin-gonic/gin" "github.com/synctv-org/synctv/internal/op" - "github.com/synctv-org/synctv/server/middlewares" "github.com/synctv-org/synctv/server/model" ) @@ -17,84 +16,6 @@ func Me(ctx *gin.Context) { })) } -func SetUserPassword(ctx *gin.Context) { - user := ctx.MustGet("user").(*op.User) - - req := model.SetUserPasswordReq{} - if err := model.Decode(ctx, &req); err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - if err := user.SetPassword(req.Password); err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - token, err := middlewares.NewAuthUserToken(user) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) - return - } - - ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ - "token": token, - })) -} - -func LoginUser(ctx *gin.Context) { - req := model.LoginUserReq{} - if err := model.Decode(ctx, &req); err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - user, err := middlewares.AuthUserWithPassword(req.Username, req.Password) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - if !user.CheckPassword(req.Password) { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - token, err := middlewares.NewAuthUserToken(user) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) - return - } - - ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ - "token": token, - })) -} - -func SignupUser(ctx *gin.Context) { - req := model.SignupUserReq{} - if err := model.Decode(ctx, &req); err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - user, err := op.CreateUser(req.Username, req.Password) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - token, err := middlewares.NewAuthUserToken(user) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) - return - } - - ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ - "token": token, - })) -} - func LogoutUser(ctx *gin.Context) { user := ctx.MustGet("user").(*op.User) diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index e40b07a..4c26051 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -19,8 +19,7 @@ var ( ) type AuthClaims struct { - UserId uint `json:"u"` - UserVersion uint32 `json:"uv"` + UserId uint `json:"u"` jwt.RegisteredClaims } @@ -76,9 +75,6 @@ func AuthRoom(Authorization string) (*op.User, *op.Room, error) { if err != nil { return nil, nil, err } - if !u.CheckVersion(claims.UserVersion) { - return nil, nil, ErrAuthExpired - } r, err := op.GetRoomByID(claims.RoomId) if err != nil { @@ -102,17 +98,6 @@ func AuthRoomWithPassword(u *op.User, roomId uint, password string) (*op.Room, e return r, nil } -func AuthUserWithPassword(username, password string) (*op.User, error) { - u, err := op.GetUserByUsername(username) - if err != nil { - return nil, err - } - if !u.CheckPassword(password) { - return nil, ErrAuthFailed - } - return u, nil -} - func AuthUser(Authorization string) (*op.User, error) { claims, err := authUser(Authorization) if err != nil { @@ -127,17 +112,13 @@ func AuthUser(Authorization string) (*op.User, error) { if err != nil { return nil, err } - if !u.CheckVersion(claims.UserVersion) { - return nil, ErrAuthExpired - } return u, nil } func NewAuthUserToken(user *op.User) (string, error) { claims := &AuthClaims{ - UserId: user.ID, - UserVersion: user.Version(), + UserId: user.ID, RegisteredClaims: jwt.RegisteredClaims{ NotBefore: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(conf.Conf.Jwt.Expire))), @@ -149,8 +130,7 @@ func NewAuthUserToken(user *op.User) (string, error) { func NewAuthRoomToken(user *op.User, room *op.Room) (string, error) { claims := &AuthRoomClaims{ AuthClaims: AuthClaims{ - UserId: user.ID, - UserVersion: user.Version(), + UserId: user.ID, RegisteredClaims: jwt.RegisteredClaims{ NotBefore: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * time.Duration(conf.Conf.Jwt.Expire))), diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go new file mode 100644 index 0000000..896c1b8 --- /dev/null +++ b/server/oauth2/auth.go @@ -0,0 +1,85 @@ +package auth + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/synctv-org/synctv/internal/conf" + "github.com/synctv-org/synctv/internal/op" + "github.com/synctv-org/synctv/internal/provider" + "github.com/synctv-org/synctv/server/middlewares" + "github.com/synctv-org/synctv/server/model" + "golang.org/x/oauth2" +) + +// /oauth2/login/:type +func OAuth2(ctx *gin.Context) { + t := ctx.Param("type") + p := provider.OAuth2Provider(t) + c, ok := conf.Conf.OAuth2[p] + if !ok { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) + } + + pi, err := p.GetProvider() + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + } + + Render(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), oauth2.AccessTypeOnline) +} + +// /oauth2/callback/:type +func OAuth2Callback(ctx *gin.Context) { + t := ctx.Param("type") + p := provider.OAuth2Provider(t) + c, ok := conf.Conf.OAuth2[p] + if !ok { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) + } + + code := ctx.Query("code") + if code == "" { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 code")) + return + } + + state := ctx.Query("state") + if state == "" { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) + return + } + + _, loaded := states.LoadAndDelete(state) + if !loaded { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) + return + } + + pi, err := p.GetProvider() + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + } + + ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), code) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + user, err := op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + token, err := middlewares.NewAuthUserToken(user) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ + "token": token, + })) +} diff --git a/server/oauth2/init.go b/server/oauth2/init.go new file mode 100644 index 0000000..883640c --- /dev/null +++ b/server/oauth2/init.go @@ -0,0 +1,13 @@ +package auth + +import "github.com/gin-gonic/gin" + +func Init(e *gin.Engine) { + { + auth := e.Group("/oauth2") + + auth.GET("/login/:type", OAuth2) + + auth.GET("/callback/:type", OAuth2Callback) + } +} diff --git a/server/oauth2/render.go b/server/oauth2/render.go new file mode 100644 index 0000000..27c3d1e --- /dev/null +++ b/server/oauth2/render.go @@ -0,0 +1,31 @@ +package auth + +import ( + "embed" + "html/template" + "time" + + "github.com/gin-gonic/gin" + "github.com/synctv-org/synctv/utils" + synccache "github.com/synctv-org/synctv/utils/syncCache" + "golang.org/x/oauth2" +) + +//go:embed templates/redirect.html +var temp embed.FS + +var ( + redirectTemplate *template.Template + states *synccache.SyncCache[string, struct{}] +) + +func Render(ctx *gin.Context, c *oauth2.Config, option ...oauth2.AuthCodeOption) error { + state := utils.RandString(16) + states.Store(state, struct{}{}, time.Minute*5) + return redirectTemplate.Execute(ctx.Writer, c.AuthCodeURL(state, option...)) +} + +func init() { + redirectTemplate = template.Must(template.ParseFS(temp, "templates/redirect.html")) + states = synccache.NewSyncCache[string, struct{}](time.Minute * 10) +} diff --git a/server/oauth2/templates/redirect.html b/server/oauth2/templates/redirect.html new file mode 100644 index 0000000..56a0cbb --- /dev/null +++ b/server/oauth2/templates/redirect.html @@ -0,0 +1,16 @@ + + + + + + + + Redirecting.. + + + +

If you are not redirected, please click here.

+ + + + \ No newline at end of file diff --git a/server/router.go b/server/router.go index 2bac368..b9d5904 100644 --- a/server/router.go +++ b/server/router.go @@ -4,10 +4,12 @@ import ( "github.com/gin-gonic/gin" "github.com/synctv-org/synctv/server/handlers" "github.com/synctv-org/synctv/server/middlewares" + auth "github.com/synctv-org/synctv/server/oauth2" ) func Init(e *gin.Engine) { middlewares.Init(e) + auth.Init(e) handlers.Init(e) } diff --git a/utils/syncCache/cache.go b/utils/syncCache/cache.go new file mode 100644 index 0000000..cfd769b --- /dev/null +++ b/utils/syncCache/cache.go @@ -0,0 +1,81 @@ +package synccache + +import ( + "time" + + "github.com/zijiren233/gencontainer/rwmap" +) + +type SyncCache[K comparable, V any] struct { + cache rwmap.RWMap[K, *entry[V]] + ticker *time.Ticker +} + +func NewSyncCache[K comparable, V any](trimTime time.Duration) *SyncCache[K, V] { + sc := &SyncCache[K, V]{ + ticker: time.NewTicker(trimTime), + } + go func() { + for range sc.ticker.C { + sc.trim() + } + }() + return sc +} + +func (sc *SyncCache[K, V]) Releases() { + sc.ticker.Stop() + sc.cache.Clear() +} + +func (sc *SyncCache[K, V]) trim() { + sc.cache.Range(func(key K, value *entry[V]) bool { + if value.IsExpired() { + sc.cache.Delete(key) + } + return true + }) +} + +func (sc *SyncCache[K, V]) Store(key K, value V, expire time.Duration) { + sc.LoadOrStore(key, value, expire) +} + +func (sc *SyncCache[K, V]) Load(key K) (value V, loaded bool) { + e, ok := sc.cache.Load(key) + if ok && !e.IsExpired() { + return e.value, ok + } + return +} + +func (sc *SyncCache[K, V]) LoadOrStore(key K, value V, expire time.Duration) (actual V, loaded bool) { + e, loaded := sc.cache.LoadOrStore(key, &entry[V]{ + expiration: time.Now().Add(expire), + value: value, + }) + if e.IsExpired() { + sc.cache.Store(key, &entry[V]{ + expiration: time.Now().Add(expire), + value: value, + }) + return value, false + } + return e.value, loaded +} + +func (sc *SyncCache[K, V]) Delete(key K) { + sc.LoadAndDelete(key) +} + +func (sc *SyncCache[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + e, loaded := sc.cache.LoadAndDelete(key) + if loaded && !e.IsExpired() { + return e.value, loaded + } + return +} + +func (sc *SyncCache[K, V]) Clear() { + sc.cache.Clear() +} diff --git a/utils/syncCache/item.go b/utils/syncCache/item.go new file mode 100644 index 0000000..b74023e --- /dev/null +++ b/utils/syncCache/item.go @@ -0,0 +1,12 @@ +package synccache + +import "time" + +type entry[V any] struct { + expiration time.Time + value V +} + +func (e *entry[V]) IsExpired() bool { + return time.Now().After(e.expiration) +}