package live

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/gorilla/sessions"
	"golang.org/x/net/html"
	"nhooyr.io/websocket"
)

// Compile-time checks that the net/http implementations satisfy the
// interfaces they are used through.
var (
	_ Engine           = (*HttpEngine)(nil)
	_ Socket           = (*HttpSocket)(nil)
	_ HttpSessionStore = (*CookieStore)(nil)
)

// sessionCookie is the key under which the live Session is stored in the
// gorilla session's Values map (the cookie itself is named by
// CookieStore.sessionName).
const sessionCookie string = "_ls"

// HttpSessionStore handles storing and retrieving sessions.
//
// HttpEngine uses it to load the per-client Session on every request, to
// persist it after the initial page render, and to discard a session that
// could not be loaded.
type HttpSessionStore interface {
	// Get loads the Session associated with the request.
	Get(*http.Request) (Session, error)
	// Save persists the Session for the request; it may write to the
	// response (CookieStore, for example, sets a cookie), so it must be
	// called before the response body is written.
	Save(http.ResponseWriter, *http.Request, Session) error
	// Clear discards whatever session the request carries.
	Clear(http.ResponseWriter, *http.Request) error
}

// HttpEngine serves live for net/http.
//
// It is an http.Handler (see ServeHTTP) created with NewHttpHandler. The
// embedded *BaseEngine supplies the shared engine behaviour used here
// (mount, params and event handlers, the socket registry and the error
// handler). acceptOptions, set via WithWebsocketAcceptOptions and possibly
// nil, is passed to websocket.Accept on upgrade; sessionStore loads, saves
// and clears the per-client Session.
type HttpEngine struct {
	acceptOptions *websocket.AcceptOptions
	sessionStore  HttpSessionStore
	*BaseEngine
}

// WithWebsocketAcceptOptions returns an EngineConfig that sets the options
// handed to websocket.Accept when a client upgrades its connection.
//
// The config only affects an *HttpEngine; applied to any other Engine it
// does nothing. It never returns an error.
func WithWebsocketAcceptOptions(options *websocket.AcceptOptions) EngineConfig {
	return func(e Engine) error {
		httpEngine, ok := e.(*HttpEngine)
		if !ok {
			return nil
		}
		httpEngine.acceptOptions = options
		return nil
	}
}

// NewHttpHandler builds the net/http handler for live, serving handler and
// keeping sessions in store.
//
// Each config is applied in order. A config that fails is logged as a
// warning and skipped; it does not abort construction.
func NewHttpHandler(store HttpSessionStore, handler Handler, configs ...EngineConfig) *HttpEngine {
	engine := &HttpEngine{BaseEngine: NewBaseEngine(handler)}
	engine.sessionStore = store

	for i := range configs {
		err := configs[i](engine)
		if err == nil {
			continue
		}
		log.Println("warning:", fmt.Errorf("could not apply config to engine: %w", err))
	}
	return engine
}

// ServeHTTP implements http.Handler by routing each request:
//   - requests asking to upgrade to a websocket are served by serveWS;
//   - other POST requests are treated as file uploads by post;
//   - everything else renders the page via get.
//
// When IgnoreFaviconRequest is set, /favicon.ico gets an empty 404.
func (h *HttpEngine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.URL.Path == "/favicon.ico" && h.IgnoreFaviconRequest {
		w.WriteHeader(http.StatusNotFound)
		return
	}

	// Check if we are going to upgrade to a websocket. The Upgrade header is
	// a comma-separated list of case-insensitive protocol tokens
	// (RFC 7230 §6.7, RFC 6455 §4.1), so "WebSocket" or "websocket, foo"
	// must be recognised as well as the exact string "websocket".
	upgrade := false
	for _, header := range r.Header.Values("Upgrade") {
		for _, token := range strings.Split(header, ",") {
			if strings.EqualFold(strings.TrimSpace(token), "websocket") {
				upgrade = true
			}
		}
	}

	ctx := httpContext(w, r)

	if !upgrade {
		switch r.Method {
		case http.MethodPost:
			h.post(ctx, w, r)
		default:
			h.get(ctx, w, r)
		}
		return
	}

	// Upgrade to the websocket version.
	h.serveWS(ctx, w, r)
}

