mirror of https://github.com/usememos/memos
feat: implement oauth2 plugin (#1110)
parent
37f9c7c8d6
commit
69726c3925
@ -0,0 +1,7 @@
|
||||
package idp
|
||||
|
||||
type IdentityProviderUserInfo struct {
|
||||
Identifier string
|
||||
DisplayName string
|
||||
Email string
|
||||
}
|
@ -0,0 +1,115 @@
|
||||
// Package oauth2 is the plugin for OAuth2 Identity Provider.
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
"github.com/usememos/memos/store"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// IdentityProvider represents an OAuth2 Identity Provider.
|
||||
type IdentityProvider struct {
|
||||
config *store.IdentityProviderOAuth2Config
|
||||
}
|
||||
|
||||
// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration.
|
||||
func NewIdentityProvider(config *store.IdentityProviderOAuth2Config) (*IdentityProvider, error) {
|
||||
for v, field := range map[string]string{
|
||||
config.ClientID: "clientId",
|
||||
config.ClientSecret: "clientSecret",
|
||||
config.TokenURL: "tokenUrl",
|
||||
config.UserInfoURL: "userInfoUrl",
|
||||
config.FieldMapping.Identifier: "fieldMapping.identifier",
|
||||
} {
|
||||
if v == "" {
|
||||
return nil, errors.Errorf(`the field "%s" is empty but required`, field)
|
||||
}
|
||||
}
|
||||
|
||||
return &IdentityProvider{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
|
||||
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) {
|
||||
conf := &oauth2.Config{
|
||||
ClientID: p.config.ClientID,
|
||||
ClientSecret: p.config.ClientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: p.config.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: p.config.AuthURL,
|
||||
TokenURL: p.config.TokenURL,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
}
|
||||
|
||||
token, err := conf.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to exchange access token")
|
||||
}
|
||||
|
||||
accessToken, ok := token.Extra("access_token").(string)
|
||||
if !ok {
|
||||
return "", errors.New(`missing "access_token" from authorization response`)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// UserInfo returns the parsed user information using the given OAuth2 token.
|
||||
func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) {
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(http.MethodGet, p.config.UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to new http request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get user information")
|
||||
}
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
|
||||
var claims map[string]any
|
||||
err = json.Unmarshal(body, &claims)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
|
||||
userInfo := &idp.IdentityProviderUserInfo{}
|
||||
if v, ok := claims[p.config.FieldMapping.Identifier].(string); ok {
|
||||
userInfo.Identifier = v
|
||||
}
|
||||
if userInfo.Identifier == "" {
|
||||
return nil, errors.Errorf("the field %q is not found in claims or has empty value", p.config.FieldMapping.Identifier)
|
||||
}
|
||||
|
||||
// Best effort to map optional fields
|
||||
if p.config.FieldMapping.DisplayName != "" {
|
||||
if v, ok := claims[p.config.FieldMapping.DisplayName].(string); ok {
|
||||
userInfo.DisplayName = v
|
||||
}
|
||||
}
|
||||
if userInfo.DisplayName == "" {
|
||||
userInfo.DisplayName = userInfo.Identifier
|
||||
}
|
||||
if p.config.FieldMapping.Email != "" {
|
||||
if v, ok := claims[p.config.FieldMapping.Email].(string); ok {
|
||||
userInfo.Email = v
|
||||
}
|
||||
}
|
||||
return userInfo, nil
|
||||
}
|
@ -0,0 +1,163 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/usememos/memos/plugin/idp"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestNewIdentityProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *store.IdentityProviderOAuth2Config
|
||||
containsErr string
|
||||
}{
|
||||
{
|
||||
name: "no tokenUrl",
|
||||
config: &store.IdentityProviderOAuth2Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "",
|
||||
TokenURL: "",
|
||||
UserInfoURL: "https://example.com/api/user",
|
||||
FieldMapping: &store.FieldMapping{
|
||||
Identifier: "login",
|
||||
},
|
||||
},
|
||||
containsErr: `the field "tokenUrl" is empty but required`,
|
||||
},
|
||||
{
|
||||
name: "no userInfoUrl",
|
||||
config: &store.IdentityProviderOAuth2Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "",
|
||||
TokenURL: "https://example.com/token",
|
||||
UserInfoURL: "",
|
||||
FieldMapping: &store.FieldMapping{
|
||||
Identifier: "login",
|
||||
},
|
||||
},
|
||||
containsErr: `the field "userInfoUrl" is empty but required`,
|
||||
},
|
||||
{
|
||||
name: "no field mapping identifier",
|
||||
config: &store.IdentityProviderOAuth2Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
AuthURL: "",
|
||||
TokenURL: "https://example.com/token",
|
||||
UserInfoURL: "https://example.com/api/user",
|
||||
FieldMapping: &store.FieldMapping{
|
||||
Identifier: "",
|
||||
},
|
||||
},
|
||||
containsErr: `the field "fieldMapping.identifier" is empty but required`,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
_, err := NewIdentityProvider(test.config)
|
||||
assert.ErrorContains(t, err, test.containsErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newMockServer(t *testing.T, code, accessToken string, userinfo []byte) *httptest.Server {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
var rawIDToken string
|
||||
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
vals, err := url.ParseQuery(string(body))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, code, vals.Get("code"))
|
||||
require.Equal(t, "authorization_code", vals.Get("grant_type"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": rawIDToken,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err := w.Write(userinfo)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
s := httptest.NewServer(mux)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func TestIdentityProvider(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
const (
|
||||
testClientID = "test-client-id"
|
||||
testCode = "test-code"
|
||||
testAccessToken = "test-access-token"
|
||||
testSubject = "123456789"
|
||||
testName = "John Doe"
|
||||
testEmail = "john.doe@example.com"
|
||||
)
|
||||
userInfo, err := json.Marshal(
|
||||
map[string]any{
|
||||
"sub": testSubject,
|
||||
"name": testName,
|
||||
"email": testEmail,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := newMockServer(t, testCode, testAccessToken, userInfo)
|
||||
|
||||
oauth2, err := NewIdentityProvider(
|
||||
&store.IdentityProviderOAuth2Config{
|
||||
ClientID: testClientID,
|
||||
ClientSecret: "test-client-secret",
|
||||
TokenURL: fmt.Sprintf("%s/oauth2/token", s.URL),
|
||||
UserInfoURL: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
|
||||
FieldMapping: &store.FieldMapping{
|
||||
Identifier: "sub",
|
||||
DisplayName: "name",
|
||||
Email: "email",
|
||||
},
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
redirectURL := "https://example.com/oauth/callback"
|
||||
oauthToken, err := oauth2.ExchangeToken(ctx, redirectURL, testCode)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testAccessToken, oauthToken)
|
||||
|
||||
userInfoResult, err := oauth2.UserInfo(oauthToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
wantUserInfo := &idp.IdentityProviderUserInfo{
|
||||
Identifier: testSubject,
|
||||
DisplayName: testName,
|
||||
Email: testEmail,
|
||||
}
|
||||
assert.Equal(t, wantUserInfo, userInfoResult)
|
||||
}
|
Loading…
Reference in New Issue