package main

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/gorilla/websocket"
)

// upgrader upgrades /ws requests to websocket connections. Only pages served
// from the two production origins are allowed to connect.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		origin := r.Header.Get("Origin")
		return origin == "https://ddlion.net" || origin == "https://ddln.app"
	},
}

// websocketCloseHandler logs that the client sent a close frame. It always
// returns nil.
func websocketCloseHandler(code int, text string) error {
	log.Print("Client closed websocket.")
	return nil
}

// Message is the envelope common to every inbound websocket message; only
// Type is decoded here, to select the handler.
type Message struct {
	Type string `json:"type"`
}

// MessageHandler handles one message type for a peer. A non-nil []byte
// result is queued back to that same peer.
type MessageHandler func([]byte, *Peer) ([]byte, error)

// messageHandlers maps a message "type" to its handler. It is not locked, so
// registerHandler must only be called during startup, before the server
// accepts connections.
var messageHandlers = make(map[string]MessageHandler)

// registerHandler installs handler for messages whose "type" is messageType.
func registerHandler(messageType string, handler MessageHandler) {
	messageHandlers[messageType] = handler
}

// dispatchMessage decodes the message type and runs the matching handler.
// An unknown type yields an {"type":"error", "message": ...} response rather
// than an error. A malformed JSON envelope is returned as an error.
func dispatchMessage(message []byte, peer *Peer) ([]byte, error) {
	var msg Message
	if err := json.Unmarshal(message, &msg); err != nil {
		return nil, err
	}
	handler, ok := messageHandlers[msg.Type]
	if !ok {
		// Marshal instead of Sprintf: msg.Type is client-controlled, and a
		// quote or backslash in it must not produce invalid/injected JSON.
		return json.Marshal(struct {
			Type    string `json:"type"`
			Message string `json:"message"`
		}{
			Type:    "error",
			Message: fmt.Sprintf("no handler registered for message type: %s", msg.Type),
		})
	}
	return handler(message, peer)
}

const (
	// writeWait bounds each websocket write performed by writePump.
	writeWait = 120 * time.Second
)

// Peer is one websocket client connection. Outbound messages are queued on
// send and written by writePump. Always go through trySend/closeSend rather
// than touching the channel directly: they serialize on mu, so nothing ever
// sends on the channel after it has been closed (which would panic).
type Peer struct {
	conn *websocket.Conn
	send chan []byte
	// lastActive is written by the connection's read loop.
	// NOTE(review): other goroutines that read it race with that write.
	lastActive time.Time
	closeOnce  sync.Once

	mu     sync.Mutex // guards closed and the close of send
	closed bool
}

// trySend queues msg for the write pump without blocking. It reports false
// when the queue is full or the peer has already been closed.
func (p *Peer) trySend(msg []byte) bool {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		return false
	}
	select {
	case p.send <- msg:
		return true
	default:
		return false
	}
}

// closeSend closes the send queue exactly once, which makes writePump send a
// close frame and exit. Safe to call from any goroutine, any number of times.
func (p *Peer) closeSend() {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.closeOnce.Do(func() {
		p.closed = true
		close(p.send)
	})
}

// PeerSet is a set of peer IDs.
type PeerSet map[string]struct{}

// peersMu guards userPeers, peerConnections and connectionPeers. Every read
// or write of those maps, from any goroutine, must hold it.
var peersMu sync.Mutex

var (
	// userPeers maps a user ID to the IDs of peers that said hello as, or
	// as knowing, that user.
	userPeers = make(map[string]PeerSet)
	// peerConnections maps a peer ID to its live connection.
	peerConnections = make(map[string]*Peer)
	// connectionPeers maps a websocket connection back to its peer ID.
	connectionPeers = make(map[*websocket.Conn]string)
)

// shortID returns at most the first five bytes of id, for compact logging.
// Slicing id[0:5] directly panics on shorter, client-supplied IDs.
func shortID(id string) string {
	const n = 5
	if len(id) > n {
		return id[:n]
	}
	return id
}

