Skip to content

Commit

Permalink
feat: create post vote endpoint (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
denysvitali committed Jul 4, 2023
1 parent d1f3548 commit e437dff
Show file tree
Hide file tree
Showing 13 changed files with 308 additions and 11 deletions.
21 changes: 21 additions & 0 deletions pkg/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,24 @@ func (s *Server) authenticatedUser(c *gin.Context) {

c.Set("username", username)
}

// maybeAuthenticatedUser is like authenticatedUser, but it doesn't
// require the user to be authenticated. If the user is authenticated, it
// sets the "username" context variable.
func (s *Server) maybeAuthenticatedUser(c *gin.Context) {
if !s.hasIdToken(c) {
return
}

idToken, done := s.idToken(c)
if done {
return
}

username, err := s.usernameForIdToken(idToken)
if err != nil {
return
}

c.Set("username", username)
}
2 changes: 2 additions & 0 deletions pkg/convert_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ func convertResponsePosts(posts []models.Post, r *models.Ring) []response.Post {
Ups: p.Ups,
Downs: p.Downs,
Nsfw: p.Nsfw,
VotedUp: p.VotedUp,
VotedDown: p.VotedDown,
})
}
return responsePosts
Expand Down
6 changes: 6 additions & 0 deletions pkg/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,9 @@ package server

const ErrInvalidPostRequest = "invalid post request"
const ErrRingDoesNotExist = "ring does not exist"
const ErrUnableToGetPost = "unable to get post"
const ErrUnableToGetVote = "unable to get vote"
const ErrUnableToSaveVote = "unable to save vote"
const ErrUnableToIncreasePostScore = "unable to increase post score"
const ErrUnableToCreateVote = "unable to create vote"
const ErrUnableToVoteUserAlreadyVoted = "unable to vote, user already voted"
3 changes: 3 additions & 0 deletions pkg/models/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@ type Post struct {
Ups int `json:"ups"`
Downs int `json:"downs"`
Nsfw bool `json:"nsfw"`

VotedUp bool `json:"votedUp" gorm:"-"`
VotedDown bool `json:"votedDown" gorm:"-"`
}
12 changes: 12 additions & 0 deletions pkg/models/post_action.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package models

import "time"

type PostAction struct {
Username string `json:"username" gorm:"primaryKey"`
User User `json:"user,omitempty" gorm:"foreignKey:Username"`
Post Post `json:"post,omitempty"`
PostId uint `json:"post_id" gorm:"primaryKey"`
Action string `json:"action" gorm:"type:post_action;index"`
CreatedAt time.Time
}
14 changes: 14 additions & 0 deletions pkg/rc_comments.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,20 @@ func setCommentActions(comments []models.Comment, actions map[uint]models.Commen
return comments
}

func setPostActions(posts []models.Post, actions map[uint]models.PostAction) {
for k, v := range posts {
action, ok := actions[v.ID]
if ok {
switch action.Action {
case models.ActionDownvote:
posts[k].VotedDown = true
case models.ActionUpvote:
posts[k].VotedUp = true
}
}
}
}