// post handles a multipart form POST carrying file uploads.
//
// The target socket is looked up from the request's session (presumably the
// client's live websocket connection registered by serveWS). Every file
// submitted under the name of one of that socket's upload configs is
// assigned to the socket, validated and staged on disk by handleFileUpload,
// and the socket is re-rendered after each file so progress and errors can
// be shown. Any failure is passed to the engine's error handler.
func (h *HttpEngine) post(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	// Get session.
	session, err := h.sessionStore.Get(r)
	if err != nil {
		h.Error()(ctx, fmt.Errorf("no session found: %w", err))
		return
	}

	// Get socket.
	sock, err := h.GetSocket(session)
	if err != nil {
		h.Error()(ctx, err)
		return
	}

	// Bound the body so an oversized request fails while being parsed.
	r.Body = http.MaxBytesReader(w, r.Body, h.MaxUploadSize)
	if err := r.ParseMultipartForm(h.MaxUploadSize); err != nil {
		h.Error()(ctx, fmt.Errorf("could not parse form for uploads: %w", err))
		return
	}

	// Pick the per-socket staging directory and make sure it exists.
	var uploadDir string
	if h.UploadStagingLocation == "" {
		uploadDir, err = os.MkdirTemp("", string(sock.ID()))
	} else {
		// The configured location is only the parent: the per-socket
		// subdirectory has to be created here, otherwise creating the staged
		// file inside it in handleFileUpload fails. 0700 matches the
		// permissions os.MkdirTemp uses for the default location.
		uploadDir = filepath.Join(h.UploadStagingLocation, string(sock.ID()))
		err = os.MkdirAll(uploadDir, 0o700)
	}
	if err != nil {
		h.Error()(ctx, fmt.Errorf("%s upload dir creation failed: %w", sock.ID(), err))
		return
	}

	for _, config := range sock.UploadConfigs() {
		for _, fileHeader := range r.MultipartForm.File[config.Name] {
			u := uploadFromFileHeader(fileHeader)
			sock.AssignUpload(config.Name, u)
			handleFileUpload(h, sock, config, u, uploadDir, fileHeader)

			render, err := RenderSocket(ctx, h, sock)
			if err != nil {
				h.Error()(ctx, err)
				return
			}
			sock.UpdateRender(render)
		}
	}
}

// uploadFromFileHeader starts an Upload record for a submitted file, seeded
// with the client-supplied file name and the size reported in the form.
func uploadFromFileHeader(fh *multipart.FileHeader) *Upload {
	upload := new(Upload)
	upload.Name = fh.Filename
	upload.Size = fh.Size
	return upload
}

