Skip to content

Commit ca4a576

Browse files
committed
Support OAuth registration with invite code
This adds any OAuth login buttons to the invite signup page, stores the invite code for the flow duration, and associates the new user with it once successfully registered. It enables invite-only instances with OAuth-based registration.
1 parent 93c2773 commit ca4a576

13 files changed

+204
-59
lines changed

database.go

+11-9
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ type writestore interface {
132132

133133
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
134134
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
135-
ValidateOAuthState(context.Context, string) (string, string, int64, error)
136-
GenerateOAuthState(context.Context, string, string, int64) (string, error)
135+
ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
136+
GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
137137
GetOauthAccounts(ctx context.Context, userID int64) ([]oauthAccountInfo, error)
138138
RemoveOauth(ctx context.Context, userID int64, provider string, clientID string, remoteUserID string) error
139139

@@ -2516,24 +2516,26 @@ func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
25162516
return &t, nil
25172517
}
25182518

2519-
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64) (string, error) {
2519+
func (db *datastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUser int64, inviteCode string) (string, error) {
25202520
state := store.Generate62RandomString(24)
25212521
attachUserVal := sql.NullInt64{Valid: attachUser > 0, Int64: attachUser}
2522-
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id) VALUES (?, ?, ?, FALSE, "+db.now()+", ?)", state, provider, clientID, attachUserVal)
2522+
inviteCodeVal := sql.NullString{Valid: inviteCode != "", String: inviteCode}
2523+
_, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, provider, client_id, used, created_at, attach_user_id, invite_code) VALUES (?, ?, ?, FALSE, "+db.now()+", ?, ?)", state, provider, clientID, attachUserVal, inviteCodeVal)
25232524
if err != nil {
25242525
return "", fmt.Errorf("unable to record oauth client state: %w", err)
25252526
}
25262527
return state, nil
25272528
}
25282529

2529-
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, error) {
2530+
func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) {
25302531
var provider string
25312532
var clientID string
25322533
var attachUserID sql.NullInt64
2534+
var inviteCode sql.NullString
25332535
err := wf_db.RunTransactionWithOptions(ctx, db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
25342536
err := tx.
2535-
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
2536-
Scan(&provider, &clientID, &attachUserID)
2537+
QueryRowContext(ctx, "SELECT provider, client_id, attach_user_id, invite_code FROM oauth_client_states WHERE state = ? AND used = FALSE", state).
2538+
Scan(&provider, &clientID, &attachUserID, &inviteCode)
25372539
if err != nil {
25382540
return err
25392541
}
@@ -2552,9 +2554,9 @@ func (db *datastore) ValidateOAuthState(ctx context.Context, state string) (stri
25522554
return nil
25532555
})
25542556
if err != nil {
2555-
return "", "", 0, nil
2557+
return "", "", 0, "", nil
25562558
}
2557-
return provider, clientID, attachUserID.Int64, nil
2559+
return provider, clientID, attachUserID.Int64, inviteCode.String, nil
25582560
}
25592561

25602562
func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error {

database_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ func TestOAuthDatastore(t *testing.T) {
1818
driverName: "",
1919
}
2020

21-
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0)
21+
state, err := ds.GenerateOAuthState(ctx, "test", "development", 0, "")
2222
assert.NoError(t, err)
2323
assert.Len(t, state, 24)
2424

2525
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = false", state)
2626

27-
_, _, _, err = ds.ValidateOAuthState(ctx, state)
27+
_, _, _, _, err = ds.ValidateOAuthState(ctx, state)
2828
assert.NoError(t, err)
2929

3030
countRows(t, ctx, db, 1, "SELECT COUNT(*) FROM `oauth_client_states` WHERE `state` = ? AND `used` = true", state)

invites.go

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright © 2019 A Bunch Tell LLC.
2+
* Copyright © 2019-2020 A Bunch Tell LLC.
33
*
44
* This file is part of WriteFreely.
55
*
@@ -42,6 +42,18 @@ func (i Invite) Expired() bool {
4242
return i.Expires != nil && i.Expires.Before(time.Now())
4343
}
4444

45+
func (i Invite) Active(db *datastore) bool {
46+
if i.Expired() {
47+
return false
48+
}
49+
if i.MaxUses.Valid && i.MaxUses.Int64 > 0 {
50+
if c := db.GetUsersInvitedCount(i.ID); c >= i.MaxUses.Int64 {
51+
return false
52+
}
53+
}
54+
return true
55+
}
56+
4557
func (i Invite) ExpiresFriendly() string {
4658
return i.Expires.Format("January 2, 2006, 3:04 PM")
4759
}
@@ -161,9 +173,11 @@ func handleViewInvite(app *App, w http.ResponseWriter, r *http.Request) error {
161173
Error string
162174
Flashes []template.HTML
163175
Invite string
176+
OAuth *OAuthButtons
164177
}{
165178
StaticPage: pageForReq(app, r),
166179
Invite: inviteCode,
180+
OAuth: NewOAuthButtons(app.cfg),
167181
}
168182

