diff --git a/main b/main index ed3f34d..5676a5f 100755 Binary files a/main and b/main differ diff --git a/main.go b/main.go index 4552dcd..6d379e8 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "io" "log" "net/http" "os" @@ -21,6 +20,20 @@ import ( "github.com/gorilla/websocket" ) +type PeerSet struct { + sync.Mutex + m map[string]struct{} +} + +var userPeersMutex sync.Mutex = sync.Mutex{} +var userPeers = make(map[string]*PeerSet) + +var peerConnectionsMutex sync.Mutex = sync.Mutex{} +var peerConnections = make(map[string]*Peer) + +var connectionPeersMutex sync.Mutex = sync.Mutex{} +var connectionPeers = make(map[*websocket.Conn]string) + var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { origin := r.Header.Get("Origin") @@ -65,23 +78,27 @@ const ( ) type Peer struct { - conn *websocket.Conn - send chan []byte - lastActive time.Time - m sync.Mutex - closeOnce sync.Once + conn *websocket.Conn + send chan []byte + lastActive time.Time + lastActiveMutex sync.Mutex + closeOnce sync.Once } func removePeer(peerID string, peer *Peer) { delete(peerConnections, peerID) + userPeersMutex.Lock() + defer userPeersMutex.Unlock() + for userID, peers := range userPeers { - delete(peers, peerID) - if len(peers) == 0 { + delete(peers.m, peerID) + if len(peers.m) == 0 { delete(userPeers, userID) // not safe need mutex } } - + connectionPeersMutex.Lock() + defer connectionPeersMutex.Unlock() delete(connectionPeers, peer.conn) // Close the peer's send channel safely @@ -104,9 +121,10 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) { // Create a Peer object with a buffered channel for sending messages peer := &Peer{ - conn: conn, - send: make(chan []byte, 256), - lastActive: time.Now(), + conn: conn, + send: make(chan []byte, 4096), + lastActiveMutex: sync.Mutex{}, + lastActive: time.Now(), } // Start the write loop in a separate goroutine @@ -119,9 +137,9 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) { break } - peer.m.Lock() + peer.lastActiveMutex.Lock() peer.lastActive = time.Now() - peer.m.Unlock() + peer.lastActiveMutex.Unlock() // fmt.Println("ws<-", connectionPeers[conn], ":", string(message[:min(80, len(message))])) @@ -187,12 +205,6 @@ func handlePing(message []byte, peer *Peer) ([]byte, error) { return []byte(`{"type":"pong"}`), nil } -type PeerSet map[string]struct{} - -var userPeers = make(map[string]PeerSet) -var peerConnections = make(map[string]*Peer) -var connectionPeers = make(map[*websocket.Conn]string) - func handleHello(message []byte, peer *Peer) ([]byte, error) { var m struct { @@ -209,21 +221,32 @@ func handleHello(message []byte, peer *Peer) ([]byte, error) { } log.Printf("Received hello from peer %s:%s, user %s:%s", m.PeerID[0:5], m.PeerName, m.UserID[0:5], m.UserName) + userPeersMutex.Lock() + defer userPeersMutex.Unlock() + if userPeers[m.UserID] == nil { - userPeers[m.UserID] = make(PeerSet) + userPeers[m.UserID] = &PeerSet{ + sync.Mutex{}, + make(map[string]struct{}), + } } for _, knownUserID := range m.KnownUsers { fmt.Printf("Adding user %s for peer %s\n", knownUserID, m.PeerID) if userPeers[knownUserID] == nil { - userPeers[knownUserID] = make(PeerSet) + userPeers[knownUserID] = &PeerSet{ + sync.Mutex{}, + make(map[string]struct{}), + } } - userPeers[knownUserID][m.PeerID] = struct{}{} + userPeers[knownUserID].Mutex.Lock() + defer userPeers[knownUserID].Mutex.Unlock() + userPeers[knownUserID].m[m.PeerID] = struct{}{} } - userPeers[m.UserID][m.PeerID] = struct{}{} + userPeers[m.UserID].m[m.PeerID] = struct{}{} peerConnections[m.PeerID] = peer connectionPeers[peer.conn] = m.PeerID @@ -278,23 +301,29 @@ func handlePeerMessage(message []byte, peer *Peer) ([]byte, error) { } // BrotliResponseWriter wraps http.ResponseWriter to support Brotli compression -type brotliResponseWriter struct { - http.ResponseWriter - Writer io.Writer -} +// type brotliResponseWriter struct { +// http.ResponseWriter +// Writer io.Writer +// } -func (w *brotliResponseWriter) Write(b []byte) (int, error) { - return w.Writer.Write(b) -} +// func (w *brotliResponseWriter) Write(b []byte) (int, error) { +// return w.Writer.Write(b) +// } // noDirListing wraps an http.FileServer handler to prevent directory listings func noDirListing(h http.Handler, root string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" { + // Early out for root. + http.ServeFile(w, r, filepath.Join(root, "/static/index.html")) + return + } + // For anything under /static/, we serve it, unless it's a directory // Otherwise we serve index.html and the app does the routing nad rendering. - log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent()) + // log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent()) if r.URL.Path == "/sw.js" { http.ServeFile(w, r, filepath.Join(root, "static/sw.js")) @@ -330,7 +359,7 @@ func noDirListing(h http.Handler, root string) http.HandlerFunc { // // Serve index.html when root is requested // if r.URL.Path == "/" { - log.Printf("Serving index %s", r.URL.Path) + // log.Printf("Serving index %s", r.URL.Path) http.ServeFile(w, r, filepath.Join(root, "/static/index.html")) // return // } @@ -415,12 +444,14 @@ func main() { // Collect inactive peers var inactivePeers []string + peerConnectionsMutex.Lock() + defer peerConnectionsMutex.Unlock() for peerID, peer := range peerConnections { - peer.m.Lock() + peer.lastActiveMutex.Lock() + defer peer.lastActiveMutex.Unlock() if now.Sub(peer.lastActive) > 60*time.Second { inactivePeers = append(inactivePeers, peerID) } - peer.m.Unlock() } // Remove inactive peers