From cc3a47fc65ec23b9b24cc5b9d4709b9b1755f722 Mon Sep 17 00:00:00 2001 From: boojack Date: Sun, 30 Jul 2023 23:49:10 +0800 Subject: [PATCH] feat: impl auth interceptor (#2055) * feat: impl auth interceptor * chore: update * chore: update * chore: update --- api/v1/auth/auth.go | 11 +- api/v1/idp.go | 11 +- api/v1/jwt.go | 18 +-- api/v1/memo.go | 15 +- api/v1/memo_organizer.go | 3 +- api/v1/memo_resource.go | 5 +- api/v1/resource.go | 13 +- api/v1/shortcut.go | 9 +- api/v1/storage.go | 9 +- api/v1/system.go | 3 +- api/v1/system_setting.go | 5 +- api/v1/tag.go | 9 +- api/v1/user.go | 9 +- api/v1/user_setting.go | 3 +- api/v2/auth/auth.go | 297 +++++++++++++++++++++++++++++++++++++++ api/v2/auth/config.go | 15 ++ api/v2/user_service.go | 10 +- api/v2/v2.go | 37 ++++- server/server.go | 19 ++- 19 files changed, 422 insertions(+), 79 deletions(-) create mode 100644 api/v2/auth/auth.go create mode 100644 api/v2/auth/config.go diff --git a/api/v1/auth/auth.go b/api/v1/auth/auth.go index 68d21d91..47c4178a 100644 --- a/api/v1/auth/auth.go +++ b/api/v1/auth/auth.go @@ -12,6 +12,10 @@ import ( ) const ( + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + UserIDContextKey = "user-id" + // issuer is the issuer of the jwt token. issuer = "memos" // Signing key section. For now, this is only used for signing, not for verifying since we only // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. @@ -23,14 +27,11 @@ const ( apiTokenDuration = 2 * time.Hour accessTokenDuration = 24 * time.Hour refreshTokenDuration = 7 * 24 * time.Hour - // RefreshThresholdDuration is the threshold duration for refreshing token. - RefreshThresholdDuration = 1 * time.Hour // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. - // Suppose we have a valid refresh token, we will refresh the token in 2 cases: - // 1. The access token is about to expire in <> - // 2. The access token has already expired, we refresh the token so that the ongoing request can pass through. + // Suppose we have a valid refresh token, we will refresh the token in cases: + // 1. The access token has already expired, we refresh the token so that the ongoing request can pass through. CookieExpDuration = refreshTokenDuration - 1*time.Minute // AccessTokenCookieName is the cookie name of access token. AccessTokenCookieName = "memos.access-token" diff --git a/api/v1/idp.go b/api/v1/idp.go index 4ceb9678..5889c499 100644 --- a/api/v1/idp.go +++ b/api/v1/idp.go @@ -7,6 +7,7 @@ import ( "strconv" "github.com/labstack/echo/v4" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" ) @@ -66,7 +67,7 @@ type UpdateIdentityProviderRequest struct { func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { g.POST("/idp", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -100,7 +101,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { g.PATCH("/idp/:idpId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -147,7 +148,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err) } - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) isHostUser := false if ok { user, err := s.Store.GetUser(ctx, &store.FindUser{ @@ -175,7 +176,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { g.GET("/idp/:idpId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -209,7 +210,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { g.DELETE("/idp/:idpId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/jwt.go b/api/v1/jwt.go index 64de6da3..0e3f646b 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -5,7 +5,6 @@ import ( "net/http" "strconv" "strings" - "time" "github.com/golang-jwt/jwt/v4" "github.com/labstack/echo/v4" @@ -15,17 +14,6 @@ import ( "github.com/usememos/memos/store" ) -const ( - // Context section - // The key name used to store user id in the context - // user id is extracted from the jwt token subject field. - userIDContextKey = "user-id" -) - -func getUserIDContextKey() string { - return userIDContextKey -} - // Claims creates a struct that will be encoded to a JWT. // We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. type Claims struct { @@ -112,7 +100,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) }) - generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration + generateToken := false if err != nil { var ve *jwt.ValidationError if errors.As(err, &ve) { @@ -203,7 +191,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e } // Stores userID into context. - c.Set(getUserIDContextKey(), userID) + c.Set(auth.UserIDContextKey, userID) return next(c) } } @@ -228,7 +216,7 @@ func (s *APIV1Service) defaultAuthSkipper(c echo.Context) bool { } if user != nil { // Stores userID into context. - c.Set(getUserIDContextKey(), user.ID) + c.Set(auth.UserIDContextKey, user.ID) return true } } diff --git a/api/v1/memo.go b/api/v1/memo.go index 4945df77..b70f561e 100644 --- a/api/v1/memo.go +++ b/api/v1/memo.go @@ -10,6 +10,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" ) @@ -113,7 +114,7 @@ const maxContentLength = 1 << 30 func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { g.POST("/memo", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -224,7 +225,7 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { g.PATCH("/memo/:memoId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -362,7 +363,7 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { } } - currentUserID, ok := c.Get(getUserIDContextKey()).(int) + currentUserID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { // Anonymous use should only fetch PUBLIC memos with specified user if findMemoMessage.CreatorID == nil { @@ -449,7 +450,7 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Memo not found: %d", memoID)) } - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if memo.Visibility == store.Private { if !ok || memo.CreatorID != userID { return echo.NewHTTPError(http.StatusForbidden, "this memo is private only") @@ -487,7 +488,7 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") } - currentUserID, ok := c.Get(getUserIDContextKey()).(int) + currentUserID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { findMemoMessage.VisibilityList = []store.Visibility{store.Public} } else { @@ -529,7 +530,7 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { g.GET("/memo/all", func(c echo.Context) error { ctx := c.Request().Context() findMemoMessage := &store.FindMemo{} - _, ok := c.Get(getUserIDContextKey()).(int) + _, ok := c.Get(auth.UserIDContextKey).(int) if !ok { findMemoMessage.VisibilityList = []store.Visibility{store.Public} } else { @@ -589,7 +590,7 @@ func (s *APIV1Service) registerMemoRoutes(g *echo.Group) { g.DELETE("/memo/:memoId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/memo_organizer.go b/api/v1/memo_organizer.go index 9adc9f66..f88d4712 100644 --- a/api/v1/memo_organizer.go +++ b/api/v1/memo_organizer.go @@ -7,6 +7,7 @@ import ( "strconv" "github.com/labstack/echo/v4" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" ) @@ -28,7 +29,7 @@ func (s *APIV1Service) registerMemoOrganizerRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) } - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/memo_resource.go b/api/v1/memo_resource.go index eaba21ff..2a197c23 100644 --- a/api/v1/memo_resource.go +++ b/api/v1/memo_resource.go @@ -8,6 +8,7 @@ import ( "time" "github.com/labstack/echo/v4" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" ) @@ -41,7 +42,7 @@ func (s *APIV1Service) registerMemoResourceRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("memoId"))).SetInternal(err) } - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -97,7 +98,7 @@ func (s *APIV1Service) registerMemoResourceRoutes(g *echo.Group) { g.DELETE("/memo/:memoId/resource/:resourceId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/resource.go b/api/v1/resource.go index 0fdb34c0..f004caf5 100644 --- a/api/v1/resource.go +++ b/api/v1/resource.go @@ -21,6 +21,7 @@ import ( "github.com/disintegration/imaging" "github.com/labstack/echo/v4" "github.com/pkg/errors" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/common/log" "github.com/usememos/memos/common/util" "github.com/usememos/memos/plugin/storage/s3" @@ -82,7 +83,7 @@ var fileKeyPattern = regexp.MustCompile(`\{[a-z]{1,9}\}`) func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { g.POST("/resource", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -156,7 +157,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { g.POST("/resource/blob", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -216,7 +217,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { g.GET("/resource", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -243,7 +244,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { g.PATCH("/resource/:resourceId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -289,7 +290,7 @@ func (s *APIV1Service) registerResourceRoutes(g *echo.Group) { g.DELETE("/resource/:resourceId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -345,7 +346,7 @@ func (s *APIV1Service) registerResourcePublicRoutes(g *echo.Group) { } // Protected resource require a logined user - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if resourceVisibility == store.Protected && (!ok || userID <= 0) { return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match").SetInternal(err) } diff --git a/api/v1/shortcut.go b/api/v1/shortcut.go index 8e99b690..f804db58 100644 --- a/api/v1/shortcut.go +++ b/api/v1/shortcut.go @@ -9,6 +9,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" ) @@ -57,7 +58,7 @@ type ShortcutDelete struct { func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { g.POST("/shortcut", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -84,7 +85,7 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -136,7 +137,7 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { g.GET("/shortcut", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find shortcut") } @@ -175,7 +176,7 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/storage.go b/api/v1/storage.go index 186d5687..31500721 100644 --- a/api/v1/storage.go +++ b/api/v1/storage.go @@ -7,6 +7,7 @@ import ( "strconv" "github.com/labstack/echo/v4" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" ) @@ -64,7 +65,7 @@ type UpdateStorageRequest struct { func (s *APIV1Service) registerStorageRoutes(g *echo.Group) { g.POST("/storage", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -110,7 +111,7 @@ func (s *APIV1Service) registerStorageRoutes(g *echo.Group) { g.PATCH("/storage/:storageId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -164,7 +165,7 @@ func (s *APIV1Service) registerStorageRoutes(g *echo.Group) { g.GET("/storage", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -198,7 +199,7 @@ func (s *APIV1Service) registerStorageRoutes(g *echo.Group) { g.DELETE("/storage/:storageId", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/system.go b/api/v1/system.go index 1808c76c..861384b9 100644 --- a/api/v1/system.go +++ b/api/v1/system.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/labstack/echo/v4" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/common/log" "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" @@ -140,7 +141,7 @@ func (s *APIV1Service) registerSystemRoutes(g *echo.Group) { g.POST("/system/vacuum", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/system_setting.go b/api/v1/system_setting.go index 116f06a5..188c7241 100644 --- a/api/v1/system_setting.go +++ b/api/v1/system_setting.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/labstack/echo/v4" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" ) @@ -186,7 +187,7 @@ func (upsert UpsertSystemSettingRequest) Validate() error { func (s *APIV1Service) registerSystemSettingRoutes(g *echo.Group) { g.POST("/system/setting", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -236,7 +237,7 @@ func (s *APIV1Service) registerSystemSettingRoutes(g *echo.Group) { g.GET("/system/setting", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/tag.go b/api/v1/tag.go index acd8407b..8c3eb6ef 100644 --- a/api/v1/tag.go +++ b/api/v1/tag.go @@ -9,6 +9,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" "golang.org/x/exp/slices" ) @@ -29,7 +30,7 @@ type DeleteTagRequest struct { func (s *APIV1Service) registerTagRoutes(g *echo.Group) { g.POST("/tag", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -58,7 +59,7 @@ func (s *APIV1Service) registerTagRoutes(g *echo.Group) { g.GET("/tag", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag") } @@ -79,7 +80,7 @@ func (s *APIV1Service) registerTagRoutes(g *echo.Group) { g.GET("/tag/suggestion", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusBadRequest, "Missing user session") } @@ -124,7 +125,7 @@ func (s *APIV1Service) registerTagRoutes(g *echo.Group) { g.POST("/tag/delete", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/user.go b/api/v1/user.go index 3a49044f..dbd89a17 100644 --- a/api/v1/user.go +++ b/api/v1/user.go @@ -9,6 +9,7 @@ import ( "github.com/labstack/echo/v4" "github.com/pkg/errors" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" "golang.org/x/crypto/bcrypt" @@ -132,7 +133,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { // POST /user - Create a new user. g.POST("/user", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") } @@ -207,7 +208,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { // GET /user/me - Get current user. g.GET("/user/me", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") } @@ -286,7 +287,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("id"))).SetInternal(err) } - currentUserID, ok := c.Get(getUserIDContextKey()).(int) + currentUserID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } @@ -366,7 +367,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { // DELETE /user/:id - Delete user by id. g.DELETE("/user/:id", func(c echo.Context) error { ctx := c.Request().Context() - currentUserID, ok := c.Get(getUserIDContextKey()).(int) + currentUserID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") } diff --git a/api/v1/user_setting.go b/api/v1/user_setting.go index b158b8aa..dedf04b7 100644 --- a/api/v1/user_setting.go +++ b/api/v1/user_setting.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/labstack/echo/v4" + "github.com/usememos/memos/api/v1/auth" "github.com/usememos/memos/store" "golang.org/x/exp/slices" ) @@ -121,7 +122,7 @@ func (upsert UpsertUserSettingRequest) Validate() error { func (s *APIV1Service) registerUserSettingRoutes(g *echo.Group) { g.POST("/user/setting", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") } diff --git a/api/v2/auth/auth.go b/api/v2/auth/auth.go new file mode 100644 index 00000000..b7bc1ef4 --- /dev/null +++ b/api/v2/auth/auth.go @@ -0,0 +1,297 @@ +// Package auth handles the auth of gRPC server. +package auth + +import ( + "context" + "errors" + "net/http" + "strconv" + "strings" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + + "github.com/golang-jwt/jwt/v4" + errs "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +// ContextKey is the key type of context value. +type ContextKey int + +const ( + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + UserIDContextKey ContextKey = iota + // issuer is the issuer of the jwt token. + issuer = "memos" + // Signing key section. For now, this is only used for signing, not for verifying since we only + // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. + keyID = "v1" + // AccessTokenAudienceName is the audience name of the access token. + AccessTokenAudienceName = "user.access-token" + // RefreshTokenAudienceName is the audience name of the refresh token. + RefreshTokenAudienceName = "user.refresh-token" + apiTokenDuration = 2 * time.Hour + accessTokenDuration = 24 * time.Hour + refreshTokenDuration = 7 * 24 * time.Hour + + // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user + // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. + // Suppose we have a valid refresh token, we will refresh the token in cases: + // 1. The access token has already expired, we refresh the token so that the ongoing request can pass through. + CookieExpDuration = refreshTokenDuration - 1*time.Minute + // AccessTokenCookieName is the cookie name of access token. + AccessTokenCookieName = "memos.access-token" + // RefreshTokenCookieName is the cookie name of refresh token. + RefreshTokenCookieName = "memos.refresh-token" +) + +// GRPCAuthInterceptor is the auth interceptor for gRPC server. +type GRPCAuthInterceptor struct { + store *store.Store + secret string +} + +// NewGRPCAuthInterceptor returns a new API auth interceptor. +func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor { + return &GRPCAuthInterceptor{ + store: store, + secret: secret, + } +} + +// AuthenticationInterceptor is the unary interceptor for gRPC API. +func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") + } + accessTokenStr, refreshTokenStr, err := getTokenFromMetadata(md) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, err.Error()) + } + + userID, err := in.authenticate(ctx, accessTokenStr, refreshTokenStr) + if err != nil { + if IsAuthenticationAllowed(serverInfo.FullMethod) { + return handler(ctx, request) + } + return nil, err + } + + // Stores userID into context. + childCtx := context.WithValue(ctx, UserIDContextKey, userID) + return handler(childCtx, request) +} + +func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr, refreshTokenStr string) (int, error) { + if accessTokenStr == "" { + return 0, status.Errorf(codes.Unauthenticated, "access token not found") + } + claims := &claimsMessage{} + generateToken := false + accessToken, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(in.secret), nil + } + } + return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"]) + }) + if err != nil { + var ve *jwt.ValidationError + if errors.As(err, &ve) && ve.Errors == jwt.ValidationErrorExpired { + // If expiration error is the only error, we will clear the err + // and generate new access token and refresh token + if refreshTokenStr == "" { + return 0, status.Errorf(codes.Unauthenticated, "access token is expired") + } + generateToken = true + } else { + return 0, status.Errorf(codes.Unauthenticated, "failed to parse claim") + } + } + if !audienceContains(claims.Audience, AccessTokenAudienceName) { + return 0, status.Errorf(codes.Unauthenticated, + "invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", + claims.Audience, + AccessTokenAudienceName, + ) + } + + userID, err := strconv.Atoi(claims.Subject) + if err != nil { + return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject) + } + user, err := in.store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return 0, status.Errorf(codes.Unauthenticated, "failed to find user ID %q in the access token", userID) + } + if user == nil { + return 0, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID) + } + if user.RowStatus == store.Archived { + return 0, status.Errorf(codes.Unauthenticated, "user ID %q has been deactivated by administrators", userID) + } + + if generateToken { + generateTokenFunc := func() error { + // Parses token and checks if it's valid. + refreshTokenClaims := &claimsMessage{} + refreshToken, err := jwt.ParseWithClaims(refreshTokenStr, refreshTokenClaims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, status.Errorf(codes.Unauthenticated, "unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) + } + + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(in.secret), nil + } + } + return nil, errs.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) + }) + if err != nil { + if err == jwt.ErrSignatureInvalid { + return errs.Errorf("failed to generate access token: invalid refresh token signature") + } + return errs.Errorf("Server error to refresh expired token, user ID %d", userID) + } + + if !audienceContains(refreshTokenClaims.Audience, RefreshTokenAudienceName) { + return errs.Errorf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", + refreshTokenClaims.Audience, + RefreshTokenAudienceName, + ) + } + + // If we have a valid refresh token, we will generate new access token and refresh token + if refreshToken != nil && refreshToken.Valid { + if err := generateTokensAndSetCookies(ctx, user.Username, user.ID, in.secret); err != nil { + return errs.Wrapf(err, "failed to regenerate token") + } + } + + return nil + } + + // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token + // In such case, we won't return the error. + if err := generateTokenFunc(); err != nil && !accessToken.Valid { + return 0, status.Errorf(codes.Unauthenticated, err.Error()) + } + } + return userID, nil +} + +func getTokenFromMetadata(md metadata.MD) (string, string, error) { + authorizationHeaders := md.Get("Authorization") + if len(md.Get("Authorization")) > 0 { + authHeaderParts := strings.Fields(authorizationHeaders[0]) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", "", errs.Errorf("authorization header format must be Bearer {token}") + } + return authHeaderParts[1], "", nil + } + // check the HTTP cookie + var accessToken, refreshToken string + for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) { + header := http.Header{} + header.Add("Cookie", t) + request := http.Request{Header: header} + if v, _ := request.Cookie(AccessTokenCookieName); v != nil { + accessToken = v.Value + } + if v, _ := request.Cookie(RefreshTokenCookieName); v != nil { + refreshToken = v.Value + } + } + if accessToken != "" && refreshToken != "" { + return accessToken, refreshToken, nil + } + return "", "", nil +} + +func audienceContains(audience jwt.ClaimStrings, token string) bool { + for _, v := range audience { + if v == token { + return true + } + } + return false +} + +type claimsMessage struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +// generateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. +func generateTokensAndSetCookies(ctx context.Context, username string, userID int, secret string) error { + accessToken, err := GenerateAccessToken(username, userID, secret) + if err != nil { + return errs.Wrap(err, "failed to generate access token") + } + // We generate here a new refresh token and saving it to the cookie. + refreshToken, err := GenerateRefreshToken(username, userID, secret) + if err != nil { + return errs.Wrap(err, "failed to generate refresh token") + } + + if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ + AccessTokenCookieName: accessToken, + RefreshTokenCookieName: refreshToken, + })); err != nil { + return errs.Wrapf(err, "failed to set grpc header") + } + return nil +} + +// GenerateAccessToken generates an access token for web. +func GenerateAccessToken(username string, userID int, secret string) (string, error) { + expirationTime := time.Now().Add(accessTokenDuration) + return generateToken(username, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) +} + +// GenerateRefreshToken generates a refresh token for web. +func GenerateRefreshToken(username string, userID int, secret string) (string, error) { + expirationTime := time.Now().Add(refreshTokenDuration) + return generateToken(username, userID, RefreshTokenAudienceName, expirationTime, []byte(secret)) +} + +// Pay attention to this function. It holds the main JWT token generation logic. +func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { + // Create the JWT claims, which includes the username and expiry time. + claims := &claimsMessage{ + Name: username, + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{aud}, + // In JWT, the expiry time is expressed as unix milliseconds. + ExpiresAt: jwt.NewNumericDate(expirationTime), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: issuer, + Subject: strconv.Itoa(userID), + }, + } + + // Declare the token with the HS256 algorithm used for signing, and the claims. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = keyID + + // Create the JWT string. + tokenString, err := token.SignedString(secret) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/api/v2/auth/config.go b/api/v2/auth/config.go new file mode 100644 index 00000000..6f52b1b4 --- /dev/null +++ b/api/v2/auth/config.go @@ -0,0 +1,15 @@ +package auth + +import "strings" + +var authenticationAllowlistMethods = map[string]bool{ + "/memos.api.v2.UserService/GetUser": true, +} + +// IsAuthenticationAllowed returns whether the method is exempted from authentication. +func IsAuthenticationAllowed(fullMethodName string) bool { + if strings.HasPrefix(fullMethodName, "/grpc.reflection") { + return true + } + return authenticationAllowlistMethods[fullMethodName] +} diff --git a/api/v2/user_service.go b/api/v2/user_service.go index f62e3748..e3f2c3e8 100644 --- a/api/v2/user_service.go +++ b/api/v2/user_service.go @@ -3,6 +3,7 @@ package v2 import ( "context" + "github.com/usememos/memos/api/v2/auth" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" "github.com/usememos/memos/store" "google.golang.org/grpc/codes" @@ -44,9 +45,12 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque if err != nil { return nil, status.Errorf(codes.Internal, "failed to list user settings: %v", err) } - // TODO: check the access permission for user settings. - for _, userSetting := range userSettings { - userMessage.Settings = append(userMessage.Settings, convertUserSettingFromStore(userSetting)) + + userID, ok := ctx.Value(auth.UserIDContextKey).(int) + if ok && userID == int(userMessage.Id) { + for _, userSetting := range userSettings { + userMessage.Settings = append(userMessage.Settings, convertUserSettingFromStore(userSetting)) + } } response := &apiv2pb.GetUserResponse{ diff --git a/api/v2/v2.go b/api/v2/v2.go index 525cbbb4..c5130694 100644 --- a/api/v2/v2.go +++ b/api/v2/v2.go @@ -6,26 +6,53 @@ import ( grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/labstack/echo/v4" + "github.com/usememos/memos/api/v2/auth" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" + "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) -func NewGRPCServer(store *store.Store) *grpc.Server { - grpcServer := grpc.NewServer() +type APIV2Service struct { + Secret string + Profile *profile.Profile + Store *store.Store + + grpcServer *grpc.Server + grpcServerPort int +} + +func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store, grpcServerPort int) *APIV2Service { + authProvider := auth.NewGRPCAuthInterceptor(store, secret) + grpcServer := grpc.NewServer( + grpc.ChainUnaryInterceptor( + authProvider.AuthenticationInterceptor, + ), + ) apiv2pb.RegisterUserServiceServer(grpcServer, NewUserService(store)) apiv2pb.RegisterTagServiceServer(grpcServer, NewTagService(store)) - return grpcServer + + return &APIV2Service{ + Secret: secret, + Profile: profile, + Store: store, + grpcServer: grpcServer, + grpcServerPort: grpcServerPort, + } +} + +func (s *APIV2Service) GetGRPCServer() *grpc.Server { + return s.grpcServer } // RegisterGateway registers the gRPC-Gateway with the given Echo instance. -func RegisterGateway(ctx context.Context, e *echo.Echo, grpcServerPort int) error { +func (s *APIV2Service) RegisterGateway(ctx context.Context, e *echo.Echo) error { // Create a client connection to the gRPC Server we just started. // This is where the gRPC-Gateway proxies the requests. conn, err := grpc.DialContext( ctx, - fmt.Sprintf(":%d", grpcServerPort), + fmt.Sprintf(":%d", s.grpcServerPort), grpc.WithTransportCredentials(insecure.NewCredentials()), ) if err != nil { diff --git a/server/server.go b/server/server.go index d1d93293..a1da098d 100644 --- a/server/server.go +++ b/server/server.go @@ -20,18 +20,19 @@ import ( "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" "go.uber.org/zap" - "google.golang.org/grpc" ) type Server struct { - e *echo.Echo - grpcServer *grpc.Server + e *echo.Echo ID string Secret string Profile *profile.Profile Store *store.Store + // API services. + apiV2Service *apiv2.APIV2Service + // Asynchronous runners. backupRunner *BackupRunner telegramBot *telegram.Bot @@ -102,11 +103,9 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store) apiV1Service.Register(rootGroup) - // Register gPRC server services. - s.grpcServer = apiv2.NewGRPCServer(store) - + s.apiV2Service = apiv2.NewAPIV2Service(s.Secret, profile, store, s.Profile.Port+1) // Register gRPC gateway as api v2. - if err := apiv2.RegisterGateway(ctx, e, s.Profile.Port+1); err != nil { + if err := s.apiV2Service.RegisterGateway(ctx, e); err != nil { return nil, fmt.Errorf("failed to register gRPC gateway: %w", err) } @@ -127,7 +126,7 @@ func (s *Server) Start(ctx context.Context) error { return err } go func() { - if err := s.grpcServer.Serve(listen); err != nil { + if err := s.apiV2Service.GetGRPCServer().Serve(listen); err != nil { log.Error("grpc server listen error", zap.Error(err)) } }() @@ -220,6 +219,6 @@ func defaultGetRequestSkipper(c echo.Context) bool { } func defaultAPIRequestSkipper(c echo.Context) bool { - path := c.Path() - return util.HasPrefixes(path, "/api", "/api/v1") + path := c.Request().URL.Path + return util.HasPrefixes(path, "/api", "/api/v1", "api/v2") }