169183
if expired {

less/app.less

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
@import "post-temp";
66
@import "effects";
77
@import "admin";
8+
@import "login";
89
@import "pages/error";
910
@import "lib/elements";
1011
@import "lib/material";

less/login.less

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright © 2020 A Bunch Tell LLC.
3+
*
4+
* This file is part of WriteFreely.
5+
*
6+
* WriteFreely is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU Affero General Public License, included
8+
* in the LICENSE file in this source code package.
9+
*/
10+
11+
.row.signinbtns {
12+
justify-content: space-evenly;
13+
font-size: 1em;
14+
margin-top: 3em;
15+
margin-bottom: 2em;
16+
17+
.loginbtn {
18+
height: 40px;
19+
}
20+
21+
#writeas-login, #gitlab-login {
22+
box-sizing: border-box;
23+
font-size: 17px;
24+
}
25+
}
26+
27+
.or {
28+
text-align: center;
29+
margin-bottom: 3.5em;
30+
31+
p {
32+
display: inline-block;
33+
background-color: white;
34+
padding: 0 1em;
35+
}
36+
37+
hr {
38+
margin-top: -1.6em;
39+
margin-bottom: 0;
40+
}
41+
42+
hr.short {
43+
max-width: 30rem;
44+
}
45+
}

migrations/migrations.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ var migrations = []Migration{
6262
New("support oauth", oauth), // V3 -> V4
6363
New("support slack oauth", oauthSlack), // V4 -> v5
6464
New("support ActivityPub mentions", supportActivityPubMentions), // V5 -> V6
65-
New("support oauth attach", oauthAttach), // V6 -> V7 (v0.12.0)
65+
New("support oauth attach", oauthAttach), // V6 -> V7
66+
New("support oauth via invite", oauthInvites), // V7 -> V8 (v0.12.0)
6667
}
6768

6869
// CurrentVer returns the current migration version the application is on

migrations/v8.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright © 2020 A Bunch Tell LLC.
3+
*
4+
* This file is part of WriteFreely.
5+
*
6+
* WriteFreely is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU Affero General Public License, included
8+
* in the LICENSE file in this source code package.
9+
*/
10+
11+
package migrations
12+
13+
import (
14+
"context"
15+
"database/sql"
16+
17+
wf_db "github.com/writeas/writefreely/db"
18+
)
19+
20+
func oauthInvites(db *datastore) error {
21+
dialect := wf_db.DialectMySQL
22+
if db.driverName == driverSQLite {
23+
dialect = wf_db.DialectSQLite
24+
}
25+
return wf_db.RunTransactionWithOptions(context.Background(), db.DB, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
26+
builders := []wf_db.SQLBuilder{
27+
dialect.
28+
AlterTable("oauth_client_states").
29+
AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{
30+
Set: true,
31+
Value: 6,
32+
}).SetNullable(true)),
33+
}
34+
for _, builder := range builders {
35+
query, err := builder.ToSQL()
36+
if err != nil {
37+
return err
38+
}
39+
if _, err := tx.ExecContext(ctx, query); err != nil {
40+
return err
41+
}
42+
}
43+
return nil
44+
})
45+
}

oauth.go

+43-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
/*
2+
* Copyright © 2019-2020 A Bunch Tell LLC.
3+
*
4+
* This file is part of WriteFreely.
5+
*
6+
* WriteFreely is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU Affero General Public License, included
8+
* in the LICENSE file in this source code package.
9+
*/
10+
111
package writefreely
212

