Files
dandelion/main.go
2024-09-22 21:12:51 -07:00

424 lines
9.9 KiB
Go

package main
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	// "strings"
	"sync"
	"syscall"
	"time"

	"github.com/andybalholm/brotli"
	"github.com/gorilla/websocket"
)
// upgrader turns requests arriving at /ws into websocket connections.
// Only pages served from the two production origins may connect; any other
// Origin header value, including an absent one, is refused.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		switch r.Header.Get("Origin") {
		case "https://ddlion.net", "https://ddln.app":
			return true
		default:
			return false
		}
	},
}
// websocketCloseHandler logs a close frame received from the client.
//
// code and text are the close status code and reason the client sent. They
// are included in the log line because they are the only indication of why
// the client left. It always returns nil.
//
// Installing this with conn.SetCloseHandler replaces gorilla/websocket's
// default close handler, which echoes the close frame back to the client.
// Callers that want the closing handshake to complete should chain to the
// previous handler (conn.CloseHandler()) after calling this.
func websocketCloseHandler(code int, text string) error {
	log.Printf("Client closed websocket (code %d, reason %q).", code, text)
	return nil
}
type (
	// Message is the minimal view of an incoming websocket frame: only the
	// "type" field is decoded, and it selects the handler.
	Message struct {
		Type string `json:"type"`
	}

	// MessageHandler processes one raw JSON message received from peer and
	// returns the bytes to send back to that same peer, or nil for no reply.
	MessageHandler func([]byte, *Peer) ([]byte, error)
)

// messageHandlers maps a message "type" to the handler that processes it.
var messageHandlers = map[string]MessageHandler{}

// registerHandler installs handler for messageType, replacing any handler
// previously registered under that type.
func registerHandler(messageType string, handler MessageHandler) {
	messageHandlers[messageType] = handler
}
// dispatchMessage decodes the "type" field of a raw JSON message and passes
// the whole message to the handler registered for that type.
//
// It returns the handler's reply (nil when there is nothing to send back) and
// the handler's error. A message that is not valid JSON yields (nil, err).
// An unknown type is not reported as an error. Instead the client receives an
// {"type":"error","message":...} reply.
func dispatchMessage(message []byte, peer *Peer) ([]byte, error) {
	var msg Message
	if err := json.Unmarshal(message, &msg); err != nil {
		return nil, err
	}
	handler, ok := messageHandlers[msg.Type]
	if !ok {
		// msg.Type comes straight from the client, so the reply must be
		// JSON-encoded rather than spliced into a format string.
		return errorResponse(fmt.Sprintf("no handler registered for message type: %s", msg.Type))
	}
	return handler(message, peer)
}

// errorResponse encodes text as an {"type":"error","message":text} JSON
// object. Going through encoding/json keeps the result valid JSON even when
// text contains quotes, backslashes or control characters.
func errorResponse(text string) ([]byte, error) {
	return json.Marshal(struct {
		Type    string `json:"type"`
		Message string `json:"message"`
	}{Type: "error", Message: text})
}
// writeWait is the deadline given to each outgoing websocket write in
// writePump; a write that cannot complete within it fails.
const writeWait = 10 * time.Second
// Peer is one connected websocket client.
type Peer struct {
// conn is the client's websocket connection. Outgoing data frames are
// written to it only by writePump.
conn *websocket.Conn
// send queues outgoing messages for writePump. It is created with a buffer
// of 256 and closed when the connection's read loop ends, which stops
// writePump.
send chan []byte
// lastActive is refreshed each time a message is read from conn. The
// inactivity monitor that would consume it is currently commented out.
lastActive time.Time
}
// removePeer unregisters peer from every registry map and closes its send
// channel, which makes its writePump send a close frame and exit.
//
// Every peer ID currently mapped to this exact *Peer is removed (a connection
// may have said hello more than once). IDs that have since been claimed by a
// different connection are left alone, so a client that reconnected with the
// same peer ID is not unregistered by its old connection's cleanup. Users
// left with no peer IDs are dropped from userPeers.
//
// It must be called exactly once per peer, by the goroutine running that
// connection's read loop, and without holding peersMu.
func removePeer(peer *Peer) {
	peersMu.Lock()
	defer peersMu.Unlock()

	delete(connectionPeers, peer.conn)
	for peerID, p := range peerConnections {
		if p != peer {
			continue
		}
		delete(peerConnections, peerID)
		for userID, peers := range userPeers {
			delete(peers, peerID)
			if len(peers) == 0 {
				delete(userPeers, userID)
			}
		}
	}
	// Closing under the write lock guarantees that no handlePeerMessage is in
	// the middle of sending to this channel (senders hold the read lock, and
	// once this lock is released they can no longer find the peer). The close
	// therefore cannot cause a "send on closed channel" panic.
	close(peer.send)
}