// handleFileUpload validates a single uploaded file and stages it on disk.
//
// The size claimed by fileHeader is checked against config.MaxSize. The
// content type is sniffed from the file's leading bytes and must appear in
// config.Accept. The file is then copied into a new file inside uploadDir,
// which must already exist, reporting progress through an UploadProgress
// for sock. On success u.Type, u.Name, u.Size and u's staged location are
// set. Problems are appended to u.Errors instead of being returned so the
// caller can render them.
func handleFileUpload(h *HttpEngine, sock Socket, config *UploadConfig, u *Upload, uploadDir string, fileHeader *multipart.FileHeader) {
	// Check file claims to be within the max size.
	if fileHeader.Size > config.MaxSize {
		u.Errors = append(u.Errors, fmt.Errorf("%s greater than max allowed size of %d", fileHeader.Filename, config.MaxSize))
		return
	}

	// Open the incoming file.
	file, err := fileHeader.Open()
	if err != nil {
		u.Errors = append(u.Errors, fmt.Errorf("could not open %s for upload: %w", fileHeader.Filename, err))
		return
	}
	defer file.Close()

	// Check the actual filetype. http.DetectContentType considers at most
	// 512 bytes. ReadFull avoids short reads, and only the bytes actually
	// read are sniffed. Zero padding past the end of a small file would make
	// it look binary (application/octet-stream). An empty or short file is
	// not an error just because the read hit EOF.
	buff := make([]byte, 512)
	n, err := io.ReadFull(file, buff)
	if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
		u.Errors = append(u.Errors, fmt.Errorf("could not check %s for type: %w", fileHeader.Filename, err))
		return
	}
	filetype := http.DetectContentType(buff[:n])
	allowed := false
	for _, a := range config.Accept {
		if filetype == a {
			allowed = true
			break
		}
	}
	if !allowed {
		u.Errors = append(u.Errors, fmt.Errorf("%s filetype is not allowed", fileHeader.Filename))
		return
	}
	u.Type = filetype

	// Rewind to the start of the file so the copy includes the sniffed bytes.
	if _, err := file.Seek(0, io.SeekStart); err != nil {
		u.Errors = append(u.Errors, fmt.Errorf("%s rewind error: %w", fileHeader.Filename, err))
		return
	}

	f, err := os.Create(filepath.Join(uploadDir, fmt.Sprintf("%d%s", time.Now().UnixNano(), filepath.Ext(fileHeader.Filename))))
	if err != nil {
		u.Errors = append(u.Errors, fmt.Errorf("%s upload file creation failed: %w", fileHeader.Filename, err))
		return
	}
	u.internalLocation = f.Name()
	u.Name = fileHeader.Filename

	written, err := io.Copy(f, io.TeeReader(file, &UploadProgress{Upload: u, Engine: h, Socket: sock}))
	// Close explicitly rather than deferring: for a file being written, Close
	// can report the final write failure, and that must fail the upload too.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		u.Errors = append(u.Errors, fmt.Errorf("%s upload failed: %w", fileHeader.Filename, err))
		return
	}
	u.Size = written
}

// get serves a plain (non-websocket) request by rendering the full page.
//
// The handler is mounted with a fresh, not-yet-connected socket. Every
// params handler is applied with the request's parameters, and the result
// is rendered to HTML. The session is saved before the status and body are
// written, so a store that writes headers (such as CookieStore) can still
// do so.
//
// If the session cannot be loaded, it is cleared and the client is
// redirected to the same URL with live-repair=1 added. If loading fails
// again on that repair request, the error goes to the engine's error
// handler instead of redirecting again.
func (h *HttpEngine) get(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	// Get session.
	session, err := h.sessionStore.Get(r)
	if err != nil {
		if r.URL.Query().Get("live-repair") != "" {
			h.Error()(ctx, fmt.Errorf("session corrupted: %w", err))
			return
		}
		log.Println(fmt.Errorf("session corrupted trying to repair: %w", err))
		// Redirect even if clearing fails: the repair request reports the
		// error through h.Error() rather than looping.
		if err := h.sessionStore.Clear(w, r); err != nil {
			log.Println(fmt.Errorf("could not clear corrupted session: %w", err))
		}
		q := r.URL.Query()
		q.Set("live-repair", "1")
		r.URL.RawQuery = q.Encode()
		http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
		return
	}

	// Get socket.
	sock := NewHttpSocket(session, h, false)

	// Run mount, this generates the state for the page we are on.
	data, err := h.Mount()(ctx, sock)
	if err != nil {
		h.Error()(ctx, err)
		return
	}
	sock.Assign(data)

	// Handle any query parameters that are on the page.
	for _, ph := range h.Params() {
		data, err := ph(ctx, sock, NewParamsFromRequest(r))
		if err != nil {
			h.Error()(ctx, err)
			return
		}
		sock.Assign(data)
	}

	// Render the page.
	render, err := RenderSocket(ctx, h, sock)
	if err != nil {
		h.Error()(ctx, err)
		return
	}
	sock.UpdateRender(render)

	// Serialise into a buffer first so a failure can still be reported
	// before any status line has been sent.
	var rendered bytes.Buffer
	if err := html.Render(&rendered, render); err != nil {
		h.Error()(ctx, fmt.Errorf("could not render html: %w", err))
		return
	}

	if err := h.sessionStore.Save(w, r, session); err != nil {
		h.Error()(ctx, err)
		return
	}

	w.WriteHeader(http.StatusOK)
	// The status has already been sent; a failed body write (typically a
	// disconnected client) cannot be reported to the client.
	_, _ = io.Copy(w, &rendered)
}

