From e23ade1f8b34922512b388916f0693a5957582bf Mon Sep 17 00:00:00 2001 From: Sergey Gorbunov Date: Wed, 7 May 2025 17:12:05 +0300 Subject: [PATCH] feat: support listening on a UNIX socket (#4654) --- bin/memos/main.go | 18 ++++++++++++++---- server/profile/profile.go | 2 ++ server/router/api/v1/v1.go | 8 +++++++- server/server.go | 11 +++++++++-- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/bin/memos/main.go b/bin/memos/main.go index 6450a73ad..37e621e08 100644 --- a/bin/memos/main.go +++ b/bin/memos/main.go @@ -39,6 +39,7 @@ var ( Mode: viper.GetString("mode"), Addr: viper.GetString("addr"), Port: viper.GetInt("port"), + UNIXSock: viper.GetString("unix-sock"), Data: viper.GetString("data"), Driver: viper.GetString("driver"), DSN: viper.GetString("dsn"), @@ -106,6 +107,7 @@ func init() { rootCmd.PersistentFlags().String("mode", "dev", `mode of server, can be "prod" or "dev" or "demo"`) rootCmd.PersistentFlags().String("addr", "", "address of server") rootCmd.PersistentFlags().Int("port", 8081, "port of server") + rootCmd.PersistentFlags().String("unix-sock", "", "path to the unix socket, overrides --addr and --port") rootCmd.PersistentFlags().String("data", "", "data directory") rootCmd.PersistentFlags().String("driver", "sqlite", "database driver") rootCmd.PersistentFlags().String("dsn", "", "database source name(aka. DSN)") @@ -120,6 +122,9 @@ func init() { if err := viper.BindPFlag("port", rootCmd.PersistentFlags().Lookup("port")); err != nil { panic(err) } + if err := viper.BindPFlag("unix-sock", rootCmd.PersistentFlags().Lookup("unix-sock")); err != nil { + panic(err) + } if err := viper.BindPFlag("data", rootCmd.PersistentFlags().Lookup("data")); err != nil { panic(err) } @@ -151,16 +156,21 @@ version: %s data: %s addr: %s port: %d +unix-sock: %s mode: %s driver: %s --- -`, profile.Version, profile.Data, profile.Addr, profile.Port, profile.Mode, profile.Driver) +`, profile.Version, profile.Data, profile.Addr, profile.Port, profile.UNIXSock, profile.Mode, profile.Driver) print(greetingBanner) - if len(profile.Addr) == 0 { - fmt.Printf("Version %s has been started on port %d\n", profile.Version, profile.Port) + if len(profile.UNIXSock) == 0 { + if len(profile.Addr) == 0 { + fmt.Printf("Version %s has been started on port %d\n", profile.Version, profile.Port) + } else { + fmt.Printf("Version %s has been started on address '%s' and port %d\n", profile.Version, profile.Addr, profile.Port) + } } else { - fmt.Printf("Version %s has been started on address '%s' and port %d\n", profile.Version, profile.Addr, profile.Port) + fmt.Printf("Version %s has been started on unix socket %s\n", profile.Version, profile.UNIXSock) } fmt.Printf(`--- See more in: diff --git a/server/profile/profile.go b/server/profile/profile.go index e6bbaf28b..8d551d669 100644 --- a/server/profile/profile.go +++ b/server/profile/profile.go @@ -19,6 +19,8 @@ type Profile struct { Addr string // Port is the binding port for server Port int + // UNIXSock is the IPC binding path. Overrides Addr and Port + UNIXSock string // Data is the data directory Data string // DSN points to where memos stores its own data diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index 901d141ec..fcbbd5d0f 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -67,8 +67,14 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store // RegisterGateway registers the gRPC-Gateway with the given Echo instance. func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error { + var target string + if len(s.Profile.UNIXSock) == 0 { + target = fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port) + } else { + target = fmt.Sprintf("unix:%s", s.Profile.UNIXSock) + } conn, err := grpc.NewClient( - fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port), + target, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)), ) diff --git a/server/server.go b/server/server.go index 69c3e8ef3..fd5834ca7 100644 --- a/server/server.go +++ b/server/server.go @@ -93,8 +93,15 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store } func (s *Server) Start(ctx context.Context) error { - address := fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port) - listener, err := net.Listen("tcp", address) + var address, network string + if len(s.Profile.UNIXSock) == 0 { + address = fmt.Sprintf("%s:%d", s.Profile.Addr, s.Profile.Port) + network = "tcp" + } else { + address = s.Profile.UNIXSock + network = "unix" + } + listener, err := net.Listen(network, address) if err != nil { return errors.Wrap(err, "failed to listen") }