func setDepth(comments []models.Comment, i int) {
for k, _ := range comments {
comments[k].Depth = i
Expand Down
79 changes: 79 additions & 0 deletions pkg/repo_posts.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package server
import (
"backend/pkg/models"
"backend/pkg/request"
"errors"
"fmt"
"gorm.io/gorm"
"net/url"
)

Expand Down Expand Up @@ -62,3 +64,80 @@ func getDomain(s string) (string, error) {
}
return u.Hostname(), nil
}

func (s *Server) repoVoteAction(action models.VoteAction, username string, id int64) error {
// Check if post exists
post, err := s.repoPost(uint(id))
if err != nil {
s.logger.Warnf("unable to get post: %v", err)
return fmt.Errorf(ErrUnableToGetPost)
}

// Check if user has already voted
vote, err := s.repoGetVote(username, uint(id))
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
s.logger.Warnf("unable to get vote: %v", err)
return fmt.Errorf(ErrUnableToGetVote)
}

if vote != nil {
if vote.Action == action {
// User has already voted with this action
return fmt.Errorf(ErrUnableToVoteUserAlreadyVoted)
}

vote.Action = action

// Run in a transaction
tx := s.db.Begin()
err = tx.Save(&vote).Error
if err != nil {
s.logger.Warnf("unable to save vote: %v", err)
tx.Rollback()
return fmt.Errorf(ErrUnableToSaveVote)
}

increase := 2
if action == models.ActionDownvote {
increase = -2
}
// Vote saved, update post, add 2 to score (1 for upvote, 1 for removing downvote)
err = s.repoIncreasePostScore(tx, post.ID, increase)
if err != nil {
s.logger.Warnf("unable to increase post score by %d: %v", increase, err)
tx.Rollback()
return fmt.Errorf(ErrUnableToIncreasePostScore)
}
tx.Commit()
} else {
// User has not voted yet
vote := models.PostAction{
Username: username,
PostId: uint(id),
Action: action,
}

// Run in a transaction
tx := s.db.Begin()
err = tx.Create(&vote).Error
if err != nil {
s.logger.Warnf("unable to create vote: %v", err)
tx.Rollback()
return fmt.Errorf(ErrUnableToCreateVote)
}
// Vote saved, update post, add +-1 to score
increase := 1
if action == models.ActionDownvote {
increase = -1
}
err = s.repoIncreasePostScore(tx, post.ID, increase)
if err != nil {
s.logger.Warnf("unable to increase post score by %d: %v", increase, err)
tx.Rollback()
return fmt.Errorf(ErrUnableToIncreasePostScore)
}
tx.Commit()
}

return nil
}
13 changes: 13 additions & 0 deletions pkg/repo_posts_votes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package server

import "backend/pkg/models"

func (s *Server) repoGetVote(username string, postId uint) (*models.PostAction, error) {
var postAction models.PostAction
tx := s.db.First(&postAction, "username = ? AND post_id = ?", username, postId)
if tx.Error != nil {
return nil, tx.Error
}

return &postAction, nil
}
2 changes: 2 additions & 0 deletions pkg/response/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ type Post struct {
Ups int `json:"ups"`
Downs int `json:"downs"`
Nsfw bool `json:"nsfw"`
VotedUp bool `json:"votedUp"`
VotedDown bool `json:"votedDown"`
}
36 changes: 36 additions & 0 deletions pkg/route_post.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package server

import (
"backend/pkg/models"
"github.com/gin-gonic/gin"
"net/http"
)

func (s *Server) innerRouteVotePost(c *gin.Context, action models.VoteAction) {
postId, done := parsePostId(c)
if done {
return
}

username := c.GetString("username")
err := s.repoVoteAction(action, username, postId)
if err != nil {
if err.Error() == ErrUnableToVoteUserAlreadyVoted {
c.JSON(http.StatusOK, gin.H{"error": "user already voted"})
return
}
s.logger.Errorf("unable to vote post: %v", err)
c.JSON(500, gin.H{"error": "unable to vote post"})
return
}

c.JSON(http.StatusAccepted, gin.H{"message": "voted"})
}

func (s *Server) routeUpvotePost(c *gin.Context) {
s.innerRouteVotePost(c, models.ActionUpvote)
}

