chore: update acl middleware

pull/135/head
boojack 3 years ago
parent 873973a088
commit d83f204d8c

@ -51,11 +51,10 @@ func removeUserSession(ctx echo.Context) error {
return nil return nil
} }
// Use session to store user.id. func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
return func(ctx echo.Context) error { return func(ctx echo.Context) error {
// Skip auth for some paths. // Skip auth for some paths.
if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:userId") { if common.HasPrefixes(ctx.Path(), "/api/auth", "/api/ping", "/api/status", "/api/user/:id") {
return next(ctx) return next(ctx)
} }
@ -76,42 +75,36 @@ func BasicAuthMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
} }
} }
needAuth := true
if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet {
if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil {
needAuth = false
}
}
{ {
sess, _ := session.Get("session", ctx) sess, _ := session.Get("session", ctx)
userIDValue := sess.Values[userIDContextKey] userIDValue := sess.Values[userIDContextKey]
if userIDValue == nil && needAuth { if userIDValue != nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session") userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
} userFind := &api.UserFind{
ID: &userID,
userID, err := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) }
if err != nil && needAuth { user, err := s.Store.FindUser(userFind)
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to malformatted user id in the session.").SetInternal(err) if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if user != nil {
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
}
ctx.Set(getUserIDContextKey(), userID)
}
} }
}
userFind := &api.UserFind{ if common.HasPrefixes(ctx.Path(), "/api/memo", "/api/tag", "/api/shortcut") && ctx.Request().Method == http.MethodGet {
ID: &userID, if _, err := strconv.Atoi(ctx.QueryParam("creatorId")); err == nil {
} return next(ctx)
user, err := s.Store.FindUser(userFind)
if err != nil && needAuth {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if needAuth {
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Not found user ID: %d", userID))
} else if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
}
} }
}
// Save userID into context. userID := ctx.Get(getUserIDContextKey())
ctx.Set(getUserIDContextKey(), userID) if userID == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing userID in session")
} }
return next(ctx) return next(ctx)

@ -72,8 +72,8 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
memoFind.CreatorID = &userID memoFind.CreatorID = &userID
} }
currentUserID := c.Get(getUserIDContextKey()).(int) currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if currentUserID == api.UNKNOWN_ID { if !ok {
if memoFind.CreatorID == nil { if memoFind.CreatorID == nil {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo")
} }

@ -58,7 +58,7 @@ func NewServer(profile *profile.Profile) *Server {
apiGroup := e.Group("/api") apiGroup := e.Group("/api")
apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return BasicAuthMiddleware(s, next) return aclMiddleware(s, next)
}) })
s.registerSystemRoutes(apiGroup) s.registerSystemRoutes(apiGroup)
s.registerAuthRoutes(apiGroup) s.registerAuthRoutes(apiGroup)

@ -25,8 +25,8 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
memoFind.CreatorID = &userID memoFind.CreatorID = &userID
} }
currentUserID := c.Get(getUserIDContextKey()).(int) currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if currentUserID == api.UNKNOWN_ID { if !ok {
if memoFind.CreatorID == nil { if memoFind.CreatorID == nil {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo") return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find memo")
} }

@ -83,12 +83,11 @@ func (s *Server) registerUserRoutes(g *echo.Group) {
// GET /api/user/me is used to check if the user is logged in. // GET /api/user/me is used to check if the user is logged in.
g.GET("/user/me", func(c echo.Context) error { g.GET("/user/me", func(c echo.Context) error {
userSessionID := c.Get(getUserIDContextKey()) userID, ok := c.Get(getUserIDContextKey()).(int)
if userSessionID == nil { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
} }
userID := userSessionID.(int)
userFind := &api.UserFind{ userFind := &api.UserFind{
ID: &userID, ID: &userID,
} }

@ -255,7 +255,6 @@ func findUserList(db *sql.DB, find *api.UserFind) ([]*userRaw, error) {
&userRaw.UpdatedTs, &userRaw.UpdatedTs,
&userRaw.RowStatus, &userRaw.RowStatus,
); err != nil { ); err != nil {
fmt.Println(err)
return nil, FormatError(err) return nil, FormatError(err)
} }

Loading…
Cancel
Save