510 lines
12 KiB
Go
510 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
// "strings"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
type PeerSet struct {
|
|
sync.Mutex
|
|
M map[string]struct{}
|
|
}
|
|
|
|
var userPeersMutex sync.Mutex = sync.Mutex{}
|
|
var userPeers = make(map[string]*PeerSet)
|
|
|
|
var peerConnectionsMutex sync.Mutex = sync.Mutex{}
|
|
var peerConnections = make(map[string]*Peer)
|
|
|
|
var connectionPeersMutex sync.Mutex = sync.Mutex{}
|
|
var connectionPeers = make(map[*websocket.Conn]string)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
origin := r.Header.Get("Origin")
|
|
return origin == "https://ddlion.net" || origin == "https://ddln.app" || origin == "https://localhost:6789"
|
|
},
|
|
}
|
|
|
|
// websocketCloseHandler is installed via conn.SetCloseHandler. It ignores the
// close code and text, logs one line, and returns nil. It writes nothing back
// to the connection.
func websocketCloseHandler(code int, text string) error {
	const closedNote = "Client closed websocket."
	log.Print(closedNote)
	return nil
}
|
|
|
|
type Message struct {
|
|
Type string `json:"type"`
|
|
}
|
|
|
|
type MessageHandler func([]byte, *Peer) ([]byte, error)
|
|
|
|
var messageHandlers = make(map[string]MessageHandler)
|
|
|
|
func registerHandler(messageType string, handler MessageHandler) {
|
|
messageHandlers[messageType] = handler
|
|
}
|
|
|
|
func dispatchMessage(message []byte, peer *Peer) ([]byte, error) {
|
|
var msg Message
|
|
if err := json.Unmarshal(message, &msg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
handler, ok := messageHandlers[msg.Type]
|
|
if !ok {
|
|
err := fmt.Errorf("no handler registered for message type: %s", msg.Type)
|
|
return []byte(fmt.Sprintf(`{"type":"error", "message": "%s"}`, err.Error())), nil
|
|
}
|
|
|
|
return handler(message, peer)
|
|
}
|
|
|
|
const (
|
|
writeWait = 120 * time.Second
|
|
)
|
|
|
|
type Peer struct {
|
|
conn *websocket.Conn
|
|
send chan []byte
|
|
lastActive time.Time
|
|
lastActiveMutex sync.Mutex
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
func removePeer(peerID string, peer *Peer) {
|
|
delete(peerConnections, peerID)
|
|
|
|
userPeersMutex.Lock()
|
|
defer userPeersMutex.Unlock()
|
|
|
|
for userID, peers := range userPeers {
|
|
delete(peers.m, peerID)
|
|
if len(peers.m) == 0 {
|
|
delete(userPeers, userID) // not safe need mutex
|
|
}
|
|
}
|
|
connectionPeersMutex.Lock()
|
|
defer connectionPeersMutex.Unlock()
|
|
delete(connectionPeers, peer.conn)
|
|
|
|
// Close the peer's send channel safely
|
|
peer.closeOnce.Do(func() {
|
|
close(peer.send)
|
|
})
|
|
}
|
|
|
|
func handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
log.Println("Websocket connection!", r.RemoteAddr)
|
|
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Println("Upgrade error:", err)
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
conn.SetCloseHandler(websocketCloseHandler)
|
|
|
|
// Create a Peer object with a buffered channel for sending messages
|
|
peer := &Peer{
|
|
conn: conn,
|
|
send: make(chan []byte, 4096),
|
|
lastActiveMutex: sync.Mutex{},
|
|
lastActive: time.Now(),
|
|
}
|
|
|
|
// Start the write loop in a separate goroutine
|
|
go writePump(peer)
|
|
|
|
for {
|
|
_, message, err := conn.ReadMessage()
|
|
if err != nil {
|
|
log.Println("ReadMessage error:", err, connectionPeers[conn])
|
|
break
|
|
}
|
|
|
|
peer.lastActiveMutex.Lock()
|
|
peer.lastActive = time.Now()
|
|
peer.lastActiveMutex.Unlock()
|
|
|
|
// fmt.Println("ws<-", connectionPeers[conn], ":", string(message[:min(80, len(message))]))
|
|
|
|
response, err := dispatchMessage(message, peer)
|
|
|
|
if err != nil {
|
|
log.Printf("Error dispatching message: %v", err)
|
|
}
|
|
|
|
if response != nil {
|
|
// Send the response to the write loop
|
|
peer.send <- response
|
|
}
|
|
}
|
|
|
|
// Clean up when the connection is closed
|
|
peer.closeOnce.Do(func() {
|
|
close(peer.send)
|
|
})
|
|
|
|
peerID := connectionPeers[peer.conn]
|
|
if peerID != "" {
|
|
delete(peerConnections, peerID)
|
|
}
|
|
}
|
|
|
|
func writePump(peer *Peer) {
|
|
defer func() {
|
|
peer.conn.Close()
|
|
}()
|
|
for {
|
|
select {
|
|
case message, ok := <-peer.send:
|
|
if !ok {
|
|
// Channel closed, close the connection
|
|
peer.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
peer.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
// fmt.Println("ws->", connectionPeers[peer.conn], ":", string(message[:min(80, len(message))]))
|
|
|
|
err := peer.conn.WriteMessage(websocket.TextMessage, message)
|
|
if err != nil {
|
|
log.Println("WriteMessage error:", err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func handlePing(message []byte, peer *Peer) ([]byte, error) {
|
|
var pingMsg struct {
|
|
Type string `json:"type"`
|
|
PeerID string `json:"peer_id"`
|
|
}
|
|
|
|
if err := json.Unmarshal(message, &pingMsg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// log.Printf("Received ping from peer: %s", pingMsg.PeerID)
|
|
|
|
return []byte(`{"type":"pong"}`), nil
|
|
}
|
|
|
|
func handleHello(message []byte, peer *Peer) ([]byte, error) {
|
|
|
|
var m struct {
|
|
Type string `json:"type"`
|
|
UserID string `json:"user_id"`
|
|
UserName string `json:"user_name"`
|
|
PeerID string `json:"peer_id"`
|
|
PeerName string `json:"peer_name"`
|
|
KnownUsers []string `json:"known_users"`
|
|
}
|
|
|
|
if err := json.Unmarshal(message, &m); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
log.Printf("Received hello from peer %s:%s, user %s:%s", m.PeerID[0:5], m.PeerName, m.UserID[0:5], m.UserName)
|
|
userPeersMutex.Lock()
|
|
defer userPeersMutex.Unlock()
|
|
|
|
if userPeers[m.UserID] == nil {
|
|
userPeers[m.UserID] = &PeerSet{
|
|
sync.Mutex{},
|
|
make(map[string]struct{}),
|
|
}
|
|
}
|
|
|
|
for _, knownUserID := range m.KnownUsers {
|
|
fmt.Printf("Adding user %s for peer %s\n", knownUserID, m.PeerID)
|
|
if userPeers[knownUserID] == nil {
|
|
userPeers[knownUserID] = &PeerSet{
|
|
sync.Mutex{},
|
|
make(map[string]struct{}),
|
|
}
|
|
}
|
|
|
|
userPeers[knownUserID].Mutex.Lock()
|
|
defer userPeers[knownUserID].Mutex.Unlock()
|
|
userPeers[knownUserID].m[m.PeerID] = struct{}{}
|
|
|
|
}
|
|
|
|
userPeers[m.UserID].m[m.PeerID] = struct{}{}
|
|
peerConnections[m.PeerID] = peer
|
|
connectionPeers[peer.conn] = m.PeerID
|
|
|
|
jsonData, _ := json.MarshalIndent(userPeers, "", " ")
|
|
fmt.Println(string(jsonData), peerConnections)
|
|
|
|
// return all the peers we know about, with their user_id and peer_id
|
|
|
|
return []byte(fmt.Sprintf(`{"type":"hello", "userPeers": %s}`, string(jsonData))), nil
|
|
}
|
|
|
|
func handlePeerMessage(message []byte, peer *Peer) ([]byte, error) {
|
|
|
|
type InnerMessage struct {
|
|
Type string `json:"type"`
|
|
UserID string `json:"user_id"`
|
|
}
|
|
|
|
type PeerMessage struct {
|
|
Type string `json:"type"`
|
|
From string `json:"from"`
|
|
FromUserName string `json:"from_username"`
|
|
FromPeerName string `json:"from_peername"`
|
|
To string `json:"to"`
|
|
Message InnerMessage `json:"message"`
|
|
}
|
|
var m PeerMessage
|
|
|
|
if err := json.Unmarshal(message, &m); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fmt.Printf("peer message type %s from %s:%s:%s to %s with message length %d\n", m.Message.Type, m.From[0:5], m.FromPeerName, m.FromUserName, m.To[0:5], len(message))
|
|
|
|
toPeer := peerConnections[m.To]
|
|
|
|
if toPeer == nil {
|
|
fmt.Printf("Couldn't find peer %s\n", m.To)
|
|
fmt.Println(peerConnections)
|
|
return nil, nil
|
|
}
|
|
|
|
// Send the message to the recipient's send channel
|
|
select {
|
|
case toPeer.send <- message:
|
|
default:
|
|
fmt.Println("Could not send message to peer; channel full or closed")
|
|
}
|
|
|
|
// No response for this type of message
|
|
return nil, nil
|
|
}
|
|
|
|
// BrotliResponseWriter wraps http.ResponseWriter to support Brotli compression
|
|
// type brotliResponseWriter struct {
|
|
// http.ResponseWriter
|
|
// Writer io.Writer
|
|
// }
|
|
|
|
// func (w *brotliResponseWriter) Write(b []byte) (int, error) {
|
|
// return w.Writer.Write(b)
|
|
// }
|
|
|
|
// noDirListing wraps h, an http.FileServer rooted at root, for a single-page
// app. It never produces directory listings.
//
// Routing:
//   - /sw.js, /robots.txt and /favicon.ico are served from root/static/.
//   - Paths under /static/ are served by h when they name an existing regular
//     file. Directories, missing files, and paths whose cleaned form leaves
//     /static/ get a 404.
//   - Every other path, including "/", gets root/static/index.html, and the
//     client-side app does the routing and rendering.
func noDirListing(h http.Handler, root string) http.HandlerFunc {
	indexPath := filepath.Join(root, "static", "index.html")

	return func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/sw.js", "/robots.txt", "/favicon.ico":
			http.ServeFile(w, r, filepath.Join(root, "static", r.URL.Path))
			return
		}

		// Match the /static/ prefix rather than a substring, so a request like
		// /x/static/y cannot reach files elsewhere under root. Also reject
		// "/static/../..." forms that clean to a path outside /static/.
		if strings.HasPrefix(r.URL.Path, "/static/") {
			cleaned := filepath.ToSlash(filepath.Clean(r.URL.Path))
			if !strings.HasPrefix(cleaned, "/static/") {
				log.Printf("404 File not found/dir")
				http.NotFound(w, r)
				return
			}

			info, err := os.Stat(filepath.Join(root, filepath.FromSlash(cleaned)))
			if err != nil || info.IsDir() {
				log.Printf("404 File not found/dir")
				http.NotFound(w, r)
				return
			}

			h.ServeHTTP(w, r)
			return
		}

		http.ServeFile(w, r, indexPath)
	}
}
|
|
|
|
// min returns the smaller of a and b. This package-level function shadows the
// Go 1.21+ builtin of the same name within this package.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
|
|
|
func main() {
|
|
// Create a channel to receive OS signals for graceful shutdown
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
// Create a channel to signal when the program should shut down
|
|
done := make(chan bool)
|
|
|
|
// Define the directory to serve and the port to listen on
|
|
dir := "./"
|
|
port := 6789
|
|
|
|
addr := ":" + strconv.Itoa(port)
|
|
log.Printf("Starting server on %s", addr)
|
|
|
|
// Register message handlers
|
|
registerHandler("hello", handleHello)
|
|
registerHandler("ping", handlePing)
|
|
registerHandler("peer_message", handlePeerMessage)
|
|
|
|
// Set up file server and WebSocket endpoint
|
|
fs := http.FileServer(http.Dir(dir))
|
|
http.Handle("/", noDirListing(fs, dir))
|
|
|
|
http.HandleFunc("/ws", handleWebSocket)
|
|
|
|
// Configure the HTTP server
|
|
server := &http.Server{
|
|
Addr: addr,
|
|
Handler: nil, // Use the default ServeMux
|
|
}
|
|
|
|
// Start the inactivity monitor goroutine
|
|
go func() {
|
|
ticker := time.NewTicker(10 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-done:
|
|
return
|
|
case <-ticker.C:
|
|
now := time.Now()
|
|
|
|
// Collect inactive peers
|
|
var inactivePeers []string
|
|
peerConnectionsMutex.Lock()
|
|
defer peerConnectionsMutex.Unlock()
|
|
for peerID, peer := range peerConnections {
|
|
peer.lastActiveMutex.Lock()
|
|
defer peer.lastActiveMutex.Unlock()
|
|
if now.Sub(peer.lastActive) > 60*time.Second {
|
|
inactivePeers = append(inactivePeers, peerID)
|
|
}
|
|
}
|
|
|
|
// Remove inactive peers
|
|
for _, peerID := range inactivePeers {
|
|
peer := peerConnections[peerID]
|
|
|
|
if peer != nil {
|
|
log.Printf("Peer %s inactive for more than 60 seconds. Closing connection.", peerID)
|
|
peer.conn.Close()
|
|
removePeer(peerID, peer)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Run a goroutine to handle graceful shutdown
|
|
go func() {
|
|
sig := <-sigChan
|
|
fmt.Println()
|
|
fmt.Println("Received signal:", sig)
|
|
|
|
// Perform cleanup here
|
|
fmt.Println("Shutting down gracefully...")
|
|
|
|
// Create a context with timeout for the shutdown
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
// Attempt to gracefully shut down the server
|
|
if err := server.Shutdown(ctx); err != nil {
|
|
log.Fatalf("Server Shutdown Failed:%+v", err)
|
|
}
|
|
|
|
// Signal that shutdown is complete
|
|
close(done)
|
|
}()
|
|
|
|
// Start the HTTP server in a separate goroutine
|
|
go func() {
|
|
log.Printf("Server is configured and serving on port %d...", port)
|
|
if err := server.ListenAndServeTLS(
|
|
"/etc/letsencrypt/live/ddlion.net/fullchain.pem",
|
|
"/etc/letsencrypt/live/ddlion.net/privkey.pem",
|
|
); err != nil && err != http.ErrServerClosed {
|
|
log.Fatalf("Could not listen on %s: %v\n", addr, err)
|
|
}
|
|
}()
|
|
|
|
fmt.Println("Program is running. Press Ctrl+C to exit.")
|
|
|
|
// Wait for the shutdown signal
|
|
<-done
|
|
fmt.Println("Program has exited.")
|
|
}
|