// unregisterPeerLocked drops peer from the registry maps; peersMu must be
// held. The peerID entries are removed only while they still point at peer,
// so a stale connection cannot evict a newer one that reused the peer ID.
func unregisterPeerLocked(peerID string, peer *Peer) {
	if peerConnections[peerID] == peer {
		delete(peerConnections, peerID)
		for userID, peers := range userPeers {
			delete(peers, peerID)
			if len(peers) == 0 {
				delete(userPeers, userID)
			}
		}
	}
	delete(connectionPeers, peer.conn)
}

// removePeer unregisters peer (known as peerID) from every registry map and
// closes its send queue. It takes peersMu itself, so callers must not hold it.
func removePeer(peerID string, peer *Peer) {
	peersMu.Lock()
	unregisterPeerLocked(peerID, peer)
	peersMu.Unlock()
	peer.closeSend()
}

// handleWebSocket upgrades the request, starts writePump for the connection
// and dispatches every inbound message to its registered handler. When the
// read loop ends, the peer is unregistered and its send queue closed.
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
	log.Println("Websocket connection!", r.RemoteAddr)
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("Upgrade error:", err)
		return
	}
	defer conn.Close()
	conn.SetCloseHandler(func(code int, text string) error {
		_ = websocketCloseHandler(code, text) // only logs
		// A custom handler replaces gorilla's default one, which echoes the
		// close frame back; do that here to complete the closing handshake.
		// WriteControl may be called concurrently with writePump's writes.
		reply := websocket.FormatCloseMessage(code, "")
		_ = conn.WriteControl(websocket.CloseMessage, reply, time.Now().Add(time.Second))
		return nil
	})

	// Create a Peer object with a buffered channel for sending messages
	peer := &Peer{
		conn:       conn,
		send:       make(chan []byte, 256),
		lastActive: time.Now(),
	}

	// Start the write loop in a separate goroutine
	go writePump(peer)

	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			peersMu.Lock()
			peerID := connectionPeers[conn]
			peersMu.Unlock()
			log.Println("ReadMessage error:", err, peerID)
			break
		}
		peer.lastActive = time.Now()
		// fmt.Println("ws<-", connectionPeers[conn], ":", string(message[:min(80, len(message))]))

		response, err := dispatchMessage(message, peer)
		if err != nil {
			log.Printf("Error dispatching message: %v", err)
		}
		// Non-blocking: the inactivity monitor may already have closed the
		// queue, and a blocking send on a closed channel panics.
		if response != nil && !peer.trySend(response) {
			log.Println("Dropping response; send queue full or closed")
		}
	}

	// Unregister first so no other goroutine can look the peer up any more,
	// then close the queue so writePump sends a close frame and exits.
	peersMu.Lock()
	unregisterPeerLocked(connectionPeers[conn], peer)
	peersMu.Unlock()
	peer.closeSend()
}

// writePump is the only goroutine that writes data frames to peer.conn. It
// drains peer.send until the queue is closed (then sends a close frame) or a
// write fails, and closes the connection on exit.
func writePump(peer *Peer) {
	defer peer.conn.Close()

	for message := range peer.send {
		peer.conn.SetWriteDeadline(time.Now().Add(writeWait))
		// fmt.Println("ws->", connectionPeers[peer.conn], ":", string(message[:min(80, len(message))]))
		if err := peer.conn.WriteMessage(websocket.TextMessage, message); err != nil {
			log.Println("WriteMessage error:", err)
			return
		}
	}

	// Queue closed: tell the client we are done before the deferred Close.
	peer.conn.SetWriteDeadline(time.Now().Add(writeWait))
	peer.conn.WriteMessage(websocket.CloseMessage, []byte{})
}

// handlePing answers a well-formed ping message with {"type":"pong"}.
func handlePing(message []byte, peer *Peer) ([]byte, error) {
	var pingMsg struct {
		Type   string `json:"type"`
		PeerID string `json:"peer_id"`
	}
	if err := json.Unmarshal(message, &pingMsg); err != nil {
		return nil, err
	}
	// log.Printf("Received ping from peer: %s", pingMsg.PeerID)
	return []byte(`{"type":"pong"}`), nil
}

