diff --git a/README.md b/README.md index e9fe8a17..ef65c73e 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ The CLI currently supports the following products: - [api](./docs/api.md) - [configure](./docs/configure.md) - [event](docs/event.md) +- [mock-api](docs/mock-api.md) - [token](docs/token.md) - [version](docs/version.md) diff --git a/internal/database/_schema.sql b/internal/database/_schema.sql index b10b2e0c..b79aeabc 100644 --- a/internal/database/_schema.sql +++ b/internal/database/_schema.sql @@ -204,7 +204,8 @@ create table drops_entitlements( benefit_id text not null, timestamp text not null, user_id text not null, - game_id text not null, + game_id text not null, + status text not null default 'CLAIMED', foreign key (user_id) references users(id), foreign key (game_id) references categories(id) ); @@ -289,3 +290,17 @@ create table clips ( foreign key (broadcaster_id) references users(id), foreign key (creator_id) references users(id) ); +create table stream_schedule( + id text not null primary key, + broadcaster_id text not null, + starttime text not null, + endtime text not null, + timezone text not null, + is_vacation boolean not null default false, + is_recurring boolean not null default false, + is_canceled boolean not null default false, + title text, + category_id text, + foreign key (broadcaster_id) references users(id), + foreign key (category_id) references categories(id) +); \ No newline at end of file diff --git a/internal/database/drops.go b/internal/database/drops.go index dffbd5ef..fab3c70a 100644 --- a/internal/database/drops.go +++ b/internal/database/drops.go @@ -10,6 +10,7 @@ type DropsEntitlement struct { BenefitID string `db:"benefit_id" json:"benefit_id"` GameID string `db:"game_id" json:"game_id"` Timestamp string `db:"timestamp" json:"timestamp"` + Status string `db:"status" json:"fulfillment_status"` } func (q *Query) GetDropsEntitlements(de DropsEntitlement) (*DBResponse, error) { @@ -51,3 +52,8 @@ func (q *Query) InsertDropsEntitlement(d DropsEntitlement) error { _, err := q.DB.NamedExec(stmt, d) return err } + +func (q *Query) UpdateDropsEntitlement(d DropsEntitlement) error { + _, err := q.DB.NamedExec(generateUpdateSQL("drops_entitlements", []string{"id"}, d), d) + return err +} diff --git a/internal/database/init.go b/internal/database/init.go index 04967c2d..174f040b 100644 --- a/internal/database/init.go +++ b/internal/database/init.go @@ -30,6 +30,10 @@ var migrateSQL = map[int]migrateMap{ SQL: `create table categories( id text not null primary key, category_name text not null ); create table users( id text not null primary key, user_login text not null, display_name text not null, email text not null, user_type text, broadcaster_type text, user_description text, created_at text not null, category_id text, modified_at text, stream_language text not null default 'en', title text not null default '', delay int not null default 0, foreign key (category_id) references categories(id) ); create table follows ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table blocks ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table bans ( broadcaster_id text not null, user_id text not null, created_at text not null, expires_at text, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table ban_events ( id text not null primary key, event_timestamp text not null, event_type text not null, event_version text not null default '1.0', broadcaster_id text not null, user_id text not null, expires_at text, foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table moderators ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table moderator_actions ( id text not null primary key, event_timestamp text not null, event_type text not null, event_version text not null default '1.0', broadcaster_id text not null, user_id text not null, foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table editors ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table channel_points_rewards( id text not null primary key, broadcaster_id text not null, reward_image text, background_color text, is_enabled boolean not null default false, cost int not null default 0, title text not null, reward_prompt text, is_user_input_required boolean default false, stream_max_enabled boolean default false, stream_max_count int default 0, stream_user_max_enabled boolean default false, stream_user_max_count int default 0, global_cooldown_enabled boolean default false, global_cooldown_seconds int default 0, is_paused boolean default false, is_in_stock boolean default true, should_redemptions_skip_queue boolean default false, redemptions_redeemed_current_stream int, cooldown_expires_at text, foreign key (broadcaster_id) references users(id) ); create table channel_points_redemptions( id text not null primary key, reward_id text not null, broadcaster_id text not null, user_id text not null, user_input text, redemption_status text not null, redeemed_at text, foreign key (reward_id) references channel_points_rewards(id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table streams( id text not null primary key, broadcaster_id id text not null, stream_type text not null default 'live', viewer_count int not null, started_at text not null, is_mature boolean not null default false, foreign key (broadcaster_id) references users(id) ); create table tags( id text not null primary key, is_auto boolean not null default false, tag_name text not null ); create table stream_tags( user_id text not null, tag_id text not null, primary key(user_id, tag_id), foreign key(user_id) references users(id), foreign key(tag_id) references tags(id) ); create table teams( id text not null primary key, background_image_url text, banner text, created_at text not null, updated_at text, info text, thumbnail_url text, team_name text, team_display_name text ); create table team_members( team_id text not null, user_id text not null, primary key (team_id, user_id) foreign key (team_id) references teams(id), foreign key (user_id) references users(id) ); create table videos( id text not null primary key, stream_id text, broadcaster_id text not null, title text not null, video_description text not null, created_at text not null, published_at text, viewable text not null, view_count int not null default 0, duration text not null, video_language text not null default 'en', category_id text, type text default 'archive', foreign key (stream_id) references streams(id), foreign key (broadcaster_id) references users(id), foreign key (category_id) references categories(id) ); create table stream_markers( id text not null primary key, video_id text not null, position_seconds int not null, created_at text not null, description text not null, broadcaster_id text not null, foreign key (broadcaster_id) references users(id), foreign key (video_id) references videos(id) ); create table video_muted_segments ( video_id text not null, video_offset int not null, duration int not null, primary key (video_id, video_offset), foreign key (video_id) references videos(id) ); create table subscriptions ( broadcaster_id text not null, user_id text not null, is_gift boolean not null default false, gifter_id text, tier text not null default '1000', created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id), foreign key (gifter_id) references users(id) ); create table drops_entitlements( id text not null primary key, benefit_id text not null, timestamp text not null, user_id text not null, game_id text not null, foreign key (user_id) references users(id), foreign key (game_id) references categories(id) ); create table clients ( id text not null primary key, secret text not null, is_extension boolean default false, name text not null ); create table authorizations ( id integer not null primary key AUTOINCREMENT, client_id text not null, user_id text, token text not null unique, expires_at text not null, scopes text, foreign key (client_id) references clients(id) ); create table polls ( id text not null primary key, broadcaster_id text not null, title text not null, bits_voting_enabled boolean default false, bits_per_vote int default 10, channel_points_voting_enabled boolean default false, channel_points_per_vote int default 10, status text not null, duration int not null, started_at text not null, ended_at text, foreign key (broadcaster_id) references users(id) ); create table poll_choices ( id text not null primary key, title text not null, votes int not null default 0, channel_points_votes int not null default 0, bits_votes int not null default 0, poll_id text not null, foreign key (poll_id) references polls(id) ); create table predictions ( id text not null primary key, broadcaster_id text not null, title text not null, winning_outcome_id text, prediction_window int, status text not null, created_at text not null, ended_at text, locked_at text, foreign key (broadcaster_id) references users(id) ); create table prediction_outcomes ( id text not null primary key, title text not null, users int not null default 0, channel_points int not null default 0, color text not null, prediction_id text not null, foreign key (prediction_id) references predictions(id) ); create table prediction_predictions ( prediction_id text not null, user_id text not null, amount int not null, outcome_id text not null, primary key(prediction_id, user_id), foreign key(user_id) references users(id), foreign key(prediction_id) references predictions(id), foreign key(outcome_id) references prediction_outcomes(id) ); create table clips ( id text not null primary key, broadcaster_id text not null, creator_id text not null, video_id text not null, game_id text not null, title text not null, view_count int default 0, created_at text not null, duration real not null, foreign key (broadcaster_id) references users(id), foreign key (creator_id) references users(id) ); `, Message: "Adding mock API tables.", }, + 3: { + SQL: `alter table drops_entitlements add column status text not null default 'CLAIMED'; create table stream_schedule( id text not null primary key, broadcaster_id text not null, starttime text not null, endtime text not null, timezone text not null, is_vacation boolean not null default false, is_recurring boolean not null default false, is_canceled boolean not null default false, title text, category_id text, foreign key(broadcaster_id) references users(id), foreign key (category_id) references categories(id));`, + Message: ``, + }, } func checkAndUpdate(db sqlx.DB) error { @@ -65,7 +69,7 @@ func checkAndUpdate(db sqlx.DB) error { } func initDatabase(db sqlx.DB) error { - createSQL := `create table events( id text not null primary key, event text not null, json text not null, from_user text not null, to_user text not null, transport text not null, timestamp text not null); create table categories( id text not null primary key, category_name text not null ); create table users( id text not null primary key, user_login text not null, display_name text not null, email text not null, user_type text, broadcaster_type text, user_description text, created_at text not null, category_id text, modified_at text, stream_language text not null default 'en', title text not null default '', delay int not null default 0, foreign key (category_id) references categories(id) ); create table follows ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table blocks ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table bans ( broadcaster_id text not null, user_id text not null, created_at text not null, expires_at text, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table ban_events ( id text not null primary key, event_timestamp text not null, event_type text not null, event_version text not null default '1.0', broadcaster_id text not null, user_id text not null, expires_at text, foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table moderators ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table moderator_actions ( id text not null primary key, event_timestamp text not null, event_type text not null, event_version text not null default '1.0', broadcaster_id text not null, user_id text not null, foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table editors ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table channel_points_rewards( id text not null primary key, broadcaster_id text not null, reward_image text, background_color text, is_enabled boolean not null default false, cost int not null default 0, title text not null, reward_prompt text, is_user_input_required boolean default false, stream_max_enabled boolean default false, stream_max_count int default 0, stream_user_max_enabled boolean default false, stream_user_max_count int default 0, global_cooldown_enabled boolean default false, global_cooldown_seconds int default 0, is_paused boolean default false, is_in_stock boolean default true, should_redemptions_skip_queue boolean default false, redemptions_redeemed_current_stream int, cooldown_expires_at text, foreign key (broadcaster_id) references users(id) ); create table channel_points_redemptions( id text not null primary key, reward_id text not null, broadcaster_id text not null, user_id text not null, user_input text, redemption_status text not null, redeemed_at text, foreign key (reward_id) references channel_points_rewards(id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table streams( id text not null primary key, broadcaster_id id text not null, stream_type text not null default 'live', viewer_count int not null, started_at text not null, is_mature boolean not null default false, foreign key (broadcaster_id) references users(id) ); create table tags( id text not null primary key, is_auto boolean not null default false, tag_name text not null ); create table stream_tags( user_id text not null, tag_id text not null, primary key(user_id, tag_id), foreign key(user_id) references users(id), foreign key(tag_id) references tags(id) ); create table teams( id text not null primary key, background_image_url text, banner text, created_at text not null, updated_at text, info text, thumbnail_url text, team_name text, team_display_name text ); create table team_members( team_id text not null, user_id text not null, primary key (team_id, user_id) foreign key (team_id) references teams(id), foreign key (user_id) references users(id) ); create table videos( id text not null primary key, stream_id text, broadcaster_id text not null, title text not null, video_description text not null, created_at text not null, published_at text, viewable text not null, view_count int not null default 0, duration text not null, video_language text not null default 'en', category_id text, type text default 'archive', foreign key (stream_id) references streams(id), foreign key (broadcaster_id) references users(id), foreign key (category_id) references categories(id) ); create table stream_markers( id text not null primary key, video_id text not null, position_seconds int not null, created_at text not null, description text not null, broadcaster_id text not null, foreign key (broadcaster_id) references users(id), foreign key (video_id) references videos(id) ); create table video_muted_segments ( video_id text not null, video_offset int not null, duration int not null, primary key (video_id, video_offset), foreign key (video_id) references videos(id) ); create table subscriptions ( broadcaster_id text not null, user_id text not null, is_gift boolean not null default false, gifter_id text, tier text not null default '1000', created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id), foreign key (gifter_id) references users(id) ); create table drops_entitlements( id text not null primary key, benefit_id text not null, timestamp text not null, user_id text not null, game_id text not null, foreign key (user_id) references users(id), foreign key (game_id) references categories(id) ); create table clients ( id text not null primary key, secret text not null, is_extension boolean default false, name text not null ); create table authorizations ( id integer not null primary key AUTOINCREMENT, client_id text not null, user_id text, token text not null unique, expires_at text not null, scopes text, foreign key (client_id) references clients(id) ); create table polls ( id text not null primary key, broadcaster_id text not null, title text not null, bits_voting_enabled boolean default false, bits_per_vote int default 10, channel_points_voting_enabled boolean default false, channel_points_per_vote int default 10, status text not null, duration int not null, started_at text not null, ended_at text, foreign key (broadcaster_id) references users(id) ); create table poll_choices ( id text not null primary key, title text not null, votes int not null default 0, channel_points_votes int not null default 0, bits_votes int not null default 0, poll_id text not null, foreign key (poll_id) references polls(id) ); create table predictions ( id text not null primary key, broadcaster_id text not null, title text not null, winning_outcome_id text, prediction_window int, status text not null, created_at text not null, ended_at text, locked_at text, foreign key (broadcaster_id) references users(id) ); create table prediction_outcomes ( id text not null primary key, title text not null, users int not null default 0, channel_points int not null default 0, color text not null, prediction_id text not null, foreign key (prediction_id) references predictions(id) ); create table prediction_predictions ( prediction_id text not null, user_id text not null, amount int not null, outcome_id text not null, primary key(prediction_id, user_id), foreign key(user_id) references users(id), foreign key(prediction_id) references predictions(id), foreign key(outcome_id) references prediction_outcomes(id) ); create table clips ( id text not null primary key, broadcaster_id text not null, creator_id text not null, video_id text not null, game_id text not null, title text not null, view_count int default 0, created_at text not null, duration real not null, foreign key (broadcaster_id) references users(id), foreign key (creator_id) references users(id) ); ` + createSQL := `create table events( id text not null primary key, event text not null, json text not null, from_user text not null, to_user text not null, transport text not null, timestamp text not null); create table categories( id text not null primary key, category_name text not null ); create table users( id text not null primary key, user_login text not null, display_name text not null, email text not null, user_type text, broadcaster_type text, user_description text, created_at text not null, category_id text, modified_at text, stream_language text not null default 'en', title text not null default '', delay int not null default 0, foreign key (category_id) references categories(id) ); create table follows ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table blocks ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table bans ( broadcaster_id text not null, user_id text not null, created_at text not null, expires_at text, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table ban_events ( id text not null primary key, event_timestamp text not null, event_type text not null, event_version text not null default '1.0', broadcaster_id text not null, user_id text not null, expires_at text, foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table moderators ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table moderator_actions ( id text not null primary key, event_timestamp text not null, event_type text not null, event_version text not null default '1.0', broadcaster_id text not null, user_id text not null, foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table editors ( broadcaster_id text not null, user_id text not null, created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table channel_points_rewards( id text not null primary key, broadcaster_id text not null, reward_image text, background_color text, is_enabled boolean not null default false, cost int not null default 0, title text not null, reward_prompt text, is_user_input_required boolean default false, stream_max_enabled boolean default false, stream_max_count int default 0, stream_user_max_enabled boolean default false, stream_user_max_count int default 0, global_cooldown_enabled boolean default false, global_cooldown_seconds int default 0, is_paused boolean default false, is_in_stock boolean default true, should_redemptions_skip_queue boolean default false, redemptions_redeemed_current_stream int, cooldown_expires_at text, foreign key (broadcaster_id) references users(id) ); create table channel_points_redemptions( id text not null primary key, reward_id text not null, broadcaster_id text not null, user_id text not null, user_input text, redemption_status text not null, redeemed_at text, foreign key (reward_id) references channel_points_rewards(id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id) ); create table streams( id text not null primary key, broadcaster_id id text not null, stream_type text not null default 'live', viewer_count int not null, started_at text not null, is_mature boolean not null default false, foreign key (broadcaster_id) references users(id) ); create table tags( id text not null primary key, is_auto boolean not null default false, tag_name text not null ); create table stream_tags( user_id text not null, tag_id text not null, primary key(user_id, tag_id), foreign key(user_id) references users(id), foreign key(tag_id) references tags(id) ); create table teams( id text not null primary key, background_image_url text, banner text, created_at text not null, updated_at text, info text, thumbnail_url text, team_name text, team_display_name text ); create table team_members( team_id text not null, user_id text not null, primary key (team_id, user_id) foreign key (team_id) references teams(id), foreign key (user_id) references users(id) ); create table videos( id text not null primary key, stream_id text, broadcaster_id text not null, title text not null, video_description text not null, created_at text not null, published_at text, viewable text not null, view_count int not null default 0, duration text not null, video_language text not null default 'en', category_id text, type text default 'archive', foreign key (stream_id) references streams(id), foreign key (broadcaster_id) references users(id), foreign key (category_id) references categories(id) ); create table stream_markers( id text not null primary key, video_id text not null, position_seconds int not null, created_at text not null, description text not null, broadcaster_id text not null, foreign key (broadcaster_id) references users(id), foreign key (video_id) references videos(id) ); create table video_muted_segments ( video_id text not null, video_offset int not null, duration int not null, primary key (video_id, video_offset), foreign key (video_id) references videos(id) ); create table subscriptions ( broadcaster_id text not null, user_id text not null, is_gift boolean not null default false, gifter_id text, tier text not null default '1000', created_at text not null, primary key (broadcaster_id, user_id), foreign key (broadcaster_id) references users(id), foreign key (user_id) references users(id), foreign key (gifter_id) references users(id) ); create table drops_entitlements( id text not null primary key, benefit_id text not null, timestamp text not null, user_id text not null, game_id text not null, status text not null default 'CLAIMED', foreign key (user_id) references users(id), foreign key (game_id) references categories(id) ); create table clients ( id text not null primary key, secret text not null, is_extension boolean default false, name text not null ); create table authorizations ( id integer not null primary key AUTOINCREMENT, client_id text not null, user_id text, token text not null unique, expires_at text not null, scopes text, foreign key (client_id) references clients(id) ); create table polls ( id text not null primary key, broadcaster_id text not null, title text not null, bits_voting_enabled boolean default false, bits_per_vote int default 10, channel_points_voting_enabled boolean default false, channel_points_per_vote int default 10, status text not null, duration int not null, started_at text not null, ended_at text, foreign key (broadcaster_id) references users(id) ); create table poll_choices ( id text not null primary key, title text not null, votes int not null default 0, channel_points_votes int not null default 0, bits_votes int not null default 0, poll_id text not null, foreign key (poll_id) references polls(id) ); create table predictions ( id text not null primary key, broadcaster_id text not null, title text not null, winning_outcome_id text, prediction_window int, status text not null, created_at text not null, ended_at text, locked_at text, foreign key (broadcaster_id) references users(id) ); create table prediction_outcomes ( id text not null primary key, title text not null, users int not null default 0, channel_points int not null default 0, color text not null, prediction_id text not null, foreign key (prediction_id) references predictions(id) ); create table prediction_predictions ( prediction_id text not null, user_id text not null, amount int not null, outcome_id text not null, primary key(prediction_id, user_id), foreign key(user_id) references users(id), foreign key(prediction_id) references predictions(id), foreign key(outcome_id) references prediction_outcomes(id) ); create table clips ( id text not null primary key, broadcaster_id text not null, creator_id text not null, video_id text not null, game_id text not null, title text not null, view_count int default 0, created_at text not null, duration real not null, foreign key (broadcaster_id) references users(id), foreign key (creator_id) references users(id) ); create table stream_schedule( id text not null primary key, broadcaster_id text not null, starttime text not null, endtime text not null, timezone text not null, is_vacation boolean not null default false, is_recurring boolean not null default false, is_canceled boolean not null default false, title text, category_id text, foreign key(broadcaster_id) references users(id), foreign key (category_id) references categories(id));` for i := 1; i <= 5; i++ { tx := db.MustBegin() tx.Exec(createSQL) diff --git a/internal/database/schedule.go b/internal/database/schedule.go new file mode 100644 index 00000000..10667838 --- /dev/null +++ b/internal/database/schedule.go @@ -0,0 +1,131 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "database/sql" + "errors" + "time" +) + +type Schedule struct { + Segments []ScheduleSegment `json:"segments"` + UserID string `db:"broadcaster_id" json:"broadcaster_id"` + UserLogin string `db:"broadcaster_login" json:"broadcaster_login" dbi:"false"` + UserName string `db:"broadcaster_name" json:"broadcaster_name" dbi:"false"` + Vacation *ScheduleVacation `json:"vacation"` +} + +type ScheduleSegment struct { + ID string `db:"id" json:"id" dbs:"s.id"` + Title string `db:"title" json:"title"` + StartTime string `db:"starttime" json:"start_time"` + EndTime string `db:"endtime" json:"end_time"` + IsRecurring bool `db:"is_recurring" json:"is_recurring"` + IsVacation bool `db:"is_vacation" json:"-"` + Category *SegmentCategory `json:"category"` + UserID string `db:"broadcaster_id" json:"-"` + Timezone string `db:"timezone" json:"timezone"` + CategoryID *string `db:"category_id" json:"-"` + CategoryName *string `db:"category_name" dbi:"false" json:"-"` + IsCanceled *bool `db:"is_canceled" json:"-"` + CanceledUntil *string `json:"canceled_until"` +} +type ScheduleVacation struct { + ID string `db:"id" json:"-"` + StartTime string `db:"starttime" json:"start_time"` + EndTime string `db:"endtime" json:"end_time"` +} + +type SegmentCategory struct { + ID *string `db:"category_id" json:"id" dbs:"category_id"` + CategoryName *string `db:"category_name" json:"name" dbi:"false"` +} + +func (q *Query) GetSchedule(p ScheduleSegment, startTime time.Time) (*DBResponse, error) { + r := Schedule{} + + u, err := q.GetUser(User{ID: p.UserID}) + if err != nil { + return nil, err + } + r.UserID = u.ID + r.UserLogin = u.UserLogin + r.UserName = u.DisplayName + + sql := generateSQL("select s.*, c.category_name from stream_schedule s left join categories c on s.category_id = c.id", p, SEP_AND) + p.StartTime = startTime.Format(time.RFC3339) + sql += " and datetime(starttime) >= datetime(:starttime) " + q.SQL + rows, err := q.DB.NamedQuery(sql, p) + if err != nil { + return nil, err + } + + for rows.Next() { + var s ScheduleSegment + err := rows.StructScan(&s) + if err != nil { + return nil, err + } + if s.CategoryID != nil { + s.Category = &SegmentCategory{ + ID: s.CategoryID, + CategoryName: s.CategoryName, + } + } + if s.IsVacation { + r.Vacation = &ScheduleVacation{ + StartTime: s.StartTime, + EndTime: s.EndTime, + } + } else { + r.Segments = append(r.Segments, s) + } + } + v, err := q.GetVacations(ScheduleSegment{UserID: p.UserID}) + if err != nil { + return nil, err + } + r.Vacation = &v + dbr := DBResponse{ + Data: r, + Limit: q.Limit, + Total: len(r.Segments), + } + + if len(r.Segments) != q.Limit { + q.PaginationCursor = "" + } + + dbr.Cursor = q.PaginationCursor + + return &dbr, err +} + +func (q *Query) InsertSchedule(p ScheduleSegment) error { + tx := q.DB.MustBegin() + _, err := tx.NamedExec(generateInsertSQL("stream_schedule", "id", p, false), p) + if err != nil { + return err + } + return tx.Commit() +} + +func (q *Query) DeleteSegment(id string, broadcasterID string) error { + _, err := q.DB.Exec("delete from stream_schedule where id=$1 and broadcaster_id=$2", id, broadcasterID) + return err +} + +func (q *Query) UpdateSegment(p ScheduleSegment) error { + _, err := q.DB.NamedExec(generateUpdateSQL("stream_schedule", []string{"id"}, p), p) + return err +} + +func (q *Query) GetVacations(p ScheduleSegment) (ScheduleVacation, error) { + v := ScheduleVacation{} + err := q.DB.Get(&v, "select id,starttime,endtime from stream_schedule where is_vacation=true and datetime(endtime) > datetime('now') and broadcaster_id= $1 limit 1", p.UserID) + if errors.As(err, &sql.ErrNoRows) { + return v, nil + } + return v, err +} diff --git a/internal/events/trigger/trigger_event.go b/internal/events/trigger/trigger_event.go index 20271cf7..93ce195c 100644 --- a/internal/events/trigger/trigger_event.go +++ b/internal/events/trigger/trigger_event.go @@ -4,6 +4,7 @@ package trigger import ( "fmt" + "log" "time" "github.com/twitchdev/twitch-cli/internal/database" @@ -111,7 +112,7 @@ func Fire(p TriggerParameters) (string, error) { } defer resp.Body.Close() - println(fmt.Sprintf(`[%v] Request Sent`, resp.StatusCode)) + log.Println(fmt.Sprintf(`[%v] Request Sent`, resp.StatusCode)) } return string(resp.JSON), nil diff --git a/internal/mock_api/authentication/authentication.go b/internal/mock_api/authentication/authentication.go index 6fb7ae3c..0c183203 100644 --- a/internal/mock_api/authentication/authentication.go +++ b/internal/mock_api/authentication/authentication.go @@ -24,6 +24,7 @@ type UserAuthentication struct { func AuthenticationMiddleware(next mock_api.MockEndpoint) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") db := r.Context().Value("db").(database.CLIDatabase) // skip auth check for unsupported methods diff --git a/internal/mock_api/endpoints/chat/channel_emotes.go b/internal/mock_api/endpoints/chat/channel_emotes.go new file mode 100644 index 00000000..1decc21b --- /dev/null +++ b/internal/mock_api/endpoints/chat/channel_emotes.go @@ -0,0 +1,96 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package chat + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/mock_api/mock_errors" + "github.com/twitchdev/twitch-cli/internal/models" + "github.com/twitchdev/twitch-cli/internal/util" +) + +var channelEmotesMethodsSupported = map[string]bool{ + http.MethodGet: true, + http.MethodPost: false, + http.MethodDelete: false, + http.MethodPatch: false, + http.MethodPut: false, +} + +var channelEmotesScopesByMethod = map[string][]string{ + http.MethodGet: {}, + http.MethodPost: {}, + http.MethodDelete: {}, + http.MethodPatch: {}, + http.MethodPut: {}, +} + +type ChannelEmotes struct{} + +func (e ChannelEmotes) Path() string { return "/chat/emotes/channel" } + +func (e ChannelEmotes) GetRequiredScopes(method string) []string { + return channelEmotesScopesByMethod[method] +} + +func (e ChannelEmotes) ValidMethod(method string) bool { + return channelEmotesMethodsSupported[method] +} + +func (e ChannelEmotes) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db = r.Context().Value("db").(database.CLIDatabase) + + switch r.Method { + case http.MethodGet: + getChannelEmotes(w, r) + break + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} +func getChannelEmotes(w http.ResponseWriter, r *http.Request) { + emotes := []EmotesResponse{} + broadcaster := r.URL.Query().Get("broadcaster_id") + if broadcaster == "" { + mock_errors.WriteBadRequest(w, "Missing required parameter broadcaster_id") + return + } + + setID := fmt.Sprint(util.RandomInt(10 * 1000)) + ownerID := util.RandomUserID() + for _, v := range defaultEmoteTypes { + emoteType := v + for i := 0; i < 5; i++ { + id := util.RandomInt(10 * 1000) + name := util.RandomGUID() + er := EmotesResponse{ + ID: fmt.Sprint(id), + Name: name, + Images: EmotesImages{ + ImageURL1X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/1.0", id), + ImageURL2X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/2.0", id), + ImageURL4X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/4.0", id), + }, + EmoteType: &emoteType, + EmoteSetID: &setID, + OwnerID: &ownerID, + } + if emoteType == "subscription" { + thousand := "1000" + er.Tier = &thousand + } else { + es := "" + er.Tier = &es + } + + emotes = append(emotes, er) + } + } + + bytes, _ := json.Marshal(models.APIResponse{Data: emotes}) + w.Write(bytes) +} diff --git a/internal/mock_api/endpoints/chat/chat_test.go b/internal/mock_api/endpoints/chat/chat_test.go index 5aa23728..d5284d50 100644 --- a/internal/mock_api/endpoints/chat/chat_test.go +++ b/internal/mock_api/endpoints/chat/chat_test.go @@ -41,3 +41,54 @@ func TestChannelBadges(t *testing.T) { a.Nil(err) a.Equal(200, resp.StatusCode) } + +func TestGlobalEmotes(t *testing.T) { + a := test_setup.SetupTestEnv(t) + ts := test_server.SetupTestServer(GlobalEmotes{}) + + // get + req, _ := http.NewRequest(http.MethodGet, ts.URL+GlobalEmotes{}.Path(), nil) + q := req.URL.Query() + req.URL.RawQuery = q.Encode() + resp, err := http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) +} + +func TestChannelEmotes(t *testing.T) { + a := test_setup.SetupTestEnv(t) + ts := test_server.SetupTestServer(ChannelEmotes{}) + + // get + req, _ := http.NewRequest(http.MethodGet, ts.URL+ChannelEmotes{}.Path(), nil) + q := req.URL.Query() + req.URL.RawQuery = q.Encode() + resp, err := http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) +} + +func TestEmoteSets(t *testing.T) { + a := test_setup.SetupTestEnv(t) + ts := test_server.SetupTestServer(EmoteSets{}) + + // get + req, _ := http.NewRequest(http.MethodGet, ts.URL+EmoteSets{}.Path(), nil) + q := req.URL.Query() + req.URL.RawQuery = q.Encode() + resp, err := http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + q.Set("emote_set_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) +} diff --git a/internal/mock_api/endpoints/chat/emote_set.go b/internal/mock_api/endpoints/chat/emote_set.go new file mode 100644 index 00000000..9d0a32dd --- /dev/null +++ b/internal/mock_api/endpoints/chat/emote_set.go @@ -0,0 +1,93 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package chat + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/mock_api/mock_errors" + "github.com/twitchdev/twitch-cli/internal/models" + "github.com/twitchdev/twitch-cli/internal/util" +) + +var emoteSetsMethodsSupported = map[string]bool{ + http.MethodGet: true, + http.MethodPost: false, + http.MethodDelete: false, + http.MethodPatch: false, + http.MethodPut: false, +} + +var emoteSetsScopesByMethod = map[string][]string{ + http.MethodGet: {}, + http.MethodPost: {}, + http.MethodDelete: {}, + http.MethodPatch: {}, + http.MethodPut: {}, +} + +type EmoteSets struct{} + +func (e EmoteSets) Path() string { return "/chat/emotes/set" } + +func (e EmoteSets) GetRequiredScopes(method string) []string { + return emoteSetsScopesByMethod[method] +} + +func (e EmoteSets) ValidMethod(method string) bool { + return emoteSetsMethodsSupported[method] +} + +func (e EmoteSets) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db = r.Context().Value("db").(database.CLIDatabase) + + switch r.Method { + case http.MethodGet: + getEmoteSets(w, r) + break + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} +func getEmoteSets(w http.ResponseWriter, r *http.Request) { + emotes := []EmotesResponse{} + setID := r.URL.Query().Get("emote_set_id") + if setID == "" { + mock_errors.WriteBadRequest(w, "Missing required parameter emote_set_id") + return + } + + for _, v := range defaultEmoteTypes { + emoteType := v + for i := 0; i < 5; i++ { + id := util.RandomInt(10 * 1000) + name := util.RandomGUID() + er := EmotesResponse{ + ID: fmt.Sprint(id), + Name: name, + Images: EmotesImages{ + ImageURL1X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/1.0", id), + ImageURL2X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/2.0", id), + ImageURL4X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/4.0", id), + }, + EmoteType: &emoteType, + EmoteSetID: &setID, + } + if emoteType == "subscription" { + thousand := "1000" + er.Tier = &thousand + } else { + es := "" + er.Tier = &es + } + + emotes = append(emotes, er) + } + } + + bytes, _ := json.Marshal(models.APIResponse{Data: emotes}) + w.Write(bytes) +} diff --git a/internal/mock_api/endpoints/chat/global_emotes.go b/internal/mock_api/endpoints/chat/global_emotes.go new file mode 100644 index 00000000..2dceedad --- /dev/null +++ b/internal/mock_api/endpoints/chat/global_emotes.go @@ -0,0 +1,74 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package chat + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/models" + "github.com/twitchdev/twitch-cli/internal/util" +) + +var globalEmotesMethodsSupported = map[string]bool{ + http.MethodGet: true, + http.MethodPost: false, + http.MethodDelete: false, + http.MethodPatch: false, + http.MethodPut: false, +} + +var globalEmotesScopesByMethod = map[string][]string{ + http.MethodGet: {}, + http.MethodPost: {}, + http.MethodDelete: {}, + http.MethodPatch: {}, + http.MethodPut: {}, +} + +type GlobalEmotes struct{} + +func (e GlobalEmotes) Path() string { return "/chat/emotes/global" } + +func (e GlobalEmotes) GetRequiredScopes(method string) []string { + return globalEmotesScopesByMethod[method] +} + +func (e GlobalEmotes) ValidMethod(method string) bool { + return globalEmotesMethodsSupported[method] +} + +func (e GlobalEmotes) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db = r.Context().Value("db").(database.CLIDatabase) + + switch r.Method { + case http.MethodGet: + getGlobalEmotes(w, r) + break + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func getGlobalEmotes(w http.ResponseWriter, r *http.Request) { + emotes := []EmotesResponse{} + + for i := 0; i < 100; i++ { + id := util.RandomInt(10 * 1000) + name := util.RandomGUID() + emotes = append(emotes, EmotesResponse{ + ID: fmt.Sprintf("%v", id), + Name: name, + Images: EmotesImages{ + ImageURL1X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/1.0", id), + ImageURL2X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/2.0", id), + ImageURL4X: fmt.Sprintf("https://static-cdn.jtvnw.net/emoticons/v1/%v/4.0", id), + }, + }) + } + + bytes, _ := json.Marshal(models.APIResponse{Data: emotes}) + w.Write(bytes) +} diff --git a/internal/mock_api/endpoints/chat/shared.go b/internal/mock_api/endpoints/chat/shared.go index 0552b074..c81b6cbd 100644 --- a/internal/mock_api/endpoints/chat/shared.go +++ b/internal/mock_api/endpoints/chat/shared.go @@ -5,6 +5,7 @@ package chat import "github.com/twitchdev/twitch-cli/internal/database" var db database.CLIDatabase +var defaultEmoteTypes = []string{"subscription", "bitstier", "follower"} type BadgesResponse struct { SetID string `json:"set_id"` @@ -17,3 +18,19 @@ type BadgesVersion struct { ImageURL2X string `json:"image_url_2x"` ImageURL4X string `json:"image_url_4x"` } + +type EmotesResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Images EmotesImages `json:"images"` + Tier *string `json:"tier,omitempty"` + EmoteType *string `json:"emote_type,omitempty"` + EmoteSetID *string `json:"emote_set_id,omitempty"` + OwnerID *string `json:"owner_id,omitempty"` +} + +type EmotesImages struct { + ImageURL1X string `json:"url_1x"` + ImageURL2X string `json:"url_2x"` + ImageURL4X string `json:"url_4x"` +} diff --git a/internal/mock_api/endpoints/clips/clips.go b/internal/mock_api/endpoints/clips/clips.go index e5c51071..9af8d60e 100644 --- a/internal/mock_api/endpoints/clips/clips.go +++ b/internal/mock_api/endpoints/clips/clips.go @@ -87,8 +87,7 @@ func getClips(w http.ResponseWriter, r *http.Request) { dbr, err := db.NewQuery(r, 100).GetClips(database.Clip{ID: id, BroadcasterID: broadcasterID, GameID: gameID}, startedAt, endedAt) if err != nil { - println(err.Error()) - mock_errors.WriteServerError(w, "Error fetching clips") + mock_errors.WriteServerError(w, err.Error()) return } @@ -150,8 +149,7 @@ func postClips(w http.ResponseWriter, r *http.Request) { err = db.NewQuery(r, 100).InsertClip(clip) if err != nil { - println(err.Error()) - mock_errors.WriteServerError(w, "Error creating clip for user") + mock_errors.WriteServerError(w, err.Error()) return } diff --git a/internal/mock_api/endpoints/drops/drops_test.go b/internal/mock_api/endpoints/drops/drops_test.go index aff4805c..69d3edfe 100644 --- a/internal/mock_api/endpoints/drops/drops_test.go +++ b/internal/mock_api/endpoints/drops/drops_test.go @@ -3,13 +3,48 @@ package drops import ( + "bytes" + "encoding/json" + "log" "net/http" + "os" "testing" + "time" + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/util" "github.com/twitchdev/twitch-cli/test_setup" "github.com/twitchdev/twitch-cli/test_setup/test_server" ) +var entitlement database.DropsEntitlement + +func TestMain(m *testing.M) { + test_setup.SetupTestEnv(&testing.T{}) + + db, err := database.NewConnection() + if err != nil { + log.Fatal(err) + } + e := database.DropsEntitlement{ + ID: util.RandomGUID(), + UserID: "1", + BenefitID: "1234", + GameID: "1", + Timestamp: util.GetTimestamp().Format(time.RFC3339), + Status: "CLAIMED", + } + + err = db.NewQuery(nil, 100).InsertDropsEntitlement(e) + if err != nil { + log.Fatal(err) + } + entitlement = e + + db.DB.Close() + + os.Exit(m.Run()) +} func TestDropsEntitlements(t *testing.T) { a := test_setup.SetupTestEnv(t) ts := test_server.SetupTestServer(DropsEntitlements{}) @@ -21,4 +56,33 @@ func TestDropsEntitlements(t *testing.T) { resp, err := http.DefaultClient.Do(req) a.Nil(err) a.Equal(200, resp.StatusCode) + + // patch + // patch tests + body := PatchEntitlementsBody{ + FulfillmentStatus: "FULFILLED", + EntitlementIDs: []string{ + entitlement.ID, + "potato", + }, + } + + b, _ := json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+DropsEntitlements{}.Path(), bytes.NewBuffer(b)) + q = req.URL.Query() + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.NotNil(resp) + a.Equal(200, resp.StatusCode) + + body.FulfillmentStatus = "potato" + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+DropsEntitlements{}.Path(), bytes.NewBuffer(b)) + q = req.URL.Query() + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.NotNil(resp) + a.Equal(400, resp.StatusCode) } diff --git a/internal/mock_api/endpoints/drops/entitlements.go b/internal/mock_api/endpoints/drops/entitlements.go index cdb9b619..6fdcf79a 100644 --- a/internal/mock_api/endpoints/drops/entitlements.go +++ b/internal/mock_api/endpoints/drops/entitlements.go @@ -16,7 +16,7 @@ var dropsEntitlementsMethodsSupported = map[string]bool{ http.MethodGet: true, http.MethodPost: false, http.MethodDelete: false, - http.MethodPatch: false, + http.MethodPatch: true, http.MethodPut: false, } @@ -30,6 +30,16 @@ var dropsEntitlementsScopesByMethod = map[string][]string{ type DropsEntitlements struct{} +type PatchEntitlementsBody struct { + FulfillmentStatus string `json:"fulfillment_status"` + EntitlementIDs []string `json:"entitlement_ids"` +} + +type PatchEntitlementsResponse struct { + Status string `json:"status"` + IDs []string `json:"ids"` +} + func (e DropsEntitlements) Path() string { return "/entitlements/drops" } func (e DropsEntitlements) GetRequiredScopes(method string) []string { @@ -46,7 +56,8 @@ func (e DropsEntitlements) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: getEntitlements(w, r) - break + case http.MethodPatch: + patchEntitlements(w, r) default: w.WriteHeader(http.StatusMethodNotAllowed) } @@ -86,3 +97,69 @@ func getEntitlements(w http.ResponseWriter, r *http.Request) { bytes, err := json.Marshal(apiResponse) w.Write(bytes) } + +func patchEntitlements(w http.ResponseWriter, r *http.Request) { + userCtx := r.Context().Value("auth").(authentication.UserAuthentication) + + var body PatchEntitlementsBody + err := json.NewDecoder(r.Body).Decode(&body) + if err != nil { + mock_errors.WriteBadRequest(w, "Invalid body") + return + } + + if body.FulfillmentStatus != "CLAIMED" && body.FulfillmentStatus != "FULFILLED" { + mock_errors.WriteBadRequest(w, "fulfillment_status must be one of CLAIMED or FULFILLED") + return + } + + if len(body.EntitlementIDs) == 0 || len(body.EntitlementIDs) > 100 { + mock_errors.WriteBadRequest(w, "entitlement_ids must be at least 1 and at most 100") + return + } + s := PatchEntitlementsResponse{ + Status: "SUCCESS", + } + ua := PatchEntitlementsResponse{Status: "UNAUTHORIZED"} + fail := PatchEntitlementsResponse{Status: "UPDATE_FAILED"} + notFound := PatchEntitlementsResponse{Status: "NOT_FOUND"} + for _, e := range body.EntitlementIDs { + dbr, err := db.NewQuery(nil, 100).GetDropsEntitlements(database.DropsEntitlement{ID: e}) + if err != nil { + fail.IDs = append(fail.IDs, e) + continue + } + entitlement := dbr.Data.([]database.DropsEntitlement) + if len(entitlement) == 0 { + notFound.IDs = append(notFound.IDs, e) + continue + } + + if userCtx.UserID != "" && userCtx.UserID != entitlement[0].UserID { + ua.IDs = append(ua.IDs, e) + continue + } + + err = db.NewQuery(nil, 100).UpdateDropsEntitlement(database.DropsEntitlement{ID: e, UserID: entitlement[0].UserID, Status: body.FulfillmentStatus}) + if err != nil { + fail.IDs = append(fail.IDs, e) + continue + } + s.IDs = append(s.IDs, e) + } + all := []PatchEntitlementsResponse{ + s, + ua, + fail, + notFound, + } + resp := []PatchEntitlementsResponse{} + for _, r := range all { + if len(r.IDs) != 0 { + resp = append(resp, r) + } + } + + bytes, _ := json.Marshal(resp) + w.Write(bytes) +} diff --git a/internal/mock_api/endpoints/endpoints.go b/internal/mock_api/endpoints/endpoints.go index 548ccd9b..257f2536 100644 --- a/internal/mock_api/endpoints/endpoints.go +++ b/internal/mock_api/endpoints/endpoints.go @@ -15,6 +15,7 @@ import ( "github.com/twitchdev/twitch-cli/internal/mock_api/endpoints/moderation" "github.com/twitchdev/twitch-cli/internal/mock_api/endpoints/polls" "github.com/twitchdev/twitch-cli/internal/mock_api/endpoints/predictions" + "github.com/twitchdev/twitch-cli/internal/mock_api/endpoints/schedule" "github.com/twitchdev/twitch-cli/internal/mock_api/endpoints/search" "github.com/twitchdev/twitch-cli/internal/mock_api/endpoints/streams" "github.com/twitchdev/twitch-cli/internal/mock_api/endpoints/subscriptions" @@ -35,7 +36,10 @@ func All() []mock_api.MockEndpoint { channels.Editors{}, channels.InformationEndpoint{}, chat.ChannelBadges{}, + chat.ChannelEmotes{}, + chat.EmoteSets{}, chat.GlobalBadges{}, + chat.GlobalEmotes{}, clips.Clips{}, drops.DropsEntitlements{}, hype_train.HypeTrainEvents{}, @@ -47,6 +51,10 @@ func All() []mock_api.MockEndpoint { moderation.Moderators{}, polls.Polls{}, predictions.Predictions{}, + schedule.Schedule{}, + schedule.ScheduleICal{}, + schedule.ScheduleSegment{}, + schedule.ScheduleSettings{}, search.SearchCategories{}, search.SearchChannels{}, streams.AllTags{}, diff --git a/internal/mock_api/endpoints/polls/polls.go b/internal/mock_api/endpoints/polls/polls.go index 8b3fd41d..bacb8000 100644 --- a/internal/mock_api/endpoints/polls/polls.go +++ b/internal/mock_api/endpoints/polls/polls.go @@ -227,8 +227,7 @@ func patchPolls(w http.ResponseWriter, r *http.Request) { dbr, err := db.NewQuery(r, 100).GetPolls(database.Poll{BroadcasterID: userCtx.UserID, ID: body.ID}) if err != nil { - println(err.Error()) - mock_errors.WriteServerError(w, "error fetching polls") + mock_errors.WriteServerError(w, err.Error()) return } diff --git a/internal/mock_api/endpoints/schedule/ical.go b/internal/mock_api/endpoints/schedule/ical.go new file mode 100644 index 00000000..ff087441 --- /dev/null +++ b/internal/mock_api/endpoints/schedule/ical.go @@ -0,0 +1,79 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package schedule + +import ( + "net/http" + + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/mock_api/mock_errors" +) + +var scheduleICalMethodsSupported = map[string]bool{ + http.MethodGet: true, + http.MethodPost: false, + http.MethodDelete: false, + http.MethodPatch: false, + http.MethodPut: false, +} + +var scheduleICalScopesByMethod = map[string][]string{ + http.MethodGet: {}, + http.MethodPost: {}, + http.MethodDelete: {}, + http.MethodPatch: {}, + http.MethodPut: {}, +} + +type ScheduleICal struct{} + +func (e ScheduleICal) Path() string { return "/schedule/icalendar" } + +func (e ScheduleICal) GetRequiredScopes(method string) []string { + return scheduleICalScopesByMethod[method] +} + +func (e ScheduleICal) ValidMethod(method string) bool { + return scheduleICalMethodsSupported[method] +} + +func (e ScheduleICal) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db = r.Context().Value("db").(database.CLIDatabase) + + switch r.Method { + case http.MethodGet: + e.getIcal(w, r) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +// stubbed with fake data for now, since .ics generation libraries are far and few between for golang +// and it's just useful for mock data +func (e ScheduleICal) getIcal(w http.ResponseWriter, r *http.Request) { + broadcaster := r.URL.Query().Get("broadcaster_id") + if broadcaster == "" { + mock_errors.WriteBadRequest(w, "Missing required paramater broadaster_id") + return + } + + body := + `BEGIN:VCALENDAR +PRODID:-//twitch.tv//StreamSchedule//1.0 +VERSION:2.0 +CALSCALE:GREGORIAN +REFRESH-INTERVAL;VALUE=DURATION:PT1H +NAME:TwitchDev +BEGIN:VEVENT +UID:e4acc724-371f-402c-81ca-23ada79759d4 +DTSTAMP:20210323T040131Z +DTSTART;TZID=/America/New_York:20210701T140000 +DTEND;TZID=/America/New_York:20210701T150000 +SUMMARY:TwitchDev Monthly Update // July 1, 2021 +DESCRIPTION:Science & Technology. +CATEGORIES:Science & Technology +END:VEVENT +END:VCALENDAR%` + w.Header().Add("Content-Type", "text/calendar") + w.Write([]byte(body)) +} diff --git a/internal/mock_api/endpoints/schedule/scehdule_test.go b/internal/mock_api/endpoints/schedule/scehdule_test.go new file mode 100644 index 00000000..5f9b54cf --- /dev/null +++ b/internal/mock_api/endpoints/schedule/scehdule_test.go @@ -0,0 +1,332 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package schedule + +import ( + "bytes" + "encoding/json" + "log" + "net/http" + "os" + "testing" + "time" + + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/util" + "github.com/twitchdev/twitch-cli/test_setup" + "github.com/twitchdev/twitch-cli/test_setup/test_server" +) + +type RewardResponse struct { + Data []database.ChannelPointsReward `json:"data"` +} + +var ( + segment database.ScheduleSegment +) + +func TestMain(m *testing.M) { + test_setup.SetupTestEnv(&testing.T{}) + + db, err := database.NewConnection() + if err != nil { + log.Fatal(err) + } + f := false + s := database.ScheduleSegment{ + ID: util.RandomGUID(), + UserID: "1", + Title: "from_unit_tests", + IsRecurring: true, + IsVacation: false, + StartTime: time.Now().UTC().Format(time.RFC3339), + EndTime: time.Now().UTC().Add(24 * time.Hour).Format(time.RFC3339), + IsCanceled: &f, + } + err = db.NewQuery(nil, 100).InsertSchedule(s) + if err != nil { + log.Fatal(err) + } + + _, err = db.NewQuery(nil, 100).GetSchedule(database.ScheduleSegment{UserID: "1"}, time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC)) + + segment = s + db.DB.Close() + + os.Exit(m.Run()) +} +func TestSchedule(t *testing.T) { + a := test_setup.SetupTestEnv(t) + + ts := test_server.SetupTestServer(Schedule{}) + + // get + req, _ := http.NewRequest(http.MethodGet, ts.URL+Schedule{}.Path(), nil) + q := req.URL.Query() + req.URL.RawQuery = q.Encode() + resp, err := http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) + + q.Set("broadcaster_id", "2") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) + + q.Set("broadcaster_id", "1") + q.Set("id", segment.ID) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) + + q.Set("utc_offset", "60") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) + + q.Set("start_time", "test") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + q.Set("start_time", segment.StartTime) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) +} + +func TestICal(t *testing.T) { + a := test_setup.SetupTestEnv(t) + + ts := test_server.SetupTestServer(ScheduleICal{}) + req, _ := http.NewRequest(http.MethodGet, ts.URL+Schedule{}.Path(), nil) + q := req.URL.Query() + req.URL.RawQuery = q.Encode() + resp, err := http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) +} + +func TestSegment(t *testing.T) { + a := test_setup.SetupTestEnv(t) + + ts := test_server.SetupTestServer(ScheduleSegment{}) + tr := true + + // post tests + body := SegmentPatchAndPostBody{ + Title: "hello", + Timezone: "America/Los_Angeles", + StartTime: time.Now().Format(time.RFC3339), + IsRecurring: &tr, + Duration: "60", + } + + b, _ := json.Marshal(body) + req, _ := http.NewRequest(http.MethodPost, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q := req.URL.Query() + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err := http.DefaultClient.Do(req) + a.Nil(err) + a.NotNil(resp) + a.Equal(200, resp.StatusCode) + + body.Title = "" + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPost, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) + + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPost, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "2") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(401, resp.StatusCode) + + body.Title = "testing" + body.Timezone = "" + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPost, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + body.Timezone = "test" + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPost, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + body.Timezone = segment.Timezone + body.IsRecurring = nil + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPost, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + // patch + // no id + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "1") + q.Del("id") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + //mismatch bid and token + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "2") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(401, resp.StatusCode) + + // good request + body.Title = "patched_title" + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSegment{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "1") + q.Set("id", segment.ID) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(200, resp.StatusCode) + + // delete + req, _ = http.NewRequest(http.MethodDelete, ts.URL+ScheduleSegment{}.Path(), nil) + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(204, resp.StatusCode) + + q.Set("id", segment.ID) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(204, resp.StatusCode) + + q.Set("broadcaster_id", "2") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(401, resp.StatusCode) +} + +func TestSettings(t *testing.T) { + a := test_setup.SetupTestEnv(t) + + ts := test_server.SetupTestServer(ScheduleSettings{}) + tr := true + f := false + + // patch tests + body := PatchSettingsBody{ + Timezone: "America/Los_Angeles", + VacationStartTime: time.Now().Format(time.RFC3339), + VacationEndTime: segment.EndTime, + IsVacationEnabled: &f, + } + + b, _ := json.Marshal(body) + req, _ := http.NewRequest(http.MethodPatch, ts.URL+ScheduleSettings{}.Path(), bytes.NewBuffer(b)) + q := req.URL.Query() + req.URL.RawQuery = q.Encode() + resp, err := http.DefaultClient.Do(req) + a.Nil(err) + a.NotNil(resp) + a.Equal(401, resp.StatusCode) + + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSettings{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(204, resp.StatusCode) + + body.IsVacationEnabled = &tr + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSettings{}.Path(), bytes.NewBuffer(b)) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(204, resp.StatusCode) + + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSettings{}.Path(), bytes.NewBuffer(b)) + q.Set("broadcaster_id", "1") + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + body.IsVacationEnabled = &f + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSettings{}.Path(), bytes.NewBuffer(b)) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(204, resp.StatusCode) + + body.IsVacationEnabled = &tr + body.VacationStartTime = "123" + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSettings{}.Path(), bytes.NewBuffer(b)) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + body.VacationStartTime = segment.StartTime + body.VacationEndTime = "123" + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSettings{}.Path(), bytes.NewBuffer(b)) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) + + body.VacationEndTime = segment.EndTime + body.Timezone = "1" + b, _ = json.Marshal(body) + req, _ = http.NewRequest(http.MethodPatch, ts.URL+ScheduleSettings{}.Path(), bytes.NewBuffer(b)) + req.URL.RawQuery = q.Encode() + resp, err = http.DefaultClient.Do(req) + a.Nil(err) + a.Equal(400, resp.StatusCode) +} diff --git a/internal/mock_api/endpoints/schedule/schedule.go b/internal/mock_api/endpoints/schedule/schedule.go new file mode 100644 index 00000000..63254015 --- /dev/null +++ b/internal/mock_api/endpoints/schedule/schedule.go @@ -0,0 +1,133 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package schedule + +import ( + "encoding/json" + "log" + "net/http" + "strconv" + "time" + + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/mock_api/mock_errors" + "github.com/twitchdev/twitch-cli/internal/models" +) + +var scheduleMethodsSupported = map[string]bool{ + http.MethodGet: true, + http.MethodPost: false, + http.MethodDelete: false, + http.MethodPatch: false, + http.MethodPut: false, +} + +var scheduleScopesByMethod = map[string][]string{ + http.MethodGet: {}, + http.MethodPost: {}, + http.MethodDelete: {}, + http.MethodPatch: {}, + http.MethodPut: {}, +} + +type Schedule struct{} + +func (e Schedule) Path() string { return "/schedule" } + +func (e Schedule) GetRequiredScopes(method string) []string { + return scheduleScopesByMethod[method] +} + +func (e Schedule) ValidMethod(method string) bool { + return scheduleMethodsSupported[method] +} + +func (e Schedule) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db = r.Context().Value("db").(database.CLIDatabase) + + switch r.Method { + case http.MethodGet: + e.getSchedule(w, r) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (e Schedule) getSchedule(w http.ResponseWriter, r *http.Request) { + broadcasterID := r.URL.Query().Get("broadcaster_id") + queryTime := r.URL.Query().Get("start_time") + offset := r.URL.Query().Get("utc_offset") + ids := r.URL.Query()["id"] + schedule := database.Schedule{} + startTime := time.Now().UTC() + apiResponse := models.APIResponse{} + + if broadcasterID == "" { + mock_errors.WriteBadRequest(w, "Required parameter broadcaster_id is missing") + return + } + + if queryTime != "" { + st, err := time.Parse(time.RFC3339, queryTime) + if err != nil { + mock_errors.WriteBadRequest(w, "Parameter start_time is in an invalid format") + return + } + startTime = st.UTC() + } + + if offset != "" { + o, err := strconv.Atoi(offset) + if err != nil { + mock_errors.WriteBadRequest(w, "Error decoding parameter offset") + return + } + tz := time.FixedZone("", o*60) + startTime = startTime.In(tz) + } + + segments := []database.ScheduleSegment{} + if len(ids) > 0 { + if len(ids) > 100 { + mock_errors.WriteBadRequest(w, "Parameter id may only have a maximum of 100 values") + return + } + for _, id := range ids { + dbr, err := db.NewQuery(r, 25).GetSchedule(database.ScheduleSegment{ID: id, UserID: broadcasterID}, startTime) + if err != nil { + log.Print(err) + mock_errors.WriteServerError(w, err.Error()) + return + } + response := dbr.Data.(database.Schedule) + schedule = response + segments = append(segments, response.Segments...) + } + schedule.Segments = segments + apiResponse = models.APIResponse{ + Data: schedule, + } + } else { + dbr, err := db.NewQuery(r, 25).GetSchedule(database.ScheduleSegment{UserID: broadcasterID}, startTime) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + response := dbr.Data.(database.Schedule) + segments = append(segments, response.Segments...) + schedule = response + schedule.Segments = segments + apiResponse = models.APIResponse{ + Data: schedule, + } + + if len(schedule.Segments) == dbr.Limit { + apiResponse.Pagination = &models.APIPagination{ + Cursor: dbr.Cursor, + } + } + } + + bytes, _ := json.Marshal(apiResponse) + w.Write(bytes) +} diff --git a/internal/mock_api/endpoints/schedule/segment.go b/internal/mock_api/endpoints/schedule/segment.go new file mode 100644 index 00000000..0b423c00 --- /dev/null +++ b/internal/mock_api/endpoints/schedule/segment.go @@ -0,0 +1,319 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package schedule + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/mock_api/authentication" + "github.com/twitchdev/twitch-cli/internal/mock_api/mock_errors" + "github.com/twitchdev/twitch-cli/internal/util" +) + +var scheduleSegmentMethodsSupported = map[string]bool{ + http.MethodGet: false, + http.MethodPost: true, + http.MethodDelete: true, + http.MethodPatch: true, + http.MethodPut: false, +} + +var scheduleSegmentScopesByMethod = map[string][]string{ + http.MethodGet: {}, + http.MethodPost: {"channel:manage:schedule"}, + http.MethodDelete: {"channel:manage:schedule"}, + http.MethodPatch: {"channel:manage:schedule"}, + http.MethodPut: {}, +} + +var f = false + +type ScheduleSegment struct{} + +type SegmentPatchAndPostBody struct { + StartTime string `json:"start_time"` + Timezone string `json:"timezone"` + IsRecurring *bool `json:"is_recurring"` + Duration string `json:"duration"` + CategoryID *string `json:"category_id"` + Title string `json:"title"` + IsCanceled *bool `json:"is_canceled"` +} + +func (e ScheduleSegment) Path() string { return "/schedule/segment" } + +func (e ScheduleSegment) GetRequiredScopes(method string) []string { + return scheduleSegmentScopesByMethod[method] +} + +func (e ScheduleSegment) ValidMethod(method string) bool { + return scheduleSegmentMethodsSupported[method] +} + +func (e ScheduleSegment) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db = r.Context().Value("db").(database.CLIDatabase) + + switch r.Method { + case http.MethodPost: + e.postSegment(w, r) + case http.MethodDelete: + e.deleteSegment(w, r) + case http.MethodPatch: + e.patchSegment(w, r) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (e ScheduleSegment) postSegment(w http.ResponseWriter, r *http.Request) { + userCtx := r.Context().Value("auth").(authentication.UserAuthentication) + duration := 240 + + if !userCtx.MatchesBroadcasterIDParam(r) { + mock_errors.WriteUnauthorized(w, "User token does not match broadcaster_id parameter") + return + } + var body SegmentPatchAndPostBody + err := json.NewDecoder(r.Body).Decode(&body) + if err != nil { + mock_errors.WriteBadRequest(w, "Error parsing body") + return + } + + if body.StartTime == "" { + mock_errors.WriteBadRequest(w, "Missing start_time") + return + } + st, err := time.Parse(time.RFC3339, body.StartTime) + if err != nil { + mock_errors.WriteBadRequest(w, "Invalid timezone provided") + return + } + if body.Timezone == "" { + mock_errors.WriteBadRequest(w, "Missing timezone") + return + } + _, err = time.LoadLocation(body.Timezone) + if err != nil { + mock_errors.WriteBadRequest(w, "Invalid timezone provided") + return + } + if body.IsRecurring == nil { + mock_errors.WriteBadRequest(w, "Missing is_recurring") + return + } + + if len(body.Title) > 140 { + mock_errors.WriteBadRequest(w, "Title must be less than 140 characters") + return + } + + if body.Duration != "" { + duration, err = strconv.Atoi(body.Duration) + if err != nil { + mock_errors.WriteBadRequest(w, "Invalid duration provided") + return + } + } + et := st.Add(time.Duration(duration) * time.Minute) + + segmentID := util.RandomGUID() + eventID := base64.RawStdEncoding.EncodeToString([]byte(fmt.Sprintf("%v\\%v", segmentID, st))) + segment := database.ScheduleSegment{ + ID: eventID, + StartTime: st.UTC().Format(time.RFC3339), + EndTime: et.UTC().Format(time.RFC3339), + IsRecurring: *body.IsRecurring, + IsVacation: false, + CategoryID: body.CategoryID, + Title: body.Title, + UserID: userCtx.UserID, + Timezone: "America/Los_Angeles", + IsCanceled: &f, + } + err = db.NewQuery(nil, 100).InsertSchedule(segment) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + if *body.IsRecurring { + // just a years worth of recurring events; mock data + for i := 0; i < 52; i++ { + weekAdd := (i + 1) * 7 * 24 + startTime := time.Now().Add(time.Duration(weekAdd) * time.Hour).UTC() + endTime := time.Now().Add(time.Duration(weekAdd) * time.Hour).UTC() + eventID := base64.RawStdEncoding.EncodeToString([]byte(fmt.Sprintf("%v\\%v", segmentID, startTime))) + + s := database.ScheduleSegment{ + ID: eventID, + StartTime: startTime.Format(time.RFC3339), + EndTime: endTime.Format(time.RFC3339), + IsRecurring: *body.IsRecurring, + IsVacation: false, + CategoryID: body.CategoryID, + Title: body.Title, + UserID: userCtx.UserID, + Timezone: body.Timezone, + IsCanceled: &f, + } + + err := db.NewQuery(nil, 100).InsertSchedule(s) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + } + } + dbr, err := db.NewQuery(nil, 100).GetSchedule(database.ScheduleSegment{ID: eventID}, time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC)) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + b := dbr.Data.(database.Schedule) + bytes, _ := json.Marshal(b) + w.Write(bytes) +} + +func (e ScheduleSegment) deleteSegment(w http.ResponseWriter, r *http.Request) { + userCtx := r.Context().Value("auth").(authentication.UserAuthentication) + id := r.URL.Query().Get("id") + if !userCtx.MatchesBroadcasterIDParam(r) { + mock_errors.WriteUnauthorized(w, "User token does not match broadcaster_id parameter") + return + } + + if id == "" { + mock_errors.WriteBadRequest(w, "Missing required parameter id") + return + } + + err := db.NewQuery(nil, 100).DeleteSegment(id, userCtx.UserID) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (e ScheduleSegment) patchSegment(w http.ResponseWriter, r *http.Request) { + userCtx := r.Context().Value("auth").(authentication.UserAuthentication) + id := r.URL.Query().Get("id") + if !userCtx.MatchesBroadcasterIDParam(r) { + mock_errors.WriteUnauthorized(w, "User token does not match broadcaster_id parameter") + return + } + if id == "" { + mock_errors.WriteBadRequest(w, "Missing required parameter id") + return + } + + dbr, err := db.NewQuery(nil, 100).GetSchedule(database.ScheduleSegment{ID: id, UserID: userCtx.UserID}, time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC)) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + b := dbr.Data.(database.Schedule) + + if len(b.Segments) == 0 { + mock_errors.WriteBadRequest(w, "Invalid ID requested") + return + } + segment := b.Segments[0] + + var body SegmentPatchAndPostBody + err = json.NewDecoder(r.Body).Decode(&body) + if err != nil { + mock_errors.WriteBadRequest(w, "Error parsing body") + return + } + + // start_time + st, err := time.Parse(time.RFC3339, segment.StartTime) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + if body.StartTime != "" { + st, err = time.Parse(time.RFC3339, body.StartTime) + if err != nil { + mock_errors.WriteBadRequest(w, "Error parsing start_time") + return + } + } + + // timezone + tz, err := time.LoadLocation(segment.Timezone) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + if body.Timezone != "" { + tz, err = time.LoadLocation(body.Timezone) + if err != nil { + mock_errors.WriteBadRequest(w, "Error parsing timezone") + return + } + } + + // is_canceled + isCanceled := false + if body.IsCanceled != nil { + isCanceled = *body.IsCanceled + } + + // title + title := segment.Title + if body.Title != "" { + if len(body.Title) > 140 { + mock_errors.WriteBadRequest(w, "Title must be less than 140 characters") + return + } + title = body.Title + } + + // duration + et, err := time.Parse(time.RFC3339, segment.EndTime) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + if body.Duration != "" { + duration, err := strconv.Atoi(body.Duration) + if err != nil { + mock_errors.WriteBadRequest(w, "Invalid duration provided") + return + } + + et = st.Add(time.Duration(duration) * time.Minute) + } + + s := database.ScheduleSegment{ + ID: segment.ID, + StartTime: st.UTC().Format(time.RFC3339), + EndTime: et.UTC().Format(time.RFC3339), + IsCanceled: &isCanceled, + Timezone: tz.String(), + Title: title, + } + + err = db.NewQuery(r, 20).UpdateSegment(s) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + + dbr, err = db.NewQuery(nil, 100).GetSchedule(database.ScheduleSegment{ID: segment.ID}, time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC)) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + b = dbr.Data.(database.Schedule) + bytes, _ := json.Marshal(b) + w.Write(bytes) +} diff --git a/internal/mock_api/endpoints/schedule/settings.go b/internal/mock_api/endpoints/schedule/settings.go new file mode 100644 index 00000000..a41a82c2 --- /dev/null +++ b/internal/mock_api/endpoints/schedule/settings.go @@ -0,0 +1,140 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package schedule + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/twitchdev/twitch-cli/internal/database" + "github.com/twitchdev/twitch-cli/internal/mock_api/authentication" + "github.com/twitchdev/twitch-cli/internal/mock_api/mock_errors" + "github.com/twitchdev/twitch-cli/internal/util" +) + +var scheduleSettingsMethodsSupported = map[string]bool{ + http.MethodGet: false, + http.MethodPost: false, + http.MethodDelete: false, + http.MethodPatch: true, + http.MethodPut: false, +} + +var scheduleSettingsScopesByMethod = map[string][]string{ + http.MethodGet: {}, + http.MethodPost: {}, + http.MethodDelete: {}, + http.MethodPatch: {"channel:manage:schedule"}, + http.MethodPut: {}, +} + +type ScheduleSettings struct{} + +type PatchSettingsBody struct { + IsVacationEnabled *bool `json:"is_vacation_enabled"` + VacationStartTime string `json:"vacation_start_time"` + VacationEndTime string `json:"vacation_end_time"` + Timezone string `json:"timezone"` +} + +func (e ScheduleSettings) Path() string { return "/schedule/settings" } + +func (e ScheduleSettings) GetRequiredScopes(method string) []string { + return scheduleSettingsScopesByMethod[method] +} + +func (e ScheduleSettings) ValidMethod(method string) bool { + return scheduleSettingsMethodsSupported[method] +} + +func (e ScheduleSettings) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db = r.Context().Value("db").(database.CLIDatabase) + + switch r.Method { + case http.MethodPatch: + e.patchSchedule(w, r) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (e ScheduleSettings) patchSchedule(w http.ResponseWriter, r *http.Request) { + userCtx := r.Context().Value("auth").(authentication.UserAuthentication) + if !userCtx.MatchesBroadcasterIDParam(r) { + mock_errors.WriteUnauthorized(w, "User token does not match broadcaster_id parameter") + return + } + + vacation, err := db.NewQuery(r, 100).GetVacations(database.ScheduleSegment{UserID: userCtx.UserID}) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + + var body PatchSettingsBody + err = json.NewDecoder(r.Body).Decode(&body) + + if body.IsVacationEnabled == nil { + w.WriteHeader(http.StatusNoContent) + return + } + + if *body.IsVacationEnabled == false { + if vacation.ID != "" { + err := db.NewQuery(r, 100).DeleteSegment(vacation.ID, userCtx.UserID) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + } + w.WriteHeader(http.StatusNoContent) + return + } + + if vacation.ID != "" && *body.IsVacationEnabled == true { + mock_errors.WriteBadRequest(w, "Existing vacation already exists") + return + } + + if body.Timezone == "" || body.VacationStartTime == "" || body.VacationEndTime == "" { + mock_errors.WriteBadRequest(w, "Missing required parameter") + return + } + + _, err = time.LoadLocation(body.Timezone) + if err != nil { + mock_errors.WriteBadRequest(w, "Invalid timezone requested") + return + } + + st, err := time.Parse(time.RFC3339, body.VacationStartTime) + if err != nil { + mock_errors.WriteBadRequest(w, "Invalid vacation_start_time requested") + return + } + + et, err := time.Parse(time.RFC3339, body.VacationEndTime) + if err != nil { + mock_errors.WriteBadRequest(w, "Invalid vacation_end_time requested") + return + } + f := false + err = db.NewQuery(r, 100).InsertSchedule(database.ScheduleSegment{ + ID: base64.RawStdEncoding.EncodeToString([]byte(fmt.Sprintf("%v\\%v", util.RandomGUID(), st))), + StartTime: st.UTC().Format(time.RFC3339), + EndTime: et.UTC().Format(time.RFC3339), + IsVacation: true, + IsRecurring: false, + IsCanceled: &f, + UserID: userCtx.UserID, + }) + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/mock_api/endpoints/schedule/shared.go b/internal/mock_api/endpoints/schedule/shared.go new file mode 100644 index 00000000..4ba602dd --- /dev/null +++ b/internal/mock_api/endpoints/schedule/shared.go @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package schedule + +import "github.com/twitchdev/twitch-cli/internal/database" + +var db database.CLIDatabase diff --git a/internal/mock_api/endpoints/search/channels.go b/internal/mock_api/endpoints/search/channels.go index 827606d8..809bf2d2 100644 --- a/internal/mock_api/endpoints/search/channels.go +++ b/internal/mock_api/endpoints/search/channels.go @@ -64,7 +64,6 @@ func searchChannels(w http.ResponseWriter, r *http.Request) { if r.URL.Query().Get("live_only") != "" { live_only, _ = strconv.ParseBool(r.URL.Query().Get("live_only")) } - println(live_only) dbr, err := db.NewQuery(r, 100).SearchChannels(query, live_only) if err != nil { log.Print(err) diff --git a/internal/mock_api/endpoints/streams/all_tags.go b/internal/mock_api/endpoints/streams/all_tags.go index 635361ff..bbd962de 100644 --- a/internal/mock_api/endpoints/streams/all_tags.go +++ b/internal/mock_api/endpoints/streams/all_tags.go @@ -63,7 +63,6 @@ func getAllTags(w http.ResponseWriter, r *http.Request) { if len(tagIDs) > 0 { for _, id := range tagIDs { - println(id) t := database.Tag{ID: id} dbr, err := db.NewQuery(r, 100).GetTags(t) if err != nil { diff --git a/internal/mock_api/endpoints/streams/markers.go b/internal/mock_api/endpoints/streams/markers.go index 7fad7065..988b65be 100644 --- a/internal/mock_api/endpoints/streams/markers.go +++ b/internal/mock_api/endpoints/streams/markers.go @@ -89,8 +89,7 @@ func getMarkers(w http.ResponseWriter, r *http.Request) { dbr, err := db.NewQuery(r, 100).GetStreamMarkers(database.StreamMarker{BroadcasterID: userID, VideoID: videoID}) if err != nil { - println(err.Error()) - mock_errors.WriteServerError(w, "error fetching markers") + mock_errors.WriteServerError(w, err.Error()) return } markerResponse := dbr.Data.([]database.StreamMarkerUser) @@ -145,8 +144,7 @@ func postMarkers(w http.ResponseWriter, r *http.Request) { err = db.NewQuery(r, 100).InsertStreamMarker(sm) if err != nil { - println(err.Error()) - mock_errors.WriteServerError(w, "error inserting marker") + mock_errors.WriteServerError(w, err.Error()) return } diff --git a/internal/mock_api/endpoints/streams/stream_tags.go b/internal/mock_api/endpoints/streams/stream_tags.go index 7a24e15b..d677937e 100644 --- a/internal/mock_api/endpoints/streams/stream_tags.go +++ b/internal/mock_api/endpoints/streams/stream_tags.go @@ -7,6 +7,7 @@ import ( "log" "net/http" + "github.com/mattn/go-sqlite3" "github.com/twitchdev/twitch-cli/internal/database" "github.com/twitchdev/twitch-cli/internal/mock_api/authentication" "github.com/twitchdev/twitch-cli/internal/mock_api/mock_errors" @@ -106,15 +107,18 @@ func putStreamTags(w http.ResponseWriter, r *http.Request) { err = db.NewQuery(r, 100).DeleteAllStreamTags(userCtx.UserID) if err != nil { - println(err.Error()) - mock_errors.WriteBadRequest(w, "error removing stream tags") + log.Print(err) + mock_errors.WriteServerError(w, err.Error()) return } for _, tag := range body.TagIDs { err = db.NewQuery(r, 100).InsertStreamTag(database.StreamTag{UserID: userCtx.UserID, TagID: tag}) if err != nil { - println(err.Error()) - mock_errors.WriteBadRequest(w, "error adding stream tag") + if database.DatabaseErrorIs(err, sqlite3.ErrConstraintForeignKey) { + mock_errors.WriteBadRequest(w, "invalid tag provided") + return + } + mock_errors.WriteServerError(w, err.Error()) return } } diff --git a/internal/mock_api/generate/generate.go b/internal/mock_api/generate/generate.go index cffa8dec..79ed7e44 100644 --- a/internal/mock_api/generate/generate.go +++ b/internal/mock_api/generate/generate.go @@ -5,6 +5,7 @@ package generate import ( "context" "database/sql" + "encoding/base64" "fmt" "log" "strings" @@ -24,6 +25,8 @@ type UserInfo struct { Type string } +var f = false + func Generate(userCount int) error { db, err := database.NewConnection() if err != nil { @@ -183,6 +186,7 @@ func generateUsers(ctx context.Context, count int) error { GameID: dropsGameID, UserID: broadcaster.ID, Timestamp: util.GetTimestamp().Format(time.RFC3339Nano), + Status: "CLAIMED", } err = db.NewQuery(nil, 1000).InsertDropsEntitlement(entitlement) if err != nil { @@ -232,7 +236,7 @@ func generateUsers(ctx context.Context, count int) error { }, { ID: util.RandomGUID(), - Title: "Choice1", + Title: "Choice2", Color: "PINK", Users: 0, ChannelPoints: 0, @@ -245,6 +249,35 @@ func generateUsers(ctx context.Context, count int) error { log.Print(err.Error()) } + // create fake schedule event + segmentID := util.RandomGUID() + + // just a years worth of recurring events; mock data + for i := 0; i < 52; i++ { + weekAdd := (i + 1) * 7 * 24 + startTime := time.Now().Add(time.Duration(weekAdd) * time.Hour).UTC() + endTime := time.Now().Add(time.Duration(weekAdd) * time.Hour).UTC() + eventID := base64.RawStdEncoding.EncodeToString([]byte(fmt.Sprintf("%v\\%v", segmentID, startTime))) + + segment := database.ScheduleSegment{ + ID: eventID, + StartTime: startTime.Format(time.RFC3339), + EndTime: endTime.Format(time.RFC3339), + IsRecurring: true, + IsVacation: false, + CategoryID: &dropsGameID, + Title: "Test Title", + UserID: broadcaster.ID, + Timezone: "America/Los_Angeles", + IsCanceled: &f, + } + + err := db.NewQuery(nil, 100).InsertSchedule(segment) + if err != nil { + log.Print(err.Error()) + } + } + for j, user := range users { // create a seed used for the below determination on if a user should follow one another- this simply simulates a social mesh userSeed := util.RandomInt(100 * 100) @@ -509,6 +542,10 @@ func generateAuthorization(ctx context.Context, c database.AuthenticationClient, if err != nil { return err } - log.Printf("Created authorization for user %v with token %v", userID, auth.Token) + if userID != "" { + log.Printf("Created authorization for user %v with token %v", userID, auth.Token) + } else { + log.Printf("Created authorization with token %v", auth.Token) + } return nil } diff --git a/internal/mock_api/generate/generate_tools.go b/internal/mock_api/generate/generate_tools.go index 3e6fd62e..7eefd8a9 100644 --- a/internal/mock_api/generate/generate_tools.go +++ b/internal/mock_api/generate/generate_tools.go @@ -30,6 +30,8 @@ var usernamePossibilities = []string{ "Skateboard", "Egg", "Lion", + "Isaac", + "Jill", } func generateUsername() string { diff --git a/internal/mock_api/mock_server/server.go b/internal/mock_api/mock_server/server.go index 819cc745..c69ff781 100644 --- a/internal/mock_api/mock_server/server.go +++ b/internal/mock_api/mock_server/server.go @@ -81,6 +81,11 @@ func StartServer(port int) { func RegisterHandlers(m *http.ServeMux) { // all mock endpoints live in the /mock/ namespace for _, e := range endpoints.All() { + // no auth requirements on this endpoint, so just add it manually + if e.Path() == "/schedule/icalendar" { + m.Handle(MOCK_NAMESPACE+e.Path(), loggerMiddleware(e)) + continue + } m.Handle(MOCK_NAMESPACE+e.Path(), loggerMiddleware(authentication.AuthenticationMiddleware(e))) } for _, e := range mock_units.All() { @@ -95,7 +100,6 @@ func RegisterHandlers(m *http.ServeMux) { func loggerMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Printf("%v %v", r.Method, r.URL.Path) - w.Header().Set("Content-Type", "application/json") next.ServeHTTP(w, r) }) } diff --git a/internal/mock_auth/mock_auth.go b/internal/mock_auth/mock_auth.go index d63f0b62..f3c91289 100644 --- a/internal/mock_auth/mock_auth.go +++ b/internal/mock_auth/mock_auth.go @@ -39,6 +39,7 @@ var validScopesByTokenType = map[string]map[string]bool{ "channel:read:polls": true, "channel:read:predictions": true, "channel:read:redemptions": true, + "channel:manage:schedule": true, "channel:read:stream_key": true, "channel:read:subscriptions": true, "clips:edit": true,