diff --git a/proto/api/v1/auth_service.proto b/proto/api/v1/auth_service.proto index 7c814e7bb..1d81ed12e 100644 --- a/proto/api/v1/auth_service.proto +++ b/proto/api/v1/auth_service.proto @@ -13,14 +13,10 @@ service AuthService { rpc GetAuthStatus(GetAuthStatusRequest) returns (User) { option (google.api.http) = {post: "/api/v1/auth/status"}; } - // SignIn signs in the user with the given username and password. + // SignIn signs in the user. rpc SignIn(SignInRequest) returns (User) { option (google.api.http) = {post: "/api/v1/auth/signin"}; } - // SignInWithSSO signs in the user with the given SSO code. - rpc SignInWithSSO(SignInWithSSORequest) returns (User) { - option (google.api.http) = {post: "/api/v1/auth/signin/sso"}; - } // SignUp signs up the user with the given username and password. rpc SignUp(SignUpRequest) returns (User) { option (google.api.http) = {post: "/api/v1/auth/signup"}; @@ -38,15 +34,26 @@ message GetAuthStatusResponse { } message SignInRequest { + // Provide one authentication method (username/password or SSO). + oneof method { + // Username and password authentication method. + PasswordCredentials password_credentials = 1; + + // SSO provider authentication method. + SSOCredentials sso_credentials = 2; + } + // Whether the session should never expire. + bool never_expire = 3; +} + +message PasswordCredentials { // The username to sign in with. string username = 1; // The password to sign in with. string password = 2; - // Whether the session should never expire. - bool never_expire = 3; } -message SignInWithSSORequest { +message SSOCredentials { // The ID of the SSO provider. int32 idp_id = 1; // The code to sign in with. diff --git a/proto/gen/api/v1/auth_service.pb.go b/proto/gen/api/v1/auth_service.pb.go index 5dab9f08d..cd08f0bf3 100644 --- a/proto/gen/api/v1/auth_service.pb.go +++ b/proto/gen/api/v1/auth_service.pb.go @@ -105,10 +105,13 @@ func (x *GetAuthStatusResponse) GetUser() *User { type SignInRequest struct { state protoimpl.MessageState `protogen:"open.v1"` - // The username to sign in with. - Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` - // The password to sign in with. - Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` + // Provide one authentication method (username/password or SSO). + // + // Types that are valid to be assigned to Method: + // + // *SignInRequest_PasswordCredentials + // *SignInRequest_SsoCredentials + Method isSignInRequest_Method `protobuf_oneof:"method"` // Whether the session should never expire. NeverExpire bool `protobuf:"varint,3,opt,name=never_expire,json=neverExpire,proto3" json:"never_expire,omitempty"` unknownFields protoimpl.UnknownFields @@ -145,18 +148,29 @@ func (*SignInRequest) Descriptor() ([]byte, []int) { return file_api_v1_auth_service_proto_rawDescGZIP(), []int{2} } -func (x *SignInRequest) GetUsername() string { +func (x *SignInRequest) GetMethod() isSignInRequest_Method { if x != nil { - return x.Username + return x.Method } - return "" + return nil } -func (x *SignInRequest) GetPassword() string { +func (x *SignInRequest) GetPasswordCredentials() *PasswordCredentials { if x != nil { - return x.Password + if x, ok := x.Method.(*SignInRequest_PasswordCredentials); ok { + return x.PasswordCredentials + } } - return "" + return nil +} + +func (x *SignInRequest) GetSsoCredentials() *SSOCredentials { + if x != nil { + if x, ok := x.Method.(*SignInRequest_SsoCredentials); ok { + return x.SsoCredentials + } + } + return nil } func (x *SignInRequest) GetNeverExpire() bool { @@ -166,7 +180,79 @@ func (x *SignInRequest) GetNeverExpire() bool { return false } -type SignInWithSSORequest struct { +type isSignInRequest_Method interface { + isSignInRequest_Method() +} + +type SignInRequest_PasswordCredentials struct { + // Username and password authentication method. + PasswordCredentials *PasswordCredentials `protobuf:"bytes,1,opt,name=password_credentials,json=passwordCredentials,proto3,oneof"` +} + +type SignInRequest_SsoCredentials struct { + // SSO provider authentication method. + SsoCredentials *SSOCredentials `protobuf:"bytes,2,opt,name=sso_credentials,json=ssoCredentials,proto3,oneof"` +} + +func (*SignInRequest_PasswordCredentials) isSignInRequest_Method() {} + +func (*SignInRequest_SsoCredentials) isSignInRequest_Method() {} + +type PasswordCredentials struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The username to sign in with. + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + // The password to sign in with. + Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PasswordCredentials) Reset() { + *x = PasswordCredentials{} + mi := &file_api_v1_auth_service_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PasswordCredentials) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PasswordCredentials) ProtoMessage() {} + +func (x *PasswordCredentials) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_auth_service_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PasswordCredentials.ProtoReflect.Descriptor instead. +func (*PasswordCredentials) Descriptor() ([]byte, []int) { + return file_api_v1_auth_service_proto_rawDescGZIP(), []int{3} +} + +func (x *PasswordCredentials) GetUsername() string { + if x != nil { + return x.Username + } + return "" +} + +func (x *PasswordCredentials) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +type SSOCredentials struct { state protoimpl.MessageState `protogen:"open.v1"` // The ID of the SSO provider. IdpId int32 `protobuf:"varint,1,opt,name=idp_id,json=idpId,proto3" json:"idp_id,omitempty"` @@ -178,21 +264,21 @@ type SignInWithSSORequest struct { sizeCache protoimpl.SizeCache } -func (x *SignInWithSSORequest) Reset() { - *x = SignInWithSSORequest{} - mi := &file_api_v1_auth_service_proto_msgTypes[3] +func (x *SSOCredentials) Reset() { + *x = SSOCredentials{} + mi := &file_api_v1_auth_service_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *SignInWithSSORequest) String() string { +func (x *SSOCredentials) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SignInWithSSORequest) ProtoMessage() {} +func (*SSOCredentials) ProtoMessage() {} -func (x *SignInWithSSORequest) ProtoReflect() protoreflect.Message { - mi := &file_api_v1_auth_service_proto_msgTypes[3] +func (x *SSOCredentials) ProtoReflect() protoreflect.Message { + mi := &file_api_v1_auth_service_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -203,26 +289,26 @@ func (x *SignInWithSSORequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SignInWithSSORequest.ProtoReflect.Descriptor instead. -func (*SignInWithSSORequest) Descriptor() ([]byte, []int) { - return file_api_v1_auth_service_proto_rawDescGZIP(), []int{3} +// Deprecated: Use SSOCredentials.ProtoReflect.Descriptor instead. +func (*SSOCredentials) Descriptor() ([]byte, []int) { + return file_api_v1_auth_service_proto_rawDescGZIP(), []int{4} } -func (x *SignInWithSSORequest) GetIdpId() int32 { +func (x *SSOCredentials) GetIdpId() int32 { if x != nil { return x.IdpId } return 0 } -func (x *SignInWithSSORequest) GetCode() string { +func (x *SSOCredentials) GetCode() string { if x != nil { return x.Code } return "" } -func (x *SignInWithSSORequest) GetRedirectUri() string { +func (x *SSOCredentials) GetRedirectUri() string { if x != nil { return x.RedirectUri } @@ -241,7 +327,7 @@ type SignUpRequest struct { func (x *SignUpRequest) Reset() { *x = SignUpRequest{} - mi := &file_api_v1_auth_service_proto_msgTypes[4] + mi := &file_api_v1_auth_service_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -253,7 +339,7 @@ func (x *SignUpRequest) String() string { func (*SignUpRequest) ProtoMessage() {} func (x *SignUpRequest) ProtoReflect() protoreflect.Message { - mi := &file_api_v1_auth_service_proto_msgTypes[4] + mi := &file_api_v1_auth_service_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -266,7 +352,7 @@ func (x *SignUpRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SignUpRequest.ProtoReflect.Descriptor instead. func (*SignUpRequest) Descriptor() ([]byte, []int) { - return file_api_v1_auth_service_proto_rawDescGZIP(), []int{4} + return file_api_v1_auth_service_proto_rawDescGZIP(), []int{5} } func (x *SignUpRequest) GetUsername() string { @@ -291,7 +377,7 @@ type SignOutRequest struct { func (x *SignOutRequest) Reset() { *x = SignOutRequest{} - mi := &file_api_v1_auth_service_proto_msgTypes[5] + mi := &file_api_v1_auth_service_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -303,7 +389,7 @@ func (x *SignOutRequest) String() string { func (*SignOutRequest) ProtoMessage() {} func (x *SignOutRequest) ProtoReflect() protoreflect.Message { - mi := &file_api_v1_auth_service_proto_msgTypes[5] + mi := &file_api_v1_auth_service_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -316,7 +402,7 @@ func (x *SignOutRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SignOutRequest.ProtoReflect.Descriptor instead. func (*SignOutRequest) Descriptor() ([]byte, []int) { - return file_api_v1_auth_service_proto_rawDescGZIP(), []int{5} + return file_api_v1_auth_service_proto_rawDescGZIP(), []int{6} } var File_api_v1_auth_service_proto protoreflect.FileDescriptor @@ -326,23 +412,26 @@ const file_api_v1_auth_service_proto_rawDesc = "" + "\x19api/v1/auth_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x1bgoogle/protobuf/empty.proto\"\x16\n" + "\x14GetAuthStatusRequest\"?\n" + "\x15GetAuthStatusResponse\x12&\n" + - "\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"j\n" + - "\rSignInRequest\x12\x1a\n" + + "\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xdd\x01\n" + + "\rSignInRequest\x12V\n" + + "\x14password_credentials\x18\x01 \x01(\v2!.memos.api.v1.PasswordCredentialsH\x00R\x13passwordCredentials\x12G\n" + + "\x0fsso_credentials\x18\x02 \x01(\v2\x1c.memos.api.v1.SSOCredentialsH\x00R\x0essoCredentials\x12!\n" + + "\fnever_expire\x18\x03 \x01(\bR\vneverExpireB\b\n" + + "\x06method\"M\n" + + "\x13PasswordCredentials\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12\x1a\n" + - "\bpassword\x18\x02 \x01(\tR\bpassword\x12!\n" + - "\fnever_expire\x18\x03 \x01(\bR\vneverExpire\"d\n" + - "\x14SignInWithSSORequest\x12\x15\n" + + "\bpassword\x18\x02 \x01(\tR\bpassword\"^\n" + + "\x0eSSOCredentials\x12\x15\n" + "\x06idp_id\x18\x01 \x01(\x05R\x05idpId\x12\x12\n" + "\x04code\x18\x02 \x01(\tR\x04code\x12!\n" + "\fredirect_uri\x18\x03 \x01(\tR\vredirectUri\"G\n" + "\rSignUpRequest\x12\x1a\n" + "\busername\x18\x01 \x01(\tR\busername\x12\x1a\n" + "\bpassword\x18\x02 \x01(\tR\bpassword\"\x10\n" + - "\x0eSignOutRequest2\xec\x03\n" + + "\x0eSignOutRequest2\x82\x03\n" + "\vAuthService\x12d\n" + "\rGetAuthStatus\x12\".memos.api.v1.GetAuthStatusRequest\x1a\x12.memos.api.v1.User\"\x1b\x82\xd3\xe4\x93\x02\x15\"\x13/api/v1/auth/status\x12V\n" + - "\x06SignIn\x12\x1b.memos.api.v1.SignInRequest\x1a\x12.memos.api.v1.User\"\x1b\x82\xd3\xe4\x93\x02\x15\"\x13/api/v1/auth/signin\x12h\n" + - "\rSignInWithSSO\x12\".memos.api.v1.SignInWithSSORequest\x1a\x12.memos.api.v1.User\"\x1f\x82\xd3\xe4\x93\x02\x19\"\x17/api/v1/auth/signin/sso\x12V\n" + + "\x06SignIn\x12\x1b.memos.api.v1.SignInRequest\x1a\x12.memos.api.v1.User\"\x1b\x82\xd3\xe4\x93\x02\x15\"\x13/api/v1/auth/signin\x12V\n" + "\x06SignUp\x12\x1b.memos.api.v1.SignUpRequest\x1a\x12.memos.api.v1.User\"\x1b\x82\xd3\xe4\x93\x02\x15\"\x13/api/v1/auth/signup\x12]\n" + "\aSignOut\x12\x1c.memos.api.v1.SignOutRequest\x1a\x16.google.protobuf.Empty\"\x1c\x82\xd3\xe4\x93\x02\x16\"\x14/api/v1/auth/signoutB\xa8\x01\n" + "\x10com.memos.api.v1B\x10AuthServiceProtoP\x01Z0github.com/usememos/memos/proto/gen/api/v1;apiv1\xa2\x02\x03MAX\xaa\x02\fMemos.Api.V1\xca\x02\fMemos\\Api\\V1\xe2\x02\x18Memos\\Api\\V1\\GPBMetadata\xea\x02\x0eMemos::Api::V1b\x06proto3" @@ -359,34 +448,35 @@ func file_api_v1_auth_service_proto_rawDescGZIP() []byte { return file_api_v1_auth_service_proto_rawDescData } -var file_api_v1_auth_service_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_api_v1_auth_service_proto_msgTypes = make([]protoimpl.MessageInfo, 7) var file_api_v1_auth_service_proto_goTypes = []any{ (*GetAuthStatusRequest)(nil), // 0: memos.api.v1.GetAuthStatusRequest (*GetAuthStatusResponse)(nil), // 1: memos.api.v1.GetAuthStatusResponse (*SignInRequest)(nil), // 2: memos.api.v1.SignInRequest - (*SignInWithSSORequest)(nil), // 3: memos.api.v1.SignInWithSSORequest - (*SignUpRequest)(nil), // 4: memos.api.v1.SignUpRequest - (*SignOutRequest)(nil), // 5: memos.api.v1.SignOutRequest - (*User)(nil), // 6: memos.api.v1.User - (*emptypb.Empty)(nil), // 7: google.protobuf.Empty + (*PasswordCredentials)(nil), // 3: memos.api.v1.PasswordCredentials + (*SSOCredentials)(nil), // 4: memos.api.v1.SSOCredentials + (*SignUpRequest)(nil), // 5: memos.api.v1.SignUpRequest + (*SignOutRequest)(nil), // 6: memos.api.v1.SignOutRequest + (*User)(nil), // 7: memos.api.v1.User + (*emptypb.Empty)(nil), // 8: google.protobuf.Empty } var file_api_v1_auth_service_proto_depIdxs = []int32{ - 6, // 0: memos.api.v1.GetAuthStatusResponse.user:type_name -> memos.api.v1.User - 0, // 1: memos.api.v1.AuthService.GetAuthStatus:input_type -> memos.api.v1.GetAuthStatusRequest - 2, // 2: memos.api.v1.AuthService.SignIn:input_type -> memos.api.v1.SignInRequest - 3, // 3: memos.api.v1.AuthService.SignInWithSSO:input_type -> memos.api.v1.SignInWithSSORequest - 4, // 4: memos.api.v1.AuthService.SignUp:input_type -> memos.api.v1.SignUpRequest - 5, // 5: memos.api.v1.AuthService.SignOut:input_type -> memos.api.v1.SignOutRequest - 6, // 6: memos.api.v1.AuthService.GetAuthStatus:output_type -> memos.api.v1.User - 6, // 7: memos.api.v1.AuthService.SignIn:output_type -> memos.api.v1.User - 6, // 8: memos.api.v1.AuthService.SignInWithSSO:output_type -> memos.api.v1.User - 6, // 9: memos.api.v1.AuthService.SignUp:output_type -> memos.api.v1.User - 7, // 10: memos.api.v1.AuthService.SignOut:output_type -> google.protobuf.Empty - 6, // [6:11] is the sub-list for method output_type - 1, // [1:6] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 7, // 0: memos.api.v1.GetAuthStatusResponse.user:type_name -> memos.api.v1.User + 3, // 1: memos.api.v1.SignInRequest.password_credentials:type_name -> memos.api.v1.PasswordCredentials + 4, // 2: memos.api.v1.SignInRequest.sso_credentials:type_name -> memos.api.v1.SSOCredentials + 0, // 3: memos.api.v1.AuthService.GetAuthStatus:input_type -> memos.api.v1.GetAuthStatusRequest + 2, // 4: memos.api.v1.AuthService.SignIn:input_type -> memos.api.v1.SignInRequest + 5, // 5: memos.api.v1.AuthService.SignUp:input_type -> memos.api.v1.SignUpRequest + 6, // 6: memos.api.v1.AuthService.SignOut:input_type -> memos.api.v1.SignOutRequest + 7, // 7: memos.api.v1.AuthService.GetAuthStatus:output_type -> memos.api.v1.User + 7, // 8: memos.api.v1.AuthService.SignIn:output_type -> memos.api.v1.User + 7, // 9: memos.api.v1.AuthService.SignUp:output_type -> memos.api.v1.User + 8, // 10: memos.api.v1.AuthService.SignOut:output_type -> google.protobuf.Empty + 7, // [7:11] is the sub-list for method output_type + 3, // [3:7] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name } func init() { file_api_v1_auth_service_proto_init() } @@ -395,13 +485,17 @@ func file_api_v1_auth_service_proto_init() { return } file_api_v1_user_service_proto_init() + file_api_v1_auth_service_proto_msgTypes[2].OneofWrappers = []any{ + (*SignInRequest_PasswordCredentials)(nil), + (*SignInRequest_SsoCredentials)(nil), + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_api_v1_auth_service_proto_rawDesc), len(file_api_v1_auth_service_proto_rawDesc)), NumEnums: 0, - NumMessages: 6, + NumMessages: 7, NumExtensions: 0, NumServices: 1, }, diff --git a/proto/gen/api/v1/auth_service.pb.gw.go b/proto/gen/api/v1/auth_service.pb.gw.go index 904050d53..cbb3c9c14 100644 --- a/proto/gen/api/v1/auth_service.pb.gw.go +++ b/proto/gen/api/v1/auth_service.pb.gw.go @@ -87,39 +87,6 @@ func local_request_AuthService_SignIn_0(ctx context.Context, marshaler runtime.M return msg, metadata, err } -var filter_AuthService_SignInWithSSO_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} - -func request_AuthService_SignInWithSSO_0(ctx context.Context, marshaler runtime.Marshaler, client AuthServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var ( - protoReq SignInWithSSORequest - metadata runtime.ServerMetadata - ) - io.Copy(io.Discard, req.Body) - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_AuthService_SignInWithSSO_0); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - msg, err := client.SignInWithSSO(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err -} - -func local_request_AuthService_SignInWithSSO_0(ctx context.Context, marshaler runtime.Marshaler, server AuthServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var ( - protoReq SignInWithSSORequest - metadata runtime.ServerMetadata - ) - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_AuthService_SignInWithSSO_0); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - msg, err := server.SignInWithSSO(ctx, &protoReq) - return msg, metadata, err -} - var filter_AuthService_SignUp_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} func request_AuthService_SignUp_0(ctx context.Context, marshaler runtime.Marshaler, client AuthServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { @@ -218,26 +185,6 @@ func RegisterAuthServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux } forward_AuthService_SignIn_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) - mux.Handle(http.MethodPost, pattern_AuthService_SignInWithSSO_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/memos.api.v1.AuthService/SignInWithSSO", runtime.WithHTTPPathPattern("/api/v1/auth/signin/sso")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_AuthService_SignInWithSSO_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - forward_AuthService_SignInWithSSO_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) mux.Handle(http.MethodPost, pattern_AuthService_SignUp_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -352,23 +299,6 @@ func RegisterAuthServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux } forward_AuthService_SignIn_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) - mux.Handle(http.MethodPost, pattern_AuthService_SignInWithSSO_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/memos.api.v1.AuthService/SignInWithSSO", runtime.WithHTTPPathPattern("/api/v1/auth/signin/sso")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_AuthService_SignInWithSSO_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - forward_AuthService_SignInWithSSO_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) mux.Handle(http.MethodPost, pattern_AuthService_SignUp_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -409,7 +339,6 @@ func RegisterAuthServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux var ( pattern_AuthService_GetAuthStatus_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "status"}, "")) pattern_AuthService_SignIn_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signin"}, "")) - pattern_AuthService_SignInWithSSO_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4}, []string{"api", "v1", "auth", "signin", "sso"}, "")) pattern_AuthService_SignUp_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signup"}, "")) pattern_AuthService_SignOut_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signout"}, "")) ) @@ -417,7 +346,6 @@ var ( var ( forward_AuthService_GetAuthStatus_0 = runtime.ForwardResponseMessage forward_AuthService_SignIn_0 = runtime.ForwardResponseMessage - forward_AuthService_SignInWithSSO_0 = runtime.ForwardResponseMessage forward_AuthService_SignUp_0 = runtime.ForwardResponseMessage forward_AuthService_SignOut_0 = runtime.ForwardResponseMessage ) diff --git a/proto/gen/api/v1/auth_service_grpc.pb.go b/proto/gen/api/v1/auth_service_grpc.pb.go index f270198ec..2004e2039 100644 --- a/proto/gen/api/v1/auth_service_grpc.pb.go +++ b/proto/gen/api/v1/auth_service_grpc.pb.go @@ -22,7 +22,6 @@ const _ = grpc.SupportPackageIsVersion9 const ( AuthService_GetAuthStatus_FullMethodName = "/memos.api.v1.AuthService/GetAuthStatus" AuthService_SignIn_FullMethodName = "/memos.api.v1.AuthService/SignIn" - AuthService_SignInWithSSO_FullMethodName = "/memos.api.v1.AuthService/SignInWithSSO" AuthService_SignUp_FullMethodName = "/memos.api.v1.AuthService/SignUp" AuthService_SignOut_FullMethodName = "/memos.api.v1.AuthService/SignOut" ) @@ -33,10 +32,8 @@ const ( type AuthServiceClient interface { // GetAuthStatus returns the current auth status of the user. GetAuthStatus(ctx context.Context, in *GetAuthStatusRequest, opts ...grpc.CallOption) (*User, error) - // SignIn signs in the user with the given username and password. + // SignIn signs in the user. SignIn(ctx context.Context, in *SignInRequest, opts ...grpc.CallOption) (*User, error) - // SignInWithSSO signs in the user with the given SSO code. - SignInWithSSO(ctx context.Context, in *SignInWithSSORequest, opts ...grpc.CallOption) (*User, error) // SignUp signs up the user with the given username and password. SignUp(ctx context.Context, in *SignUpRequest, opts ...grpc.CallOption) (*User, error) // SignOut signs out the user. @@ -71,16 +68,6 @@ func (c *authServiceClient) SignIn(ctx context.Context, in *SignInRequest, opts return out, nil } -func (c *authServiceClient) SignInWithSSO(ctx context.Context, in *SignInWithSSORequest, opts ...grpc.CallOption) (*User, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(User) - err := c.cc.Invoke(ctx, AuthService_SignInWithSSO_FullMethodName, in, out, cOpts...) - if err != nil { - return nil, err - } - return out, nil -} - func (c *authServiceClient) SignUp(ctx context.Context, in *SignUpRequest, opts ...grpc.CallOption) (*User, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(User) @@ -107,10 +94,8 @@ func (c *authServiceClient) SignOut(ctx context.Context, in *SignOutRequest, opt type AuthServiceServer interface { // GetAuthStatus returns the current auth status of the user. GetAuthStatus(context.Context, *GetAuthStatusRequest) (*User, error) - // SignIn signs in the user with the given username and password. + // SignIn signs in the user. SignIn(context.Context, *SignInRequest) (*User, error) - // SignInWithSSO signs in the user with the given SSO code. - SignInWithSSO(context.Context, *SignInWithSSORequest) (*User, error) // SignUp signs up the user with the given username and password. SignUp(context.Context, *SignUpRequest) (*User, error) // SignOut signs out the user. @@ -131,9 +116,6 @@ func (UnimplementedAuthServiceServer) GetAuthStatus(context.Context, *GetAuthSta func (UnimplementedAuthServiceServer) SignIn(context.Context, *SignInRequest) (*User, error) { return nil, status.Errorf(codes.Unimplemented, "method SignIn not implemented") } -func (UnimplementedAuthServiceServer) SignInWithSSO(context.Context, *SignInWithSSORequest) (*User, error) { - return nil, status.Errorf(codes.Unimplemented, "method SignInWithSSO not implemented") -} func (UnimplementedAuthServiceServer) SignUp(context.Context, *SignUpRequest) (*User, error) { return nil, status.Errorf(codes.Unimplemented, "method SignUp not implemented") } @@ -197,24 +179,6 @@ func _AuthService_SignIn_Handler(srv interface{}, ctx context.Context, dec func( return interceptor(ctx, in, info, handler) } -func _AuthService_SignInWithSSO_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SignInWithSSORequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(AuthServiceServer).SignInWithSSO(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: AuthService_SignInWithSSO_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(AuthServiceServer).SignInWithSSO(ctx, req.(*SignInWithSSORequest)) - } - return interceptor(ctx, in, info, handler) -} - func _AuthService_SignUp_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(SignUpRequest) if err := dec(in); err != nil { @@ -266,10 +230,6 @@ var AuthService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SignIn", Handler: _AuthService_SignIn_Handler, }, - { - MethodName: "SignInWithSSO", - Handler: _AuthService_SignInWithSSO_Handler, - }, { MethodName: "SignUp", Handler: _AuthService_SignUp_Handler, diff --git a/proto/gen/apidocs.swagger.yaml b/proto/gen/apidocs.swagger.yaml index 8c0caa938..a6589e3ab 100644 --- a/proto/gen/apidocs.swagger.yaml +++ b/proto/gen/apidocs.swagger.yaml @@ -21,7 +21,7 @@ produces: paths: /api/v1/auth/signin: post: - summary: SignIn signs in the user with the given username and password. + summary: SignIn signs in the user. operationId: AuthService_SignIn responses: "200": @@ -33,53 +33,37 @@ paths: schema: $ref: '#/definitions/googlerpcStatus' parameters: - - name: username + - name: passwordCredentials.username description: The username to sign in with. in: query required: false type: string - - name: password + - name: passwordCredentials.password description: The password to sign in with. in: query required: false type: string - - name: neverExpire - description: Whether the session should never expire. - in: query - required: false - type: boolean - tags: - - AuthService - /api/v1/auth/signin/sso: - post: - summary: SignInWithSSO signs in the user with the given SSO code. - operationId: AuthService_SignInWithSSO - responses: - "200": - description: A successful response. - schema: - $ref: '#/definitions/v1User' - default: - description: An unexpected error response. - schema: - $ref: '#/definitions/googlerpcStatus' - parameters: - - name: idpId + - name: ssoCredentials.idpId description: The ID of the SSO provider. in: query required: false type: integer format: int32 - - name: code + - name: ssoCredentials.code description: The code to sign in with. in: query required: false type: string - - name: redirectUri + - name: ssoCredentials.redirectUri description: The redirect URI. in: query required: false type: string + - name: neverExpire + description: Whether the session should never expire. + in: query + required: false + type: boolean tags: - AuthService /api/v1/auth/signout: @@ -305,7 +289,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: parent - description: "The parent is the owner of the memos.\r\nIf not specified or `users/-`, it will list all memos." + description: |- + The parent is the owner of the memos. + If not specified or `users/-`, it will list all memos. in: query required: false type: string @@ -316,12 +302,16 @@ paths: type: integer format: int32 - name: pageToken - description: "A page token, received from a previous `ListMemos` call.\r\nProvide this to retrieve the subsequent page." + description: |- + A page token, received from a previous `ListMemos` call. + Provide this to retrieve the subsequent page. in: query required: false type: string - name: state - description: "The state of the memos to list.\r\nDefault to `NORMAL`. Set to `ARCHIVED` to list archived memos." + description: |- + The state of the memos to list. + Default to `NORMAL`. Set to `ARCHIVED` to list archived memos. in: query required: false type: string @@ -331,12 +321,16 @@ paths: - ARCHIVED default: STATE_UNSPECIFIED - name: sort - description: "What field to sort the results by.\r\nDefault to display_time." + description: |- + What field to sort the results by. + Default to display_time. in: query required: false type: string - name: direction - description: "The direction to sort the results by.\r\nDefault to DESC." + description: |- + The direction to sort the results by. + Default to DESC. in: query required: false type: string @@ -346,12 +340,16 @@ paths: - DESC default: DIRECTION_UNSPECIFIED - name: filter - description: "Filter is a CEL expression to filter memos.\r\nRefer to `Shortcut.filter`." + description: |- + Filter is a CEL expression to filter memos. + Refer to `Shortcut.filter`. in: query required: false type: string - name: oldFilter - description: "[Deprecated] Old filter contains some specific conditions to filter memos.\r\nFormat: \"creator == 'users/{user}' && visibilities == ['PUBLIC', 'PROTECTED']\"" + description: |- + [Deprecated] Old filter contains some specific conditions to filter memos. + Format: "creator == 'users/{user}' && visibilities == ['PUBLIC', 'PROTECTED']" in: query required: false type: string @@ -396,7 +394,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: id - description: "The id of the reaction.\r\nRefer to the `Reaction.id`." + description: |- + The id of the reaction. + Refer to the `Reaction.id`. in: path required: true type: integer @@ -662,7 +662,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: name - description: "The resource name of the workspace setting.\r\nFormat: settings/{setting}" + description: |- + The resource name of the workspace setting. + Format: settings/{setting} in: path required: true type: string @@ -684,7 +686,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: setting.name - description: "name is the name of the setting.\r\nFormat: settings/{setting}" + description: |- + name is the name of the setting. + Format: settings/{setting} in: path required: true type: string @@ -806,13 +810,17 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: memo.name - description: "The name of the memo.\r\nFormat: memos/{memo}, memo is the user defined id or uuid." + description: |- + The name of the memo. + Format: memos/{memo}, memo is the user defined id or uuid. in: path required: true type: string pattern: memos/[^/]+ - name: memo - description: "The memo to update.\r\nThe `name` field is required." + description: |- + The memo to update. + The `name` field is required. in: body required: true schema: @@ -822,7 +830,9 @@ paths: $ref: '#/definitions/v1State' creator: type: string - title: "The name of the creator.\r\nFormat: users/{user}" + title: |- + The name of the creator. + Format: users/{user} createTime: type: string format: date-time @@ -870,7 +880,9 @@ paths: readOnly: true parent: type: string - title: "The name of the parent memo.\r\nFormat: memos/{id}" + title: |- + The name of the parent memo. + Format: memos/{id} readOnly: true snippet: type: string @@ -879,7 +891,9 @@ paths: location: $ref: '#/definitions/apiv1Location' description: The location of the memo. - title: "The memo to update.\r\nThe `name` field is required." + title: |- + The memo to update. + The `name` field is required. required: - memo tags: @@ -1075,7 +1089,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: name - description: "The name of the activity.\r\nFormat: activities/{id}, id is the system generated auto-incremented id." + description: |- + The name of the activity. + Format: activities/{id}, id is the system generated auto-incremented id. in: path required: true type: string @@ -1434,7 +1450,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: parent - description: "The parent is the owner of the memos.\r\nIf not specified or `users/-`, it will list all memos." + description: |- + The parent is the owner of the memos. + If not specified or `users/-`, it will list all memos. in: path required: true type: string @@ -1446,12 +1464,16 @@ paths: type: integer format: int32 - name: pageToken - description: "A page token, received from a previous `ListMemos` call.\r\nProvide this to retrieve the subsequent page." + description: |- + A page token, received from a previous `ListMemos` call. + Provide this to retrieve the subsequent page. in: query required: false type: string - name: state - description: "The state of the memos to list.\r\nDefault to `NORMAL`. Set to `ARCHIVED` to list archived memos." + description: |- + The state of the memos to list. + Default to `NORMAL`. Set to `ARCHIVED` to list archived memos. in: query required: false type: string @@ -1461,12 +1483,16 @@ paths: - ARCHIVED default: STATE_UNSPECIFIED - name: sort - description: "What field to sort the results by.\r\nDefault to display_time." + description: |- + What field to sort the results by. + Default to display_time. in: query required: false type: string - name: direction - description: "The direction to sort the results by.\r\nDefault to DESC." + description: |- + The direction to sort the results by. + Default to DESC. in: query required: false type: string @@ -1476,12 +1502,16 @@ paths: - DESC default: DIRECTION_UNSPECIFIED - name: filter - description: "Filter is a CEL expression to filter memos.\r\nRefer to `Shortcut.filter`." + description: |- + Filter is a CEL expression to filter memos. + Refer to `Shortcut.filter`. in: query required: false type: string - name: oldFilter - description: "[Deprecated] Old filter contains some specific conditions to filter memos.\r\nFormat: \"creator == 'users/{user}' && visibilities == ['PUBLIC', 'PROTECTED']\"" + description: |- + [Deprecated] Old filter contains some specific conditions to filter memos. + Format: "creator == 'users/{user}' && visibilities == ['PUBLIC', 'PROTECTED']" in: query required: false type: string @@ -1619,7 +1649,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: parent - description: "The parent, who owns the tags.\r\nFormat: memos/{id}. Use \"memos/-\" to delete all tags." + description: |- + The parent, who owns the tags. + Format: memos/{id}. Use "memos/-" to delete all tags. in: path required: true type: string @@ -1650,7 +1682,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: parent - description: "The parent, who owns the tags.\r\nFormat: memos/{id}. Use \"memos/-\" to rename all tags." + description: |- + The parent, who owns the tags. + Format: memos/{id}. Use "memos/-" to rename all tags. in: path required: true type: string @@ -1677,7 +1711,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: resource.name - description: "The name of the resource.\r\nFormat: resources/{resource}, resource is the user defined if or uuid." + description: |- + The name of the resource. + Format: resources/{resource}, resource is the user defined if or uuid. in: path required: true type: string @@ -1763,7 +1799,9 @@ paths: $ref: '#/definitions/googlerpcStatus' parameters: - name: user.name - description: "The name of the user.\r\nFormat: users/{id}, id is the system generated auto-incremented id." + description: |- + The name of the user. + Format: users/{id}, id is the system generated auto-incremented id. in: path required: true type: string @@ -2026,7 +2064,9 @@ definitions: properties: memo: type: string - description: "The memo name of comment.\r\nRefer to `Memo.name`." + description: |- + The memo name of comment. + Refer to `Memo.name`. relatedMemo: type: string description: The name of related memo. @@ -2090,13 +2130,17 @@ definitions: properties: name: type: string - description: "The name of the memo.\r\nFormat: memos/{memo}, memo is the user defined id or uuid." + description: |- + The name of the memo. + Format: memos/{memo}, memo is the user defined id or uuid. readOnly: true state: $ref: '#/definitions/v1State' creator: type: string - title: "The name of the creator.\r\nFormat: users/{user}" + title: |- + The name of the creator. + Format: users/{user} createTime: type: string format: date-time @@ -2144,7 +2188,9 @@ definitions: readOnly: true parent: type: string - title: "The name of the parent memo.\r\nFormat: memos/{id}" + title: |- + The name of the parent memo. + Format: memos/{id} readOnly: true snippet: type: string @@ -2230,7 +2276,10 @@ definitions: weekStartDayOffset: type: integer format: int32 - description: "week_start_day_offset is the week start day offset from Sunday.\r\n0: Sunday, 1: Monday, 2: Tuesday, 3: Wednesday, 4: Thursday, 5: Friday, 6: Saturday\r\nDefault is Sunday." + description: |- + week_start_day_offset is the week start day offset from Sunday. + 0: Sunday, 1: Monday, 2: Tuesday, 3: Wednesday, 4: Thursday, 5: Friday, 6: Saturday + Default is Sunday. disallowChangeUsername: type: boolean description: disallow_change_username disallows changing username. @@ -2283,7 +2332,9 @@ definitions: properties: name: type: string - title: "name is the name of the setting.\r\nFormat: settings/{setting}" + title: |- + name is the name of the setting. + Format: settings/{setting} generalSetting: $ref: '#/definitions/apiv1WorkspaceGeneralSetting' storageSetting: @@ -2298,7 +2349,9 @@ definitions: description: storage_type is the storage type. filepathTemplate: type: string - title: "The template of file path.\r\ne.g. assets/{timestamp}_{filename}" + title: |- + The template of file path. + e.g. assets/{timestamp}_{filename} uploadSizeLimitMb: type: string format: int64 @@ -2457,11 +2510,15 @@ definitions: properties: name: type: string - title: "The name of the activity.\r\nFormat: activities/{id}" + title: |- + The name of the activity. + Format: activities/{id} readOnly: true creator: type: string - title: "The name of the creator.\r\nFormat: users/{user}" + title: |- + The name of the creator. + Format: users/{user} type: type: string description: The type of the activity. @@ -2723,7 +2780,9 @@ definitions: $ref: '#/definitions/apiv1Memo' nextPageToken: type: string - description: "A token, which can be sent as `page_token` to retrieve the next page.\r\nIf this field is omitted, there are no subsequent pages." + description: |- + A token, which can be sent as `page_token` to retrieve the next page. + If this field is omitted, there are no subsequent pages. v1ListNode: type: object properties: @@ -2812,7 +2871,9 @@ definitions: properties: name: type: string - title: "The name of the memo.\r\nFormat: memos/{id}" + title: |- + The name of the memo. + Format: memos/{id} uid: type: string snippet: @@ -2968,6 +3029,15 @@ definitions: items: type: object $ref: '#/definitions/v1Node' + v1PasswordCredentials: + type: object + properties: + username: + type: string + description: The username to sign in with. + password: + type: string + description: The password to sign in with. v1Reaction: type: object properties: @@ -2998,7 +3068,9 @@ definitions: properties: name: type: string - description: "The name of the resource.\r\nFormat: resources/{resource}, resource is the user defined if or uuid." + description: |- + The name of the resource. + Format: resources/{resource}, resource is the user defined if or uuid. readOnly: true createTime: type: string @@ -3032,6 +3104,19 @@ definitions: properties: markdown: type: string + v1SSOCredentials: + type: object + properties: + idpId: + type: integer + format: int32 + description: The ID of the SSO provider. + code: + type: string + description: The code to sign in with. + redirectUri: + type: string + description: The redirect URI. v1SpoilerNode: type: object properties: @@ -3132,7 +3217,9 @@ definitions: properties: name: type: string - description: "The name of the user.\r\nFormat: users/{id}, id is the system generated auto-incremented id." + description: |- + The name of the user. + Format: users/{id}, id is the system generated auto-incremented id. readOnly: true role: $ref: '#/definitions/UserRole' @@ -3182,7 +3269,9 @@ definitions: items: type: string format: date-time - description: "The timestamps when the memos were displayed.\r\nWe should return raw data to the client, and let the client format the data with the user's timezone." + description: |- + The timestamps when the memos were displayed. + We should return raw data to the client, and let the client format the data with the user's timezone. memoTypeStats: $ref: '#/definitions/UserStatsMemoTypeStats' description: The stats of memo types. @@ -3191,7 +3280,9 @@ definitions: additionalProperties: type: integer format: int32 - title: "The count of tags.\r\nFormat: \"tag1\": 1, \"tag2\": 2" + title: |- + The count of tags. + Format: "tag1": 1, "tag2": 2 pinnedMemos: type: array items: diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index 8f3f11845..c02a6dbd2 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -44,128 +44,126 @@ func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusR } func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) { - user, err := s.Store.GetUser(ctx, &store.FindUser{ - Username: &request.Username, - }) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err) - } - if user == nil { - return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError) - } - // Compare the stored hashed password, with the hashed version of the password that was received. - if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil { - return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError) - } - - workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err) - } - // Check if the password auth in is allowed. - if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser { - return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed") - } - if user.RowStatus == store.Archived { - return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", request.Username) - } - - expireTime := time.Now().Add(AccessTokenDuration) - if request.NeverExpire { - // Set the expire time to 100 years. - expireTime = time.Now().Add(100 * 365 * 24 * time.Hour) - } - if err := s.doSignIn(ctx, user, expireTime); err != nil { - return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err) - } - return convertUserFromStore(user), nil -} - -func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWithSSORequest) (*v1pb.User, error) { - identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ - ID: &request.IdpId, - }) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err) - } - if identityProvider == nil { - return nil, status.Errorf(codes.InvalidArgument, "identity provider not found") - } - - var userInfo *idp.IdentityProviderUserInfo - if identityProvider.Type == storepb.IdentityProvider_OAUTH2 { - oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config()) + var existingUser *store.User + if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil { + user, err := s.Store.GetUser(ctx, &store.FindUser{ + Username: &passwordCredentials.Username, + }) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err) + return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err) } - token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err) + if user == nil { + return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError) } - userInfo, err = oauth2IdentityProvider.UserInfo(token) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err) + // Compare the stored hashed password, with the hashed version of the password that was received. + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(passwordCredentials.Password)); err != nil { + return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError) } - } - - identifierFilter := identityProvider.IdentifierFilter - if identifierFilter != "" { - identifierFilterRegex, err := regexp.Compile(identifierFilter) + workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err) + return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err) } - if !identifierFilterRegex.MatchString(userInfo.Identifier) { - return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier) + // Check if the password auth in is allowed. + if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser { + return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed") } - } - - user, err := s.Store.GetUser(ctx, &store.FindUser{ - Username: &userInfo.Identifier, - }) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err) - } - if user == nil { - // Check if the user is allowed to sign up. - workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx) + existingUser = user + } else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil { + identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ + ID: &ssoCredentials.IdpId, + }) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err) + return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err) } - if workspaceGeneralSetting.DisallowUserRegistration { - return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed") + if identityProvider == nil { + return nil, status.Errorf(codes.InvalidArgument, "identity provider not found") } - // Create a new user with the user info from the identity provider. - userCreate := &store.User{ - Username: userInfo.Identifier, - // The new signup user should be normal user by default. - Role: store.RoleUser, - Nickname: userInfo.DisplayName, - Email: userInfo.Email, - AvatarURL: userInfo.AvatarURL, + var userInfo *idp.IdentityProviderUserInfo + if identityProvider.Type == storepb.IdentityProvider_OAUTH2 { + oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config()) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err) + } + token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err) + } + userInfo, err = oauth2IdentityProvider.UserInfo(token) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get user info, error: %v", err) + } } - password, err := util.RandomString(20) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err) + + identifierFilter := identityProvider.IdentifierFilter + if identifierFilter != "" { + identifierFilterRegex, err := regexp.Compile(identifierFilter) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to compile identifier filter regex, error: %v", err) + } + if !identifierFilterRegex.MatchString(userInfo.Identifier) { + return nil, status.Errorf(codes.PermissionDenied, "identifier %s is not allowed", userInfo.Identifier) + } } - passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + + user, err := s.Store.GetUser(ctx, &store.FindUser{ + Username: &userInfo.Identifier, + }) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err) + return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err) } - userCreate.PasswordHash = string(passwordHash) - user, err = s.Store.CreateUser(ctx, userCreate) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err) + if user == nil { + // Check if the user is allowed to sign up. + workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err) + } + if workspaceGeneralSetting.DisallowUserRegistration { + return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed") + } + + // Create a new user with the user info from the identity provider. + userCreate := &store.User{ + Username: userInfo.Identifier, + // The new signup user should be normal user by default. + Role: store.RoleUser, + Nickname: userInfo.DisplayName, + Email: userInfo.Email, + AvatarURL: userInfo.AvatarURL, + } + password, err := util.RandomString(20) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to generate random password, error: %v", err) + } + passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to generate password hash, error: %v", err) + } + userCreate.PasswordHash = string(passwordHash) + user, err = s.Store.CreateUser(ctx, userCreate) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err) + } } + existingUser = user } - if user.RowStatus == store.Archived { - return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", userInfo.Identifier) + + if existingUser == nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid credentials") + } + if existingUser.RowStatus == store.Archived { + return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", existingUser.Username) } - if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { + expireTime := time.Now().Add(AccessTokenDuration) + if request.NeverExpire { + // Set the expire time to 100 years. + expireTime = time.Now().Add(100 * 365 * 24 * time.Hour) + } + if err := s.doSignIn(ctx, existingUser, expireTime); err != nil { return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err) } - return convertUserFromStore(user), nil + return convertUserFromStore(existingUser), nil } func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error { diff --git a/web/src/components/PasswordSignInForm.tsx b/web/src/components/PasswordSignInForm.tsx index eb36ef942..82634ea99 100644 --- a/web/src/components/PasswordSignInForm.tsx +++ b/web/src/components/PasswordSignInForm.tsx @@ -45,7 +45,7 @@ const PasswordSignInForm = observer(() => { try { actionBtnLoadingState.setLoading(); - await authServiceClient.signIn({ username, password, neverExpire: remember }); + await authServiceClient.signIn({ passwordCredentials: { username, password }, neverExpire: remember }); await initialUserStore(); navigateTo("/"); } catch (error: any) { diff --git a/web/src/pages/AuthCallback.tsx b/web/src/pages/AuthCallback.tsx index 8488066d5..61ade8071 100644 --- a/web/src/pages/AuthCallback.tsx +++ b/web/src/pages/AuthCallback.tsx @@ -45,10 +45,12 @@ const AuthCallback = () => { const redirectUri = absolutifyLink("/auth/callback"); (async () => { try { - await authServiceClient.signInWithSSO({ - idpId: identityProviderId, - code, - redirectUri, + await authServiceClient.signIn({ + ssoCredentials: { + idpId: identityProviderId, + code, + redirectUri, + }, }); setState({ loading: false, diff --git a/web/src/types/proto/api/v1/auth_service.ts b/web/src/types/proto/api/v1/auth_service.ts index f35f8b621..86e1ade67 100644 --- a/web/src/types/proto/api/v1/auth_service.ts +++ b/web/src/types/proto/api/v1/auth_service.ts @@ -19,15 +19,26 @@ export interface GetAuthStatusResponse { } export interface SignInRequest { + /** Username and password authentication method. */ + passwordCredentials?: + | PasswordCredentials + | undefined; + /** SSO provider authentication method. */ + ssoCredentials?: + | SSOCredentials + | undefined; + /** Whether the session should never expire. */ + neverExpire: boolean; +} + +export interface PasswordCredentials { /** The username to sign in with. */ username: string; /** The password to sign in with. */ password: string; - /** Whether the session should never expire. */ - neverExpire: boolean; } -export interface SignInWithSSORequest { +export interface SSOCredentials { /** The ID of the SSO provider. */ idpId: number; /** The code to sign in with. */ @@ -127,16 +138,16 @@ export const GetAuthStatusResponse: MessageFns = { }; function createBaseSignInRequest(): SignInRequest { - return { username: "", password: "", neverExpire: false }; + return { passwordCredentials: undefined, ssoCredentials: undefined, neverExpire: false }; } export const SignInRequest: MessageFns = { encode(message: SignInRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { - if (message.username !== "") { - writer.uint32(10).string(message.username); + if (message.passwordCredentials !== undefined) { + PasswordCredentials.encode(message.passwordCredentials, writer.uint32(10).fork()).join(); } - if (message.password !== "") { - writer.uint32(18).string(message.password); + if (message.ssoCredentials !== undefined) { + SSOCredentials.encode(message.ssoCredentials, writer.uint32(18).fork()).join(); } if (message.neverExpire !== false) { writer.uint32(24).bool(message.neverExpire); @@ -156,7 +167,7 @@ export const SignInRequest: MessageFns = { break; } - message.username = reader.string(); + message.passwordCredentials = PasswordCredentials.decode(reader, reader.uint32()); continue; } case 2: { @@ -164,7 +175,7 @@ export const SignInRequest: MessageFns = { break; } - message.password = reader.string(); + message.ssoCredentials = SSOCredentials.decode(reader, reader.uint32()); continue; } case 3: { @@ -189,19 +200,81 @@ export const SignInRequest: MessageFns = { }, fromPartial(object: DeepPartial): SignInRequest { const message = createBaseSignInRequest(); + message.passwordCredentials = (object.passwordCredentials !== undefined && object.passwordCredentials !== null) + ? PasswordCredentials.fromPartial(object.passwordCredentials) + : undefined; + message.ssoCredentials = (object.ssoCredentials !== undefined && object.ssoCredentials !== null) + ? SSOCredentials.fromPartial(object.ssoCredentials) + : undefined; + message.neverExpire = object.neverExpire ?? false; + return message; + }, +}; + +function createBasePasswordCredentials(): PasswordCredentials { + return { username: "", password: "" }; +} + +export const PasswordCredentials: MessageFns = { + encode(message: PasswordCredentials, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { + if (message.username !== "") { + writer.uint32(10).string(message.username); + } + if (message.password !== "") { + writer.uint32(18).string(message.password); + } + return writer; + }, + + decode(input: BinaryReader | Uint8Array, length?: number): PasswordCredentials { + const reader = input instanceof BinaryReader ? input : new BinaryReader(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBasePasswordCredentials(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: { + if (tag !== 10) { + break; + } + + message.username = reader.string(); + continue; + } + case 2: { + if (tag !== 18) { + break; + } + + message.password = reader.string(); + continue; + } + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skip(tag & 7); + } + return message; + }, + + create(base?: DeepPartial): PasswordCredentials { + return PasswordCredentials.fromPartial(base ?? {}); + }, + fromPartial(object: DeepPartial): PasswordCredentials { + const message = createBasePasswordCredentials(); message.username = object.username ?? ""; message.password = object.password ?? ""; - message.neverExpire = object.neverExpire ?? false; return message; }, }; -function createBaseSignInWithSSORequest(): SignInWithSSORequest { +function createBaseSSOCredentials(): SSOCredentials { return { idpId: 0, code: "", redirectUri: "" }; } -export const SignInWithSSORequest: MessageFns = { - encode(message: SignInWithSSORequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { +export const SSOCredentials: MessageFns = { + encode(message: SSOCredentials, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { if (message.idpId !== 0) { writer.uint32(8).int32(message.idpId); } @@ -214,10 +287,10 @@ export const SignInWithSSORequest: MessageFns = { return writer; }, - decode(input: BinaryReader | Uint8Array, length?: number): SignInWithSSORequest { + decode(input: BinaryReader | Uint8Array, length?: number): SSOCredentials { const reader = input instanceof BinaryReader ? input : new BinaryReader(input); let end = length === undefined ? reader.len : reader.pos + length; - const message = createBaseSignInWithSSORequest(); + const message = createBaseSSOCredentials(); while (reader.pos < end) { const tag = reader.uint32(); switch (tag >>> 3) { @@ -254,11 +327,11 @@ export const SignInWithSSORequest: MessageFns = { return message; }, - create(base?: DeepPartial): SignInWithSSORequest { - return SignInWithSSORequest.fromPartial(base ?? {}); + create(base?: DeepPartial): SSOCredentials { + return SSOCredentials.fromPartial(base ?? {}); }, - fromPartial(object: DeepPartial): SignInWithSSORequest { - const message = createBaseSignInWithSSORequest(); + fromPartial(object: DeepPartial): SSOCredentials { + const message = createBaseSSOCredentials(); message.idpId = object.idpId ?? 0; message.code = object.code ?? ""; message.redirectUri = object.redirectUri ?? ""; @@ -401,7 +474,7 @@ export const AuthServiceDefinition = { }, }, }, - /** SignIn signs in the user with the given username and password. */ + /** SignIn signs in the user. */ signIn: { name: "SignIn", requestType: SignInRequest, @@ -439,48 +512,6 @@ export const AuthServiceDefinition = { }, }, }, - /** SignInWithSSO signs in the user with the given SSO code. */ - signInWithSSO: { - name: "SignInWithSSO", - requestType: SignInWithSSORequest, - requestStream: false, - responseType: User, - responseStream: false, - options: { - _unknownFields: { - 578365826: [ - new Uint8Array([ - 25, - 34, - 23, - 47, - 97, - 112, - 105, - 47, - 118, - 49, - 47, - 97, - 117, - 116, - 104, - 47, - 115, - 105, - 103, - 110, - 105, - 110, - 47, - 115, - 115, - 111, - ]), - ], - }, - }, - }, /** SignUp signs up the user with the given username and password. */ signUp: { name: "SignUp",