// handleHello registers the sending connection under its peer ID. It adds the
// peer to the set of its own user and of every user listed in known_users,
// then replies with the full user -> peers map.
// NOTE(review): the reply exposes every user's peer IDs to any client that
// says hello; confirm that is intended.
func handleHello(message []byte, peer *Peer) ([]byte, error) {
	var m struct {
		Type       string   `json:"type"`
		UserID     string   `json:"user_id"`
		UserName   string   `json:"user_name"`
		PeerID     string   `json:"peer_id"`
		PeerName   string   `json:"peer_name"`
		KnownUsers []string `json:"known_users"`
	}
	if err := json.Unmarshal(message, &m); err != nil {
		return nil, err
	}
	log.Printf("Received hello from peer %s:%s, user %s:%s",
		shortID(m.PeerID), m.PeerName, shortID(m.UserID), m.UserName)

	peersMu.Lock()
	defer peersMu.Unlock()

	addPeer := func(userID string) {
		set := userPeers[userID]
		if set == nil {
			set = make(PeerSet)
			userPeers[userID] = set
		}
		set[m.PeerID] = struct{}{}
	}
	for _, knownUserID := range m.KnownUsers {
		fmt.Printf("Adding user %s for peer %s\n", knownUserID, m.PeerID)
		addPeer(knownUserID)
	}
	addPeer(m.UserID)
	peerConnections[m.PeerID] = peer
	connectionPeers[peer.conn] = m.PeerID

	jsonData, err := json.MarshalIndent(userPeers, "", " ")
	if err != nil {
		return nil, err
	}
	fmt.Println(string(jsonData), peerConnections)
	// return all the peers we know about, with their user_id and peer_id
	return []byte(fmt.Sprintf(`{"type":"hello", "userPeers": %s}`, string(jsonData))), nil
}