// peerIDFor returns the peer ID that conn registered via "hello", or "" if it
// has not registered one. It is used for log lines.
func peerIDFor(conn *websocket.Conn) string {
	peersMu.RLock()
	defer peersMu.RUnlock()
	return connectionPeers[conn]
}

// handleWebSocket upgrades the request to a websocket connection and serves
// it until reading fails.
//
// Each incoming frame is routed through dispatchMessage, and any reply is
// queued on the peer's send channel. A dedicated writePump goroutine drains
// that channel, because gorilla/websocket allows only one concurrent writer
// per connection. When reading fails, the peer is unregistered via
// removePeer, which also stops writePump.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	log.Println("Websocket connection!", r.RemoteAddr)
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("Upgrade error:", err)
		return
	}
	defer conn.Close()

	// SetCloseHandler replaces gorilla's default close handler, which echoes
	// the client's close frame back. Chain to that default so the closing
	// handshake still completes after our logging.
	echoClose := conn.CloseHandler()
	conn.SetCloseHandler(func(code int, text string) error {
		if err := websocketCloseHandler(code, text); err != nil {
			return err
		}
		return echoClose(code, text)
	})

	peer := &Peer{
		conn:       conn,
		send:       make(chan []byte, 256),
		lastActive: time.Now(),
	}

	// writerDone is closed when writePump returns (e.g. after a failed write).
	// Replies are then abandoned instead of blocking forever on a full send
	// buffer that nobody drains any more.
	writerDone := make(chan struct{})
	go func() {
		defer close(writerDone)
		writePump(peer)
	}()

	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			log.Println("ReadMessage error:", err)
			break
		}
		peer.lastActive = time.Now()
		fmt.Println("ws<-", peerIDFor(conn), ":", string(message[:min(80, len(message))]))
		response, err := dispatchMessage(message, peer)
		if err != nil {
			log.Printf("Error dispatching message: %v", err)
		}
		if response == nil {
			continue
		}
		select {
		case peer.send <- response:
		case <-writerDone:
			// writePump has already closed conn, so the next ReadMessage
			// fails and ends this loop.
		}
	}

	removePeer(peer)
}

// writePump is the only goroutine that writes data frames to peer.conn. It
// writes each queued message with a writeWait deadline. Once peer.send is
// closed, it sends a close frame. It closes the connection when it returns,
// so a failed write also ends the reader's loop.
func writePump(peer *Peer) {
	defer peer.conn.Close()

	for message := range peer.send {
		peer.conn.SetWriteDeadline(time.Now().Add(writeWait))
		fmt.Println("ws->", peerIDFor(peer.conn), ":", string(message[:min(80, len(message))]))
		if err := peer.conn.WriteMessage(websocket.TextMessage, message); err != nil {
			log.Println("WriteMessage error:", err)
			return
		}
	}

	// peer.send was closed: send a best-effort close frame. Errors are
	// ignored because the connection is being torn down either way.
	peer.conn.SetWriteDeadline(time.Now().Add(writeWait))
	peer.conn.WriteMessage(websocket.CloseMessage, []byte{})
}

// handlePing answers a "ping" message with {"type":"pong"}. The message is
// still decoded, so malformed JSON is reported as an error.
func handlePing(message []byte, peer *Peer) ([]byte, error) {
	var pingMsg struct {
		Type   string `json:"type"`
		PeerID string `json:"peer_id"`
	}
	if err := json.Unmarshal(message, &pingMsg); err != nil {
		return nil, err
	}
	return []byte(`{"type":"pong"}`), nil
}

// PeerSet is a set of peer IDs.
type PeerSet map[string]struct{}

// Peer registry.
var (
	// peersMu guards userPeers, peerConnections and connectionPeers.
	// net/http serves every connection on its own goroutine, so the handlers
	// reached through dispatchMessage run concurrently across connections,
	// alongside one writePump goroutine per connection. Unsynchronized access
	// to these maps is a data race that the Go runtime may abort the process
	// on ("concurrent map writes"). The lock also orders removePeer's channel
	// close against handlePeerMessage's sends.
	peersMu sync.RWMutex
	// userPeers maps a user ID to the peer IDs announced for it in "hello".
	userPeers = make(map[string]PeerSet)
	// peerConnections maps a peer ID to the connection that currently owns it.
	peerConnections = make(map[string]*Peer)
	// connectionPeers maps a connection to the last peer ID it announced.
	connectionPeers = make(map[*websocket.Conn]string)
)

