mirror of https://github.com/synctv-org/synctv
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
312 lines
9.9 KiB
Go
312 lines
9.9 KiB
Go
package bootstrap
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/maruel/natural"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/synctv-org/synctv/cmd/flags"
|
|
"github.com/synctv-org/synctv/internal/conf"
|
|
"github.com/synctv-org/synctv/internal/model"
|
|
"github.com/synctv-org/synctv/internal/provider"
|
|
"github.com/synctv-org/synctv/internal/provider/aggregations"
|
|
"github.com/synctv-org/synctv/internal/provider/plugins"
|
|
"github.com/synctv-org/synctv/internal/provider/providers"
|
|
"github.com/synctv-org/synctv/internal/settings"
|
|
"github.com/zijiren233/gencontainer/refreshcache0"
|
|
)
|
|
|
|
// ProviderGroupSettings indexes the per-provider settings by setting group.
// The group name is "<model.SettingGroupOauth2>_<provider>". Entries are added
// by InitProviderSetting and InitAggregationProviderSetting.
//
// NOTE(review): this map is written during setting registration and read by
// Oauth2SignupEnabledCache without any locking. Confirm those never overlap.
var ProviderGroupSettings = map[model.SettingGroup]*ProviderGroupSetting{}
|
|
|
|
// ProviderGroupSetting bundles the settings registered for a single OAuth2
// provider. Instances are stored in ProviderGroupSettings.
type ProviderGroupSetting struct {
	// Enabled toggles the provider. Its hooks call providers.EnableProvider /
	// providers.DisableProvider and refresh Oauth2EnabledCache.
	Enabled settings.BoolSetting

	// ClientID, ClientSecret and RedirectURL are copied into a shared
	// provider.Oauth2Option, and the provider's Init is re-run with it whenever
	// one of them changes.
	ClientID     settings.StringSetting
	ClientSecret settings.StringSetting
	RedirectURL  settings.StringSetting

	// DisableUserSignup, when true, excludes the provider from
	// Oauth2SignupEnabledCache.
	DisableUserSignup settings.BoolSetting

	// SignupNeedReview is not read in this file. Presumably it requires admin
	// review of accounts created through this provider (confirm at call sites).
	SignupNeedReview settings.BoolSetting
}
|
|
|
|
var Oauth2EnabledCache = refreshcache0.NewRefreshCache[[]provider.OAuth2Provider](func(context.Context) ([]provider.OAuth2Provider, error) {
|
|
ps := providers.EnabledProvider()
|
|
r := make([]provider.OAuth2Provider, 0, ps.Len())
|
|
ps.Range(func(p provider.OAuth2Provider, value struct{}) bool {
|
|
r = append(r, p)
|
|
return true
|
|
})
|
|
slices.SortStableFunc(r, func(a, b provider.OAuth2Provider) int {
|
|
if a == b {
|
|
return 0
|
|
} else if natural.Less(a, b) {
|
|
return -1
|
|
} else {
|
|
return 1
|
|
}
|
|
})
|
|
return r, nil
|
|
}, 0)
|
|
|
|
var Oauth2SignupEnabledCache = refreshcache0.NewRefreshCache[[]provider.OAuth2Provider](func(ctx context.Context) ([]provider.OAuth2Provider, error) {
|
|
ps := providers.EnabledProvider()
|
|
r := make([]provider.OAuth2Provider, 0, ps.Len())
|
|
ps.Range(func(p provider.OAuth2Provider, value struct{}) bool {
|
|
group := model.SettingGroup(fmt.Sprintf("%s_%s", model.SettingGroupOauth2, p))
|
|
groupSettings := ProviderGroupSettings[group]
|
|
if groupSettings.Enabled.Get() && !groupSettings.DisableUserSignup.Get() {
|
|
r = append(r, p)
|
|
}
|
|
return true
|
|
})
|
|
slices.SortStableFunc(r, func(a, b provider.OAuth2Provider) int {
|
|
if a == b {
|
|
return 0
|
|
} else if natural.Less(a, b) {
|
|
return -1
|
|
} else {
|
|
return 1
|
|
}
|
|
})
|
|
return r, nil
|
|
}, 0)
|
|
|
|
func InitProvider(ctx context.Context) (err error) {
|
|
logOur := log.StandardLogger().Writer()
|
|
logLevle := hclog.Info
|
|
if flags.Global.Dev {
|
|
logLevle = hclog.Debug
|
|
}
|
|
for _, op := range conf.Conf.Oauth2Plugins {
|
|
log.Infof("load oauth2 plugin: %s", op.PluginFile)
|
|
err := os.MkdirAll(filepath.Dir(op.PluginFile), 0o755)
|
|
if err != nil {
|
|
log.Fatalf("create plugin dir: %s failed: %s", filepath.Dir(op.PluginFile), err)
|
|
return err
|
|
}
|
|
err = plugins.InitProviderPlugins(op.PluginFile, op.Args, hclog.New(&hclog.LoggerOptions{
|
|
Name: op.PluginFile,
|
|
Level: logLevle,
|
|
Output: logOur,
|
|
Color: hclog.ForceColor,
|
|
}))
|
|
if err != nil {
|
|
log.Fatalf("load oauth2 plugin: %s failed: %s", op.PluginFile, err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
for _, pi := range providers.AllProvider() {
|
|
InitProviderSetting(pi)
|
|
}
|
|
|
|
for _, api := range aggregations.AllAggregation() {
|
|
InitAggregationSetting(api)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func InitProviderSetting(pi provider.Provider) {
|
|
group := model.SettingGroup(fmt.Sprintf("%s_%s", model.SettingGroupOauth2, pi.Provider()))
|
|
groupSettings := &ProviderGroupSetting{}
|
|
ProviderGroupSettings[group] = groupSettings
|
|
|
|
groupSettings.Enabled = settings.NewBoolSetting(fmt.Sprintf("%s_enabled", group), false, group,
|
|
settings.WithBeforeInitBool(func(bs settings.BoolSetting, b bool) (bool, error) {
|
|
defer Oauth2EnabledCache.Refresh(context.Background())
|
|
if b {
|
|
return b, providers.EnableProvider(pi.Provider())
|
|
} else {
|
|
providers.DisableProvider(pi.Provider())
|
|
return b, nil
|
|
}
|
|
}),
|
|
settings.WithInitPriorityBool(1),
|
|
settings.WithBeforeSetBool(func(bs settings.BoolSetting, b bool) (bool, error) {
|
|
defer Oauth2EnabledCache.Refresh(context.Background())
|
|
if b {
|
|
return b, providers.EnableProvider(pi.Provider())
|
|
} else {
|
|
providers.DisableProvider(pi.Provider())
|
|
return b, nil
|
|
}
|
|
}),
|
|
)
|
|
|
|
opt := provider.Oauth2Option{}
|
|
|
|
groupSettings.ClientID = settings.NewStringSetting(fmt.Sprintf("%s_client_id", group), opt.ClientID, group,
|
|
settings.WithBeforeInitString(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.ClientID = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
}),
|
|
settings.WithInitPriorityString(1),
|
|
settings.WithBeforeSetString(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.ClientID = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
}))
|
|
|
|
groupSettings.ClientSecret = settings.NewStringSetting(fmt.Sprintf("%s_client_secret", group), opt.ClientSecret, group,
|
|
settings.WithBeforeInitString(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.ClientSecret = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
}),
|
|
settings.WithInitPriorityString(1),
|
|
settings.WithBeforeSetString(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.ClientSecret = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
}))
|
|
|
|
groupSettings.RedirectURL = settings.NewStringSetting(fmt.Sprintf("%s_redirect_url", group), opt.RedirectURL, group,
|
|
settings.WithBeforeInitString(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.RedirectURL = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
}),
|
|
settings.WithInitPriorityString(1),
|
|
settings.WithBeforeSetString(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.RedirectURL = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
}))
|
|
|
|
groupSettings.DisableUserSignup = settings.NewBoolSetting(fmt.Sprintf("%s_disable_user_signup", group), false, group)
|
|
|
|
groupSettings.SignupNeedReview = settings.NewBoolSetting(fmt.Sprintf("%s_signup_need_review", group), false, group)
|
|
}
|
|
|
|
func InitAggregationProviderSetting(pi provider.Provider) {
|
|
group := model.SettingGroup(fmt.Sprintf("%s_%s", model.SettingGroupOauth2, pi.Provider()))
|
|
groupSettings := &ProviderGroupSetting{}
|
|
ProviderGroupSettings[group] = groupSettings
|
|
|
|
groupSettings.Enabled = settings.LoadOrNewBoolSetting(fmt.Sprintf("%s_enabled", group), false, group,
|
|
settings.WithBeforeSetBool(func(bs settings.BoolSetting, b bool) (bool, error) {
|
|
defer Oauth2EnabledCache.Refresh(context.Background())
|
|
if b {
|
|
return b, providers.EnableProvider(pi.Provider())
|
|
} else {
|
|
providers.DisableProvider(pi.Provider())
|
|
return b, nil
|
|
}
|
|
}),
|
|
)
|
|
|
|
opt := provider.Oauth2Option{}
|
|
|
|
groupSettings.ClientID = settings.LoadOrNewStringSetting(fmt.Sprintf("%s_client_id", group), opt.ClientID, group)
|
|
opt.ClientID = groupSettings.ClientID.Get()
|
|
groupSettings.ClientID.SetBeforeSet(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.ClientID = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
})
|
|
|
|
groupSettings.ClientSecret = settings.LoadOrNewStringSetting(fmt.Sprintf("%s_client_secret", group), opt.ClientSecret, group)
|
|
opt.ClientSecret = groupSettings.ClientSecret.Get()
|
|
groupSettings.ClientSecret.SetBeforeSet(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.ClientSecret = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
})
|
|
|
|
groupSettings.RedirectURL = settings.LoadOrNewStringSetting(fmt.Sprintf("%s_redirect_url", group), opt.RedirectURL, group)
|
|
opt.RedirectURL = groupSettings.RedirectURL.Get()
|
|
groupSettings.RedirectURL.SetBeforeSet(func(ss settings.StringSetting, s string) (string, error) {
|
|
opt.RedirectURL = s
|
|
pi.Init(opt)
|
|
return s, nil
|
|
})
|
|
|
|
pi.Init(opt)
|
|
|
|
groupSettings.DisableUserSignup = settings.LoadOrNewBoolSetting(fmt.Sprintf("%s_disable_user_signup", group), false, group)
|
|
|
|
groupSettings.SignupNeedReview = settings.LoadOrNewBoolSetting(fmt.Sprintf("%s_signup_need_review", group), false, group)
|
|
}
|
|
|
|
func InitAggregationSetting(pi provider.AggregationProviderInterface) {
|
|
group := model.SettingGroup(fmt.Sprintf("%s_%s", model.SettingGroupOauth2, pi.Provider()))
|
|
|
|
switch pi := pi.(type) {
|
|
case *aggregations.Rainbow:
|
|
settings.NewStringSetting(fmt.Sprintf("%s_api", group), aggregations.DefaultRainbowApi, group,
|
|
settings.WithBeforeInitString(func(ss settings.StringSetting, s string) (string, error) {
|
|
pi.SetAPI(s)
|
|
return s, nil
|
|
},
|
|
),
|
|
settings.WithBeforeSetString(func(ss settings.StringSetting, s string) (string, error) {
|
|
pi.SetAPI(s)
|
|
return s, nil
|
|
},
|
|
),
|
|
)
|
|
}
|
|
|
|
list := settings.NewStringSetting(fmt.Sprintf("%s_enabled_list", group), "", group,
|
|
settings.WithBeforeInitString(func(ss settings.StringSetting, s string) (string, error) {
|
|
return s, nil
|
|
}),
|
|
settings.WithInitPriorityString(1),
|
|
settings.WithBeforeSetString(func(ss settings.StringSetting, s string) (string, error) {
|
|
if s == "" {
|
|
return s, nil
|
|
}
|
|
list := strings.Split(s, ",")
|
|
for _, p := range list {
|
|
if slices.Index(pi.Providers(), p) == -1 {
|
|
return s, fmt.Errorf("provider %s not found", p)
|
|
}
|
|
}
|
|
return s, nil
|
|
}),
|
|
)
|
|
|
|
settings.NewBoolSetting(fmt.Sprintf("%s_enabled", group), false, group,
|
|
settings.WithBeforeInitBool(func(bs settings.BoolSetting, b bool) (bool, error) {
|
|
if b {
|
|
s := list.Get()
|
|
if s == "" {
|
|
log.Warnf("aggregation provider %s enabled, but no provider enabled", pi.Provider())
|
|
}
|
|
all := pi.Providers()
|
|
list := strings.Split(s, ",")
|
|
enabled := make([]provider.OAuth2Provider, 0, len(list))
|
|
for _, p := range list {
|
|
if slices.Index(all, p) != -1 {
|
|
enabled = append(enabled, p)
|
|
} else {
|
|
log.Warnf("aggregation provider %s enabled, but provider %s not found", pi.Provider(), p)
|
|
}
|
|
}
|
|
|
|
pi2, err := provider.ExtractProviders(pi, enabled...)
|
|
if err != nil {
|
|
log.Errorf("aggregation provider %s enabled, but extract provider failed: %s", pi.Provider(), err)
|
|
return b, nil
|
|
}
|
|
for _, pi2 := range pi2 {
|
|
providers.RegisterProvider(pi2)
|
|
InitAggregationProviderSetting(pi2)
|
|
}
|
|
}
|
|
return b, nil
|
|
}),
|
|
settings.WithBeforeSetBool(func(bs settings.BoolSetting, b bool) (bool, error) {
|
|
if len(list.Get()) == 0 {
|
|
return b, fmt.Errorf("enabled provider list is empty")
|
|
}
|
|
return b, nil
|
|
}),
|
|
)
|
|
}
|