// serveWS upgrades the request to a websocket and runs the live session over
// it (see _serveWS) until the connection ends.
//
// Failures before the upgrade go to the engine's error handler. Once
// upgraded, normal closures, going-away closures and context cancellation
// are silent, and any other ending is logged with its close status.
func (h *HttpEngine) serveWS(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	// Get the session from the http request.
	session, err := h.sessionStore.Get(r)
	if err != nil {
		h.Error()(ctx, err)
		return
	}

	// Compression is disabled for Safari user agents, see:
	// https://github.com/nhooyr/websocket/issues/218
	// https://github.com/gorilla/websocket/issues/731
	// The engine-wide options are copied, not modified. ServeHTTP runs
	// concurrently, so writing to h.acceptOptions would race with other
	// requests. It would also switch compression off for every later
	// client, not just this one.
	// NOTE(review): Chrome and other Chromium user agents also contain
	// "Safari", so this matches most browsers; confirm that is intended.
	opts := h.acceptOptions
	if strings.Contains(r.UserAgent(), "Safari") {
		safariOpts := websocket.AcceptOptions{}
		if opts != nil {
			safariOpts = *opts
		}
		safariOpts.CompressionMode = websocket.CompressionDisabled
		opts = &safariOpts
	}

	c, err := websocket.Accept(w, r, opts)
	if err != nil {
		h.Error()(ctx, err)
		return
	}
	// Sent to the client if nothing closed the connection more cleanly.
	defer c.Close(websocket.StatusInternalError, "")

	if err := writeTimeout(ctx, time.Second*5, c, Event{T: EventConnect}); err != nil {
		log.Println(fmt.Errorf("ws could not send connect event: %w", err))
		return
	}

	err = h._serveWS(ctx, r, session, c)
	// A nil error means the request context finished; it has no close
	// status, so it must not reach the logging branch below.
	if err == nil || errors.Is(err, context.Canceled) {
		return
	}
	switch websocket.CloseStatus(err) {
	case websocket.StatusNormalClosure, websocket.StatusGoingAway:
		return
	default:
		log.Println(fmt.Errorf("ws closed with status (%d): %w", websocket.CloseStatus(err), err))
	}
}