// handleHello registers the sending connection under the user_id/peer_id it
// announces. It replies with every known user and their peer IDs, as
// {"type":"hello", "userPeers": {user_id: {peer_id: {}, ...}, ...}}.
func handleHello(message []byte, peer *Peer) ([]byte, error) {
	var m struct {
		Type   string `json:"type"`
		UserID string `json:"user_id"`
		PeerID string `json:"peer_id"`
	}
	if err := json.Unmarshal(message, &m); err != nil {
		return nil, err
	}

	peersMu.Lock()
	defer peersMu.Unlock()

	if userPeers[m.UserID] == nil {
		userPeers[m.UserID] = make(PeerSet)
	}
	userPeers[m.UserID][m.PeerID] = struct{}{}
	peerConnections[m.PeerID] = peer
	connectionPeers[peer.conn] = m.PeerID

	// Marshal while still holding the lock: encoding reads the maps.
	jsonData, err := json.MarshalIndent(userPeers, "", " ")
	if err != nil {
		return nil, err
	}
	fmt.Println(string(jsonData), peerConnections)
	return []byte(fmt.Sprintf(`{"type":"hello", "userPeers": %s}`, string(jsonData))), nil
}

// handlePeerMessage relays a peer_message verbatim to the peer named in its
// "to" field. Delivery is best effort: if the recipient is unknown or its
// send buffer is full, the message is dropped and logged. It never produces
// a reply for the sender.
//
// NOTE(review): "from" is taken from the client and is not checked against
// the peer ID the sending connection registered, so a client can claim to be
// any peer. Consider verifying it against connectionPeers[peer.conn].
func handlePeerMessage(message []byte, peer *Peer) ([]byte, error) {
	type InnerMessage struct {
		Type   string `json:"type"`
		UserID string `json:"user_id"`
	}
	type PeerMessage struct {
		Type    string       `json:"type"`
		From    string       `json:"from"`
		To      string       `json:"to"`
		Message InnerMessage `json:"message"`
	}
	var m PeerMessage
	if err := json.Unmarshal(message, &m); err != nil {
		return nil, err
	}
	fmt.Printf("peer message type %s from %s to %s with message length %d\n", m.Message.Type, m.From, m.To, len(message))

	// Hold the read lock across lookup and send. removePeer closes a peer's
	// channel only under the write lock, so the channel cannot be closed
	// between finding toPeer and sending to it. The send is non-blocking,
	// so the lock is held only briefly.
	peersMu.RLock()
	defer peersMu.RUnlock()

	toPeer := peerConnections[m.To]
	if toPeer == nil {
		fmt.Printf("Couldn't find peer %s\n", m.To)
		fmt.Println(peerConnections)
		return nil, nil
	}
	select {
	case toPeer.send <- message:
	default:
		fmt.Println("Could not send message to peer; channel full or closed")
	}
	// No response for this type of message.
	return nil, nil
}
// brotliResponseWriter is an http.ResponseWriter whose body bytes are routed
// to Writer (a brotli.Writer in noDirListing). Headers and the status code
// still go through the embedded ResponseWriter.
type brotliResponseWriter struct {
	http.ResponseWriter
	Writer io.Writer
}

// Compile-time check that the wrapper still satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*brotliResponseWriter)(nil)

// Write sends p to the wrapped Writer instead of the underlying response.
func (bw *brotliResponseWriter) Write(p []byte) (int, error) {
	n, err := bw.Writer.Write(p)
	return n, err
}
// noDirListing wraps an http.FileServer handler h (serving root) so that it
// never produces a directory listing.
//
// "/" is answered with root/index.html. Any other request gets a 404 in
// three cases: the path contains a dot-prefixed segment (such as
// /.git/config or /.env), it does not exist under root, or it names a
// directory. Everything else is handed to h.
//
// The dot-segment rule exists because root is served wholesale, so VCS
// metadata and dotfiles kept alongside the site would otherwise be
// downloadable. /.well-known/ is exempt because it is the standard location
// for public metadata files.
func noDirListing(h http.Handler, root string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Serve index.html when root is requested.
		if r.URL.Path == "/" {
			http.ServeFile(w, r, filepath.Join(root, "index.html"))
			return
		}
		path := filepath.Join(root, r.URL.Path)
		info, err := os.Stat(path)
		// hasHiddenSegment is checked first; when it is false, err is checked
		// before info is dereferenced.
		if hasHiddenSegment(r.URL.Path) || err != nil || info.IsDir() {
			log.Printf("404 File not found/dir serving: %s to ip %s, useragent %s", r.URL.Path, r.RemoteAddr, r.UserAgent())
			http.NotFound(w, r)
			return
		}
		log.Printf("Serving: %s to ip %s, useragent %s", r.URL.Path, r.RemoteAddr, r.UserAgent())
		// w.Header().Set("Cache-Control", "no-cache")
		// Brotli compression is currently disabled. The branch below is kept
		// (and keeps the brotli import in use) until it is re-enabled with:
		// if strings.Contains(r.Header.Get("Accept-Encoding"), "br") {
		if false {
			w.Header().Set("Content-Encoding", "br")
			w.Header().Del("Content-Length") // Cannot know content length with compressed data
			// Wrap the ResponseWriter with a Brotli writer.
			brWriter := brotli.NewWriter(w)
			defer brWriter.Close()
			bw := &brotliResponseWriter{
				ResponseWriter: w,
				Writer:         brWriter,
			}
			http.ServeFile(bw, r, path)
			return
		}
		h.ServeHTTP(w, r)
	}
}

