Skip to content

Commit

Permalink
Merge pull request juanfont#542 from mpldr/issue-342-send-on-closed-c…
Browse files Browse the repository at this point in the history
…hannel
  • Loading branch information
kradalby authored Apr 11, 2022
2 parents a92f6ab + 9f03a01 commit 367f848
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Changes

- Fix labels cardinality error when registering unknown pre-auth key [#519](https://github.com/juanfont/headscale/pull/519)
- Fix send on closed channel crash in polling [#542](https://github.com/juanfont/headscale/pull/542)

## 0.15.0 (2022-03-20)

Expand Down
68 changes: 45 additions & 23 deletions poll.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,32 +175,13 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
Str("machine", machine.Name).
Msg("Loading or creating update channel")

// TODO: could probably remove all that duplication once generics land.
closeChanWithLog := func(channel interface{}, name string) {
log.Trace().
Str("handler", "PollNetMap").
Str("machine", machine.Name).
Str("channel", "Done").
Msg(fmt.Sprintf("Closing %s channel", name))

switch c := channel.(type) {
case (chan struct{}):
close(c)

case (chan []byte):
close(c)
}
}

const chanSize = 8
updateChan := make(chan struct{}, chanSize)
defer closeChanWithLog(updateChan, "updateChan")

pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, "pollDataChan")
defer closeChanWithLog(pollDataChan, machine.Name, "pollDataChan")

keepAliveChan := make(chan []byte)
defer closeChanWithLog(keepAliveChan, "keepAliveChan")

if req.OmitPeers && !req.Stream {
log.Info().
Expand Down Expand Up @@ -273,7 +254,27 @@ func (h *Headscale) PollNetMapStream(
updateChan chan struct{},
) {
{
ctx, cancel := context.WithCancel(ctx.Request.Context())
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
ctx.String(http.StatusUnauthorized, "")

return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
ctx.String(http.StatusInternalServerError, "")

return
}

ctx := context.WithValue(ctx.Request.Context(), "machineName", machine.Name)

ctx, cancel := context.WithCancel(ctx)
defer cancel()

go h.scheduledPollWorker(
Expand Down Expand Up @@ -564,15 +565,26 @@ func (h *Headscale) PollNetMapStream(

func (h *Headscale) scheduledPollWorker(
ctx context.Context,
updateChan chan<- struct{},
keepAliveChan chan<- []byte,
updateChan chan struct{},
keepAliveChan chan []byte,
machineKey key.MachinePublic,
mapRequest tailcfg.MapRequest,
machine *Machine,
) {
keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(updateCheckInterval)

defer closeChanWithLog(
updateChan,
fmt.Sprint(ctx.Value("machineName")),
"updateChan",
)
defer closeChanWithLog(
keepAliveChan,
fmt.Sprint(ctx.Value("machineName")),
"updateChan",
)

for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -606,3 +618,13 @@ func (h *Headscale) scheduledPollWorker(
}
}
}

func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
log.Trace().
Str("handler", "PollNetMap").
Str("machine", machine).
Str("channel", "Done").
Msg(fmt.Sprintf("Closing %s channel", name))

close(channel)
}

0 comments on commit 367f848

Please sign in to comment.