// _serveWS runs the live protocol over an accepted websocket connection.
//
// It registers a connected socket with the engine (removed again on return)
// and starts a goroutine that reads client events, dispatches them to the
// params or event handlers, re-renders and acknowledges them. It then
// re-runs mount and params now that the client is connected, renders, and
// loops writing outgoing socket messages and error events to c.
//
// It returns nil when ctx is done. It returns a non-nil error when mount,
// params or render fail, when writing to c fails, or when the reader reports
// an internal error. That includes the read error that ends the connection,
// which stays wrapped so callers can inspect its close status.
func (h *HttpEngine) _serveWS(ctx context.Context, r *http.Request, session Session, c *websocket.Conn) error {
	// Get the sessions socket and register it with the server.
	sock := NewHttpSocket(session, h, true)
	sock.assignWS(c)
	h.AddSocket(sock)
	defer h.DeleteSocket(sock)

	// done is closed when this function returns. The channels below are
	// unbuffered, so every send from the reader goroutine also selects on
	// done. Otherwise the reader would block forever, leaking the goroutine,
	// once nobody is left to receive.
	done := make(chan struct{})
	defer close(done)

	// Internal errors; receiving one ends the connection.
	internalErrors := make(chan error)

	// Event errors; forwarded to the client as EventError messages.
	eventErrors := make(chan ErrorEvent)

	// reportInternal and reportEvent hand an error to the write loop below.
	// They return false if this function has already returned.
	reportInternal := func(err error) bool {
		select {
		case internalErrors <- err:
			return true
		case <-done:
			return false
		}
	}
	reportEvent := func(ee ErrorEvent) bool {
		select {
		case eventErrors <- ee:
			return true
		case <-done:
			return false
		}
	}

	// handleText decodes one text message, dispatches it, re-renders and
	// acknowledges it. It returns false when the reader should stop because
	// nobody is receiving any more.
	handleText := func(d []byte) bool {
		var m Event
		if err := json.Unmarshal(d, &m); err != nil {
			// Report, then skip this message and keep reading.
			return reportInternal(err)
		}

		var err error
		if m.T == EventParams {
			err = h.CallParams(ctx, sock, m)
		} else {
			err = h.CallEvent(ctx, m.T, sock, m)
		}
		if err != nil {
			if errors.Is(err, ErrNoEventHandler) {
				log.Println("event error", m, err)
			} else if !reportEvent(ErrorEvent{Source: m, Err: err.Error()}) {
				return false
			}
		}

		render, err := RenderSocket(ctx, h, sock)
		if err != nil {
			if !reportInternal(fmt.Errorf("socket handle error: %w", err)) {
				return false
			}
		} else {
			sock.UpdateRender(render)
		}
		if err := sock.Send(EventAck, nil, WithID(m.ID)); err != nil {
			return reportInternal(fmt.Errorf("socket send error: %w", err))
		}
		return true
	}

	// Handle events coming from the websocket connection.
	go func() {
		for {
			t, d, err := c.Read(ctx)
			if err != nil {
				// The connection is unusable; report it and stop reading.
				reportInternal(err)
				return
			}
			switch t {
			case websocket.MessageText:
				if !handleText(d) {
					return
				}
			case websocket.MessageBinary:
				log.Println("binary messages unhandled")
			}
		}
	}()

	// Run mount again now that the socket is connected.
	data, err := h.Mount()(ctx, sock)
	if err != nil {
		return fmt.Errorf("socket mount error: %w", err)
	}
	sock.Assign(data)

	// Run params again now that the socket is connected.
	for _, ph := range h.Params() {
		data, err := ph(ctx, sock, NewParamsFromRequest(r))
		if err != nil {
			return fmt.Errorf("socket params error: %w", err)
		}
		sock.Assign(data)
	}

	// Run render now that we are connected for the first time and we have just
	// mounted again. This will generate and send any patches if there have
	// been changes.
	render, err := RenderSocket(ctx, h, sock)
	if err != nil {
		return fmt.Errorf("socket render error: %w", err)
	}
	sock.UpdateRender(render)

	// Send events to the websocket connection.
	for {
		select {
		case msg := <-sock.msgs:
			if err := writeTimeout(ctx, time.Second*5, c, msg); err != nil {
				return fmt.Errorf("writing to socket error: %w", err)
			}
		case ee := <-eventErrors:
			d, err := json.Marshal(ee)
			if err != nil {
				return fmt.Errorf("writing to socket error: %w", err)
			}
			if err := writeTimeout(ctx, time.Second*5, c, Event{T: EventError, Data: d}); err != nil {
				return fmt.Errorf("writing to socket error: %w", err)
			}
		case err := <-internalErrors:
			if err == nil {
				continue
			}
			// Something catastrophic has happened: tell the client, then stop.
			// Distinct names keep the received error from being shadowed, so
			// the returned error wraps it rather than nil.
			d, marshalErr := json.Marshal(err.Error())
			if marshalErr != nil {
				return fmt.Errorf("writing to socket error: %w", marshalErr)
			}
			if writeErr := writeTimeout(ctx, time.Second*5, c, Event{T: EventError, Data: d}); writeErr != nil {
				return fmt.Errorf("writing to socket error: %w", writeErr)
			}
			return fmt.Errorf("internal error: %w", err)
		case <-ctx.Done():
			return nil
		}
	}
}

// HttpSocket is the Socket implementation used by HttpEngine. Its behaviour
// comes from the embedded *BaseSocket; assignWS attaches a websocket
// connection to it for live (connected) sockets.
type HttpSocket struct {
	*BaseSocket
}

// NewHttpSocket wraps a new BaseSocket for session s on engine e.
// connected is false for a one-off page render and true for a socket
// backed by a websocket connection.
func NewHttpSocket(s Session, e Engine, connected bool) *HttpSocket {
	base := NewBaseSocket(s, e, connected)
	return &HttpSocket{BaseSocket: base}
}

