Try and fix data races

This commit is contained in:
“bobbydigitales”
2024-10-08 21:10:55 -07:00
parent 0183537c7e
commit a3a9790eed
2 changed files with 66 additions and 35 deletions

BIN
main

Binary file not shown.

101
main.go
View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"log" "log"
"net/http" "net/http"
"os" "os"
@@ -21,6 +20,20 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
type PeerSet struct {
sync.Mutex
m map[string]struct{}
}
var userPeersMutex sync.Mutex = sync.Mutex{}
var userPeers = make(map[string]*PeerSet)
var peerConnectionsMutex sync.Mutex = sync.Mutex{}
var peerConnections = make(map[string]*Peer)
var connectionPeersMutex sync.Mutex = sync.Mutex{}
var connectionPeers = make(map[*websocket.Conn]string)
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
@@ -65,23 +78,27 @@ const (
) )
type Peer struct { type Peer struct {
conn *websocket.Conn conn *websocket.Conn
send chan []byte send chan []byte
lastActive time.Time lastActive time.Time
m sync.Mutex lastActiveMutex sync.Mutex
closeOnce sync.Once closeOnce sync.Once
} }
func removePeer(peerID string, peer *Peer) { func removePeer(peerID string, peer *Peer) {
delete(peerConnections, peerID) delete(peerConnections, peerID)
userPeersMutex.Lock()
defer userPeersMutex.Unlock()
for userID, peers := range userPeers { for userID, peers := range userPeers {
delete(peers, peerID) delete(peers.m, peerID)
if len(peers) == 0 { if len(peers.m) == 0 {
delete(userPeers, userID) // not safe need mutex delete(userPeers, userID) // not safe need mutex
} }
} }
connectionPeersMutex.Lock()
defer connectionPeersMutex.Unlock()
delete(connectionPeers, peer.conn) delete(connectionPeers, peer.conn)
// Close the peer's send channel safely // Close the peer's send channel safely
@@ -104,9 +121,10 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) {
// Create a Peer object with a buffered channel for sending messages // Create a Peer object with a buffered channel for sending messages
peer := &Peer{ peer := &Peer{
conn: conn, conn: conn,
send: make(chan []byte, 256), send: make(chan []byte, 4096),
lastActive: time.Now(), lastActiveMutex: sync.Mutex{},
lastActive: time.Now(),
} }
// Start the write loop in a separate goroutine // Start the write loop in a separate goroutine
@@ -119,9 +137,9 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) {
break break
} }
peer.m.Lock() peer.lastActiveMutex.Lock()
peer.lastActive = time.Now() peer.lastActive = time.Now()
peer.m.Unlock() peer.lastActiveMutex.Unlock()
// fmt.Println("ws<-", connectionPeers[conn], ":", string(message[:min(80, len(message))])) // fmt.Println("ws<-", connectionPeers[conn], ":", string(message[:min(80, len(message))]))
@@ -187,12 +205,6 @@ func handlePing(message []byte, peer *Peer) ([]byte, error) {
return []byte(`{"type":"pong"}`), nil return []byte(`{"type":"pong"}`), nil
} }
type PeerSet map[string]struct{}
var userPeers = make(map[string]PeerSet)
var peerConnections = make(map[string]*Peer)
var connectionPeers = make(map[*websocket.Conn]string)
func handleHello(message []byte, peer *Peer) ([]byte, error) { func handleHello(message []byte, peer *Peer) ([]byte, error) {
var m struct { var m struct {
@@ -209,21 +221,32 @@ func handleHello(message []byte, peer *Peer) ([]byte, error) {
} }
log.Printf("Received hello from peer %s:%s, user %s:%s", m.PeerID[0:5], m.PeerName, m.UserID[0:5], m.UserName) log.Printf("Received hello from peer %s:%s, user %s:%s", m.PeerID[0:5], m.PeerName, m.UserID[0:5], m.UserName)
userPeersMutex.Lock()
defer userPeersMutex.Unlock()
if userPeers[m.UserID] == nil { if userPeers[m.UserID] == nil {
userPeers[m.UserID] = make(PeerSet) userPeers[m.UserID] = &PeerSet{
sync.Mutex{},
make(map[string]struct{}),
}
} }
for _, knownUserID := range m.KnownUsers { for _, knownUserID := range m.KnownUsers {
fmt.Printf("Adding user %s for peer %s\n", knownUserID, m.PeerID) fmt.Printf("Adding user %s for peer %s\n", knownUserID, m.PeerID)
if userPeers[knownUserID] == nil { if userPeers[knownUserID] == nil {
userPeers[knownUserID] = make(PeerSet) userPeers[knownUserID] = &PeerSet{
sync.Mutex{},
make(map[string]struct{}),
}
} }
userPeers[knownUserID][m.PeerID] = struct{}{} userPeers[knownUserID].Mutex.Lock()
defer userPeers[knownUserID].Mutex.Unlock()
userPeers[knownUserID].m[m.PeerID] = struct{}{}
} }
userPeers[m.UserID][m.PeerID] = struct{}{} userPeers[m.UserID].m[m.PeerID] = struct{}{}
peerConnections[m.PeerID] = peer peerConnections[m.PeerID] = peer
connectionPeers[peer.conn] = m.PeerID connectionPeers[peer.conn] = m.PeerID
@@ -278,23 +301,29 @@ func handlePeerMessage(message []byte, peer *Peer) ([]byte, error) {
} }
// BrotliResponseWriter wraps http.ResponseWriter to support Brotli compression // BrotliResponseWriter wraps http.ResponseWriter to support Brotli compression
type brotliResponseWriter struct { // type brotliResponseWriter struct {
http.ResponseWriter // http.ResponseWriter
Writer io.Writer // Writer io.Writer
} // }
func (w *brotliResponseWriter) Write(b []byte) (int, error) { // func (w *brotliResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b) // return w.Writer.Write(b)
} // }
// noDirListing wraps an http.FileServer handler to prevent directory listings // noDirListing wraps an http.FileServer handler to prevent directory listings
func noDirListing(h http.Handler, root string) http.HandlerFunc { func noDirListing(h http.Handler, root string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
// Early out for root.
http.ServeFile(w, r, filepath.Join(root, "/static/index.html"))
return
}
// For anything under /static/, we serve it, unless it's a directory // For anything under /static/, we serve it, unless it's a directory
// Otherwise we serve index.html and the app does the routing nad rendering. // Otherwise we serve index.html and the app does the routing nad rendering.
log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent()) // log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent())
if r.URL.Path == "/sw.js" { if r.URL.Path == "/sw.js" {
http.ServeFile(w, r, filepath.Join(root, "static/sw.js")) http.ServeFile(w, r, filepath.Join(root, "static/sw.js"))
@@ -330,7 +359,7 @@ func noDirListing(h http.Handler, root string) http.HandlerFunc {
// // Serve index.html when root is requested // // Serve index.html when root is requested
// if r.URL.Path == "/" { // if r.URL.Path == "/" {
log.Printf("Serving index %s", r.URL.Path) // log.Printf("Serving index %s", r.URL.Path)
http.ServeFile(w, r, filepath.Join(root, "/static/index.html")) http.ServeFile(w, r, filepath.Join(root, "/static/index.html"))
// return // return
// } // }
@@ -415,12 +444,14 @@ func main() {
// Collect inactive peers // Collect inactive peers
var inactivePeers []string var inactivePeers []string
peerConnectionsMutex.Lock()
defer peerConnectionsMutex.Unlock()
for peerID, peer := range peerConnections { for peerID, peer := range peerConnections {
peer.m.Lock() peer.lastActiveMutex.Lock()
defer peer.lastActiveMutex.Unlock()
if now.Sub(peer.lastActive) > 60*time.Second { if now.Sub(peer.lastActive) > 60*time.Second {
inactivePeers = append(inactivePeers, peerID) inactivePeers = append(inactivePeers, peerID)
} }
peer.m.Unlock()
} }
// Remove inactive peers // Remove inactive peers