func (s *Server) routeDownvotePost(c *gin.Context) {
s.innerRouteVotePost(c, models.ActionDownvote)
}
7 changes: 5 additions & 2 deletions pkg/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ func (s *Server) initRoutes() {
// Rings
g.GET("/rings", s.routeGetRings)
g.GET("/r/:ring", s.routeGetRing)
g.GET("/r/:ring/posts", s.getRingPosts)
g.GET("/r/:ring/posts", s.maybeAuthenticatedUser, s.getRingPosts)
g.POST("/r/:ring", s.authenticatedUser, s.routeCreateRing)

// Posts
g.POST("/posts", s.authenticatedUser, s.createPost)
g.GET("/posts/:id", s.getPost)
g.GET("/posts/:id", s.maybeAuthenticatedUser, s.getPost)
g.GET("/posts/:id/comments", s.getComments)
g.POST("/posts/:id/comments", s.postComment)
g.DELETE("/posts/:id/comments/:commentId", s.deleteComment)

g.PUT("/posts/:id/upvote", s.authenticatedUser, s.routeUpvotePost)
g.PUT("/posts/:id/downvote", s.authenticatedUser, s.routeDownvotePost)

// Comments
g.GET("/comments", s.routeGetRecentComments)
g.PUT("/posts/:id/comments/:commentId/upvote", s.upvoteComment)
Expand Down
104 changes: 95 additions & 9 deletions pkg/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,14 @@ func validateRingName(name string) bool {
func (s *Server) initModels() error {
// Auto-migrate all the models in `models`
// Check if the comment_action enum exists
var commentActionExists bool
tx := s.db.Raw("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'comment_action');").Scan(&commentActionExists)
if tx.Error != nil {
return tx.Error
err := s.createCommentAction()
if err != nil {
return err
}

if !commentActionExists {
tx := s.db.Exec("CREATE TYPE comment_action AS ENUM ('upvote', 'downvote');")
if tx.Error != nil {
return tx.Error
}
err = s.createPostAction()
if err != nil {
return err
}

return s.db.AutoMigrate(
Expand All @@ -136,9 +133,45 @@ func (s *Server) initModels() error {
&models.User{},
&models.SocialLink{},
&models.CommentAction{},
&models.PostAction{},
)
}

func (s *Server) createCommentAction() error {
var commentActionExists bool
tx := s.db.
Raw("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'comment_action')").
Scan(&commentActionExists)
if tx.Error != nil {
return tx.Error
}

if !commentActionExists {
tx = s.db.Exec("CREATE TYPE comment_action AS ENUM ('upvote', 'downvote');")
if tx.Error != nil {
return tx.Error
}
}
return nil
}
func (s *Server) createPostAction() error {
var exists bool
tx := s.db.
Raw("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'post_action')").
Scan(&exists)
if tx.Error != nil {
return tx.Error
}

if !exists {
tx = s.db.Exec("CREATE TYPE post_action AS ENUM ('upvote', 'downvote');")
if tx.Error != nil {
return tx.Error
}
}
return nil
}

func (s *Server) getRingPosts(context *gin.Context) {
// Gets the posts in ring, sorted by score
ringName := context.Param("ring")
Expand Down Expand Up @@ -171,6 +204,41 @@ func (s *Server) getRingPosts(context *gin.Context) {
return
}

// Get the user's vote on each post
if context.GetString("username") != "" {
// Get a list of IDs for the posts
var postIds []uint
for _, p := range posts {
postIds = append(postIds, p.ID)
}

// Get votes for the posts
var votes []models.PostAction
tx = s.db.
Where("post_id IN ?", postIds).
Where("username = ?", context.GetString("username")).
Find(&votes)
if tx.Error != nil {
s.logger.Errorf("Unable to get votes for posts: %v", tx.Error)
internalServerError(context)
return
}

// Add votes to map
votesMap := make(map[uint]models.PostAction)
for _, v := range votes {
votesMap[v.PostId] = v
}

// Add votes to posts
for i, p := range posts {
if v, ok := votesMap[p.ID]; ok {
posts[i].VotedUp = v.Action == models.ActionUpvote
posts[i].VotedDown = v.Action == models.ActionDownvote
}
}
}

context.JSON(200, convertResponsePosts(posts, r))
}

Expand Down Expand Up @@ -666,6 +734,24 @@ func (s *Server) repoRecentComments(after *uint64) ([]models.Comment, error) {
return comments, tx.Error
}

func (s *Server) repoIncreasePostScore(tx *gorm.DB, id uint, value int) error {
tx = tx.Model(&models.Post{}).
Where("id = ?", id).
UpdateColumn("score", gorm.Expr("score + ?", value))
return tx.Error
}

func (s *Server) repoPostAction(postId uint, username string) (*models.PostAction, error) {
var postAction models.PostAction
tx := s.db.
Where("post_id = ? AND username = ?", postId, username).
First(&postAction)
if tx.Error != nil {
return nil, tx.Error
}
return &postAction, nil
}

func parseNilAsEmpty[T any](element T) T {
// Given a RedditPosts struct, parse the struct tag for the `json` key and check if it does
// have the `nilasempty` key. If it does, then set the value to an empty array.
Expand Down
Loading

0 comments on commit e437dff

Please sign in to comment.