Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: resource visibility #1777

Merged
merged 4 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions server/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,18 @@ func (s *Server) registerResourcePublicRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
}

resourceVisibility, err := CheckResourceVisibility(ctx, s.Store, resourceID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Failed to get resource visibility").SetInternal(err)
}

// Protected resource require a logined user
userID, ok := c.Get(getUserIDContextKey()).(int)
if resourceVisibility == store.Protected && (!ok || userID <= 0) {
return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match").SetInternal(err)
}

publicID, err := url.QueryUnescape(c.Param("publicId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("publicID is invalid: %s", c.Param("publicId"))).SetInternal(err)
Expand All @@ -370,6 +382,11 @@ func (s *Server) registerResourcePublicRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find resource by ID: %v", resourceID)).SetInternal(err)
}

// Private resource require logined user is the creator
if resourceVisibility == store.Private && (!ok || userID != resource.CreatorID) {
return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match").SetInternal(err)
}

blob := resource.Blob
if resource.InternalPath != "" {
src, err := os.Open(resource.InternalPath)
Expand Down Expand Up @@ -401,6 +418,18 @@ func (s *Server) registerResourcePublicRoutes(g *echo.Group) {
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("resourceId"))).SetInternal(err)
}

resourceVisibility, err := CheckResourceVisibility(ctx, s.Store, resourceID)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Failed to get resource visibility").SetInternal(err)
}

// Protected resource require a logined user
userID, ok := c.Get(getUserIDContextKey()).(int)
if resourceVisibility == store.Protected && (!ok || userID <= 0) {
return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match").SetInternal(err)
}

publicID, err := url.QueryUnescape(c.Param("publicId"))
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("publicID is invalid: %s", c.Param("publicId"))).SetInternal(err)
Expand All @@ -420,6 +449,11 @@ func (s *Server) registerResourcePublicRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find resource by ID: %v", resourceID)).SetInternal(err)
}

// Private resource require logined user is the creator
if resourceVisibility == store.Private && (!ok || userID != resource.CreatorID) {
return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match").SetInternal(err)
}

blob := resource.Blob
if resource.InternalPath != "" {
resourcePath := resource.InternalPath
Expand Down Expand Up @@ -552,3 +586,48 @@ func getOrGenerateThumbnailImage(srcBlob []byte, dstPath string) ([]byte, error)
}
return dstBlob, nil
}

func CheckResourceVisibility(ctx context.Context, s *store.Store, resourceID int) (store.Visibility, error) {
memoResourceFind := &api.MemoResourceFind{
ResourceID: &resourceID,
}

memoResources, err := s.FindMemoResourceList(ctx, memoResourceFind)
if err != nil {
return store.Private, err
}

// If resource is belongs to no memo, it'll always PRIVATE
if len(memoResources) == 0 {
return store.Private, nil
}

memoIDs := make([]int, 0, len(memoResources))
for _, memoResource := range memoResources {
memoIDs = append(memoIDs, memoResource.MemoID)
}
visibilityList, err := s.FindMemosVisibilityList(ctx, memoIDs)
if err != nil {
return store.Private, err
}

var isProtected bool
for _, visibility := range visibilityList {
// If any memo is PUBLIC, resource do
if visibility == store.Public {
return store.Public, nil
}

if visibility == store.Protected {
isProtected = true
}
}

// If no memo is PUBLIC, but any memo is PROTECTED, resource do
if isProtected {
return store.Protected, nil
}

// If all memo is PRIVATE, the resource do
return store.Private, nil
}
40 changes: 40 additions & 0 deletions store/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,46 @@ func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemoMessage) error
return err
}

func (s *Store) FindMemosVisibilityList(ctx context.Context, memoIDs []int) ([]Visibility, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()

args := make([]any, 0, len(memoIDs))
list := make([]string, 0, len(memoIDs))
for _, memoID := range memoIDs {
args = append(args, memoID)
list = append(list, "?")
}

where := fmt.Sprintf("id in (%s)", strings.Join(list, ","))

query := `SELECT DISTINCT(visibility) FROM memo WHERE ` + where

rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
}
defer rows.Close()

visibilityList := make([]Visibility, 0)
for rows.Next() {
var visibility Visibility
if err := rows.Scan(&visibility); err != nil {
return nil, FormatError(err)
}
visibilityList = append(visibilityList, visibility)
}

if err := rows.Err(); err != nil {
return nil, FormatError(err)
}

return visibilityList, nil
}

func listMemos(ctx context.Context, tx *sql.Tx, find *FindMemoMessage) ([]*MemoMessage, error) {
where, args := []string{"1 = 1"}, []any{}

Expand Down