From e437dfff2528972f89ea0d1547c7c83d1186e439 Mon Sep 17 00:00:00 2001 From: Denys Vitali Date: Wed, 5 Jul 2023 01:12:42 +0200 Subject: [PATCH] feat: create post vote endpoint (#2) --- pkg/auth.go | 21 ++++++++ pkg/convert_response.go | 2 + pkg/error.go | 6 +++ pkg/models/post.go | 3 ++ pkg/models/post_action.go | 12 +++++ pkg/rc_comments.go | 14 +++++ pkg/repo_posts.go | 79 +++++++++++++++++++++++++++++ pkg/repo_posts_votes.go | 13 +++++ pkg/response/post.go | 2 + pkg/route_post.go | 36 +++++++++++++ pkg/routes.go | 7 ++- pkg/server.go | 104 ++++++++++++++++++++++++++++++++++---- pkg/server_post.go | 20 ++++++++ 13 files changed, 308 insertions(+), 11 deletions(-) create mode 100644 pkg/models/post_action.go create mode 100644 pkg/repo_posts_votes.go create mode 100644 pkg/route_post.go diff --git a/pkg/auth.go b/pkg/auth.go index 3aa7cd7..ecd25b9 100644 --- a/pkg/auth.go +++ b/pkg/auth.go @@ -29,3 +29,24 @@ func (s *Server) authenticatedUser(c *gin.Context) { c.Set("username", username) } + +// maybeAuthenticatedUser is like authenticatedUser, but it doesn't +// require the user to be authenticated. If the user is authenticated, it +// sets the "username" context variable. +func (s *Server) maybeAuthenticatedUser(c *gin.Context) { + if !s.hasIdToken(c) { + return + } + + idToken, done := s.idToken(c) + if done { + return + } + + username, err := s.usernameForIdToken(idToken) + if err != nil { + return + } + + c.Set("username", username) +} diff --git a/pkg/convert_response.go b/pkg/convert_response.go index 0d99f02..913be88 100644 --- a/pkg/convert_response.go +++ b/pkg/convert_response.go @@ -24,6 +24,8 @@ func convertResponsePosts(posts []models.Post, r *models.Ring) []response.Post { Ups: p.Ups, Downs: p.Downs, Nsfw: p.Nsfw, + VotedUp: p.VotedUp, + VotedDown: p.VotedDown, }) } return responsePosts diff --git a/pkg/error.go b/pkg/error.go index ac3082b..4d2dc90 100644 --- a/pkg/error.go +++ b/pkg/error.go @@ -2,3 +2,9 @@ package server const ErrInvalidPostRequest = "invalid post request" const ErrRingDoesNotExist = "ring does not exist" +const ErrUnableToGetPost = "unable to get post" +const ErrUnableToGetVote = "unable to get vote" +const ErrUnableToSaveVote = "unable to save vote" +const ErrUnableToIncreasePostScore = "unable to increase post score" +const ErrUnableToCreateVote = "unable to create vote" +const ErrUnableToVoteUserAlreadyVoted = "unable to vote, user already voted" diff --git a/pkg/models/post.go b/pkg/models/post.go index a3ea3e9..bd7dab5 100644 --- a/pkg/models/post.go +++ b/pkg/models/post.go @@ -16,4 +16,7 @@ type Post struct { Ups int `json:"ups"` Downs int `json:"downs"` Nsfw bool `json:"nsfw"` + + VotedUp bool `json:"votedUp" gorm:"-"` + VotedDown bool `json:"votedDown" gorm:"-"` } diff --git a/pkg/models/post_action.go b/pkg/models/post_action.go new file mode 100644 index 0000000..9551143 --- /dev/null +++ b/pkg/models/post_action.go @@ -0,0 +1,12 @@ +package models + +import "time" + +type PostAction struct { + Username string `json:"username" gorm:"primaryKey"` + User User `json:"user,omitempty" gorm:"foreignKey:Username"` + Post Post `json:"post,omitempty"` + PostId uint `json:"post_id" gorm:"primaryKey"` + Action string `json:"action" gorm:"type:post_action;index"` + CreatedAt time.Time +} diff --git a/pkg/rc_comments.go b/pkg/rc_comments.go index ce9db96..481ba2c 100644 --- a/pkg/rc_comments.go +++ b/pkg/rc_comments.go @@ -126,6 +126,20 @@ func setCommentActions(comments []models.Comment, actions map[uint]models.Commen return comments } +func setPostActions(posts []models.Post, actions map[uint]models.PostAction) { + for k, v := range posts { + action, ok := actions[v.ID] + if ok { + switch action.Action { + case models.ActionDownvote: + posts[k].VotedDown = true + case models.ActionUpvote: + posts[k].VotedUp = true + } + } + } +} + func setDepth(comments []models.Comment, i int) { for k, _ := range comments { comments[k].Depth = i diff --git a/pkg/repo_posts.go b/pkg/repo_posts.go index 7f8c8f8..015111e 100644 --- a/pkg/repo_posts.go +++ b/pkg/repo_posts.go @@ -3,7 +3,9 @@ package server import ( "backend/pkg/models" "backend/pkg/request" + "errors" "fmt" + "gorm.io/gorm" "net/url" ) @@ -62,3 +64,80 @@ func getDomain(s string) (string, error) { } return u.Hostname(), nil } + +func (s *Server) repoVoteAction(action models.VoteAction, username string, id int64) error { + // Check if post exists + post, err := s.repoPost(uint(id)) + if err != nil { + s.logger.Warnf("unable to get post: %v", err) + return fmt.Errorf(ErrUnableToGetPost) + } + + // Check if user has already voted + vote, err := s.repoGetVote(username, uint(id)) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + s.logger.Warnf("unable to get vote: %v", err) + return fmt.Errorf(ErrUnableToGetVote) + } + + if vote != nil { + if vote.Action == action { + // User has already voted with this action + return fmt.Errorf(ErrUnableToVoteUserAlreadyVoted) + } + + vote.Action = action + + // Run in a transaction + tx := s.db.Begin() + err = tx.Save(&vote).Error + if err != nil { + s.logger.Warnf("unable to save vote: %v", err) + tx.Rollback() + return fmt.Errorf(ErrUnableToSaveVote) + } + + increase := 2 + if action == models.ActionDownvote { + increase = -2 + } + // Vote saved, update post, add 2 to score (1 for upvote, 1 for removing downvote) + err = s.repoIncreasePostScore(tx, post.ID, increase) + if err != nil { + s.logger.Warnf("unable to increase post score by %d: %v", increase, err) + tx.Rollback() + return fmt.Errorf(ErrUnableToIncreasePostScore) + } + tx.Commit() + } else { + // User has not voted yet + vote := models.PostAction{ + Username: username, + PostId: uint(id), + Action: action, + } + + // Run in a transaction + tx := s.db.Begin() + err = tx.Create(&vote).Error + if err != nil { + s.logger.Warnf("unable to create vote: %v", err) + tx.Rollback() + return fmt.Errorf(ErrUnableToCreateVote) + } + // Vote saved, update post, add +-1 to score + increase := 1 + if action == models.ActionDownvote { + increase = -1 + } + err = s.repoIncreasePostScore(tx, post.ID, increase) + if err != nil { + s.logger.Warnf("unable to increase post score by %d: %v", increase, err) + tx.Rollback() + return fmt.Errorf(ErrUnableToIncreasePostScore) + } + tx.Commit() + } + + return nil +} diff --git a/pkg/repo_posts_votes.go b/pkg/repo_posts_votes.go new file mode 100644 index 0000000..ae4e2e3 --- /dev/null +++ b/pkg/repo_posts_votes.go @@ -0,0 +1,13 @@ +package server + +import "backend/pkg/models" + +func (s *Server) repoGetVote(username string, postId uint) (*models.PostAction, error) { + var postAction models.PostAction + tx := s.db.First(&postAction, "username = ? AND post_id = ?", username, postId) + if tx.Error != nil { + return nil, tx.Error + } + + return &postAction, nil +} diff --git a/pkg/response/post.go b/pkg/response/post.go index 9e3dd3a..91e1d46 100644 --- a/pkg/response/post.go +++ b/pkg/response/post.go @@ -25,4 +25,6 @@ type Post struct { Ups int `json:"ups"` Downs int `json:"downs"` Nsfw bool `json:"nsfw"` + VotedUp bool `json:"votedUp"` + VotedDown bool `json:"votedDown"` } diff --git a/pkg/route_post.go b/pkg/route_post.go new file mode 100644 index 0000000..94e603e --- /dev/null +++ b/pkg/route_post.go @@ -0,0 +1,36 @@ +package server + +import ( + "backend/pkg/models" + "github.com/gin-gonic/gin" + "net/http" +) + +func (s *Server) innerRouteVotePost(c *gin.Context, action models.VoteAction) { + postId, done := parsePostId(c) + if done { + return + } + + username := c.GetString("username") + err := s.repoVoteAction(action, username, postId) + if err != nil { + if err.Error() == ErrUnableToVoteUserAlreadyVoted { + c.JSON(http.StatusOK, gin.H{"error": "user already voted"}) + return + } + s.logger.Errorf("unable to vote post: %v", err) + c.JSON(500, gin.H{"error": "unable to vote post"}) + return + } + + c.JSON(http.StatusAccepted, gin.H{"message": "voted"}) +} + +func (s *Server) routeUpvotePost(c *gin.Context) { + s.innerRouteVotePost(c, models.ActionUpvote) +} + +func (s *Server) routeDownvotePost(c *gin.Context) { + s.innerRouteVotePost(c, models.ActionDownvote) +} diff --git a/pkg/routes.go b/pkg/routes.go index e398468..9f86f80 100644 --- a/pkg/routes.go +++ b/pkg/routes.go @@ -28,16 +28,19 @@ func (s *Server) initRoutes() { // Rings g.GET("/rings", s.routeGetRings) g.GET("/r/:ring", s.routeGetRing) - g.GET("/r/:ring/posts", s.getRingPosts) + g.GET("/r/:ring/posts", s.maybeAuthenticatedUser, s.getRingPosts) g.POST("/r/:ring", s.authenticatedUser, s.routeCreateRing) // Posts g.POST("/posts", s.authenticatedUser, s.createPost) - g.GET("/posts/:id", s.getPost) + g.GET("/posts/:id", s.maybeAuthenticatedUser, s.getPost) g.GET("/posts/:id/comments", s.getComments) g.POST("/posts/:id/comments", s.postComment) g.DELETE("/posts/:id/comments/:commentId", s.deleteComment) + g.PUT("/posts/:id/upvote", s.authenticatedUser, s.routeUpvotePost) + g.PUT("/posts/:id/downvote", s.authenticatedUser, s.routeDownvotePost) + // Comments g.GET("/comments", s.routeGetRecentComments) g.PUT("/posts/:id/comments/:commentId/upvote", s.upvoteComment) diff --git a/pkg/server.go b/pkg/server.go index 611104b..35aef5a 100644 --- a/pkg/server.go +++ b/pkg/server.go @@ -116,17 +116,14 @@ func validateRingName(name string) bool { func (s *Server) initModels() error { // Auto-migrate all the models in `models` // Check if the comment_action enum exists - var commentActionExists bool - tx := s.db.Raw("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'comment_action');").Scan(&commentActionExists) - if tx.Error != nil { - return tx.Error + err := s.createCommentAction() + if err != nil { + return err } - if !commentActionExists { - tx := s.db.Exec("CREATE TYPE comment_action AS ENUM ('upvote', 'downvote');") - if tx.Error != nil { - return tx.Error - } + err = s.createPostAction() + if err != nil { + return err } return s.db.AutoMigrate( @@ -136,9 +133,45 @@ func (s *Server) initModels() error { &models.User{}, &models.SocialLink{}, &models.CommentAction{}, + &models.PostAction{}, ) } +func (s *Server) createCommentAction() error { + var commentActionExists bool + tx := s.db. + Raw("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'comment_action')"). + Scan(&commentActionExists) + if tx.Error != nil { + return tx.Error + } + + if !commentActionExists { + tx = s.db.Exec("CREATE TYPE comment_action AS ENUM ('upvote', 'downvote');") + if tx.Error != nil { + return tx.Error + } + } + return nil +} +func (s *Server) createPostAction() error { + var exists bool + tx := s.db. + Raw("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'post_action')"). + Scan(&exists) + if tx.Error != nil { + return tx.Error + } + + if !exists { + tx = s.db.Exec("CREATE TYPE post_action AS ENUM ('upvote', 'downvote');") + if tx.Error != nil { + return tx.Error + } + } + return nil +} + func (s *Server) getRingPosts(context *gin.Context) { // Gets the posts in ring, sorted by score ringName := context.Param("ring") @@ -171,6 +204,41 @@ func (s *Server) getRingPosts(context *gin.Context) { return } + // Get the user's vote on each post + if context.GetString("username") != "" { + // Get a list of IDs for the posts + var postIds []uint + for _, p := range posts { + postIds = append(postIds, p.ID) + } + + // Get votes for the posts + var votes []models.PostAction + tx = s.db. + Where("post_id IN ?", postIds). + Where("username = ?", context.GetString("username")). + Find(&votes) + if tx.Error != nil { + s.logger.Errorf("Unable to get votes for posts: %v", tx.Error) + internalServerError(context) + return + } + + // Add votes to map + votesMap := make(map[uint]models.PostAction) + for _, v := range votes { + votesMap[v.PostId] = v + } + + // Add votes to posts + for i, p := range posts { + if v, ok := votesMap[p.ID]; ok { + posts[i].VotedUp = v.Action == models.ActionUpvote + posts[i].VotedDown = v.Action == models.ActionDownvote + } + } + } + context.JSON(200, convertResponsePosts(posts, r)) } @@ -666,6 +734,24 @@ func (s *Server) repoRecentComments(after *uint64) ([]models.Comment, error) { return comments, tx.Error } +func (s *Server) repoIncreasePostScore(tx *gorm.DB, id uint, value int) error { + tx = tx.Model(&models.Post{}). + Where("id = ?", id). + UpdateColumn("score", gorm.Expr("score + ?", value)) + return tx.Error +} + +func (s *Server) repoPostAction(postId uint, username string) (*models.PostAction, error) { + var postAction models.PostAction + tx := s.db. + Where("post_id = ? AND username = ?", postId, username). + First(&postAction) + if tx.Error != nil { + return nil, tx.Error + } + return &postAction, nil +} + func parseNilAsEmpty[T any](element T) T { // Given a RedditPosts struct, parse the struct tag for the `json` key and check if it does // have the `nilasempty` key. If it does, then set the value to an empty array. diff --git a/pkg/server_post.go b/pkg/server_post.go index c27ed77..28a3139 100644 --- a/pkg/server_post.go +++ b/pkg/server_post.go @@ -1,6 +1,7 @@ package server import ( + "backend/pkg/models" "errors" "github.com/gin-gonic/gin" "gorm.io/gorm" @@ -25,6 +26,25 @@ func (s *Server) getPost(c *gin.Context) { return } + // If user is logged, check if user has upvoted this post + username := c.GetString("username") + if username != "" { + action, err := s.repoPostAction(uint(id), username) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + // User hasn't voted this post + c.JSON(200, post) + return + } else { + s.logger.Errorf("unable to check if user %s has upvoted post %d: %v", username, id, err) + internalServerError(c) + return + } + } + post.VotedUp = action.Action == models.ActionUpvote + post.VotedDown = action.Action == models.ActionDownvote + } + c.JSON(200, post) }