// handlePeerMessage relays the raw message unchanged to the peer named in
// "to". It never produces a response; an unknown recipient or a full/closed
// recipient queue is only logged.
// NOTE(review): "from" is self-declared by the sender and not verified.
func handlePeerMessage(message []byte, peer *Peer) ([]byte, error) {
	type InnerMessage struct {
		Type   string `json:"type"`
		UserID string `json:"user_id"`
	}
	type PeerMessage struct {
		Type         string       `json:"type"`
		From         string       `json:"from"`
		FromUserName string       `json:"from_username"`
		FromPeerName string       `json:"from_peername"`
		To           string       `json:"to"`
		Message      InnerMessage `json:"message"`
	}
	var m PeerMessage
	if err := json.Unmarshal(message, &m); err != nil {
		return nil, err
	}
	fmt.Printf("peer message type %s from %s:%s:%s to %s with message length %d\n",
		m.Message.Type, shortID(m.From), m.FromPeerName, m.FromUserName, shortID(m.To), len(message))

	peersMu.Lock()
	toPeer := peerConnections[m.To]
	if toPeer == nil {
		fmt.Printf("Couldn't find peer %s\n", m.To)
		fmt.Println(peerConnections)
	}
	peersMu.Unlock()

	// trySend never blocks and never sends on a closed queue, so a recipient
	// that is disconnecting concurrently cannot make this goroutine panic.
	if toPeer != nil && !toPeer.trySend(message) {
		fmt.Println("Could not send message to peer; channel full or closed")
	}

	// No response for this type of message
	return nil, nil
}
compression type brotliResponseWriter struct { http.ResponseWriter Writer io.Writer } func (w *brotliResponseWriter) Write(b []byte) (int, error) { return w.Writer.Write(b) } // noDirListing wraps an http.FileServer handler to prevent directory listings func noDirListing(h http.Handler, root string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // For anything under /static/, we serve it, unless it's a directory // Otherwise we serve index.html and the app does the routing nad rendering. log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent()) if r.URL.Path == "/sw.js" { http.ServeFile(w, r, filepath.Join(root, "static/sw.js")) return } if r.URL.Path == "/robots.txt" { http.ServeFile(w, r, filepath.Join(root, "static/robots.txt")) return } if r.URL.Path == "/favicon.ico" { http.ServeFile(w, r, filepath.Join(root, "static/favicon.ico")) return } if strings.Contains(r.URL.Path, "/static/") { log.Print("Serving static") path := filepath.Join(root, r.URL.Path) info, err := os.Stat(path) if err != nil || info.IsDir() { log.Printf("404 File not found/dir") http.NotFound(w, r) return } log.Printf("Serving") h.ServeHTTP(w, r) return } // // Serve index.html when root is requested // if r.URL.Path == "/" { log.Printf("Serving index %s", r.URL.Path) http.ServeFile(w, r, filepath.Join(root, "/static/index.html")) // return // } // w.Header().Set("Cache-Control", "no-cache") // Check if client supports Brotli encoding // if strings.Contains(r.Header.Get("Accept-Encoding"), "br") { // if false { // w.Header().Set("Content-Encoding", "br") // w.Header().Del("Content-Length") // Cannot know content length with compressed data // // Wrap the ResponseWriter with Brotli writer // brWriter := brotli.NewWriter(w) // defer brWriter.Close() // // Create a ResponseWriter that writes to brWriter // bw := &brotliResponseWriter{ // ResponseWriter: w, // Writer: brWriter, // } // // Serve the file using http.ServeFile // http.ServeFile(bw, r, path) // return 
// } // h.ServeHTTP(w, r) } } func min(a, b int) int { if a < b { return a } return b } func main() { // Create a channel to receive OS signals for graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) // Create a channel to signal when the program should shut down done := make(chan bool) // Define the directory to serve and the port to listen on dir := "./" port := 6789 addr := ":" + strconv.Itoa(port) log.Printf("Starting server on %s", addr) // Register message handlers registerHandler("hello", handleHello) registerHandler("ping", handlePing) registerHandler("peer_message", handlePeerMessage) // Set up file server and WebSocket endpoint fs := http.FileServer(http.Dir(dir)) http.Handle("/", noDirListing(fs, dir)) http.HandleFunc("/ws", handleWebSocket) // Configure the HTTP server server := &http.Server{ Addr: addr, Handler: nil, // Use the default ServeMux } // Start the inactivity monitor goroutine go func() { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for { select { case <-done: return case <-ticker.C: now := time.Now() // Collect inactive peers var inactivePeers []string for peerID, peer := range peerConnections { if now.Sub(peer.lastActive) > 60*time.Second { inactivePeers = append(inactivePeers, peerID) } } // Remove inactive peers for _, peerID := range inactivePeers { peer := peerConnections[peerID] if peer != nil { log.Printf("Peer %s inactive for more than 60 seconds. 
Closing connection.", peerID) peer.conn.Close() removePeer(peerID, peer) } } } } }() // Run a goroutine to handle graceful shutdown go func() { sig := <-sigChan fmt.Println() fmt.Println("Received signal:", sig) // Perform cleanup here fmt.Println("Shutting down gracefully...") // Create a context with timeout for the shutdown ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Attempt to gracefully shut down the server if err := server.Shutdown(ctx); err != nil { log.Fatalf("Server Shutdown Failed:%+v", err) } // Signal that shutdown is complete close(done) }() // Start the HTTP server in a separate goroutine go func() { log.Printf("Server is configured and serving on port %d...", port) if err := server.ListenAndServeTLS( "/etc/letsencrypt/live/ddlion.net/fullchain.pem", "/etc/letsencrypt/live/ddlion.net/privkey.pem", ); err != nil && err != http.ErrServerClosed { log.Fatalf("Could not listen on %s: %v\n", addr, err) } }() fmt.Println("Program is running. Press Ctrl+C to exit.") // Wait for the shutdown signal <-done fmt.Println("Program has exited.") }