// assignWS ties the websocket connection ws to the socket by installing the
// closeSlow hook, which closes ws with a policy-violation status.
// Presumably BaseSocket calls it when the client cannot keep up with
// outgoing messages; confirm in BaseSocket.
func (s *HttpSocket) assignWS(ws *websocket.Conn) {
	closeTooSlow := func() {
		ws.Close(websocket.StatusPolicyViolation, "socket too slow to keep up with messages")
	}
	s.closeSlow = closeTooSlow
}

func httpContext(w http.ResponseWriter, r *http.Request) context.Context {
	ctx := r.Context()
	ctx = contextWithRequest(ctx, r)
	ctx = contextWithWriter(ctx, w)
	return ctx
}

// writeTimeout JSON-encodes msg and writes it to c as a single text message,
// giving up once timeout elapses or ctx is done.
func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg Event) error {
	writeCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	payload, err := json.Marshal(&msg)
	if err != nil {
		return fmt.Errorf("failed writeTimeout: %w", err)
	}
	return c.Write(writeCtx, websocket.MessageText, payload)
}

// CookieStore a `gorilla/sessions` based cookie store.
//
// It implements HttpSessionStore by keeping the live Session under the
// sessionCookie key of a gorilla session named sessionName, held in the
// wrapped Store. Use NewCookieStore to construct one.
type CookieStore struct {
	Store       *sessions.CookieStore
	sessionName string // session name.
}

// NewCookieStore creates a `gorilla/sessions` based cookie store whose
// gorilla session is named sessionName and whose cookies are protected with
// keyPairs.
//
// The cookie is HttpOnly and SameSite=Strict. Secure is false, so the
// cookie is also sent over plain HTTP (e.g. local development); when
// serving only HTTPS, callers may set Store.Options.Secure to true.
func NewCookieStore(sessionName string, keyPairs ...[]byte) *CookieStore {
	store := sessions.NewCookieStore(keyPairs...)

	opts := store.Options
	opts.HttpOnly = true
	opts.Secure = false
	opts.SameSite = http.SameSiteStrictMode

	return &CookieStore{Store: store, sessionName: sessionName}
}

// Get loads the live Session for r from the named gorilla session.
//
// If the store reports an error (for example an undecodable cookie), that
// error is returned together with a fresh Session. A missing or
// wrongly-typed value is not an error; a fresh Session is returned in that
// case too.
func (c CookieStore) Get(r *http.Request) (Session, error) {
	gs, err := c.Store.Get(r, c.sessionName)
	if err != nil {
		return NewSession(), err
	}
	if existing, ok := gs.Values[sessionCookie].(Session); ok {
		return existing, nil
	}
	return NewSession(), nil
}

// Save stores session under the sessionCookie key of the named gorilla
// session and writes that session out through the response.
func (c CookieStore) Save(w http.ResponseWriter, r *http.Request, session Session) error {
	gs, err := c.Store.Get(r, c.sessionName)
	if err == nil {
		gs.Values[sessionCookie] = session
		err = gs.Save(r, w)
	}
	return err
}

// Clear expires the session cookie on the client. It always returns nil.
//
// When the store's options are available, the deletion cookie copies their
// Path, Domain, Secure and SameSite settings. A browser only replaces a
// cookie whose name, path and domain match, so a fixed Path of "/" and no
// Domain would leave a cookie set with customised options in place.
func (c CookieStore) Clear(w http.ResponseWriter, r *http.Request) error {
	cookie := &http.Cookie{
		Name:     c.sessionName,
		Value:    "",
		Path:     "/",
		Expires:  time.Unix(0, 0),
		MaxAge:   -1, // sent as "Max-Age=0": delete immediately
		HttpOnly: true,
	}
	if c.Store != nil && c.Store.Options != nil {
		opts := c.Store.Options
		if opts.Path != "" {
			cookie.Path = opts.Path
		}
		cookie.Domain = opts.Domain
		cookie.Secure = opts.Secure
		cookie.SameSite = opts.SameSite
	}
	http.SetCookie(w, cookie)
	return nil
}