Skip to content

Commit

Permalink
Merge branch 'main' into ipv4
Browse files Browse the repository at this point in the history
  • Loading branch information
SputNikPlop authored Sep 12, 2023
2 parents 97405a7 + 98ddbfb commit a63a147
Show file tree
Hide file tree
Showing 21 changed files with 78 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
- name: Checkout code
uses: actions/checkout@v3
- name: Test
run: go test ./...
run: go test ./... -timeout=120s
16 changes: 11 additions & 5 deletions internal/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@ type CLIDatabase struct {
DB *sqlx.DB
}

func NewConnection() (CLIDatabase, error) {
db, err := getDatabase()
func NewConnection(extendedBusyTimeout bool) (CLIDatabase, error) {
db, err := getDatabase(extendedBusyTimeout)
if err != nil {
return CLIDatabase{}, err
}

return CLIDatabase{DB: &db}, nil
}

func getDatabase() (sqlx.DB, error) {
// extendedBusyTimeout sets an extended timeout for waiting on a busy database. This is mainly an issue in tests on WSL, so this flag shouldn't be used in production.
func getDatabase(extendedBusyTimeout bool) (sqlx.DB, error) {
home, err := util.GetApplicationDir()
if err != nil {
return sqlx.DB{}, err
Expand All @@ -46,9 +47,14 @@ func getDatabase() (sqlx.DB, error) {
needToInit = true
}

// force Foreign Key support ("fk=true")
dbFlags := "?_fk=true&cache=shared"
if extendedBusyTimeout {
// https://www.sqlite.org/c3ref/busy_timeout.html
dbFlags += "&_busy_timeout=60000"
}
for i := 0; i <= 5; i++ {
// open and force Foreign Key support ("fk=true")
db, err := sqlx.Open("sqlite3", path+"?_fk=true&cache=shared")
db, err := sqlx.Open("sqlite3", path+dbFlags)
if err != nil {
log.Print(i)
if i == 5 {
Expand Down
2 changes: 1 addition & 1 deletion internal/database/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestMain(m *testing.M) {
log.Fatal(err)
}

db, err = NewConnection()
db, err = NewConnection(true)
if err != nil {
log.Print(err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/events/trigger/retrigger_event.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func RefireEvent(id string, p TriggerParameters) (string, error) {
db, err := database.NewConnection()
db, err := database.NewConnection(false)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/events/trigger/trigger_event.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ https://dev.twitch.tv/docs/eventsub/handling-webhook-events#processing-an-event`
return "", err
}

db, err := database.NewConnection()
db, err := database.NewConnection(false)
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/events/websocket/mock_server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type Client struct {
clientName string // Unique name for the client. Not the Client ID.
conn *websocket.Conn
mutex sync.Mutex
ConnectedAtTimestamp string
ConnectedAtTimestamp string // RFC3339Nano timestamp indicating when the client connected to the server
connectionUrl string

mustSubscribeTimer *time.Timer
Expand Down
9 changes: 7 additions & 2 deletions internal/events/websocket/mock_server/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,14 @@ func subscriptionPageHandlerGet(w http.ResponseWriter, r *http.Request) {

for clientName, clientSubscriptions := range server.Subscriptions {
for _, subscription := range clientSubscriptions {
if clientID == "debug" || subscription.ClientID == clientID {
disabledAndExpired := false // Production EventSub only shows disabled WebSocket subscriptions that were disabled under 1 hour ago
if subscription.DisabledAt != nil && subscription.DisabledAt.Add(time.Hour).Before(util.GetTimestamp()) {
disabledAndExpired = true
}

if clientID == "debug" || (subscription.ClientID == clientID && !disabledAndExpired) {
allSubscriptions = append(allSubscriptions, SubscriptionPostSuccessResponseBody{
ID: subscription.ClientID,
ID: subscription.SubscriptionID,
Status: subscription.Status,
Type: subscription.Type,
Version: subscription.Version,
Expand Down
11 changes: 7 additions & 4 deletions internal/events/websocket/mock_server/rpc_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,7 @@ func RPCFireEventSubHandler(args rpc.RPCArgs) rpc.RPCResponse {
}
}

clientName, exists := args.Variables["ClientName"]
if !exists {

}
clientName := args.Variables["ClientName"]
if sessionRegex.MatchString(clientName) {
// Users can include the full session_id given in the response. If they do, subtract it to just the client name
clientName = sessionRegex.FindAllStringSubmatch(clientName, -1)[0][2]
Expand Down Expand Up @@ -253,6 +250,12 @@ func RPCSubscriptionHandler(args rpc.RPCArgs) rpc.RPCResponse {
found = true

server.Subscriptions[client][i].Status = args.Variables["SubscriptionStatus"]
if args.Variables["SubscriptionStatus"] == STATUS_ENABLED {
server.Subscriptions[client][i].DisabledAt = nil
} else {
tNow := util.GetTimestamp()
server.Subscriptions[client][i].DisabledAt = &tNow
}
break
}
}
Expand Down
29 changes: 23 additions & 6 deletions internal/events/websocket/mock_server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ import (
const KEEPALIVE_TIMEOUT_SECONDS = 10

type WebSocketServer struct {
ServerId string // Int representing the ID of the server
//ConnectionUrl string // Server's url for people to connect to. Used for messaging in reconnect testing
DebugEnabled bool // Display debug messages; --debug
StrictMode bool // Force stricter production-like qualities; --strict
Upgrader websocket.Upgrader
ServerId string // Int representing the ID of the server
DebugEnabled bool // Display debug messages; --debug
StrictMode bool // Force stricter production-like qualities; --strict

Upgrader websocket.Upgrader

Clients *util.List[Client] // All connected clients
muClients sync.Mutex // Mutex for WebSocketServer.Clients
Expand Down Expand Up @@ -410,6 +410,8 @@ func (ws *WebSocketServer) HandleRPCEventSubForwarding(eventsubBody string, clie
foundClientId = sub.ClientID

ws.Subscriptions[client][i].Status = STATUS_AUTHORIZATION_REVOKED
tNow := util.GetTimestamp()
ws.Subscriptions[client][i].DisabledAt = &tNow
break
}
}
Expand All @@ -426,6 +428,7 @@ func (ws *WebSocketServer) HandleRPCEventSubForwarding(eventsubBody string, clie
}

// Check for subscriptions when running with --require-subscription
subscriptionCreatedAtTimestamp := "" // Used below if in strict mode
if ws.StrictMode {
found := false
for _, clientSubscriptions := range ws.Subscriptions {
Expand All @@ -436,6 +439,7 @@ func (ws *WebSocketServer) HandleRPCEventSubForwarding(eventsubBody string, clie
for _, sub := range clientSubscriptions {
if sub.SessionClientName == client.clientName && sub.Type == eventObj.Subscription.Type && sub.Version == eventObj.Subscription.Version {
found = true
subscriptionCreatedAtTimestamp = sub.CreatedAt
}
}
}
Expand All @@ -448,6 +452,16 @@ func (ws *WebSocketServer) HandleRPCEventSubForwarding(eventsubBody string, clie
// Change payload's subscription.transport.session_id to contain the correct Session ID
eventObj.Subscription.Transport.SessionID = fmt.Sprintf("%v_%v", ws.ServerId, client.clientName)

// Change payload's subscription.created_at to contain the correct timestamp -- https://github.com/twitchdev/twitch-cli/issues/264
if ws.StrictMode {
// When running WITH --require-subscription, created_at will be set to the time the subscription was created using the mock EventSub REST endpoint
eventObj.Subscription.CreatedAt = subscriptionCreatedAtTimestamp
} else {
// When running WITHOUT --require-subscription, created_at will be set to the time the client connected
// This is because without --require-subscription the server "grants" access to all event subscriptions at the moment the client is connected
eventObj.Subscription.CreatedAt = client.ConnectedAtTimestamp
}

// Build notification message
notificationMsg, err := json.Marshal(
NotificationMessage{
Expand Down Expand Up @@ -503,9 +517,12 @@ func (ws *WebSocketServer) handleClientConnectionClose(client *Client, closeReas
subscriptions := ws.Subscriptions[client.clientName]
for i := range subscriptions {
if subscriptions[i].Status == STATUS_ENABLED {
tNow := util.GetTimestamp()

subscriptions[i].Status = getStatusFromCloseMessage(closeReason)
subscriptions[i].ClientConnectedAt = ""
subscriptions[i].ClientDisconnectedAt = time.Now().UTC().Format(time.RFC3339Nano)
subscriptions[i].ClientDisconnectedAt = tNow.Format(time.RFC3339Nano)
subscriptions[i].DisabledAt = &tNow
}
}
ws.Subscriptions[client.clientName] = subscriptions
Expand Down
21 changes: 13 additions & 8 deletions internal/events/websocket/mock_server/subscription.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
package mock_server

import "github.com/twitchdev/twitch-cli/internal/models"
import (
"time"

"github.com/twitchdev/twitch-cli/internal/models"
)

type Subscription struct {
SubscriptionID string // Random GUID for the subscription
ClientID string // Client ID included in headers
Type string // EventSub topic
Version string // EventSub topic version
CreatedAt string // Timestamp of when the subscription was created
Status string // Status of the subscription
SessionClientName string // Client name of the session this is associated with.
SubscriptionID string // Random GUID for the subscription
ClientID string // Client ID included in headers
Type string // EventSub topic
Version string // EventSub topic version
CreatedAt string // Timestamp of when the subscription was created
DisabledAt *time.Time // Not public; Timestamp of when the subscription was disabled
Status string // Status of the subscription
SessionClientName string // Client name of the session this is associated with.

ClientConnectedAt string // Time client connected
ClientDisconnectedAt string // Time client disconnected
Expand Down
2 changes: 1 addition & 1 deletion internal/login/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestUserAuthServer(t *testing.T) {
userResponse <- *res
}()

time.Sleep(25)
time.Sleep(1 * time.Second)
_, err = loginRequest(http.MethodGet, fmt.Sprintf("http://localhost:3000?code=%s&state=%s", code, state), nil)
a.Nil(err, err)

Expand Down
2 changes: 1 addition & 1 deletion internal/mock_api/authentication/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func baseMiddleware(next http.Handler) http.Handler {
ctx := context.Background()

// just stub it all
db, err := database.NewConnection()
db, err := database.NewConnection(false)
if err != nil {
log.Fatalf("Error connecting to database: %v", err.Error())
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ var (
func TestMain(m *testing.M) {
test_setup.SetupTestEnv(&testing.T{})

db, err := database.NewConnection()
db, err := database.NewConnection(true)
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/mock_api/endpoints/channels/channels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestMain(m *testing.M) {
test_setup.SetupTestEnv(&testing.T{})

// adding mock data
db, _ := database.NewConnection()
db, _ := database.NewConnection(true)
q := db.NewQuery(nil, 100)
q.InsertStream(database.Stream{ID: util.RandomGUID(), UserID: "1", StreamType: "live", ViewerCount: 0}, false)
db.DB.Close()
Expand Down
2 changes: 1 addition & 1 deletion internal/mock_api/endpoints/clips/clips_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestMain(m *testing.M) {
test_setup.SetupTestEnv(&testing.T{})

// adding mock data
db, _ := database.NewConnection()
db, _ := database.NewConnection(true)
q := db.NewQuery(nil, 100)
q.InsertStream(database.Stream{ID: util.RandomGUID(), UserID: "1", StreamType: "live", ViewerCount: 0}, false)
db.DB.Close()
Expand Down
2 changes: 1 addition & 1 deletion internal/mock_api/endpoints/drops/drops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var entitlement database.DropsEntitlement
func TestMain(m *testing.M) {
test_setup.SetupTestEnv(&testing.T{})

db, err := database.NewConnection()
db, err := database.NewConnection(true)
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/mock_api/endpoints/schedule/scehdule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var (
func TestMain(m *testing.M) {
test_setup.SetupTestEnv(&testing.T{})

db, err := database.NewConnection()
db, err := database.NewConnection(true)
if err != nil {
log.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/mock_api/generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type UserInfo struct {
var f = false

func Generate(userCount int) error {
db, err := database.NewConnection()
db, err := database.NewConnection(false)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/mock_api/mock_server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func StartServer(port int) error {

ctx := context.Background()

db, err := database.NewConnection()
db, err := database.NewConnection(false)
if err != nil {
return fmt.Errorf("Error connecting to database: %v", err.Error())
}
Expand Down
4 changes: 2 additions & 2 deletions internal/mock_auth/mock_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestValidateToken(t *testing.T) {
a.Nil(err, err)
a.Equal(401, resp.StatusCode)

db, err := database.NewConnection()
db, err := database.NewConnection(true)
a.Nil(err, err)
defer db.DB.Close()

Expand Down Expand Up @@ -152,7 +152,7 @@ func baseMiddleware(next http.Handler) http.Handler {
ctx := context.Background()

// just stub it all
db, err := database.NewConnection()
db, err := database.NewConnection(true)
if err != nil {
log.Fatalf("Error connecting to database: %v", err.Error())
return
Expand Down
2 changes: 1 addition & 1 deletion test_setup/test_server/test_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func SetupTestServer(next mock_api.MockEndpoint) *httptest.Server {
ctx := context.Background()

// just stub it all
db, err := database.NewConnection()
db, err := database.NewConnection(true)
if err != nil {
log.Fatalf("Error connecting to database: %v", err.Error())
return
Expand Down

0 comments on commit a63a147

Please sign in to comment.