diff --git a/internal/app/admin_server/controller/admin/list.go b/internal/app/admin_server/controller/admin/list.go index b25a223cc..ecaae9a0d 100644 --- a/internal/app/admin_server/controller/admin/list.go +++ b/internal/app/admin_server/controller/admin/list.go @@ -15,10 +15,10 @@ import ( type Query struct { schema.Query - Status *model.AdminStatus `json:"status"` // 管理员状态 + Status *model.AdminStatus `json:"status" url:"status" validate:"omitempty,number" comment:"管理员状态"` // 管理员状态 } -func GetList(c helper.Context, input Query) (res schema.Response) { +func GetList(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.AdminProfile, 0) @@ -40,15 +40,17 @@ func GetList(c helper.Context, input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + list := make([]model.Admin, 0) filter := map[string]interface{}{} - if input.Status != nil { - filter["status"] = *input.Status + if query.Status != nil { + filter["status"] = *query.Status } if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(filter).Find(&list).Error; err != nil { diff --git a/internal/app/admin_server/controller/banner/create.go b/internal/app/admin_server/controller/banner/create.go index a63ed4079..5a3444146 100644 --- a/internal/app/admin_server/controller/banner/create.go +++ b/internal/app/admin_server/controller/banner/create.go @@ -18,7 +18,7 @@ import ( type CreateParams struct { Image string `json:"image" validate:"required,url,max=255" comment:"图片地址"` // 图片 URL Href string `json:"href" validate:"required,url,max=255" comment:"图片跳转的地址"` // 图片跳转的 URL - Platform model.BannerPlatform `json:"platform" validate:"required,max=32,oneof=web app" comment:"平台"` // 用于哪个平台, web/app + Platform model.BannerPlatform `json:"platform" validate:"required,max=32,oneof=pc app" comment:"平台"` // 用于哪个平台, web/app Description *string `json:"description" validate:"omitempty,max=255" comment:"描述"` // Banner 描述 Priority *int `json:"priority" validate:"omitempty,gt=0" comment:"优先级"` // 优先级,用于排序 Identifier *string `json:"identifier" validate:"omitempty,max=32" comment:"APP 标识符"` // APP 跳转标识符 diff --git a/internal/app/admin_server/controller/banner/list.go b/internal/app/admin_server/controller/banner/list.go index feb42f524..acb6eefda 100644 --- a/internal/app/admin_server/controller/banner/list.go +++ b/internal/app/admin_server/controller/banner/list.go @@ -15,11 +15,11 @@ import ( type Query struct { schema.Query - Platform *model.BannerPlatform `json:"platform" form:"platform"` // 根据平台筛选 - Active *bool `json:"active" form:"active"` // 是否激活 + Platform *model.BannerPlatform `json:"platform" url:"platform" validate:"omitempty,oneof=pc app" comment:"平台"` // 根据平台筛选 + Active *bool `json:"active" url:"active" validate:"omitempty" comment:"是否激活"` // 是否激活 } -func GetBannerList(c helper.Context, q Query) (res schema.Response) { +func GetBannerList(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Banner, 0) @@ -41,20 +41,26 @@ func GetBannerList(c helper.Context, q Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := q.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + + if err = query.Validate(); err != nil { + return + } + list := make([]model.Banner, 0) filter := map[string]interface{}{} - if q.Platform != nil { - filter["platform"] = *q.Platform + if query.Platform != nil { + filter["platform"] = *query.Platform } - if q.Active != nil { - filter["active"] = *q.Active + if query.Active != nil { + filter["active"] = *query.Active } else { filter["active"] = true } diff --git a/internal/app/admin_server/controller/banner/update.go b/internal/app/admin_server/controller/banner/update.go index ec7a21842..7b8ff29d3 100644 --- a/internal/app/admin_server/controller/banner/update.go +++ b/internal/app/admin_server/controller/banner/update.go @@ -18,7 +18,7 @@ import ( type UpdateParams struct { Image *string `json:"image" validate:"omitempty,url,max=255" comment:"图片地址"` // 图片 URL Href *string `json:"href" validate:"omitempty,url,max=255" comment:"图片跳转的地址"` // 图片跳转的 URL - Platform *model.BannerPlatform `json:"platform" validate:"omitempty,max=32,oneof=web app" comment:"平台"` // 用于哪个平台, web/app + Platform *model.BannerPlatform `json:"platform" validate:"omitempty,max=32,oneof=pc app" comment:"平台"` // 用于哪个平台, web/app Description *string `json:"description" validate:"omitempty,max=255" comment:"描述"` // Banner 描述 Priority *int `json:"priority" validate:"omitempty,gt=0" comment:"优先级"` // 优先级,用于排序 Identifier *string `json:"identifier" validate:"omitempty,max=32" comment:"APP 标识符"` // APP 跳转标识符 diff --git a/internal/app/admin_server/controller/config/list.go b/internal/app/admin_server/controller/config/list.go index 82012b17b..7fc6a41d4 100644 --- a/internal/app/admin_server/controller/config/list.go +++ b/internal/app/admin_server/controller/config/list.go @@ -18,7 +18,7 @@ type Query struct { schema.Query } -func GetList(c helper.Context, q Query) (res schema.Response) { +func GetList(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Config, 0) @@ -49,6 +49,12 @@ func GetList(c helper.Context, q Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() + query.Normalize() + + if err = query.Validate(); err != nil { + return + } + tx = database.Db.Begin() adminInfo := model.Admin{ @@ -68,10 +74,6 @@ func GetList(c helper.Context, q Query) (res schema.Response) { return } - query := q.Query - - query.Normalize() - list := make([]model.Config, 0) filter := map[string]interface{}{} diff --git a/internal/app/admin_server/controller/help/list.go b/internal/app/admin_server/controller/help/list.go index 70d82fa5d..383f4a276 100644 --- a/internal/app/admin_server/controller/help/list.go +++ b/internal/app/admin_server/controller/help/list.go @@ -15,11 +15,11 @@ import ( type Query struct { schema.Query - Status *model.HelpStatus `json:"status" form:"status"` // 根据状态筛选 - Type *model.HelpType `json:"type" form:"type"` // 根据类型筛选 + Status *model.HelpStatus `json:"status" url:"status" validate:"omitempty,number" comment:"状态"` // 根据状态筛选 + Type *model.HelpType `json:"type" url:"type" validate:"omitempty" comment:"类型"` // 根据类型筛选 } -func GetHelpList(c helper.Context, q Query) (res schema.Response) { +func GetHelpList(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Help, 0) @@ -41,20 +41,22 @@ func GetHelpList(c helper.Context, q Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := q.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + list := make([]model.Help, 0) filter := map[string]interface{}{} - if q.Status != nil { - filter["status"] = *q.Status + if query.Status != nil { + filter["status"] = *query.Status } - if q.Type != nil { - filter["type"] = *q.Type + if query.Type != nil { + filter["type"] = *query.Type } var total int64 diff --git a/internal/app/admin_server/controller/logger/login/list.go b/internal/app/admin_server/controller/logger/login/list.go index c98281684..b2a52e6c6 100644 --- a/internal/app/admin_server/controller/logger/login/list.go +++ b/internal/app/admin_server/controller/logger/login/list.go @@ -15,13 +15,13 @@ import ( type Query struct { schema.Query - Uid *string `json:"uid" form:"uid"` // 根据用户 ID 筛选 - Type *int `json:"type" form:"type"` // 根据类型筛选 - Command *int `json:"command" form:"command"` // 根据登陆命令筛选 - Ip *string `json:"ip"` // 根据 IP 筛选 + Uid *string `json:"uid" url:"uid" validate:"omitempty" comment:"用户ID"` // 根据用户 ID 筛选 + Type *int `json:"type" url:"type" validate:"omitempty,number" comment:"类型"` // 根据类型筛选 + Command *int `json:"command" url:"command" validate:"omitempty,number" comment:"登陆命令"` // 根据登陆命令筛选 + Ip *string `json:"ip" url:"ip" validate:"omitempty,ip" comment:"IP"` // 根据 IP 筛选 } -func GetLoginLogs(c helper.Context, q Query) (res schema.Response) { +func GetLoginLogs(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.LogLogin, 0) @@ -43,28 +43,30 @@ func GetLoginLogs(c helper.Context, q Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := q.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + list := make([]model.LoginLog, 0) filter := map[string]interface{}{} - if q.Uid != nil { - filter["uid"] = *q.Uid + if query.Uid != nil { + filter["uid"] = *query.Uid } - if q.Type != nil { - filter["type"] = *q.Type + if query.Type != nil { + filter["type"] = *query.Type } - if q.Command != nil { - filter["command"] = *q.Command + if query.Command != nil { + filter["command"] = *query.Command } - if q.Ip != nil { - filter["last_ip"] = *q.Ip + if query.Ip != nil { + filter["last_ip"] = *query.Ip } var total int64 diff --git a/internal/app/admin_server/controller/message/list.go b/internal/app/admin_server/controller/message/list.go index 8e60cbd0e..2aaf17114 100644 --- a/internal/app/admin_server/controller/message/list.go +++ b/internal/app/admin_server/controller/message/list.go @@ -15,21 +15,17 @@ import ( type Query struct { schema.Query - Status *model.MessageStatus `json:"status" form:"status"` - Read *bool `json:"read" form:"read"` -} - -type QueryAdmin struct { - Query - Uid *string `json:"uid" form:"uid"` // 指定某个用户ID + Status *model.MessageStatus `json:"status" url:"status" validate:"omitempty,number" comment:"状态"` + Read *bool `json:"read" url:"read" validate:"omitempty" comment:"是否已读"` + Uid *string `json:"uid" url:"uid" validate:"omitempty" comment:"用户ID"` } // 用户获取自己的消息列表 -func GetMessageListByUser(c helper.Context, input Query) (res schema.Response) { +func GetMessageListByAdmin(c helper.Context, query Query) (res schema.Response) { var ( err error - data = make([]schema.Message, 0) // 接口输出的数据 - list = make([]model.Message, 0) // 数据库查询出的原始数据 + data = make([]schema.MessageAdmin, 0) // 接口输出的数据 + list = make([]model.Message, 0) // 数据库查询出的原始数据 meta = &schema.Meta{} ) @@ -48,101 +44,33 @@ func GetMessageListByUser(c helper.Context, input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() - var total int64 - - filter := map[string]interface{}{} - - filter["uid"] = c.Uid - - if input.Read != nil { - filter["read"] = *input.Read - } - - if input.Status != nil { - filter["status"] = *input.Status - } - - if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(filter).Find(&list).Error; err != nil { + if err = query.Validate(); err != nil { return } - if err = database.Db.Model(model.Message{}).Where(filter).Count(&total).Error; err != nil { - return - } - - for _, v := range list { - d := schema.Message{} - if er := mapstructure.Decode(v, &d.MessagePure); er != nil { - err = er - return - } - d.CreatedAt = v.CreatedAt.Format(time.RFC3339Nano) - d.UpdatedAt = v.UpdatedAt.Format(time.RFC3339Nano) - data = append(data, d) - } - - meta.Total = total - meta.Num = len(data) - meta.Page = query.Page - meta.Limit = query.Limit - meta.Sort = query.Sort - - return -} - -// 用户获取自己的消息列表 -func GetMessageListByAdmin(c helper.Context, input QueryAdmin) (res schema.Response) { - var ( - err error - data = make([]schema.MessageAdmin, 0) // 接口输出的数据 - list = make([]model.Message, 0) // 数据库查询出的原始数据 - meta = &schema.Meta{} - ) - - defer func() { - if r := recover(); r != nil { - switch t := r.(type) { - case string: - err = errors.New(t) - case error: - err = t - default: - err = exception.Unknown - } - } - - helper.Response(&res, data, meta, err) - }() - - query := input.Query - - query.Normalize() - - filter := model.Message{} + filter := map[string]interface{}{} - if input.Uid != nil { - filter.Uid = *input.Uid + if query.Uid != nil { + filter["uid"] = *query.Uid } - if input.Read != nil { - filter.Read = *input.Read + if query.Read != nil { + filter["read"] = *query.Read } - if input.Status != nil { - filter.Status = *input.Status + if query.Status != nil { + filter["status"] = *query.Status } var total int64 - if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(&filter).Find(&list).Error; err != nil { + if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(filter).Find(&list).Error; err != nil { return } - if err = database.Db.Model(&filter).Where(&filter).Count(&total).Error; err != nil { + if err = database.Db.Model(model.Message{}).Where(filter).Count(&total).Error; err != nil { return } @@ -168,7 +96,7 @@ func GetMessageListByAdmin(c helper.Context, input QueryAdmin) (res schema.Respo var GetMessageListByAdminRouter = router.Handler(func(c router.Context) { var ( - input QueryAdmin + input Query ) c.ResponseFunc(c.ShouldBindQuery(&input), func() schema.Response { diff --git a/internal/app/admin_server/controller/message/list_test.go b/internal/app/admin_server/controller/message/list_test.go index 573e3245e..9a0a60014 100644 --- a/internal/app/admin_server/controller/message/list_test.go +++ b/internal/app/admin_server/controller/message/list_test.go @@ -3,6 +3,7 @@ package message_test import ( "encoding/json" + "fmt" "github.com/axetroy/go-server/internal/app/admin_server/controller/message" "github.com/axetroy/go-server/internal/library/helper" "github.com/axetroy/go-server/internal/schema" @@ -14,88 +15,6 @@ import ( "testing" ) -func TestGetMessageListByUser(t *testing.T) { - - { - var ( - data = make([]schema.Message, 0) - ) - query := schema.Query{ - Limit: 20, - } - r := message.GetMessageListByUser(helper.Context{ - Uid: "123123", - }, message.Query{ - Query: query, - }) - - assert.Equal(t, schema.StatusSuccess, r.Status) - assert.Equal(t, "", r.Message) - - assert.Nil(t, r.Decode(&data)) - assert.Equal(t, query.Limit, r.Meta.Limit) - assert.Equal(t, schema.DefaultPage, r.Meta.Page) - assert.Equal(t, 0, r.Meta.Num) - assert.Equal(t, int64(0), r.Meta.Total) - } - - adminInfo, _ := tester.LoginAdmin() - - userInfo, _ := tester.CreateUser() - - defer tester.DeleteUserByUserName(userInfo.Username) - - { - var ( - title = "test" - content = "test" - ) - - r := message.Create(helper.Context{ - Uid: adminInfo.Id, - }, message.CreateMessageParams{ - Uid: userInfo.Id, - Title: title, - Content: content, - }) - - assert.Equal(t, schema.StatusSuccess, r.Status) - assert.Equal(t, "", r.Message) - - n := schema.Message{} - - assert.Nil(t, r.Decode(&n)) - - defer message.DeleteMessageById(n.Id) - } - - // 3. 获取列表 - { - data := make([]schema.Message, 0) - - query := schema.Query{ - Limit: 20, - } - r := message.GetMessageListByUser(helper.Context{ - Uid: userInfo.Id, - }, message.Query{ - Query: query, - }) - - assert.Equal(t, schema.StatusSuccess, r.Status) - assert.Equal(t, "", r.Message) - - assert.Nil(t, r.Decode(&data)) - - assert.Equal(t, query.Limit, r.Meta.Limit) - assert.Equal(t, schema.DefaultPage, r.Meta.Page) - assert.Equal(t, 1, r.Meta.Num) - assert.Equal(t, int64(1), r.Meta.Total) - - assert.Len(t, data, 1) - } -} - func TestGetMessageListByAdmin(t *testing.T) { { @@ -105,20 +24,21 @@ func TestGetMessageListByAdmin(t *testing.T) { query := schema.Query{ Limit: 20, } - r := message.GetMessageListByUser(helper.Context{ + r := message.GetMessageListByAdmin(helper.Context{ Uid: "123123", }, message.Query{ Query: query, }) + fmt.Printf("%+v\n", r) + assert.Equal(t, schema.StatusSuccess, r.Status) assert.Equal(t, "", r.Message) assert.Nil(t, r.Decode(&data)) + assert.NotNil(t, r.Meta.Limit) assert.Equal(t, query.Limit, r.Meta.Limit) assert.Equal(t, schema.DefaultPage, r.Meta.Page) - assert.Equal(t, 0, r.Meta.Num) - assert.Equal(t, int64(0), r.Meta.Total) } adminInfo, _ := tester.LoginAdmin() @@ -162,9 +82,7 @@ func TestGetMessageListByAdmin(t *testing.T) { } r := message.GetMessageListByAdmin(helper.Context{ Uid: adminInfo.Id, - }, message.QueryAdmin{ - Query: query, - }) + }, query) assert.Equal(t, schema.StatusSuccess, r.Status) assert.Equal(t, "", r.Message) @@ -179,75 +97,6 @@ func TestGetMessageListByAdmin(t *testing.T) { } } -func TestGetMessageListByUserRouter(t *testing.T) { - adminInfo, _ := tester.LoginAdmin() - - userInfo, _ := tester.CreateUser() - - defer tester.DeleteUserByUserName(userInfo.Username) - - { - var ( - title = "test" - content = "test" - ) - - r := message.Create(helper.Context{ - Uid: adminInfo.Id, - }, message.CreateMessageParams{ - Uid: userInfo.Id, - Title: title, - Content: content, - }) - - assert.Equal(t, schema.StatusSuccess, r.Status) - assert.Equal(t, "", r.Message) - - n := schema.Message{} - - assert.Nil(t, r.Decode(&n)) - - //defer message.DeleteMessageById(n.Id) - } - - header := mocker.Header{ - "Authorization": token.Prefix + " " + userInfo.Token, - } - - { - r := tester.HttpUser.Get("/v1/message", nil, &header) - - res := schema.Response{} - - assert.Equal(t, http.StatusOK, r.Code) - - if !assert.Nil(t, json.Unmarshal(r.Body.Bytes(), &res)) { - return - } - - if !assert.Equal(t, schema.StatusSuccess, res.Status) { - return - } - - if !assert.Equal(t, "", res.Message) { - return - } - - messages := make([]schema.Message, 0) - - assert.Nil(t, res.Decode(&messages)) - - assert.True(t, len(messages) > 0) - - for _, b := range messages { - assert.IsType(t, "string", b.Title) - assert.IsType(t, "string", b.Content) - assert.IsType(t, "string", b.CreatedAt) - assert.IsType(t, "string", b.UpdatedAt) - } - } -} - func TestGetMessageListByAdminRouter(t *testing.T) { adminInfo, _ := tester.LoginAdmin() diff --git a/internal/app/admin_server/controller/news/list.go b/internal/app/admin_server/controller/news/list.go index ccd5830ca..b0c3999e9 100644 --- a/internal/app/admin_server/controller/news/list.go +++ b/internal/app/admin_server/controller/news/list.go @@ -15,11 +15,11 @@ import ( type Query struct { schema.Query - Status *model.NewsStatus `json:"status" form:"status"` - Type *model.NewsType `json:"type" form:"type"` + Status *model.NewsStatus `json:"status" url:"status" validate:"omitempty,number" comment:"状态"` + Type *model.NewsType `json:"type" url:"type" validate:"omitempty" comment:"状态"` } -func GetNewsList(input Query) (res schema.Response) { +func GetNewsList(query Query) (res schema.Response) { var ( err error data = make([]schema.News, 0) // 接口输出的数据 @@ -42,18 +42,20 @@ func GetNewsList(input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + filter := map[string]interface{}{} - if input.Status != nil { - filter["status"] = *input.Status + if query.Status != nil { + filter["status"] = *query.Status } - if input.Type != nil { - filter["type"] = *input.Type + if query.Type != nil { + filter["type"] = *query.Type } if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(filter).Find(&list).Error; err != nil { diff --git a/internal/app/admin_server/controller/notification/list.go b/internal/app/admin_server/controller/notification/list.go index 40bb7e752..bf138afac 100644 --- a/internal/app/admin_server/controller/notification/list.go +++ b/internal/app/admin_server/controller/notification/list.go @@ -20,7 +20,7 @@ type Query struct { } // GetList get notification list -func GetNotificationListByAdmin(c helper.Context, input Query) (res schema.Response) { +func GetNotificationListByAdmin(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.NotificationAdmin, 0) @@ -51,10 +51,12 @@ func GetNotificationListByAdmin(c helper.Context, input Query) (res schema.Respo helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + tx = database.Db.Begin() list := make([]model.Notification, 0) diff --git a/internal/app/admin_server/controller/report/list.go b/internal/app/admin_server/controller/report/list.go index 4b7b57efa..7585aecf0 100644 --- a/internal/app/admin_server/controller/report/list.go +++ b/internal/app/admin_server/controller/report/list.go @@ -15,16 +15,12 @@ import ( type Query struct { schema.Query - Type *model.ReportType `json:"type" form:"type"` // 类型 - Status *model.ReportStatus `json:"status" form:"status"` // 状态 + Uid *string `json:"uid" url:"uid" validate:"omitempty,max=32" comment:"用户ID"` // 用户ID + Type *model.ReportType `json:"type" url:"type" validate:"omitempty" comment:"类型"` // 类型 + Status *model.ReportStatus `json:"status" url:"status" validate:"omitempty,number" comment:"状态"` // 状态 } -type QueryAdmin struct { - Query - Uid string `json:"uid"` -} - -func GetListByAdmin(c helper.Context, input QueryAdmin) (res schema.Response) { +func GetListByAdmin(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Report, 0) @@ -46,20 +42,26 @@ func GetListByAdmin(c helper.Context, input QueryAdmin) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + list := make([]model.Report, 0) filter := map[string]interface{}{} - if input.Type != nil { - filter["type"] = *input.Type + if query.Uid != nil { + filter["uid"] = *query.Uid + } + + if query.Type != nil { + filter["type"] = *query.Type } - if input.Status != nil { - filter["status"] = *input.Status + if query.Status != nil { + filter["status"] = *query.Status } if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(filter).Find(&list).Error; err != nil { @@ -90,7 +92,7 @@ func GetListByAdmin(c helper.Context, input QueryAdmin) (res schema.Response) { var GetListByAdminRouter = router.Handler(func(c router.Context) { var ( - input QueryAdmin + input Query ) c.ResponseFunc(c.ShouldBindQuery(&input), func() schema.Response { diff --git a/internal/app/admin_server/controller/report/list_test.go b/internal/app/admin_server/controller/report/list_test.go index fcdbb6c7a..4c0722fab 100644 --- a/internal/app/admin_server/controller/report/list_test.go +++ b/internal/app/admin_server/controller/report/list_test.go @@ -53,7 +53,7 @@ func TestGetListByAdmin(t *testing.T) { // 获取列表 { - r := report.GetListByAdmin(helper.Context{Uid: adminInfo.Id}, report.QueryAdmin{}) + r := report.GetListByAdmin(helper.Context{Uid: adminInfo.Id}, report.Query{}) assert.Equal(t, schema.StatusSuccess, r.Status) assert.Equal(t, "", r.Message) diff --git a/internal/app/user_server/controller/address/list.go b/internal/app/user_server/controller/address/list.go index 49cbc7ea5..78fb30c05 100644 --- a/internal/app/user_server/controller/address/list.go +++ b/internal/app/user_server/controller/address/list.go @@ -19,7 +19,7 @@ type Query struct { //Status model.NewsStatus `json:"status" form:"status"` } -func GetAddressListByUser(c helper.Context, input Query) (res schema.Response) { +func GetAddressListByUser(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Address, 0) // 输出到外部的结果 @@ -43,10 +43,12 @@ func GetAddressListByUser(c helper.Context, input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + tx = database.Db.Begin() filter := map[string]interface{}{} diff --git a/internal/app/user_server/controller/banner/list.go b/internal/app/user_server/controller/banner/list.go index feb42f524..439ebb201 100644 --- a/internal/app/user_server/controller/banner/list.go +++ b/internal/app/user_server/controller/banner/list.go @@ -15,11 +15,11 @@ import ( type Query struct { schema.Query - Platform *model.BannerPlatform `json:"platform" form:"platform"` // 根据平台筛选 - Active *bool `json:"active" form:"active"` // 是否激活 + Platform *model.BannerPlatform `json:"platform" url:"platform" validate:"omitempty,oneof=pc app" comment:"平台"` // 根据平台筛选 + Active *bool `json:"active" url:"active" validate:"omitempty" comment:"是否激活"` // 是否激活 } -func GetBannerList(c helper.Context, q Query) (res schema.Response) { +func GetBannerList(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Banner, 0) @@ -41,20 +41,22 @@ func GetBannerList(c helper.Context, q Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := q.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + list := make([]model.Banner, 0) filter := map[string]interface{}{} - if q.Platform != nil { - filter["platform"] = *q.Platform + if query.Platform != nil { + filter["platform"] = *query.Platform } - if q.Active != nil { - filter["active"] = *q.Active + if query.Active != nil { + filter["active"] = *query.Active } else { filter["active"] = true } diff --git a/internal/app/user_server/controller/invite/list.go b/internal/app/user_server/controller/invite/list.go index f4ddc2f77..454557211 100644 --- a/internal/app/user_server/controller/invite/list.go +++ b/internal/app/user_server/controller/invite/list.go @@ -15,7 +15,7 @@ type Query struct { schema.Query } -func GetInviteListByUser(input Query) (res schema.Response) { +func GetInviteListByUser(query Query) (res schema.Response) { var ( err error data = make([]model.InviteHistory, 0) @@ -37,10 +37,12 @@ func GetInviteListByUser(input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + filter := map[string]interface{}{} if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page).Where(filter)).Find(&data).Error; err != nil { diff --git a/internal/app/user_server/controller/message/list.go b/internal/app/user_server/controller/message/list.go index b18dbd739..6eccd494a 100644 --- a/internal/app/user_server/controller/message/list.go +++ b/internal/app/user_server/controller/message/list.go @@ -15,12 +15,12 @@ import ( type Query struct { schema.Query - Status *model.MessageStatus `json:"status" form:"status"` - Read *bool `json:"read" form:"read"` + Status *model.MessageStatus `json:"status" url:"status" validate:"omitempty,number" comment:"状态"` + Read *bool `json:"read" url:"read" validate:"omitempty" comment:"是否已读"` } // 用户获取自己的消息列表 -func GetMessageListByUser(c helper.Context, input Query) (res schema.Response) { +func GetMessageListByUser(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Message, 0) // 接口输出的数据 @@ -43,22 +43,24 @@ func GetMessageListByUser(c helper.Context, input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + var total int64 filter := map[string]interface{}{} filter["uid"] = c.Uid - if input.Read != nil { - filter["read"] = *input.Read + if query.Read != nil { + filter["read"] = *query.Read } - if input.Status != nil { - filter["status"] = *input.Status + if query.Status != nil { + filter["status"] = *query.Status } if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(filter).Find(&list).Error; err != nil { diff --git a/internal/app/user_server/controller/news/list.go b/internal/app/user_server/controller/news/list.go index ccd5830ca..98c296eb5 100644 --- a/internal/app/user_server/controller/news/list.go +++ b/internal/app/user_server/controller/news/list.go @@ -15,11 +15,11 @@ import ( type Query struct { schema.Query - Status *model.NewsStatus `json:"status" form:"status"` - Type *model.NewsType `json:"type" form:"type"` + Status *model.NewsStatus `json:"status" url:"status" validate:"omitempty,number" comment:"状态"` + Type *model.NewsType `json:"type" url:"type" validate:"omitempty" comment:"类型"` } -func GetNewsList(input Query) (res schema.Response) { +func GetNewsList(query Query) (res schema.Response) { var ( err error data = make([]schema.News, 0) // 接口输出的数据 @@ -42,18 +42,20 @@ func GetNewsList(input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + filter := map[string]interface{}{} - if input.Status != nil { - filter["status"] = *input.Status + if query.Status != nil { + filter["status"] = *query.Status } - if input.Type != nil { - filter["type"] = *input.Type + if query.Type != nil { + filter["type"] = *query.Type } if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(filter).Find(&list).Error; err != nil { diff --git a/internal/app/user_server/controller/notification/list.go b/internal/app/user_server/controller/notification/list.go index 2bc2b5cdf..c4db6b8de 100644 --- a/internal/app/user_server/controller/notification/list.go +++ b/internal/app/user_server/controller/notification/list.go @@ -20,7 +20,7 @@ type Query struct { } // GetList get notification list -func GetNotificationListByUser(c helper.Context, input Query) (res schema.Response) { +func GetNotificationListByUser(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Notification, 0) @@ -51,10 +51,12 @@ func GetNotificationListByUser(c helper.Context, input Query) (res schema.Respon helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + tx = database.Db.Begin() list := make([]model.Notification, 0) diff --git a/internal/app/user_server/controller/report/list.go b/internal/app/user_server/controller/report/list.go index 7a21b7967..f7b613a4a 100644 --- a/internal/app/user_server/controller/report/list.go +++ b/internal/app/user_server/controller/report/list.go @@ -15,16 +15,11 @@ import ( type Query struct { schema.Query - Type *model.ReportType `json:"type" form:"type"` // 类型 - Status *model.ReportStatus `json:"status" form:"status"` // 状态 + Type *model.ReportType `json:"type" url:"type" validate:"omitempty,max=16" comment:"类型"` // 类型 + Status *model.ReportStatus `json:"status" url:"status" validate:"omitempty" comment:"状态"` // 状态 } -type QueryAdmin struct { - Query - Uid string `json:"uid"` -} - -func GetList(c helper.Context, input Query) (res schema.Response) { +func GetList(c helper.Context, query Query) (res schema.Response) { var ( err error data = make([]schema.Report, 0) @@ -46,22 +41,24 @@ func GetList(c helper.Context, input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() - query := input.Query - query.Normalize() + if err = query.Validate(); err != nil { + return + } + list := make([]model.Report, 0) filter := map[string]interface{}{} filter["uid"] = c.Uid - if input.Type != nil { - filter["type"] = *input.Type + if query.Type != nil { + filter["type"] = *query.Type } - if input.Status != nil { - filter["status"] = *input.Status + if query.Status != nil { + filter["status"] = *query.Status } if err = query.Order(database.Db.Limit(query.Limit).Offset(query.Limit * query.Page)).Where(filter).Find(&list).Error; err != nil { diff --git a/internal/app/user_server/controller/transfer/history.go b/internal/app/user_server/controller/transfer/history.go index 2b2fc630b..087218bb8 100644 --- a/internal/app/user_server/controller/transfer/history.go +++ b/internal/app/user_server/controller/transfer/history.go @@ -18,7 +18,7 @@ type Query struct { schema.Query } -func GetHistory(c helper.Context, input Query) (res schema.Response) { +func GetHistory(c helper.Context, query Query) (res schema.Response) { var ( err error tx *gorm.DB @@ -49,6 +49,12 @@ func GetHistory(c helper.Context, input Query) (res schema.Response) { helper.Response(&res, data, meta, err) }() + query.Normalize() + + if err = query.Validate(); err != nil { + return + } + tx = database.Db.Begin() userInfo := model.User{Id: c.Uid} @@ -60,10 +66,6 @@ func GetHistory(c helper.Context, input Query) (res schema.Response) { return } - query := input.Query - - query.Normalize() - list := make([]model.TransferLog, 0) condition := QueryParams{ diff --git a/internal/app/user_server/controller/user/profile.go b/internal/app/user_server/controller/user/profile.go index a5459264c..66ba0d1b4 100644 --- a/internal/app/user_server/controller/user/profile.go +++ b/internal/app/user_server/controller/user/profile.go @@ -16,22 +16,22 @@ import ( ) type UpdateProfileParams struct { - Username *string `json:"username"` // 用户名,部分用户有机会修改自己的用户名,比如微信注册的帐号 - Nickname *string `json:"nickname" valid:"length(1|36)~昵称长度为1-36位"` - Gender *model.Gender `json:"gender"` - Avatar *string `json:"avatar"` - Wechat *UpdateWechatProfileParams `json:"wechat"` // 更新微信绑定的帐号相关 + Username *string `json:"username" validate:"omitempty,max=32" comment:"用户名"` // 用户名,部分用户有机会修改自己的用户名,比如微信注册的帐号 + Nickname *string `json:"nickname" validate:"omitempty,max=32" comment:"昵称"` + Gender *model.Gender `json:"gender" validate:"omitempty,number,oneof=0 1 2" comment:"性别"` + Avatar *string `json:"avatar" validate:"omitempty,url,max=255" comment:"头像"` + Wechat *UpdateWechatProfileParams `json:"wechat" validate:"omitempty" comment:"微信绑定信息"` // 更新微信绑定的帐号相关 } // 绑定的微信信息帐号相关 type UpdateWechatProfileParams struct { - Nickname *string `json:"nickname"` // 用户昵称 - AvatarUrl *string `json:"avatar_url"` // 用户头像 - Gender *int `json:"gender"` // 性别 - Country *string `json:"country"` // 国家 - Province *string `json:"province"` // 省份 - City *string `json:"city"` // 城市 - Language *string `json:"language"` // 语言 + Nickname *string `json:"nickname" validate:"omitempty,max=32" comment:"微信昵称"` // 用户昵称 + AvatarUrl *string `json:"avatar_url" validate:"omitempty,url,max=255" comment:"微信头像"` // 用户头像 + Gender *int `json:"gender" validate:"omitempty,number" comment:"性别"` // 性别 + Country *string `json:"country" validate:"omitempty,max=32" comment:"国家"` // 国家 + Province *string `json:"province" validate:"omitempty,max=32" comment:"省份"` // 省份 + City *string `json:"city" validate:"omitempty,max=32" comment:"城市"` // 城市 + Language *string `json:"language" validate:"omitempty,max=32" comment:"语言"` // 语言 } func GetProfile(c helper.Context) (res schema.Response) { diff --git a/internal/schema/request.go b/internal/schema/request.go index f7829bbd4..4fe0d5e20 100644 --- a/internal/schema/request.go +++ b/internal/schema/request.go @@ -3,6 +3,7 @@ package schema import ( "fmt" + "github.com/axetroy/go-server/internal/library/validator" "github.com/jinzhu/gorm" "regexp" "strings" @@ -11,10 +12,10 @@ import ( type Order string type Query struct { - Limit int `json:"limit" form:"limit"` - Page int `json:"page" form:"page"` - Sort string `json:"sort" form:"sort"` - Platform *string `json:"platform" form:"platform"` + Limit int `json:"limit" url:"limit" validate:"omitempty,number,gte=1" comment:"每页数量"` + Page int `json:"page" url:"page" validate:"omitempty,number,gte=0" comment:"页数"` + Sort string `json:"sort" url:"sort" validate:"omitempty,max=255" comment:"排序"` + Platform *string `json:"platform" url:"platform" validate:"omitempty,max=16" comment:"平台"` } type Sort struct { @@ -93,3 +94,7 @@ func (q *Query) Normalize() *Query { return q } + +func (q *Query) Validate() error { + return validator.ValidateStruct(q) +}