313
import (
@@ -15,10 +25,27 @@ import (
1525
"github.com/gorilla/sessions"
1626
"github.com/writeas/impart"
1727
"github.com/writeas/web-core/log"
18-
1928
"github.com/writeas/writefreely/config"
2029
)
2130

31+
// OAuthButtons holds display information for different OAuth providers we support.
32+
type OAuthButtons struct {
33+
SlackEnabled bool
34+
WriteAsEnabled bool
35+
GitLabEnabled bool
36+
GitLabDisplayName string
37+
}
38+
39+
// NewOAuthButtons creates a new OAuthButtons struct based on our app configuration.
40+
func NewOAuthButtons(cfg *config.Config) *OAuthButtons {
41+
return &OAuthButtons{
42+
SlackEnabled: cfg.SlackOauth.ClientID != "",
43+
WriteAsEnabled: cfg.WriteAsOauth.ClientID != "",
44+
GitLabEnabled: cfg.GitlabOauth.ClientID != "",
45+
GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName),
46+
}
47+
}
48+
2249
// TokenResponse contains data returned when a token is created either
2350
// through a code exchange or using a refresh token.
2451
type TokenResponse struct {
@@ -61,8 +88,8 @@ type OAuthDatastoreProvider interface {
6188
type OAuthDatastore interface {
6289
GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
6390
RecordRemoteUserID(context.Context, int64, string, string, string, string) error
64-
ValidateOAuthState(context.Context, string) (string, string, int64, error)
65-
GenerateOAuthState(context.Context, string, string, int64) (string, error)
91+
ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
92+
GenerateOAuthState(context.Context, string, string, int64, string) (string, error)
6693

6794
CreateUser(*config.Config, *User, string) error
6895
GetUserByID(int64) (*User, error)
@@ -108,7 +135,7 @@ func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Req
108135
attachUser = user.ID
109136
}
110137

111-
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser)
138+
state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code"))
112139
if err != nil {
113140
log.Error("viewOauthInit error: %s", err)
114141
return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
@@ -228,7 +255,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
228255
code := r.FormValue("code")
229256
state := r.FormValue("state")
230257

231-
provider, clientID, attachUserID, err := h.DB.ValidateOAuthState(ctx, state)
258+
provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state)
232259
if err != nil {
233260
log.Error("Unable to ValidateOAuthState: %s", err)
234261
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
@@ -285,7 +312,16 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
285312

286313
// New user registration below.
287314
// First, verify that user is allowed to register
288-
if !app.cfg.App.OpenRegistration {
315+
if inviteCode != "" {
316+
// Verify invite code is valid
317+
i, err := app.db.GetUserInvite(inviteCode)
318+
if err != nil {
319+
return impart.HTTPError{http.StatusInternalServerError, err.Error()}
320+
}
321+
if !i.Active(app.db) {
322+
return impart.HTTPError{http.StatusNotFound, "Invite link has expired."}
323+
}
324+
} else if !app.cfg.App.OpenRegistration {
289325
addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil)
290326
return impart.HTTPError{http.StatusFound, "/login"}
291327
}
@@ -303,6 +339,7 @@ func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http
303339
TokenRemoteUser: tokenInfo.UserID,
304340
Provider: provider,
305341
ClientID: clientID,
342+
InviteCode: inviteCode,
306343
}
307344
tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
308345

oauth_signup.go

+13
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type viewOauthSignupVars struct {
3838
Provider string
3939
ClientID string
4040
TokenHash string
41+
InviteCode string
4142

4243
LoginUsername string
4344
Alias string // TODO: rename this to match the data it represents: the collection title
@@ -57,6 +58,7 @@ const (
5758
oauthParamAlias = "alias"
5859
oauthParamEmail = "email"
5960
oauthParamPassword = "password"
61+
oauthParamInviteCode = "invite_code"
6062
)
6163

6264
type oauthSignupPageParams struct {
@@ -68,6 +70,7 @@ type oauthSignupPageParams struct {
6870
ClientID string
6971
Provider string
7072
TokenHash string
73+
InviteCode string
7174
}
7275

7376
func (p oauthSignupPageParams) HashTokenParams(key string) string {
@@ -92,6 +95,7 @@ func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.R
9295
TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID),
9396
ClientID: r.FormValue(oauthParamClientID),
9497
Provider: r.FormValue(oauthParamProvider),
98+
InviteCode: r.FormValue(oauthParamInviteCode),
9599
}
96100
if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) {
97101
return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."}
@@ -128,6 +132,14 @@ func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.R
128132
return h.showOauthSignupPage(app, w, r, tp, err)
129133
}
130134

135+
// Log invite if needed
136+
if tp.InviteCode != "" {
137+
err = app.db.CreateInvitedUser(tp.InviteCode, newUser.ID)
138+
if err != nil {
139+
return err
140+
}
141+
}
142+
131143
err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken))
132144
if err != nil {
133145
return h.showOauthSignupPage(app, w, r, tp, err)
@@ -195,6 +207,7 @@ func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *ht
195207
Provider: tp.Provider,
196208
ClientID: tp.ClientID,
197209
TokenHash: tp.TokenHash,
210+
InviteCode: tp.InviteCode,
198211

199212
LoginUsername: username,
200213
Alias: collTitle,

0 commit comments

Comments
 (0)