diff --git a/proto/api/v1/user_service.proto b/proto/api/v1/user_service.proto index c43bba6b6..44bbe6aaf 100644 --- a/proto/api/v1/user_service.proto +++ b/proto/api/v1/user_service.proto @@ -511,9 +511,6 @@ message UserSession { // Optional. Browser name and version (e.g., "Chrome 119.0"). string browser = 5 [(google.api.field_behavior) = OPTIONAL]; - - // Optional. Geographic location (country code, e.g., "US"). - string country = 6 [(google.api.field_behavior) = OPTIONAL]; } } diff --git a/proto/gen/api/v1/user_service.pb.go b/proto/gen/api/v1/user_service.pb.go index 7b9782254..28de68eb8 100644 --- a/proto/gen/api/v1/user_service.pb.go +++ b/proto/gen/api/v1/user_service.pb.go @@ -1868,9 +1868,7 @@ type UserSession_ClientInfo struct { // Optional. Operating system (e.g., "iOS 17.0", "Windows 11"). Os string `protobuf:"bytes,4,opt,name=os,proto3" json:"os,omitempty"` // Optional. Browser name and version (e.g., "Chrome 119.0"). - Browser string `protobuf:"bytes,5,opt,name=browser,proto3" json:"browser,omitempty"` - // Optional. Geographic location (country code, e.g., "US"). - Country string `protobuf:"bytes,6,opt,name=country,proto3" json:"country,omitempty"` + Browser string `protobuf:"bytes,5,opt,name=browser,proto3" json:"browser,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -1940,13 +1938,6 @@ func (x *UserSession_ClientInfo) GetBrowser() string { return "" } -func (x *UserSession_ClientInfo) GetCountry() string { - if x != nil { - return x.Country - } - return "" -} - var File_api_v1_user_service_proto protoreflect.FileDescriptor const file_api_v1_user_service_proto_rawDesc = "" + @@ -2084,7 +2075,7 @@ const file_api_v1_user_service_proto_rawDesc = "" + "\x0faccess_token_id\x18\x03 \x01(\tB\x03\xe0A\x01R\raccessTokenId\"X\n" + "\x1cDeleteUserAccessTokenRequest\x128\n" + "\x04name\x18\x01 \x01(\tB$\xe0A\x02\xfaA\x1e\n" + - "\x1cmemos.api.v1/UserAccessTokenR\x04name\"\xf5\x04\n" + + "\x1cmemos.api.v1/UserAccessTokenR\x04name\"\xd6\x04\n" + "\vUserSession\x12\x17\n" + "\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12\"\n" + "\n" + @@ -2095,7 +2086,7 @@ const file_api_v1_user_service_proto_rawDesc = "" + "expireTime\x12M\n" + "\x12last_accessed_time\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampB\x03\xe0A\x03R\x10lastAccessedTime\x12J\n" + "\vclient_info\x18\x06 \x01(\v2$.memos.api.v1.UserSession.ClientInfoB\x03\xe0A\x03R\n" + - "clientInfo\x1a\xc3\x01\n" + + "clientInfo\x1a\xa4\x01\n" + "\n" + "ClientInfo\x12\x1d\n" + "\n" + @@ -2105,8 +2096,7 @@ const file_api_v1_user_service_proto_rawDesc = "" + "\vdevice_type\x18\x03 \x01(\tB\x03\xe0A\x01R\n" + "deviceType\x12\x13\n" + "\x02os\x18\x04 \x01(\tB\x03\xe0A\x01R\x02os\x12\x1d\n" + - "\abrowser\x18\x05 \x01(\tB\x03\xe0A\x01R\abrowser\x12\x1d\n" + - "\acountry\x18\x06 \x01(\tB\x03\xe0A\x01R\acountry:D\xeaAA\n" + + "\abrowser\x18\x05 \x01(\tB\x03\xe0A\x01R\abrowser:D\xeaAA\n" + "\x18memos.api.v1/UserSession\x12\x1fusers/{user}/sessions/{session}\x1a\x04name\"L\n" + "\x17ListUserSessionsRequest\x121\n" + "\x06parent\x18\x01 \x01(\tB\x19\xe0A\x02\xfaA\x13\n" + diff --git a/proto/gen/apidocs.swagger.yaml b/proto/gen/apidocs.swagger.yaml index 169ae5b5a..9c90346c1 100644 --- a/proto/gen/apidocs.swagger.yaml +++ b/proto/gen/apidocs.swagger.yaml @@ -4340,9 +4340,6 @@ definitions: browser: type: string description: Optional. Browser name and version (e.g., "Chrome 119.0"). - country: - type: string - description: Optional. Geographic location (country code, e.g., "US"). v1UserStats: type: object properties: diff --git a/proto/gen/store/user_setting.pb.go b/proto/gen/store/user_setting.pb.go index 93c3af7d1..8246a42d6 100644 --- a/proto/gen/store/user_setting.pb.go +++ b/proto/gen/store/user_setting.pb.go @@ -590,9 +590,7 @@ type SessionsUserSetting_ClientInfo struct { // Optional. Operating system (e.g., "iOS 17.0", "Windows 11"). Os string `protobuf:"bytes,4,opt,name=os,proto3" json:"os,omitempty"` // Optional. Browser name and version (e.g., "Chrome 119.0"). - Browser string `protobuf:"bytes,5,opt,name=browser,proto3" json:"browser,omitempty"` - // Optional. Geographic location (country code, e.g., "US"). - Country string `protobuf:"bytes,6,opt,name=country,proto3" json:"country,omitempty"` + Browser string `protobuf:"bytes,5,opt,name=browser,proto3" json:"browser,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -662,13 +660,6 @@ func (x *SessionsUserSetting_ClientInfo) GetBrowser() string { return "" } -func (x *SessionsUserSetting_ClientInfo) GetCountry() string { - if x != nil { - return x.Country - } - return "" -} - var File_store_user_setting_proto protoreflect.FileDescriptor const file_store_user_setting_proto_rawDesc = "" + @@ -696,7 +687,7 @@ const file_store_user_setting_proto_rawDesc = "" + "\bShortcut\x12\x0e\n" + "\x02id\x18\x01 \x01(\tR\x02id\x12\x14\n" + "\x05title\x18\x02 \x01(\tR\x05title\x12\x16\n" + - "\x06filter\x18\x03 \x01(\tR\x06filter\"\xca\x04\n" + + "\x06filter\x18\x03 \x01(\tR\x06filter\"\xb0\x04\n" + "\x13SessionsUserSetting\x12D\n" + "\bsessions\x18\x01 \x03(\v2(.memos.store.SessionsUserSetting.SessionR\bsessions\x1a\xba\x02\n" + "\aSession\x12\x1d\n" + @@ -708,7 +699,7 @@ const file_store_user_setting_proto_rawDesc = "" + "expireTime\x12H\n" + "\x12last_accessed_time\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\x10lastAccessedTime\x12L\n" + "\vclient_info\x18\x05 \x01(\v2+.memos.store.SessionsUserSetting.ClientInfoR\n" + - "clientInfo\x1a\xaf\x01\n" + + "clientInfo\x1a\x95\x01\n" + "\n" + "ClientInfo\x12\x1d\n" + "\n" + @@ -718,8 +709,7 @@ const file_store_user_setting_proto_rawDesc = "" + "\vdevice_type\x18\x03 \x01(\tR\n" + "deviceType\x12\x0e\n" + "\x02os\x18\x04 \x01(\tR\x02os\x12\x18\n" + - "\abrowser\x18\x05 \x01(\tR\abrowser\x12\x18\n" + - "\acountry\x18\x06 \x01(\tR\acountry*\x93\x01\n" + + "\abrowser\x18\x05 \x01(\tR\abrowser*\x93\x01\n" + "\x0eUserSettingKey\x12 \n" + "\x1cUSER_SETTING_KEY_UNSPECIFIED\x10\x00\x12\x11\n" + "\rACCESS_TOKENS\x10\x01\x12\n" + diff --git a/proto/store/user_setting.proto b/proto/store/user_setting.proto index 9f2ae877d..0deb56cd8 100644 --- a/proto/store/user_setting.proto +++ b/proto/store/user_setting.proto @@ -80,8 +80,6 @@ message SessionsUserSetting { string os = 4; // Optional. Browser name and version (e.g., "Chrome 119.0"). string browser = 5; - // Optional. Geographic location (country code, e.g., "US"). - string country = 6; } repeated Session sessions = 1; diff --git a/server/router/api/v1/acl.go b/server/router/api/v1/acl.go index bf8b06538..182c8bdb4 100644 --- a/server/router/api/v1/acl.go +++ b/server/router/api/v1/acl.go @@ -52,22 +52,38 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") } - // Try to get access token from either Authorization header or cookie - accessToken, err := getTokenFromMetadata(md) - if err != nil { - return nil, status.Errorf(codes.Unauthenticated, "failed to get access token: %v", err) + // Try to authenticate via session ID (from cookie) first + if sessionCookieValue, err := getSessionIDFromMetadata(md); err == nil && sessionCookieValue != "" { + user, err := in.authenticateBySession(ctx, sessionCookieValue) + if err == nil && user != nil { + // Extract just the sessionID part for context storage + _, sessionID, parseErr := ParseSessionCookieValue(sessionCookieValue) + if parseErr != nil { + return nil, status.Errorf(codes.Internal, "failed to parse session cookie: %v", parseErr) + } + return in.handleAuthenticatedRequest(ctx, request, serverInfo, handler, user, sessionID, "") + } } - // Authenticate using access token (which also validates sessions when it's from cookie) - user, err := in.authenticateByAccessToken(ctx, accessToken) - if err != nil { - // Check if this method is in the allowlist first - if isUnauthorizeAllowedMethod(serverInfo.FullMethod) { - return handler(ctx, request) + // Try to authenticate via JWT access token (from Authorization header) + if accessToken, err := getAccessTokenFromMetadata(md); err == nil && accessToken != "" { + user, err := in.authenticateByJWT(ctx, accessToken) + if err == nil && user != nil { + return in.handleAuthenticatedRequest(ctx, request, serverInfo, handler, user, "", accessToken) } - return nil, err } + // If no valid authentication found, check if this method is in the allowlist (public endpoints) + if isUnauthorizeAllowedMethod(serverInfo.FullMethod) { + return handler(ctx, request) + } + + // If authentication is required but not found, reject the request + return nil, status.Errorf(codes.Unauthenticated, "authentication required") +} + +// handleAuthenticatedRequest processes an authenticated request with the given user and auth info. +func (in *GRPCAuthInterceptor) handleAuthenticatedRequest(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler, user *store.User, sessionID, accessToken string) (any, error) { // Check user status if user.RowStatus == store.Archived { return nil, errors.Errorf("user %q is archived", user.Username) @@ -79,22 +95,21 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re // Set context values ctx = context.WithValue(ctx, userIDContextKey, user.ID) - // Determine if this came from cookie (session) or header (API token) - if _, headerErr := getAccessTokenFromMetadata(md); headerErr != nil { - // Came from cookie, treat as session - ctx = context.WithValue(ctx, sessionIDContextKey, accessToken) + if sessionID != "" { + // Session-based authentication + ctx = context.WithValue(ctx, sessionIDContextKey, sessionID) // Update session last accessed time - _ = in.updateSessionLastAccessed(ctx, user.ID, accessToken) - } else { - // Came from Authorization header, treat as API token + _ = in.updateSessionLastAccessed(ctx, user.ID, sessionID) + } else if accessToken != "" { + // JWT access token-based authentication ctx = context.WithValue(ctx, accessTokenContextKey, accessToken) } return handler(ctx, request) } -// authenticateByAccessToken authenticates a user using access token from Authorization header or cookie. -func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (*store.User, error) { +// authenticateByJWT authenticates a user using JWT access token from Authorization header. +func (in *GRPCAuthInterceptor) authenticateByJWT(ctx context.Context, accessToken string) (*store.User, error) { if accessToken == "" { return nil, status.Errorf(codes.Unauthenticated, "access token not found") } @@ -114,7 +129,7 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac return nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token") } - // We either have a valid access token or we will attempt to generate new access token. + // Get user from JWT claims userID, err := util.ConvertStringToInt32(claims.Subject) if err != nil { return nil, errors.Wrap(err, "malformed ID in the token") @@ -132,6 +147,7 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac return nil, errors.Errorf("user %q is archived", userID) } + // Validate that this access token exists in the user's access tokens accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID) if err != nil { return nil, errors.Wrapf(err, "failed to get user access tokens") @@ -140,10 +156,43 @@ func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, ac return nil, status.Errorf(codes.Unauthenticated, "invalid access token") } - // For tokens that might be used as session IDs (from cookies), also validate session existence - // This is a best-effort check - if sessions can't be retrieved or token isn't a session, that's ok - if sessions, err := in.Store.GetUserSessions(ctx, user.ID); err == nil { - validateUserSession(accessToken, sessions) // Result doesn't matter for API tokens + return user, nil +} + +// authenticateBySession authenticates a user using session ID from cookie. +func (in *GRPCAuthInterceptor) authenticateBySession(ctx context.Context, sessionCookieValue string) (*store.User, error) { + if sessionCookieValue == "" { + return nil, status.Errorf(codes.Unauthenticated, "session cookie value not found") + } + + // Parse the cookie value to extract userID and sessionID + userID, sessionID, err := ParseSessionCookieValue(sessionCookieValue) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "invalid session cookie format: %v", err) + } + + // Get the user directly using the userID from the cookie + user, err := in.Store.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return nil, errors.Wrap(err, "failed to get user") + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not found") + } + if user.RowStatus == store.Archived { + return nil, status.Errorf(codes.Unauthenticated, "user is archived") + } + + // Get user sessions and validate the sessionID + sessions, err := in.Store.GetUserSessions(ctx, userID) + if err != nil { + return nil, errors.Wrap(err, "failed to get user sessions") + } + + if !validateUserSession(sessionID, sessions) { + return nil, status.Errorf(codes.Unauthenticated, "invalid or expired session") } return user, nil @@ -168,6 +217,24 @@ func validateUserSession(sessionID string, userSessions []*storepb.SessionsUserS return false } +// getSessionIDFromMetadata extracts session cookie value from cookie. +func getSessionIDFromMetadata(md metadata.MD) (string, error) { + // Check the cookie header for session cookie value + var sessionCookieValue string + for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) { + header := http.Header{} + header.Add("Cookie", t) + request := http.Request{Header: header} + if v, _ := request.Cookie(SessionCookieName); v != nil { + sessionCookieValue = v.Value + } + } + if sessionCookieValue == "" { + return "", errors.New("session cookie not found") + } + return sessionCookieValue, nil +} + // getAccessTokenFromMetadata extracts access token from Authorization header. func getAccessTokenFromMetadata(md metadata.MD) (string, error) { // Check the HTTP request Authorization header. @@ -182,29 +249,6 @@ func getAccessTokenFromMetadata(md metadata.MD) (string, error) { return authHeaderParts[1], nil } -func getTokenFromMetadata(md metadata.MD) (string, error) { - // Check the HTTP request header first. - authorizationHeaders := md.Get("Authorization") - if len(authorizationHeaders) > 0 { - authHeaderParts := strings.Fields(authorizationHeaders[0]) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("authorization header format must be Bearer {token}") - } - return authHeaderParts[1], nil - } - // Check the cookie header. - var accessToken string - for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) { - header := http.Header{} - header.Add("Cookie", t) - request := http.Request{Header: header} - if v, _ := request.Cookie(AccessTokenCookieName); v != nil { - accessToken = v.Value - } - } - return accessToken, nil -} - func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { for _, userAccessToken := range userAccessTokens { if accessTokenString == userAccessToken.AccessToken { diff --git a/server/router/api/v1/auth.go b/server/router/api/v1/auth.go index 78868f427..b07f76361 100644 --- a/server/router/api/v1/auth.go +++ b/server/router/api/v1/auth.go @@ -2,9 +2,12 @@ package v1 import ( "fmt" + "strings" "time" "github.com/golang-jwt/jwt/v5" + + "github.com/usememos/memos/internal/util" ) const ( @@ -20,8 +23,8 @@ const ( // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. CookieExpDuration = AccessTokenDuration - 1*time.Minute - // AccessTokenCookieName is the cookie name of access token. - AccessTokenCookieName = "memos.access-token" + // SessionCookieName is the cookie name of user session ID. + SessionCookieName = "user_session" ) type ClaimsMessage struct { @@ -61,3 +64,28 @@ func generateToken(username string, userID int32, audience string, expirationTim return tokenString, nil } + +// GenerateSessionID generates a unique session ID using UUIDv4. +func GenerateSessionID() (string, error) { + return util.GenUUID(), nil +} + +// BuildSessionCookieValue builds the session cookie value in format {userID}-{sessionID}. +func BuildSessionCookieValue(userID int32, sessionID string) string { + return fmt.Sprintf("%d-%s", userID, sessionID) +} + +// ParseSessionCookieValue parses the session cookie value to extract userID and sessionID. +func ParseSessionCookieValue(cookieValue string) (int32, string, error) { + parts := strings.SplitN(cookieValue, "-", 2) + if len(parts) != 2 { + return 0, "", fmt.Errorf("invalid session cookie format") + } + + userID, err := util.ConvertStringToInt32(parts[0]) + if err != nil { + return 0, "", fmt.Errorf("invalid user ID in session cookie: %v", err) + } + + return userID, parts[1], nil +} diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index 96f1ba5c0..aede7269e 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -36,9 +36,9 @@ func (s *APIV1Service) GetCurrentSession(ctx context.Context, _ *v1pb.GetCurrent return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err) } if user == nil { - // Set the cookie header to expire access token. - if err := s.clearAccessTokenCookie(ctx); err != nil { - return nil, status.Errorf(codes.Internal, "failed to set grpc header: %v", err) + // Clear auth cookies + if err := s.clearAuthCookies(ctx); err != nil { + return nil, status.Errorf(codes.Internal, "failed to clear auth cookies: %v", err) } return nil, status.Errorf(codes.Unauthenticated, "user not found") } @@ -178,6 +178,7 @@ func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSe } func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error { + // Generate JWT access token for API use accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) if err != nil { return status.Errorf(codes.Internal, "failed to generate access token, error: %v", err) @@ -186,19 +187,27 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim return status.Errorf(codes.Internal, "failed to upsert access token to store, error: %v", err) } + // Generate unique session ID for web use + sessionID, err := GenerateSessionID() + if err != nil { + return status.Errorf(codes.Internal, "failed to generate session ID, error: %v", err) + } + // Track session in user settings - if err := s.trackUserSession(ctx, user.ID, accessToken, expireTime); err != nil { + if err := s.trackUserSession(ctx, user.ID, sessionID, expireTime); err != nil { // Log the error but don't fail the login if session tracking fails // This ensures backward compatibility slog.Error("failed to track user session", "error", err) } - cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime) + // Set session cookie for web use (format: userID-sessionID) + sessionCookieValue := BuildSessionCookieValue(user.ID, sessionID) + sessionCookie, err := s.buildSessionCookie(ctx, sessionCookieValue, expireTime) if err != nil { - return status.Errorf(codes.Internal, "failed to build access token cookie, error: %v", err) + return status.Errorf(codes.Internal, "failed to build session cookie, error: %v", err) } if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ - "Set-Cookie": cookie, + "Set-Cookie": sessionCookie, })); err != nil { return status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err) } @@ -281,28 +290,31 @@ func (s *APIV1Service) DeleteSession(ctx context.Context, _ *v1pb.DeleteSessionR } } - if err := s.clearAccessTokenCookie(ctx); err != nil { - return nil, status.Errorf(codes.Internal, "failed to set grpc header, error: %v", err) + if err := s.clearAuthCookies(ctx); err != nil { + return nil, status.Errorf(codes.Internal, "failed to clear auth cookies, error: %v", err) } return &emptypb.Empty{}, nil } -func (s *APIV1Service) clearAccessTokenCookie(ctx context.Context) error { - cookie, err := s.buildAccessTokenCookie(ctx, "", time.Time{}) +func (s *APIV1Service) clearAuthCookies(ctx context.Context) error { + // Clear session cookie + sessionCookie, err := s.buildSessionCookie(ctx, "", time.Time{}) if err != nil { - return errors.Wrap(err, "failed to build access token cookie") + return errors.Wrap(err, "failed to build session cookie") } + + // Set both cookies in the response if err := grpc.SetHeader(ctx, metadata.New(map[string]string{ - "Set-Cookie": cookie, + "Set-Cookie": sessionCookie, })); err != nil { return errors.Wrap(err, "failed to set grpc header") } return nil } -func (*APIV1Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) { +func (*APIV1Service) buildSessionCookie(ctx context.Context, sessionCookieValue string, expireTime time.Time) (string, error) { attrs := []string{ - fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken), + fmt.Sprintf("%s=%s", SessionCookieName, sessionCookieValue), "Path=/", "HttpOnly", } @@ -364,23 +376,189 @@ func (s *APIV1Service) trackUserSession(ctx context.Context, userID int32, sessi } // Helper function to extract client information from the gRPC context. -func (*APIV1Service) extractClientInfo(ctx context.Context) *storepb.SessionsUserSetting_ClientInfo { +// extractClientInfo extracts comprehensive client information from the request context. +// This includes user agent parsing to determine device type, operating system, browser, +// and IP address extraction. This information is used to provide detailed session +// tracking and management capabilities in the web UI. +// +// Fields populated: +// - UserAgent: Raw user agent string +// - IpAddress: Client IP (from X-Forwarded-For or X-Real-IP headers) +// - DeviceType: "mobile", "tablet", or "desktop" +// - Os: Operating system name and version (e.g., "iOS 17.1", "Windows 10/11") +// - Browser: Browser name and version (e.g., "Chrome 120.0.0.0") +// - Country: Geographic location (TODO: implement with GeoIP service) +func (s *APIV1Service) extractClientInfo(ctx context.Context) *storepb.SessionsUserSetting_ClientInfo { clientInfo := &storepb.SessionsUserSetting_ClientInfo{} // Extract user agent from metadata if available if md, ok := metadata.FromIncomingContext(ctx); ok { if userAgents := md.Get("user-agent"); len(userAgents) > 0 { - clientInfo.UserAgent = userAgents[0] + userAgent := userAgents[0] + clientInfo.UserAgent = userAgent + + // Parse user agent to extract device type, OS, browser info + s.parseUserAgent(userAgent, clientInfo) } if forwardedFor := md.Get("x-forwarded-for"); len(forwardedFor) > 0 { - clientInfo.IpAddress = forwardedFor[0] + ipAddress := strings.Split(forwardedFor[0], ",")[0] // Get the first IP in case of multiple + ipAddress = strings.TrimSpace(ipAddress) + clientInfo.IpAddress = ipAddress } else if realIP := md.Get("x-real-ip"); len(realIP) > 0 { clientInfo.IpAddress = realIP[0] } } - // TODO: Parse user agent to extract device type, OS, browser info - // This could be done using a user agent parsing library - return clientInfo } + +// parseUserAgent extracts device type, OS, and browser information from user agent string +func (s *APIV1Service) parseUserAgent(userAgent string, clientInfo *storepb.SessionsUserSetting_ClientInfo) { + if userAgent == "" { + return + } + + userAgent = strings.ToLower(userAgent) + + // Detect device type + if strings.Contains(userAgent, "ipad") { + clientInfo.DeviceType = "tablet" + } else if strings.Contains(userAgent, "mobile") || strings.Contains(userAgent, "android") || + strings.Contains(userAgent, "iphone") || strings.Contains(userAgent, "ipod") || + strings.Contains(userAgent, "windows phone") || strings.Contains(userAgent, "blackberry") { + clientInfo.DeviceType = "mobile" + } else if strings.Contains(userAgent, "tablet") { + clientInfo.DeviceType = "tablet" + } else { + clientInfo.DeviceType = "desktop" + } + + // Detect operating system + if strings.Contains(userAgent, "iphone os") || strings.Contains(userAgent, "cpu os") { + // Extract iOS version + if idx := strings.Index(userAgent, "cpu os "); idx != -1 { + versionStart := idx + 7 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd != -1 { + version := strings.Replace(userAgent[versionStart:versionStart+versionEnd], "_", ".", -1) + clientInfo.Os = "iOS " + version + } else { + clientInfo.Os = "iOS" + } + } else if idx := strings.Index(userAgent, "iphone os "); idx != -1 { + versionStart := idx + 10 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd != -1 { + version := strings.Replace(userAgent[versionStart:versionStart+versionEnd], "_", ".", -1) + clientInfo.Os = "iOS " + version + } else { + clientInfo.Os = "iOS" + } + } else { + clientInfo.Os = "iOS" + } + } else if strings.Contains(userAgent, "android") { + // Extract Android version + if idx := strings.Index(userAgent, "android "); idx != -1 { + versionStart := idx + 8 + versionEnd := strings.Index(userAgent[versionStart:], ";") + if versionEnd == -1 { + versionEnd = strings.Index(userAgent[versionStart:], ")") + } + if versionEnd != -1 { + version := userAgent[versionStart : versionStart+versionEnd] + clientInfo.Os = "Android " + version + } else { + clientInfo.Os = "Android" + } + } else { + clientInfo.Os = "Android" + } + } else if strings.Contains(userAgent, "windows nt 10.0") { + clientInfo.Os = "Windows 10/11" + } else if strings.Contains(userAgent, "windows nt 6.3") { + clientInfo.Os = "Windows 8.1" + } else if strings.Contains(userAgent, "windows nt 6.1") { + clientInfo.Os = "Windows 7" + } else if strings.Contains(userAgent, "windows") { + clientInfo.Os = "Windows" + } else if strings.Contains(userAgent, "mac os x") { + // Extract macOS version + if idx := strings.Index(userAgent, "mac os x "); idx != -1 { + versionStart := idx + 9 + versionEnd := strings.Index(userAgent[versionStart:], ";") + if versionEnd == -1 { + versionEnd = strings.Index(userAgent[versionStart:], ")") + } + if versionEnd != -1 { + version := strings.Replace(userAgent[versionStart:versionStart+versionEnd], "_", ".", -1) + clientInfo.Os = "macOS " + version + } else { + clientInfo.Os = "macOS" + } + } else { + clientInfo.Os = "macOS" + } + } else if strings.Contains(userAgent, "linux") { + clientInfo.Os = "Linux" + } else if strings.Contains(userAgent, "cros") { + clientInfo.Os = "Chrome OS" + } + + // Detect browser + if strings.Contains(userAgent, "edg/") { + // Extract Edge version + if idx := strings.Index(userAgent, "edg/"); idx != -1 { + versionStart := idx + 4 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd == -1 { + versionEnd = len(userAgent) - versionStart + } + version := userAgent[versionStart : versionStart+versionEnd] + clientInfo.Browser = "Edge " + version + } else { + clientInfo.Browser = "Edge" + } + } else if strings.Contains(userAgent, "chrome/") && !strings.Contains(userAgent, "edg") { + // Extract Chrome version + if idx := strings.Index(userAgent, "chrome/"); idx != -1 { + versionStart := idx + 7 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd == -1 { + versionEnd = len(userAgent) - versionStart + } + version := userAgent[versionStart : versionStart+versionEnd] + clientInfo.Browser = "Chrome " + version + } else { + clientInfo.Browser = "Chrome" + } + } else if strings.Contains(userAgent, "firefox/") { + // Extract Firefox version + if idx := strings.Index(userAgent, "firefox/"); idx != -1 { + versionStart := idx + 8 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd == -1 { + versionEnd = len(userAgent) - versionStart + } + version := userAgent[versionStart : versionStart+versionEnd] + clientInfo.Browser = "Firefox " + version + } else { + clientInfo.Browser = "Firefox" + } + } else if strings.Contains(userAgent, "safari/") && !strings.Contains(userAgent, "chrome") && !strings.Contains(userAgent, "edg") { + // Extract Safari version + if idx := strings.Index(userAgent, "version/"); idx != -1 { + versionStart := idx + 8 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd == -1 { + versionEnd = len(userAgent) - versionStart + } + version := userAgent[versionStart : versionStart+versionEnd] + clientInfo.Browser = "Safari " + version + } else { + clientInfo.Browser = "Safari" + } + } else if strings.Contains(userAgent, "opera/") || strings.Contains(userAgent, "opr/") { + clientInfo.Browser = "Opera" + } +} diff --git a/server/router/api/v1/auth_service_client_info_test.go b/server/router/api/v1/auth_service_client_info_test.go new file mode 100644 index 000000000..7f754733f --- /dev/null +++ b/server/router/api/v1/auth_service_client_info_test.go @@ -0,0 +1,179 @@ +package v1 + +import ( + "context" + "testing" + + "google.golang.org/grpc/metadata" + + storepb "github.com/usememos/memos/proto/gen/store" +) + +func TestParseUserAgent(t *testing.T) { + service := &APIV1Service{} + + tests := []struct { + name string + userAgent string + expectedDevice string + expectedOS string + expectedBrowser string + }{ + { + name: "Chrome on Windows", + userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36", + expectedDevice: "desktop", + expectedOS: "Windows 10/11", + expectedBrowser: "Chrome 119.0.0.0", + }, + { + name: "Safari on macOS", + userAgent: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15", + expectedDevice: "desktop", + expectedOS: "macOS 10.15.7", + expectedBrowser: "Safari 17.0", + }, + { + name: "Chrome on Android Mobile", + userAgent: "Mozilla/5.0 (Linux; Android 13; SM-G998B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Mobile Safari/537.36", + expectedDevice: "mobile", + expectedOS: "Android 13", + expectedBrowser: "Chrome 119.0.0.0", + }, + { + name: "Safari on iPhone", + userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1", + expectedDevice: "mobile", + expectedOS: "iOS 17.0", + expectedBrowser: "Safari 17.0", + }, + { + name: "Firefox on Windows", + userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/119.0", + expectedDevice: "desktop", + expectedOS: "Windows 10/11", + expectedBrowser: "Firefox 119.0", + }, + { + name: "Edge on Windows", + userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0", + expectedDevice: "desktop", + expectedOS: "Windows 10/11", + expectedBrowser: "Edge 119.0.0.0", + }, + { + name: "iPad Safari", + userAgent: "Mozilla/5.0 (iPad; CPU OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1", + expectedDevice: "tablet", + expectedOS: "iOS 17.0", + expectedBrowser: "Safari 17.0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientInfo := &storepb.SessionsUserSetting_ClientInfo{} + service.parseUserAgent(tt.userAgent, clientInfo) + + if clientInfo.DeviceType != tt.expectedDevice { + t.Errorf("Expected device type %s, got %s", tt.expectedDevice, clientInfo.DeviceType) + } + if clientInfo.Os != tt.expectedOS { + t.Errorf("Expected OS %s, got %s", tt.expectedOS, clientInfo.Os) + } + if clientInfo.Browser != tt.expectedBrowser { + t.Errorf("Expected browser %s, got %s", tt.expectedBrowser, clientInfo.Browser) + } + }) + } +} + +func TestExtractClientInfo(t *testing.T) { + service := &APIV1Service{} + + // Test with metadata containing user agent and IP + md := metadata.New(map[string]string{ + "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36", + "x-forwarded-for": "203.0.113.1, 198.51.100.1", + "x-real-ip": "203.0.113.1", + }) + + ctx := metadata.NewIncomingContext(context.Background(), md) + + clientInfo := service.extractClientInfo(ctx) + + if clientInfo.UserAgent == "" { + t.Error("Expected user agent to be set") + } + if clientInfo.IpAddress != "203.0.113.1" { + t.Errorf("Expected IP address to be 203.0.113.1, got %s", clientInfo.IpAddress) + } + if clientInfo.DeviceType != "desktop" { + t.Errorf("Expected device type to be desktop, got %s", clientInfo.DeviceType) + } + if clientInfo.Os != "Windows 10/11" { + t.Errorf("Expected OS to be Windows 10/11, got %s", clientInfo.Os) + } + if clientInfo.Browser != "Chrome 119.0.0.0" { + t.Errorf("Expected browser to be Chrome 119.0.0.0, got %s", clientInfo.Browser) + } +} + +// TestClientInfoExamples demonstrates the enhanced client info extraction with various user agents +func TestClientInfoExamples(t *testing.T) { + service := &APIV1Service{} + + examples := []struct { + description string + userAgent string + }{ + { + description: "Modern Chrome on Windows 11", + userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + }, + { + description: "Safari on iPhone 15 Pro", + userAgent: "Mozilla/5.0 (iPhone; CPU iPhone OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1", + }, + { + description: "Chrome on Samsung Galaxy", + userAgent: "Mozilla/5.0 (Linux; Android 14; SM-S918B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36", + }, + { + description: "Firefox on Ubuntu", + userAgent: "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/120.0", + }, + { + description: "Edge on Windows 10", + userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", + }, + { + description: "Safari on iPad Air", + userAgent: "Mozilla/5.0 (iPad; CPU OS 17_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.1 Mobile/15E148 Safari/604.1", + }, + } + + for _, example := range examples { + t.Run(example.description, func(t *testing.T) { + clientInfo := &storepb.SessionsUserSetting_ClientInfo{} + service.parseUserAgent(example.userAgent, clientInfo) + + t.Logf("User Agent: %s", example.userAgent) + t.Logf("Device Type: %s", clientInfo.DeviceType) + t.Logf("Operating System: %s", clientInfo.Os) + t.Logf("Browser: %s", clientInfo.Browser) + t.Logf("---") + + // Ensure all fields are populated + if clientInfo.DeviceType == "" { + t.Error("Device type should not be empty") + } + if clientInfo.Os == "" { + t.Error("OS should not be empty") + } + if clientInfo.Browser == "" { + t.Error("Browser should not be empty") + } + }) + } +} diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index c52f80b8e..8e2c489e8 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -627,7 +627,6 @@ func (s *APIV1Service) ListUserSessions(ctx context.Context, request *v1pb.ListU DeviceType: userSession.ClientInfo.DeviceType, Os: userSession.ClientInfo.Os, Browser: userSession.ClientInfo.Browser, - Country: userSession.ClientInfo.Country, } } diff --git a/web/src/components/Settings/MyAccountSection.tsx b/web/src/components/Settings/MyAccountSection.tsx index 2b48e1055..1ce983fec 100644 --- a/web/src/components/Settings/MyAccountSection.tsx +++ b/web/src/components/Settings/MyAccountSection.tsx @@ -7,6 +7,7 @@ import showUpdateAccountDialog from "../UpdateAccountDialog"; import UserAvatar from "../UserAvatar"; import { Popover, PopoverContent, PopoverTrigger } from "../ui/Popover"; import AccessTokenSection from "./AccessTokenSection"; +import UserSessionsSection from "./UserSessionsSection"; const MyAccountSection = () => { const t = useTranslate(); @@ -48,6 +49,7 @@ const MyAccountSection = () => { + ); }; diff --git a/web/src/components/Settings/UserSessionsSection.tsx b/web/src/components/Settings/UserSessionsSection.tsx new file mode 100644 index 000000000..e95fb2a16 --- /dev/null +++ b/web/src/components/Settings/UserSessionsSection.tsx @@ -0,0 +1,177 @@ +import { Button } from "@usememos/mui"; +import { ClockIcon, MonitorIcon, SmartphoneIcon, TabletIcon, TrashIcon, WifiIcon } from "lucide-react"; +import { useEffect, useState } from "react"; +import { toast } from "react-hot-toast"; +import { userServiceClient } from "@/grpcweb"; +import useCurrentUser from "@/hooks/useCurrentUser"; +import { UserSession } from "@/types/proto/api/v1/user_service"; +import { useTranslate } from "@/utils/i18n"; +import LearnMore from "../LearnMore"; + +const listUserSessions = async (parent: string) => { + const { sessions } = await userServiceClient.listUserSessions({ parent }); + return sessions.sort((a, b) => (b.lastAccessedTime?.getTime() ?? 0) - (a.lastAccessedTime?.getTime() ?? 0)); +}; + +const UserSessionsSection = () => { + const t = useTranslate(); + const currentUser = useCurrentUser(); + const [userSessions, setUserSessions] = useState([]); + + useEffect(() => { + listUserSessions(currentUser.name).then((sessions) => { + setUserSessions(sessions); + }); + }, []); + + const handleRevokeSession = async (userSession: UserSession) => { + const formattedSessionId = getFormattedSessionId(userSession.sessionId); + const confirmed = window.confirm(t("setting.user-sessions-section.session-revocation", { sessionId: formattedSessionId })); + if (confirmed) { + await userServiceClient.revokeUserSession({ name: userSession.name }); + setUserSessions(userSessions.filter((session) => session.sessionId !== userSession.sessionId)); + toast.success(t("setting.user-sessions-section.session-revoked")); + } + }; + + const getFormattedSessionId = (sessionId: string) => { + return `${sessionId.slice(0, 8)}...${sessionId.slice(-8)}`; + }; + + const getDeviceIcon = (deviceType: string) => { + switch (deviceType?.toLowerCase()) { + case "mobile": + return ; + case "tablet": + return ; + case "desktop": + default: + return ; + } + }; + + const formatLocation = (clientInfo: UserSession["clientInfo"]) => { + if (!clientInfo) return "Unknown"; + + const parts = []; + if (clientInfo.ipAddress) parts.push(clientInfo.ipAddress); + + return parts.length > 0 ? parts.join(" • ") : "Unknown"; + }; + + const formatDeviceInfo = (clientInfo: UserSession["clientInfo"]) => { + if (!clientInfo) return "Unknown Device"; + + const parts = []; + if (clientInfo.os) parts.push(clientInfo.os); + if (clientInfo.browser) parts.push(clientInfo.browser); + + return parts.length > 0 ? parts.join(" • ") : "Unknown Device"; + }; + + const isCurrentSession = (session: UserSession) => { + // A simple heuristic: the most recently accessed session is likely the current one + if (userSessions.length === 0) return false; + const mostRecent = userSessions[0]; + return session.sessionId === mostRecent.sessionId; + }; + + return ( +
+
+
+
+

+ {t("setting.user-sessions-section.title")} + +

+

{t("setting.user-sessions-section.description")}

+
+
+
+
+
+ + + + + + + + + + + + {userSessions.map((userSession) => ( + + + + + + + + ))} + +
+ {t("setting.user-sessions-section.device")} + + {t("setting.user-sessions-section.location")} + + {t("setting.user-sessions-section.last-active")} + + {t("setting.user-sessions-section.expires")} + + {t("common.delete")} +
+
+ {getDeviceIcon(userSession.clientInfo?.deviceType || "")} +
+ + {formatDeviceInfo(userSession.clientInfo)} + {isCurrentSession(userSession) && ( + + + {t("setting.user-sessions-section.current")} + + )} + + {getFormattedSessionId(userSession.sessionId)} +
+
+
+ {formatLocation(userSession.clientInfo)} + +
+ + {userSession.lastAccessedTime?.toLocaleString()} +
+
+ {userSession.expireTime?.toLocaleString() ?? t("setting.user-sessions-section.never")} + + +
+ {userSessions.length === 0 && ( +
{t("setting.user-sessions-section.no-sessions")}
+ )} +
+
+
+
+
+ ); +}; + +export default UserSessionsSection; diff --git a/web/src/locales/en.json b/web/src/locales/en.json index f02c89ca6..c0868605b 100644 --- a/web/src/locales/en.json +++ b/web/src/locales/en.json @@ -251,6 +251,21 @@ "title": "Access Tokens", "token": "Token" }, + "user-sessions-section": { + "title": "Active Sessions", + "description": "A list of all active sessions for your account. You can revoke any session except the current one.", + "device": "Device", + "location": "Location", + "last-active": "Last Active", + "expires": "Expires", + "current": "Current", + "never": "Never", + "session-revocation": "Are you sure to revoke session {{sessionId}}? You will need to sign in again on that device.", + "session-revoked": "Session revoked successfully", + "revoke-session": "Revoke session", + "cannot-revoke-current": "Cannot revoke current session", + "no-sessions": "No active sessions found" + }, "account-section": { "change-password": "Change password", "email-note": "Optional", diff --git a/web/src/types/proto/api/v1/user_service.ts b/web/src/types/proto/api/v1/user_service.ts index e899cb6f4..fe9b0d5bd 100644 --- a/web/src/types/proto/api/v1/user_service.ts +++ b/web/src/types/proto/api/v1/user_service.ts @@ -394,8 +394,6 @@ export interface UserSession_ClientInfo { os: string; /** Optional. Browser name and version (e.g., "Chrome 119.0"). */ browser: string; - /** Optional. Geographic location (country code, e.g., "US"). */ - country: string; } export interface ListUserSessionsRequest { @@ -2222,7 +2220,7 @@ export const UserSession: MessageFns = { }; function createBaseUserSession_ClientInfo(): UserSession_ClientInfo { - return { userAgent: "", ipAddress: "", deviceType: "", os: "", browser: "", country: "" }; + return { userAgent: "", ipAddress: "", deviceType: "", os: "", browser: "" }; } export const UserSession_ClientInfo: MessageFns = { @@ -2242,9 +2240,6 @@ export const UserSession_ClientInfo: MessageFns = { if (message.browser !== "") { writer.uint32(42).string(message.browser); } - if (message.country !== "") { - writer.uint32(50).string(message.country); - } return writer; }, @@ -2295,14 +2290,6 @@ export const UserSession_ClientInfo: MessageFns = { message.browser = reader.string(); continue; } - case 6: { - if (tag !== 50) { - break; - } - - message.country = reader.string(); - continue; - } } if ((tag & 7) === 4 || tag === 0) { break; @@ -2322,7 +2309,6 @@ export const UserSession_ClientInfo: MessageFns = { message.deviceType = object.deviceType ?? ""; message.os = object.os ?? ""; message.browser = object.browser ?? ""; - message.country = object.country ?? ""; return message; }, };