1
+ /*
2
+ * Copyright © 2019-2020 A Bunch Tell LLC.
3
+ *
4
+ * This file is part of WriteFreely.
5
+ *
6
+ * WriteFreely is free software: you can redistribute it and/or modify
7
+ * it under the terms of the GNU Affero General Public License, included
8
+ * in the LICENSE file in this source code package.
9
+ */
10
+
1
11
package writefreely
2
12
3
13
import (
@@ -15,10 +25,27 @@ import (
15
25
"github.com/gorilla/sessions"
16
26
"github.com/writeas/impart"
17
27
"github.com/writeas/web-core/log"
18
-
19
28
"github.com/writeas/writefreely/config"
20
29
)
21
30
31
+ // OAuthButtons holds display information for different OAuth providers we support.
32
+ type OAuthButtons struct {
33
+ SlackEnabled bool
34
+ WriteAsEnabled bool
35
+ GitLabEnabled bool
36
+ GitLabDisplayName string
37
+ }
38
+
39
+ // NewOAuthButtons creates a new OAuthButtons struct based on our app configuration.
40
+ func NewOAuthButtons (cfg * config.Config ) * OAuthButtons {
41
+ return & OAuthButtons {
42
+ SlackEnabled : cfg .SlackOauth .ClientID != "" ,
43
+ WriteAsEnabled : cfg .WriteAsOauth .ClientID != "" ,
44
+ GitLabEnabled : cfg .GitlabOauth .ClientID != "" ,
45
+ GitLabDisplayName : config .OrDefaultString (cfg .GitlabOauth .DisplayName , gitlabDisplayName ),
46
+ }
47
+ }
48
+
22
49
// TokenResponse contains data returned when a token is created either
23
50
// through a code exchange or using a refresh token.
24
51
type TokenResponse struct {
@@ -61,8 +88,8 @@ type OAuthDatastoreProvider interface {
61
88
type OAuthDatastore interface {
62
89
GetIDForRemoteUser (context.Context , string , string , string ) (int64 , error )
63
90
RecordRemoteUserID (context.Context , int64 , string , string , string , string ) error
64
- ValidateOAuthState (context.Context , string ) (string , string , int64 , error )
65
- GenerateOAuthState (context.Context , string , string , int64 ) (string , error )
91
+ ValidateOAuthState (context.Context , string ) (string , string , int64 , string , error )
92
+ GenerateOAuthState (context.Context , string , string , int64 , string ) (string , error )
66
93
67
94
CreateUser (* config.Config , * User , string ) error
68
95
GetUserByID (int64 ) (* User , error )
@@ -108,7 +135,7 @@ func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Req
108
135
attachUser = user .ID
109
136
}
110
137
111
- state , err := h .DB .GenerateOAuthState (ctx , h .oauthClient .GetProvider (), h .oauthClient .GetClientID (), attachUser )
138
+ state , err := h .DB .GenerateOAuthState (ctx , h .oauthClient .GetProvider (), h .oauthClient .GetClientID (), attachUser , r . FormValue ( "invite_code" ) )
112
139
if err != nil {
113
140
log .Error ("viewOauthInit error: %s" , err )
114
141
return impart.HTTPError {http .StatusInternalServerError , "could not prepare oauth redirect url" }
@@ -228,7 +255,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
228
255
code := r .FormValue ("code" )
229
256
state := r .FormValue ("state" )
230
257
231
- provider , clientID , attachUserID , err := h .DB .ValidateOAuthState (ctx , state )
258
+ provider , clientID , attachUserID , inviteCode , err := h .DB .ValidateOAuthState (ctx , state )
232
259
if err != nil {
233
260
log .Error ("Unable to ValidateOAuthState: %s" , err )
234
261
return impart.HTTPError {http .StatusInternalServerError , err .Error ()}
@@ -285,7 +312,16 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
285
312
286
313
// New user registration below.
287
314
// First, verify that user is allowed to register
288
- if ! app .cfg .App .OpenRegistration {
315
+ if inviteCode != "" {
316
+ // Verify invite code is valid
317
+ i , err := app .db .GetUserInvite (inviteCode )
318
+ if err != nil {
319
+ return impart.HTTPError {http .StatusInternalServerError , err .Error ()}
320
+ }
321
+ if ! i .Active (app .db ) {
322
+ return impart.HTTPError {http .StatusNotFound , "Invite link has expired." }
323
+ }
324
+ } else if ! app .cfg .App .OpenRegistration {
289
325
addSessionFlash (app , w , r , ErrUserNotFound .Error (), nil )
290
326
return impart.HTTPError {http .StatusFound , "/login" }
291
327
}
@@ -303,6 +339,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
303
339
TokenRemoteUser : tokenInfo .UserID ,
304
340
Provider : provider ,
305
341
ClientID : clientID ,
342
+ InviteCode : inviteCode ,
306
343
}
307
344
tp .TokenHash = tp .HashTokenParams (h .Config .Server .HashSeed )
308
345
0 commit comments