Try and fix data races

This commit is contained in:
“bobbydigitales”
2024-10-08 21:10:55 -07:00
parent 0183537c7e
commit a3a9790eed
2 changed files with 66 additions and 35 deletions

101
main.go
View File

@@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
@@ -21,6 +20,20 @@ import (
"github.com/gorilla/websocket"
)
type PeerSet struct {
sync.Mutex
m map[string]struct{}
}
var userPeersMutex sync.Mutex = sync.Mutex{}
var userPeers = make(map[string]*PeerSet)
var peerConnectionsMutex sync.Mutex = sync.Mutex{}
var peerConnections = make(map[string]*Peer)
var connectionPeersMutex sync.Mutex = sync.Mutex{}
var connectionPeers = make(map[*websocket.Conn]string)
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin")
@@ -65,23 +78,27 @@ const (
)
type Peer struct {
conn *websocket.Conn
send chan []byte
lastActive time.Time
m sync.Mutex
closeOnce sync.Once
conn *websocket.Conn
send chan []byte
lastActive time.Time
lastActiveMutex sync.Mutex
closeOnce sync.Once
}
func removePeer(peerID string, peer *Peer) {
delete(peerConnections, peerID)
userPeersMutex.Lock()
defer userPeersMutex.Unlock()
for userID, peers := range userPeers {
delete(peers, peerID)
if len(peers) == 0 {
delete(peers.m, peerID)
if len(peers.m) == 0 {
delete(userPeers, userID) // not safe need mutex
}
}
connectionPeersMutex.Lock()
defer connectionPeersMutex.Unlock()
delete(connectionPeers, peer.conn)
// Close the peer's send channel safely
@@ -104,9 +121,10 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) {
// Create a Peer object with a buffered channel for sending messages
peer := &Peer{
conn: conn,
send: make(chan []byte, 256),
lastActive: time.Now(),
conn: conn,
send: make(chan []byte, 4096),
lastActiveMutex: sync.Mutex{},
lastActive: time.Now(),
}
// Start the write loop in a separate goroutine
@@ -119,9 +137,9 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) {
break
}
peer.m.Lock()
peer.lastActiveMutex.Lock()
peer.lastActive = time.Now()
peer.m.Unlock()
peer.lastActiveMutex.Unlock()
// fmt.Println("ws<-", connectionPeers[conn], ":", string(message[:min(80, len(message))]))
@@ -187,12 +205,6 @@ func handlePing(message []byte, peer *Peer) ([]byte, error) {
return []byte(`{"type":"pong"}`), nil
}
type PeerSet map[string]struct{}
var userPeers = make(map[string]PeerSet)
var peerConnections = make(map[string]*Peer)
var connectionPeers = make(map[*websocket.Conn]string)
func handleHello(message []byte, peer *Peer) ([]byte, error) {
var m struct {
@@ -209,21 +221,32 @@ func handleHello(message []byte, peer *Peer) ([]byte, error) {
}
log.Printf("Received hello from peer %s:%s, user %s:%s", m.PeerID[0:5], m.PeerName, m.UserID[0:5], m.UserName)
userPeersMutex.Lock()
defer userPeersMutex.Unlock()
if userPeers[m.UserID] == nil {
userPeers[m.UserID] = make(PeerSet)
userPeers[m.UserID] = &PeerSet{
sync.Mutex{},
make(map[string]struct{}),
}
}
for _, knownUserID := range m.KnownUsers {
fmt.Printf("Adding user %s for peer %s\n", knownUserID, m.PeerID)
if userPeers[knownUserID] == nil {
userPeers[knownUserID] = make(PeerSet)
userPeers[knownUserID] = &PeerSet{
sync.Mutex{},
make(map[string]struct{}),
}
}
userPeers[knownUserID][m.PeerID] = struct{}{}
userPeers[knownUserID].Mutex.Lock()
defer userPeers[knownUserID].Mutex.Unlock()
userPeers[knownUserID].m[m.PeerID] = struct{}{}
}
userPeers[m.UserID][m.PeerID] = struct{}{}
userPeers[m.UserID].m[m.PeerID] = struct{}{}
peerConnections[m.PeerID] = peer
connectionPeers[peer.conn] = m.PeerID
@@ -278,23 +301,29 @@ func handlePeerMessage(message []byte, peer *Peer) ([]byte, error) {
}
// BrotliResponseWriter wraps http.ResponseWriter to support Brotli compression
type brotliResponseWriter struct {
http.ResponseWriter
Writer io.Writer
}
// type brotliResponseWriter struct {
// http.ResponseWriter
// Writer io.Writer
// }
func (w *brotliResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}
// func (w *brotliResponseWriter) Write(b []byte) (int, error) {
// return w.Writer.Write(b)
// }
// noDirListing wraps an http.FileServer handler to prevent directory listings
func noDirListing(h http.Handler, root string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
// Early out for root.
http.ServeFile(w, r, filepath.Join(root, "/static/index.html"))
return
}
// For anything under /static/, we serve it, unless it's a directory
// Otherwise we serve index.html and the app does the routing nad rendering.
log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent())
// log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent())
if r.URL.Path == "/sw.js" {
http.ServeFile(w, r, filepath.Join(root, "static/sw.js"))
@@ -330,7 +359,7 @@ func noDirListing(h http.Handler, root string) http.HandlerFunc {
// // Serve index.html when root is requested
// if r.URL.Path == "/" {
log.Printf("Serving index %s", r.URL.Path)
// log.Printf("Serving index %s", r.URL.Path)
http.ServeFile(w, r, filepath.Join(root, "/static/index.html"))
// return
// }
@@ -415,12 +444,14 @@ func main() {
// Collect inactive peers
var inactivePeers []string
peerConnectionsMutex.Lock()
defer peerConnectionsMutex.Unlock()
for peerID, peer := range peerConnections {
peer.m.Lock()
peer.lastActiveMutex.Lock()
defer peer.lastActiveMutex.Unlock()
if now.Sub(peer.lastActive) > 60*time.Second {
inactivePeers = append(inactivePeers, peerID)
}
peer.m.Unlock()
}
// Remove inactive peers