// hasHiddenSegment reports whether any "/"-separated segment of urlPath
// starts with "." (e.g. "/.git/config" or "/a/.env"). The ".well-known"
// directory is not counted as hidden.
func hasHiddenSegment(urlPath string) bool {
	start := 0
	for i := 0; i <= len(urlPath); i++ {
		if i < len(urlPath) && urlPath[i] != '/' {
			continue
		}
		if seg := urlPath[start:i]; seg != "" && seg[0] == '.' && seg != ".well-known" {
			return true
		}
		start = i + 1
	}
	return false
}
// min returns the smaller of a and b. The file uses it to cap how many bytes
// of a websocket message are echoed to the log.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// main starts the HTTPS server on :6789. It serves static files from the
// working directory (through noDirListing) plus the websocket endpoint on
// /ws, and blocks until SIGINT or SIGTERM triggers a graceful shutdown.
func main() {
	// Receive OS signals for graceful shutdown.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	// done is closed once the server has been shut down.
	done := make(chan struct{})

	// Directory to serve and port to listen on.
	dir := "./"
	port := 6789
	addr := ":" + strconv.Itoa(port)
	log.Printf("Starting server on %s", addr)

	// Register message handlers.
	registerHandler("hello", handleHello)
	registerHandler("ping", handlePing)
	registerHandler("peer_message", handlePeerMessage)

	// File server and WebSocket endpoint.
	fs := http.FileServer(http.Dir(dir))
	http.Handle("/", noDirListing(fs, dir))
	http.HandleFunc("/ws", handleWebSocket)

	server := &http.Server{
		Addr:    addr,
		Handler: nil, // use http.DefaultServeMux
		// Without a header timeout, a client can keep a connection open
		// indefinitely by trickling request headers (Slowloris). This bounds
		// only header reading; full Read/WriteTimeouts are deliberately not
		// set because they would also cut off long static downloads.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// NOTE(review): if this inactivity monitor is re-enabled, it must
	// synchronize its access to peerConnections with the connection handlers,
	// which read and write that map from their own goroutines.
	// // Start the inactivity monitor goroutine
	// go func() {
	// ticker := time.NewTicker(10 * time.Second)
	// defer ticker.Stop()
	// for {
	// select {
	// case <-done:
	// return
	// case <-ticker.C:
	// now := time.Now()
	// // Collect inactive peers
	// var inactivePeers []string
	// for peerID, peer := range peerConnections {
	// if now.Sub(peer.lastActive) > 60*time.Second {
	// inactivePeers = append(inactivePeers, peerID)
	// }
	// }
	// // Remove inactive peers
	// for _, peerID := range inactivePeers {
	// peer := peerConnections[peerID]
	// if peer != nil {
	// log.Printf("Peer %s inactive for more than 60 seconds. Closing connection.", peerID)
	// peer.conn.Close()
	// removePeer(peerID, peer)
	// }
	// }
	// }
	// }
	// }()

	// Graceful shutdown on signal.
	go func() {
		sig := <-sigChan
		fmt.Println()
		fmt.Println("Received signal:", sig)
		fmt.Println("Shutting down gracefully...")
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			log.Fatalf("Server Shutdown Failed:%+v", err)
		}
		close(done)
	}()

	// Serve in a separate goroutine so main can wait on done.
	go func() {
		log.Printf("Server is configured and serving on port %d...", port)
		if err := server.ListenAndServeTLS(
			"/etc/letsencrypt/live/ddlion.net/fullchain.pem",
			"/etc/letsencrypt/live/ddlion.net/privkey.pem",
		); err != nil && err != http.ErrServerClosed {
			log.Fatalf("Could not listen on %s: %v\n", addr, err)
		}
	}()

	fmt.Println("Program is running. Press Ctrl+C to exit.")
	<-done
	fmt.Println("Program has exited.")
}