From c7a57191bd7e7fb73540dcb1d1053d825ede32b8 Mon Sep 17 00:00:00 2001 From: boojack Date: Sun, 2 Apr 2023 09:28:02 +0800 Subject: [PATCH] feat: add jwt auth (#1441) * feat: add jwt auth * chore: update --- go.mod | 5 +- go.sum | 10 +- plugin/http-getter/html_meta_test.go | 11 +- server/acl.go | 95 ---------- server/auth.go | 23 +-- server/auth/auth.go | 88 +++++++++ server/jwt.go | 256 +++++++++++++++++++++++++++ server/server.go | 7 +- 8 files changed, 359 insertions(+), 136 deletions(-) delete mode 100644 server/acl.go create mode 100644 server/auth/auth.go create mode 100644 server/jwt.go diff --git a/go.mod b/go.mod index 2df8995d..357e61e6 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.30.3 github.com/google/uuid v1.3.0 github.com/gorilla/feeds v1.1.1 - github.com/gorilla/sessions v1.2.1 - github.com/labstack/echo-contrib v0.13.0 github.com/labstack/echo/v4 v4.9.0 github.com/mattn/go-sqlite3 v1.14.9 github.com/pkg/errors v0.9.1 @@ -44,9 +42,8 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 github.com/golang/protobuf v1.5.2 // indirect - github.com/gorilla/context v1.1.1 // indirect - github.com/gorilla/securecookie v1.1.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect diff --git a/go.sum b/go.sum index 64caa9ed..eaa894e1 100644 --- a/go.sum +++ b/go.sum @@ -106,6 +106,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -168,14 +170,8 @@ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= -github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= -github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/feeds v1.1.1 h1:HwKXxqzcRNg9to+BbvJog4+f3s/xzvtZXICcQGutYfY= github.com/gorilla/feeds v1.1.1/go.mod h1:Nk0jZrvPFZX1OBe5NPiddPw7CfwF6Q9eqzaBbaightA= -github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= -github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -200,8 +196,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/labstack/echo-contrib v0.13.0 h1:bzSG0SpuZZd7BmJLvsWtPfU23W0Enh3K0tok3aENVKA= -github.com/labstack/echo-contrib v0.13.0/go.mod h1:IF9+MJu22ADOZEHD+bAV67XMIO3vNXUy7Naz/ABPHEs= github.com/labstack/echo/v4 v4.9.0 h1:wPOF1CE6gvt/kmbMR4dGzWvHMPT+sAEUJOwOTtvITVY= github.com/labstack/echo/v4 v4.9.0/go.mod h1:xkCDAdFCIf8jsFQ5NnbK7oqaF/yU1A1X20Ltm0OvSks= github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= diff --git a/plugin/http-getter/html_meta_test.go b/plugin/http-getter/html_meta_test.go index 31920eac..310e64e4 100644 --- a/plugin/http-getter/html_meta_test.go +++ b/plugin/http-getter/html_meta_test.go @@ -10,16 +10,7 @@ func TestGetHTMLMeta(t *testing.T) { tests := []struct { urlStr string htmlMeta HTMLMeta - }{ - { - urlStr: "https://www.bytebase.com/blog/sql-review-tool-for-devs", - htmlMeta: HTMLMeta{ - Title: "The SQL Review Tool for Developers", - Description: "Reviewing SQL can be somewhat tedious, yet is essential to keep your database fleet reliable. At Bytebase, we are building a developer-first SQL review tool to empower the DevOps system.", - Image: "https://www.bytebase.com/static/blog/sql-review-tool-for-devs/dev-fighting-dba.webp", - }, - }, - } + }{} for _, test := range tests { metadata, err := GetHTMLMeta(test.urlStr) require.NoError(t, err) diff --git a/server/acl.go b/server/acl.go deleted file mode 100644 index 7c2a1a3b..00000000 --- a/server/acl.go +++ /dev/null @@ -1,95 +0,0 @@ -package server - -import ( - "fmt" - "net/http" - "strconv" - - "github.com/usememos/memos/api" - "github.com/usememos/memos/common" - - "github.com/gorilla/sessions" - "github.com/labstack/echo-contrib/session" - "github.com/labstack/echo/v4" -) - -var ( - userIDContextKey = "user-id" - sessionName = "memos_session" -) - -func getUserIDContextKey() string { - return userIDContextKey -} - -func setUserSession(ctx echo.Context, user *api.User) error { - sess, _ := session.Get(sessionName, ctx) - sess.Options = &sessions.Options{ - Path: "/", - MaxAge: 3600 * 24 * 30, - HttpOnly: true, - SameSite: http.SameSiteStrictMode, - } - sess.Values[userIDContextKey] = user.ID - err := sess.Save(ctx.Request(), ctx.Response()) - if err != nil { - return fmt.Errorf("failed to set session, err: %w", err) - } - return nil -} - -func removeUserSession(ctx echo.Context) error { - sess, _ := session.Get(sessionName, ctx) - sess.Options = &sessions.Options{ - Path: "/", - MaxAge: 0, - HttpOnly: true, - } - sess.Values[userIDContextKey] = nil - err := sess.Save(ctx.Request(), ctx.Response()) - if err != nil { - return fmt.Errorf("failed to set session, err: %w", err) - } - return nil -} - -func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - ctx := c.Request().Context() - path := c.Path() - - if s.defaultAuthSkipper(c) { - return next(c) - } - - sess, _ := session.Get(sessionName, c) - userIDValue := sess.Values[userIDContextKey] - if userIDValue != nil { - userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) - userFind := &api.UserFind{ - ID: &userID, - } - user, err := s.Store.FindUser(ctx, userFind) - if err != nil && common.ErrorCode(err) != common.NotFound { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) - } - if user != nil { - if user.RowStatus == api.Archived { - return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username)) - } - c.Set(getUserIDContextKey(), userID) - } - } - - if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/idp", "/api/user/:id", "/api/memo") && c.Request().Method == http.MethodGet { - return next(c) - } - - userID := c.Get(getUserIDContextKey()) - if userID == nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - return next(c) - } -} diff --git a/server/auth.go b/server/auth.go index 982db801..7f238774 100644 --- a/server/auth.go +++ b/server/auth.go @@ -17,7 +17,7 @@ import ( "golang.org/x/crypto/bcrypt" ) -func (s *Server) registerAuthRoutes(g *echo.Group) { +func (s *Server) registerAuthRoutes(g *echo.Group, secret string) { g.POST("/auth/signin", func(c echo.Context) error { ctx := c.Request().Context() signin := &api.SignIn{} @@ -44,8 +44,8 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect login credentials, please try again") } - if err = setUserSession(c, user); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signin session").SetInternal(err) + if err := GenerateTokensAndSetCookies(c, user, s.Profile.Mode, secret); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } if err := s.createUserAuthSignInActivity(c, user); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) @@ -128,8 +128,8 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier)) } - if err = setUserSession(c, user); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signin session").SetInternal(err) + if err := GenerateTokensAndSetCookies(c, user, s.Profile.Mode, secret); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } if err := s.createUserAuthSignInActivity(c, user); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) @@ -196,23 +196,18 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } + if err := GenerateTokensAndSetCookies(c, user, s.Profile.Mode, secret); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) + } if err := s.createUserAuthSignUpActivity(c, user); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) } - err = setUserSession(c, user) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signup session").SetInternal(err) - } return c.JSON(http.StatusOK, composeResponse(user)) }) g.POST("/auth/signout", func(c echo.Context) error { - err := removeUserSession(c) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set sign out session").SetInternal(err) - } - + RemoveTokensAndCookies(c) return c.JSON(http.StatusOK, true) }) } diff --git a/server/auth/auth.go b/server/auth/auth.go new file mode 100644 index 00000000..f481795f --- /dev/null +++ b/server/auth/auth.go @@ -0,0 +1,88 @@ +package auth + +import ( + "fmt" + "strconv" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +const ( + issuer = "memos" + // Signing key section. For now, this is only used for signing, not for verifying since we only + // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. + keyID = "v1" + // AccessTokenAudienceFmt is the format of the acccess token audience. + AccessTokenAudienceFmt = "user.access.%s" + // RefreshTokenAudienceFmt is the format of the refresh token audience. + RefreshTokenAudienceFmt = "user.refresh.%s" + apiTokenDuration = 2 * time.Hour + accessTokenDuration = 24 * time.Hour + refreshTokenDuration = 7 * 24 * time.Hour + // RefreshThresholdDuration is the threshold duration for refreshing token. + RefreshThresholdDuration = 1 * time.Hour + + // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user + // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. + // Suppose we have a valid refresh token, we will refresh the token in 2 cases: + // 1. The access token is about to expire in <> + // 2. The access token has already expired, we refresh the token so that the ongoing request can pass through. + CookieExpDuration = refreshTokenDuration - 1*time.Minute + // AccessTokenCookieName is the cookie name of access token. + AccessTokenCookieName = "access-token" + // RefreshTokenCookieName is the cookie name of refresh token. + RefreshTokenCookieName = "refresh-token" + // UserIDCookieName is the cookie name of user ID. + UserIDCookieName = "user" +) + +type claimsMessage struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +// GenerateAPIToken generates an API token. +func GenerateAPIToken(userName string, userID int, mode string, secret string) (string, error) { + expirationTime := time.Now().Add(apiTokenDuration) + return generateToken(userName, userID, fmt.Sprintf(AccessTokenAudienceFmt, mode), expirationTime, []byte(secret)) +} + +// GenerateAccessToken generates an access token for web. +func GenerateAccessToken(userName string, userID int, mode string, secret string) (string, error) { + expirationTime := time.Now().Add(accessTokenDuration) + return generateToken(userName, userID, fmt.Sprintf(AccessTokenAudienceFmt, mode), expirationTime, []byte(secret)) +} + +// GenerateRefreshToken generates a refresh token for web. +func GenerateRefreshToken(userName string, userID int, mode string, secret string) (string, error) { + expirationTime := time.Now().Add(refreshTokenDuration) + return generateToken(userName, userID, fmt.Sprintf(RefreshTokenAudienceFmt, mode), expirationTime, []byte(secret)) +} + +func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { + // Create the JWT claims, which includes the username and expiry time. + claims := &claimsMessage{ + Name: username, + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{aud}, + // In JWT, the expiry time is expressed as unix milliseconds. + ExpiresAt: jwt.NewNumericDate(expirationTime), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: issuer, + Subject: strconv.Itoa(userID), + }, + } + + // Declare the token with the HS256 algorithm used for signing, and the claims. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = keyID + + // Create the JWT string. + tokenString, err := token.SignedString(secret) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/server/jwt.go b/server/jwt.go new file mode 100644 index 00000000..613bc819 --- /dev/null +++ b/server/jwt.go @@ -0,0 +1,256 @@ +package server + +import ( + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v4" + pkgerrors "github.com/pkg/errors" + "github.com/usememos/memos/api" + "github.com/usememos/memos/common" + "github.com/usememos/memos/server/auth" +) + +const ( + // Context section + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + userIDContextKey = "user-id" +) + +// Claims creates a struct that will be encoded to a JWT. +// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. +type Claims struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +func getUserIDContextKey() string { + return userIDContextKey +} + +// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. +func GenerateTokensAndSetCookies(c echo.Context, user *api.User, mode string, secret string) error { + accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, mode, secret) + if err != nil { + return pkgerrors.Wrap(err, "failed to generate access token") + } + + cookieExp := time.Now().Add(auth.CookieExpDuration) + setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) + + // We generate here a new refresh token and saving it to the cookie. + refreshToken, err := auth.GenerateRefreshToken(user.Username, user.ID, mode, secret) + if err != nil { + return pkgerrors.Wrap(err, "failed to generate refresh token") + } + setTokenCookie(c, auth.RefreshTokenCookieName, refreshToken, cookieExp) + + return nil +} + +// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies. +func RemoveTokensAndCookies(c echo.Context) { + // We set the expiration time to the past, so that the cookie will be removed. + cookieExp := time.Now().Add(-1 * time.Hour) + setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp) + setTokenCookie(c, auth.RefreshTokenCookieName, "", cookieExp) +} + +// Here we are creating a new cookie, which will store the valid JWT token. +func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { + cookie := new(http.Cookie) + cookie.Name = name + cookie.Value = token + cookie.Expires = expiration + cookie.Path = "/" + // Http-only helps mitigate the risk of client side script accessing the protected cookie. + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteStrictMode + c.SetCookie(cookie) +} + +func extractTokenFromHeader(c echo.Context) (string, error) { + authHeader := c.Request().Header.Get("Authorization") + if authHeader == "" { + return "", nil + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + +func findAccessToken(c echo.Context) string { + accessToken := "" + cookie, _ := c.Cookie(auth.AccessTokenCookieName) + if cookie != nil { + accessToken = cookie.Value + } + if accessToken == "" { + accessToken, _ = extractTokenFromHeader(c) + } + + return accessToken +} + +// JWTMiddleware validates the access token. +// If the access token is about to expire or has expired and the request has a valid refresh token, it +// will try to generate new access token and refresh token. +func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc { + return func(c echo.Context) error { + path := c.Request().URL.Path + method := c.Request().Method + mode := server.Profile.Mode + + if server.defaultAuthSkipper(c) { + return next(c) + } + + // Skip validation for server status endpoints. + if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/idp", "/api/user/:id") && method == http.MethodGet { + return next(c) + } + + token := findAccessToken(c) + if token == "" { + // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. + if common.HasPrefixes(path, "/api/memo") && method == http.MethodGet { + return next(c) + } + return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") + } + + claims := &Claims{} + accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, pkgerrors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(secret), nil + } + } + return nil, pkgerrors.Errorf("unexpected access token kid=%v", t.Header["kid"]) + }) + + if !audienceContains(claims.Audience, fmt.Sprintf(auth.AccessTokenAudienceFmt, mode)) { + return echo.NewHTTPError(http.StatusUnauthorized, + fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", + claims.Audience, + fmt.Sprintf(auth.AccessTokenAudienceFmt, mode), + )) + } + + generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration + if err != nil { + var ve *jwt.ValidationError + if errors.As(err, &ve) { + // If expiration error is the only error, we will clear the err + // and generate new access token and refresh token + if ve.Errors == jwt.ValidationErrorExpired { + generateToken = true + } + } else { + return &echo.HTTPError{ + Code: http.StatusUnauthorized, + Message: "Invalid or expired access token", + Internal: err, + } + } + } + + // We either have a valid access token or we will attempt to generate new access token and refresh token + ctx := c.Request().Context() + userID, err := strconv.Atoi(claims.Subject) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") + } + + // Even if there is no error, we still need to make sure the user still exists. + user, err := server.Store.FindUser(ctx, &api.UserFind{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) + } + if user == nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) + } + + if generateToken { + generateTokenFunc := func() error { + rc, err := c.Cookie(auth.RefreshTokenCookieName) + + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") + } + + // Parses token and checks if it's valid. + refreshTokenClaims := &Claims{} + refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, pkgerrors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) + } + + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(secret), nil + } + } + return nil, pkgerrors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) + }) + if err != nil { + if err == jwt.ErrSignatureInvalid { + return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.") + } + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) + } + + if !audienceContains(refreshTokenClaims.Audience, fmt.Sprintf(auth.RefreshTokenAudienceFmt, mode)) { + return echo.NewHTTPError(http.StatusUnauthorized, + fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", + refreshTokenClaims.Audience, + fmt.Sprintf(auth.RefreshTokenAudienceFmt, mode), + )) + } + + // If we have a valid refresh token, we will generate new access token and refresh token + if refreshToken != nil && refreshToken.Valid { + if err := GenerateTokensAndSetCookies(c, user, mode, secret); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) + } + } + + return nil + } + + // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token + // In such case, we won't return the error. + if err := generateTokenFunc(); err != nil && !accessToken.Valid { + return err + } + } + + // Stores userID into context. + c.Set(getUserIDContextKey(), userID) + return next(c) + } +} + +func audienceContains(audience jwt.ClaimStrings, token string) bool { + for _, v := range audience { + if v == token { + return true + } + } + return false +} diff --git a/server/server.go b/server/server.go index 3ed3b8db..e853aadd 100644 --- a/server/server.go +++ b/server/server.go @@ -13,8 +13,6 @@ import ( "github.com/usememos/memos/store" "github.com/usememos/memos/store/db" - "github.com/gorilla/sessions" - "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" ) @@ -88,7 +86,6 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { return nil, err } } - e.Use(session.Middleware(sessions.NewCookieStore([]byte(secretSessionName)))) embedFrontend(e) @@ -101,10 +98,10 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { apiGroup := e.Group("/api") apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return aclMiddleware(s, next) + return JWTMiddleware(s, next, secretSessionName) }) s.registerSystemRoutes(apiGroup) - s.registerAuthRoutes(apiGroup) + s.registerAuthRoutes(apiGroup, secretSessionName) s.registerUserRoutes(apiGroup) s.registerMemoRoutes(apiGroup) s.registerShortcutRoutes(apiGroup)