diff --git a/server/router/api/v1/logger_interceptor.go b/server/router/api/v1/logger_interceptor.go index b5dd2eac3..c8f97d6e8 100644 --- a/server/router/api/v1/logger_interceptor.go +++ b/server/router/api/v1/logger_interceptor.go @@ -2,6 +2,7 @@ package v1 import ( "context" + "fmt" "log/slog" "google.golang.org/grpc" @@ -10,10 +11,11 @@ import ( ) type LoggerInterceptor struct { + logStacktrace bool } -func NewLoggerInterceptor() *LoggerInterceptor { - return &LoggerInterceptor{} +func NewLoggerInterceptor(logStacktrace bool) *LoggerInterceptor { + return &LoggerInterceptor{logStacktrace: logStacktrace} } func (in *LoggerInterceptor) LoggerInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { @@ -22,7 +24,7 @@ func (in *LoggerInterceptor) LoggerInterceptor(ctx context.Context, request any, return resp, err } -func (*LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod string, err error) { +func (in *LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod string, err error) { st := status.Convert(err) var logLevel slog.Level var logMsg string @@ -43,6 +45,9 @@ func (*LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod st logAttrs := []slog.Attr{slog.String("method", fullMethod)} if err != nil { logAttrs = append(logAttrs, slog.String("error", err.Error())) + if in.logStacktrace { + logAttrs = append(logAttrs, slog.String("stacktrace", fmt.Sprintf("%v", err))) + } } slog.LogAttrs(ctx, logLevel, logMsg, logAttrs...) } diff --git a/server/server.go b/server/server.go index c15c8c161..d4e58bd82 100644 --- a/server/server.go +++ b/server/server.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "runtime" + "runtime/debug" "time" "github.com/google/uuid" @@ -83,12 +84,15 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store // Create and register RSS routes. rss.NewRSSService(s.Profile, s.Store).RegisterRoutes(rootGroup) + // Log full stacktraces if we're in dev + logStacktraces := profile.IsDev() + grpcServer := grpc.NewServer( // Override the maximum receiving message size to math.MaxInt32 for uploading large attachments. grpc.MaxRecvMsgSize(math.MaxInt32), grpc.ChainUnaryInterceptor( - apiv1.NewLoggerInterceptor().LoggerInterceptor, - grpcrecovery.UnaryServerInterceptor(), + apiv1.NewLoggerInterceptor(logStacktraces).LoggerInterceptor, + newRecoveryInterceptor(logStacktraces), apiv1.NewGRPCAuthInterceptor(store, secret).AuthenticationInterceptor, )) s.grpcServer = grpcServer @@ -102,6 +106,26 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store return s, nil } +func newRecoveryInterceptor(logStacktraces bool) grpc.UnaryServerInterceptor { + var recoveryOptions []grpcrecovery.Option + if logStacktraces { + recoveryOptions = append(recoveryOptions, grpcrecovery.WithRecoveryHandler(func(p any) error { + if p == nil { + return nil + } + + switch val := p.(type) { + case runtime.Error: + return &stacktraceError{err: val, stacktrace: debug.Stack()} + default: + return nil + } + })) + } + + return grpcrecovery.UnaryServerInterceptor(recoveryOptions...) +} + func (s *Server) Start(ctx context.Context) error { var address, network string if len(s.Profile.UNIXSock) == 0 { @@ -227,3 +251,23 @@ func (s *Server) getOrUpsertWorkspaceBasicSetting(ctx context.Context) (*storepb } return workspaceBasicSetting, nil } + +// stacktraceError wraps an underlying error and captures the stacktrace. It +// implements fmt.Formatter, so it'll be rendered when invoked by something like +// `fmt.Sprint("%v", err)`. +type stacktraceError struct { + err error + stacktrace []byte +} + +func (e *stacktraceError) Error() string { + return e.err.Error() +} + +func (e *stacktraceError) Unwrap() error { + return e.err +} + +func (e *stacktraceError) Format(f fmt.State, _ rune) { + f.Write(e.stacktrace) +}