mirror of https://github.com/synctv-org/synctv
Feat: use oauth2
parent
92fdbc31be
commit
dab669708d
@ -0,0 +1,22 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"github.com/synctv-org/synctv/internal/provider"
|
||||
)
|
||||
|
||||
type OAuth2Config map[provider.OAuth2Provider]OAuth2ProviderConfig
|
||||
|
||||
type OAuth2ProviderConfig struct {
|
||||
ClientID string `yaml:"client_id" lc:"oauth2 client id"`
|
||||
ClientSecret string `yaml:"client_secret" lc:"oauth2 client secret"`
|
||||
// CustomRedirectURL string `yaml:"custom_redirect_url" lc:"oauth2 custom redirect url"`
|
||||
}
|
||||
|
||||
func DefaultOAuth2Config() OAuth2Config {
|
||||
return OAuth2Config{
|
||||
provider.GithubProvider{}.Provider(): {
|
||||
ClientID: "github_client_id",
|
||||
ClientSecret: "github_client_secret",
|
||||
},
|
||||
}
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/synctv-org/synctv/internal/provider"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserProvider struct {
|
||||
gorm.Model
|
||||
UserID uint `gorm:"not null"`
|
||||
Provider provider.OAuth2Provider `gorm:"not null;uniqueIndex:provider_user_id"`
|
||||
ProviderUserID uint `gorm:"not null;uniqueIndex:provider_user_id"`
|
||||
}
|
@ -0,0 +1,104 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
json "github.com/json-iterator/go"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/github"
|
||||
)
|
||||
|
||||
type GithubProvider struct{}
|
||||
|
||||
func (p GithubProvider) Provider() OAuth2Provider {
|
||||
return "github"
|
||||
}
|
||||
|
||||
func (p GithubProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: ClientID,
|
||||
ClientSecret: ClientSecret,
|
||||
Scopes: []string{"user"},
|
||||
Endpoint: github.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
func (p GithubProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
|
||||
oauth2Token, err := config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := config.Client(ctx, oauth2Token)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
ui := githubUserInfo{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&ui)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &UserInfo{
|
||||
Username: ui.Login,
|
||||
ProviderUserID: ui.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterProvider(GithubProvider{})
|
||||
}
|
||||
|
||||
type githubUserInfo struct {
|
||||
Login string `json:"login"`
|
||||
ID uint `json:"id"`
|
||||
NodeID string `json:"node_id"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
GravatarID string `json:"gravatar_id"`
|
||||
URL string `json:"url"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
FollowersURL string `json:"followers_url"`
|
||||
FollowingURL string `json:"following_url"`
|
||||
GistsURL string `json:"gists_url"`
|
||||
StarredURL string `json:"starred_url"`
|
||||
SubscriptionsURL string `json:"subscriptions_url"`
|
||||
OrganizationsURL string `json:"organizations_url"`
|
||||
ReposURL string `json:"repos_url"`
|
||||
EventsURL string `json:"events_url"`
|
||||
ReceivedEventsURL string `json:"received_events_url"`
|
||||
Type string `json:"type"`
|
||||
SiteAdmin bool `json:"site_admin"`
|
||||
Name string `json:"name"`
|
||||
Company interface{} `json:"company"`
|
||||
Blog string `json:"blog"`
|
||||
Location string `json:"location"`
|
||||
Email interface{} `json:"email"`
|
||||
Hireable interface{} `json:"hireable"`
|
||||
Bio string `json:"bio"`
|
||||
TwitterUsername interface{} `json:"twitter_username"`
|
||||
PublicRepos int `json:"public_repos"`
|
||||
PublicGists int `json:"public_gists"`
|
||||
Followers int `json:"followers"`
|
||||
Following int `json:"following"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
PrivateGists int `json:"private_gists"`
|
||||
TotalPrivateRepos int `json:"total_private_repos"`
|
||||
OwnedPrivateRepos int `json:"owned_private_repos"`
|
||||
DiskUsage int `json:"disk_usage"`
|
||||
Collaborators int `json:"collaborators"`
|
||||
TwoFactorAuthentication bool `json:"two_factor_authentication"`
|
||||
Plan Plan `json:"plan"`
|
||||
}
|
||||
type Plan struct {
|
||||
Name string `json:"name"`
|
||||
Space int `json:"space"`
|
||||
Collaborators int `json:"collaborators"`
|
||||
PrivateRepos int `json:"private_repos"`
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/gitlab"
|
||||
)
|
||||
|
||||
type GitlabProvider struct{}
|
||||
|
||||
func (g GitlabProvider) Provider() OAuth2Provider {
|
||||
return "gitlab"
|
||||
}
|
||||
|
||||
func (g GitlabProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: ClientID,
|
||||
ClientSecret: ClientSecret,
|
||||
Scopes: []string{"read_user"},
|
||||
Endpoint: gitlab.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
func (g GitlabProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
|
||||
oauth2Token, err := config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := config.Client(ctx, oauth2Token)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://gitlab.com/api/v4/user", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return nil, FormatErrNotImplemented("gitlab")
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterProvider(GitlabProvider{})
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
type GoogleProvider struct{}
|
||||
|
||||
func (g GoogleProvider) Provider() OAuth2Provider {
|
||||
return "google"
|
||||
}
|
||||
|
||||
func (g GoogleProvider) NewConfig(ClientID, ClientSecret string) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: ClientID,
|
||||
ClientSecret: ClientSecret,
|
||||
Scopes: []string{"profile"},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
func (g GoogleProvider) GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error) {
|
||||
oauth2Token, err := config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client := config.Client(ctx, oauth2Token)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v2/userinfo", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return nil, FormatErrNotImplemented("google")
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterProvider(GoogleProvider{})
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type OAuth2Provider string
|
||||
|
||||
var (
|
||||
providers = make(map[OAuth2Provider]ProviderInterface)
|
||||
lock sync.Mutex
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
Username string
|
||||
ProviderUserID uint
|
||||
}
|
||||
|
||||
type ProviderInterface interface {
|
||||
Provider() OAuth2Provider
|
||||
NewConfig(ClientID, ClientSecret string) *oauth2.Config
|
||||
GetUserInfo(ctx context.Context, config *oauth2.Config, code string) (*UserInfo, error)
|
||||
}
|
||||
|
||||
func RegisterProvider(provider ProviderInterface) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
providers[provider.Provider()] = provider
|
||||
}
|
||||
|
||||
func (p OAuth2Provider) GetProvider() (ProviderInterface, error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
pi, ok := providers[p]
|
||||
if !ok {
|
||||
return nil, FormatErrNotImplemented(p)
|
||||
}
|
||||
return pi, nil
|
||||
}
|
||||
|
||||
type FormatErrNotImplemented string
|
||||
|
||||
func (f FormatErrNotImplemented) Error() string {
|
||||
return fmt.Sprintf("%s not implemented", string(f))
|
||||
}
|
@ -0,0 +1,85 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/synctv-org/synctv/internal/conf"
|
||||
"github.com/synctv-org/synctv/internal/op"
|
||||
"github.com/synctv-org/synctv/internal/provider"
|
||||
"github.com/synctv-org/synctv/server/middlewares"
|
||||
"github.com/synctv-org/synctv/server/model"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// /oauth2/login/:type
|
||||
func OAuth2(ctx *gin.Context) {
|
||||
t := ctx.Param("type")
|
||||
p := provider.OAuth2Provider(t)
|
||||
c, ok := conf.Conf.OAuth2[p]
|
||||
if !ok {
|
||||
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider"))
|
||||
}
|
||||
|
||||
pi, err := p.GetProvider()
|
||||
if err != nil {
|
||||
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
|
||||
}
|
||||
|
||||
Render(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), oauth2.AccessTypeOnline)
|
||||
}
|
||||
|
||||
// /oauth2/callback/:type
|
||||
func OAuth2Callback(ctx *gin.Context) {
|
||||
t := ctx.Param("type")
|
||||
p := provider.OAuth2Provider(t)
|
||||
c, ok := conf.Conf.OAuth2[p]
|
||||
if !ok {
|
||||
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider"))
|
||||
}
|
||||
|
||||
code := ctx.Query("code")
|
||||
if code == "" {
|
||||
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 code"))
|
||||
return
|
||||
}
|
||||
|
||||
state := ctx.Query("state")
|
||||
if state == "" {
|
||||
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state"))
|
||||
return
|
||||
}
|
||||
|
||||
_, loaded := states.LoadAndDelete(state)
|
||||
if !loaded {
|
||||
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state"))
|
||||
return
|
||||
}
|
||||
|
||||
pi, err := p.GetProvider()
|
||||
if err != nil {
|
||||
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
|
||||
}
|
||||
|
||||
ui, err := pi.GetUserInfo(ctx, pi.NewConfig(c.ClientID, c.ClientSecret), code)
|
||||
if err != nil {
|
||||
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
|
||||
return
|
||||
}
|
||||
|
||||
user, err := op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID)
|
||||
if err != nil {
|
||||
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
|
||||
return
|
||||
}
|
||||
|
||||
token, err := middlewares.NewAuthUserToken(user)
|
||||
if err != nil {
|
||||
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{
|
||||
"token": token,
|
||||
}))
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
package auth
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func Init(e *gin.Engine) {
|
||||
{
|
||||
auth := e.Group("/oauth2")
|
||||
|
||||
auth.GET("/login/:type", OAuth2)
|
||||
|
||||
auth.GET("/callback/:type", OAuth2Callback)
|
||||
}
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"html/template"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/synctv-org/synctv/utils"
|
||||
synccache "github.com/synctv-org/synctv/utils/syncCache"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
//go:embed templates/redirect.html
|
||||
var temp embed.FS
|
||||
|
||||
var (
|
||||
redirectTemplate *template.Template
|
||||
states *synccache.SyncCache[string, struct{}]
|
||||
)
|
||||
|
||||
func Render(ctx *gin.Context, c *oauth2.Config, option ...oauth2.AuthCodeOption) error {
|
||||
state := utils.RandString(16)
|
||||
states.Store(state, struct{}{}, time.Minute*5)
|
||||
return redirectTemplate.Execute(ctx.Writer, c.AuthCodeURL(state, option...))
|
||||
}
|
||||
|
||||
func init() {
|
||||
redirectTemplate = template.Must(template.ParseFS(temp, "templates/redirect.html"))
|
||||
states = synccache.NewSyncCache[string, struct{}](time.Minute * 10)
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Redirecting..</title>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<p>If you are not redirected, please click <a href="{{ . }}">here</a>.</p>
|
||||
<script>window.location.href = "{{ . }}"</script>
|
||||
</body>
|
||||
|
||||
</html>
|
@ -0,0 +1,81 @@
|
||||
package synccache
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zijiren233/gencontainer/rwmap"
|
||||
)
|
||||
|
||||
type SyncCache[K comparable, V any] struct {
|
||||
cache rwmap.RWMap[K, *entry[V]]
|
||||
ticker *time.Ticker
|
||||
}
|
||||
|
||||
func NewSyncCache[K comparable, V any](trimTime time.Duration) *SyncCache[K, V] {
|
||||
sc := &SyncCache[K, V]{
|
||||
ticker: time.NewTicker(trimTime),
|
||||
}
|
||||
go func() {
|
||||
for range sc.ticker.C {
|
||||
sc.trim()
|
||||
}
|
||||
}()
|
||||
return sc
|
||||
}
|
||||
|
||||
func (sc *SyncCache[K, V]) Releases() {
|
||||
sc.ticker.Stop()
|
||||
sc.cache.Clear()
|
||||
}
|
||||
|
||||
func (sc *SyncCache[K, V]) trim() {
|
||||
sc.cache.Range(func(key K, value *entry[V]) bool {
|
||||
if value.IsExpired() {
|
||||
sc.cache.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (sc *SyncCache[K, V]) Store(key K, value V, expire time.Duration) {
|
||||
sc.LoadOrStore(key, value, expire)
|
||||
}
|
||||
|
||||
func (sc *SyncCache[K, V]) Load(key K) (value V, loaded bool) {
|
||||
e, ok := sc.cache.Load(key)
|
||||
if ok && !e.IsExpired() {
|
||||
return e.value, ok
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (sc *SyncCache[K, V]) LoadOrStore(key K, value V, expire time.Duration) (actual V, loaded bool) {
|
||||
e, loaded := sc.cache.LoadOrStore(key, &entry[V]{
|
||||
expiration: time.Now().Add(expire),
|
||||
value: value,
|
||||
})
|
||||
if e.IsExpired() {
|
||||
sc.cache.Store(key, &entry[V]{
|
||||
expiration: time.Now().Add(expire),
|
||||
value: value,
|
||||
})
|
||||
return value, false
|
||||
}
|
||||
return e.value, loaded
|
||||
}
|
||||
|
||||
func (sc *SyncCache[K, V]) Delete(key K) {
|
||||
sc.LoadAndDelete(key)
|
||||
}
|
||||
|
||||
func (sc *SyncCache[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
|
||||
e, loaded := sc.cache.LoadAndDelete(key)
|
||||
if loaded && !e.IsExpired() {
|
||||
return e.value, loaded
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (sc *SyncCache[K, V]) Clear() {
|
||||
sc.cache.Clear()
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
package synccache
|
||||
|
||||
import "time"
|
||||
|
||||
type entry[V any] struct {
|
||||
expiration time.Time
|
||||
value V
|
||||
}
|
||||
|
||||
func (e *entry[V]) IsExpired() bool {
|
||||
return time.Now().After(e.expiration)
|
||||
}
|
Loading…
Reference in New Issue