mirror of https://github.com/synctv-org/synctv
Feat: use oauth2
parent
92fdbc31be
commit
dab669708d
@ -0,0 +1,22 @@
|
|||||||
|
package conf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/synctv-org/synctv/internal/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OAuth2Config map[provider.OAuth2Provider]OAuth2ProviderConfig
|
||||||
|
|
||||||
|
type OAuth2ProviderConfig struct {
|
||||||
|
ClientID string `yaml:"client_id" lc:"oauth2 client id"`
|
||||||
|
ClientSecret string `yaml:"client_secret" lc:"oauth2 client secret"`
|
||||||
|
// CustomRedirectURL string `yaml:"custom_redirect_url" lc:"oauth2 custom redirect url"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultOAuth2Config() OAuth2Config {
|
||||||
|
return OAuth2Config{
|
||||||
|
provider.GithubProvider{}.Provider(): {
|
||||||
|
ClientID: "github_client_id",
|
||||||
|
ClientSecret: "github_client_secret",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,13 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/synctv-org/synctv/internal/provider"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserProvider struct {
|
||||||
|
gorm.Model
|
||||||
|
UserID uint `gorm:"not null"`
|
||||||
|
Provider provider.OAuth2Provider `gorm:"not null;uniqueIndex:provider_user_id"`
|
||||||
|
ProviderUserID uint `gorm:"not null;uniqueIndex:provider_user_id"`
|
||||||
|
}
|
@ -0,0 +1,104 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
json "github.com/json-iterator/go"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/github"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GithubProvider struct{}
|
||||||
|
|
||||||
|
func (p GithubProvider) Provider() OAuth2Provider {
|
||||||
|
return "github"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p GithubProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
|
||||||
|
return &oauth2.Config{
|
||||||
|
ClientID: ClientID,
|
||||||
|
ClientSecret: ClientSecret,
|
||||||
|
Scopes: []string{"user"},
|
||||||
|
Endpoint: github.Endpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
|
||||||
|
oauth2Token, err := config.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
client := config.Client(ctx, oauth2Token)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
ui := githubUserInfo{}
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&ui)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &UserInfo{
|
||||||
|
Username: ui.Login,
|
||||||
|
ProviderUserID: ui.ID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterProvider(GithubProvider{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type githubUserInfo struct {
|
||||||
|
Login string `json:"login"`
|
||||||
|
ID uint `json:"id"`
|
||||||
|
NodeID string `json:"node_id"`
|
||||||
|
AvatarURL string `json:"avatar_url"`
|
||||||
|
GravatarID string `json:"gravatar_id"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
HTMLURL string `json:"html_url"`
|
||||||
|
FollowersURL string `json:"followers_url"`
|
||||||
|
FollowingURL string `json:"following_url"`
|
||||||
|
GistsURL string `json:"gists_url"`
|
||||||
|
StarredURL string `json:"starred_url"`
|
||||||
|
SubscriptionsURL string `json:"subscriptions_url"`
|
||||||
|
OrganizationsURL string `json:"organizations_url"`
|
||||||
|
ReposURL string `json:"repos_url"`
|
||||||
|
EventsURL string `json:"events_url"`
|
||||||
|
ReceivedEventsURL string `json:"received_events_url"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
SiteAdmin bool `json:"site_admin"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Company interface{} `json:"company"`
|
||||||
|
Blog string `json:"blog"`
|
||||||
|
Location string `json:"location"`
|
||||||
|
Email interface{} `json:"email"`
|
||||||
|
Hireable interface{} `json:"hireable"`
|
||||||
|
Bio string `json:"bio"`
|
||||||
|
TwitterUsername interface{} `json:"twitter_username"`
|
||||||
|
PublicRepos int `json:"public_repos"`
|
||||||
|
PublicGists int `json:"public_gists"`
|
||||||
|
Followers int `json:"followers"`
|
||||||
|
Following int `json:"following"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
PrivateGists int `json:"private_gists"`
|
||||||
|
TotalPrivateRepos int `json:"total_private_repos"`
|
||||||
|
OwnedPrivateRepos int `json:"owned_private_repos"`
|
||||||
|
DiskUsage int `json:"disk_usage"`
|
||||||
|
Collaborators int `json:"collaborators"`
|
||||||
|
TwoFactorAuthentication bool `json:"two_factor_authentication"`
|
||||||
|
Plan Plan `json:"plan"`
|
||||||
|
}
|
||||||
|
type Plan struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Space int `json:"space"`
|
||||||
|
Collaborators int `json:"collaborators"`
|
||||||
|
PrivateRepos int `json:"private_repos"`
|
||||||
|
}
|
@ -0,0 +1,46 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/gitlab"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GitlabProvider struct{}
|
||||||
|
|
||||||
|
func (g GitlabProvider) Provider() OAuth2Provider {
|
||||||
|
return "gitlab"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g GitlabProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
|
||||||
|
return &oauth2.Config{
|
||||||
|
ClientID: ClientID,
|
||||||
|
ClientSecret: ClientSecret,
|
||||||
|
Scopes: []string{"read_user"},
|
||||||
|
Endpoint: gitlab.Endpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
|
||||||
|
oauth2Token, err := config.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
client := config.Client(ctx, oauth2Token)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://gitlab.com/api/v4/user", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return nil, FormatErrNotImplemented("gitlab")
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterProvider(GitlabProvider{})
|
||||||
|
}
|
@ -0,0 +1,46 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/google"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GoogleProvider struct{}
|
||||||
|
|
||||||
|
func (g GoogleProvider) Provider() OAuth2Provider {
|
||||||
|
return "google"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g GoogleProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
|
||||||
|
return &oauth2.Config{
|
||||||
|
ClientID: ClientID,
|
||||||
|
ClientSecret: ClientSecret,
|
||||||
|
Scopes: []string{"profile"},
|
||||||
|
Endpoint: google.Endpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
|
||||||
|
oauth2Token, err := config.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
client := config.Client(ctx, oauth2Token)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v2/userinfo", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return nil, FormatErrNotImplemented("google")
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterProvider(GoogleProvider{})
|
||||||
|
}
|
@ -0,0 +1,49 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OAuth2Provider string
|
||||||
|
|
||||||
|
var (
|
||||||
|
providers = make(map[OAuth2Provider]ProviderInterface)
|
||||||
|
lock sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserInfo struct {
|
||||||
|
Username string
|
||||||
|
ProviderUserID uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProviderInterface interface {
|
||||||
|
Provider() OAuth2Provider
|
||||||
|
NewConfig(ClientID, ClientSecret string) *oauth2.Config
|
||||||
|
GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterProvider(provider ProviderInterface) {
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
providers[provider.Provider()] = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p OAuth2Provider) GetProvider() (ProviderInterface, error) {
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
pi, ok := providers[p]
|
||||||
|
if !ok {
|
||||||
|
return nil, FormatErrNotImplemented(p)
|
||||||
|
}
|
||||||
|
return pi, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type FormatErrNotImplemented string
|
||||||
|
|
||||||
|
func (f FormatErrNotImplemented) Error() string {
|
||||||
|
return fmt.Sprintf("%s not implemented", string(f))
|
||||||
|
}
|
@ -0,0 +1,85 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/synctv-org/synctv/internal/conf"
|
||||||
|
"github.com/synctv-org/synctv/internal/op"
|
||||||
|
"github.com/synctv-org/synctv/internal/provider"
|
||||||
|
"github.com/synctv-org/synctv/server/middlewares"
|
||||||
|
"github.com/synctv-org/synctv/server/model"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// /oauth2/login/:type
|
||||||
|
func OAuth2(ctx *gin.Context) {
|
||||||
|
t := ctx.Param("type")
|
||||||
|
p := provider.OAuth2Provider(t)
|
||||||
|
c, ok := conf.Conf.OAuth2[p]
|
||||||
|
if !ok {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pi, err := p.GetProvider()
|
||||||
|
if err != nil {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
Render(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), oauth2.AccessTypeOnline)
|
||||||
|
}
|
||||||
|
|
||||||
|
// /oauth2/callback/:type
|
||||||
|
func OAuth2Callback(ctx *gin.Context) {
|
||||||
|
t := ctx.Param("type")
|
||||||
|
p := provider.OAuth2Provider(t)
|
||||||
|
c, ok := conf.Conf.OAuth2[p]
|
||||||
|
if !ok {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider"))
|
||||||
|
}
|
||||||
|
|
||||||
|
code := ctx.Query("code")
|
||||||
|
if code == "" {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 code"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state := ctx.Query("state")
|
||||||
|
if state == "" {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, loaded := states.LoadAndDelete(state)
|
||||||
|
if !loaded {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pi, err := p.GetProvider()
|
||||||
|
if err != nil {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), code)
|
||||||
|
if err != nil {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID)
|
||||||
|
if err != nil {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := middlewares.NewAuthUserToken(user)
|
||||||
|
if err != nil {
|
||||||
|
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{
|
||||||
|
"token": token,
|
||||||
|
}))
|
||||||
|
}
|
@ -0,0 +1,13 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
func Init(e *gin.Engine) {
|
||||||
|
{
|
||||||
|
auth := e.Group("/oauth2")
|
||||||
|
|
||||||
|
auth.GET("/login/:type", OAuth2)
|
||||||
|
|
||||||
|
auth.GET("/callback/:type", OAuth2Callback)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,31 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"html/template"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/synctv-org/synctv/utils"
|
||||||
|
synccache "github.com/synctv-org/synctv/utils/syncCache"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed templates/redirect.html
|
||||||
|
var temp embed.FS
|
||||||
|
|
||||||
|
var (
|
||||||
|
redirectTemplate *template.Template
|
||||||
|
states *synccache.SyncCache[string, struct{}]
|
||||||
|
)
|
||||||
|
|
||||||
|
func Render(ctx *gin.Context, c *oauth2.Config, option ...oauth2.AuthCodeOption) error {
|
||||||
|
state := utils.RandString(16)
|
||||||
|
states.Store(state, struct{}{}, time.Minute*5)
|
||||||
|
return redirectTemplate.Execute(ctx.Writer, c.AuthCodeURL(state, option...))
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
redirectTemplate = template.Must(template.ParseFS(temp, "templates/redirect.html"))
|
||||||
|
states = synccache.NewSyncCache[string, struct{}](time.Minute * 10)
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Redirecting..</title>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<p>If you are not redirected, please click <a href="{{ . }}">here</a>.</p>
|
||||||
|
<script>window.location.href = "{{ . }}"</script>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
@ -0,0 +1,81 @@
|
|||||||
|
package synccache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zijiren233/gencontainer/rwmap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SyncCache[K comparable, V any] struct {
|
||||||
|
cache rwmap.RWMap[K, *entry[V]]
|
||||||
|
ticker *time.Ticker
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSyncCache[K comparable, V any](trimTime time.Duration) *SyncCache[K, V] {
|
||||||
|
sc := &SyncCache[K, V]{
|
||||||
|
ticker: time.NewTicker(trimTime),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for range sc.ticker.C {
|
||||||
|
sc.trim()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return sc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncCache[K, V]) Releases() {
|
||||||
|
sc.ticker.Stop()
|
||||||
|
sc.cache.Clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncCache[K, V]) trim() {
|
||||||
|
sc.cache.Range(func(key K, value *entry[V]) bool {
|
||||||
|
if value.IsExpired() {
|
||||||
|
sc.cache.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncCache[K, V]) Store(key K, value V, expire time.Duration) {
|
||||||
|
sc.LoadOrStore(key, value, expire)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncCache[K, V]) Load(key K) (value V, loaded bool) {
|
||||||
|
e, ok := sc.cache.Load(key)
|
||||||
|
if ok && !e.IsExpired() {
|
||||||
|
return e.value, ok
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncCache[K, V]) LoadOrStore(key K, value V, expire time.Duration) (actual V, loaded bool) {
|
||||||
|
e, loaded := sc.cache.LoadOrStore(key, &entry[V]{
|
||||||
|
expiration: time.Now().Add(expire),
|
||||||
|
value: value,
|
||||||
|
})
|
||||||
|
if e.IsExpired() {
|
||||||
|
sc.cache.Store(key, &entry[V]{
|
||||||
|
expiration: time.Now().Add(expire),
|
||||||
|
value: value,
|
||||||
|
})
|
||||||
|
return value, false
|
||||||
|
}
|
||||||
|
return e.value, loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncCache[K, V]) Delete(key K) {
|
||||||
|
sc.LoadAndDelete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncCache[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
|
||||||
|
e, loaded := sc.cache.LoadAndDelete(key)
|
||||||
|
if loaded && !e.IsExpired() {
|
||||||
|
return e.value, loaded
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *SyncCache[K, V]) Clear() {
|
||||||
|
sc.cache.Clear()
|
||||||
|
}
|
@ -0,0 +1,12 @@
|
|||||||
|
package synccache
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type entry[V any] struct {
|
||||||
|
expiration time.Time
|
||||||
|
value V
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *entry[V]) IsExpired() bool {
|
||||||
|
return time.Now().After(e.expiration)
|
||||||
|
}
|
Loading…
Reference in New Issue