Try and fix data races
This commit is contained in:
89
main.go
89
main.go
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -21,6 +20,20 @@ import (
|
|||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PeerSet struct {
|
||||||
|
sync.Mutex
|
||||||
|
m map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var userPeersMutex sync.Mutex = sync.Mutex{}
|
||||||
|
var userPeers = make(map[string]*PeerSet)
|
||||||
|
|
||||||
|
var peerConnectionsMutex sync.Mutex = sync.Mutex{}
|
||||||
|
var peerConnections = make(map[string]*Peer)
|
||||||
|
|
||||||
|
var connectionPeersMutex sync.Mutex = sync.Mutex{}
|
||||||
|
var connectionPeers = make(map[*websocket.Conn]string)
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{
|
var upgrader = websocket.Upgrader{
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
@@ -68,20 +81,24 @@ type Peer struct {
|
|||||||
conn *websocket.Conn
|
conn *websocket.Conn
|
||||||
send chan []byte
|
send chan []byte
|
||||||
lastActive time.Time
|
lastActive time.Time
|
||||||
m sync.Mutex
|
lastActiveMutex sync.Mutex
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func removePeer(peerID string, peer *Peer) {
|
func removePeer(peerID string, peer *Peer) {
|
||||||
delete(peerConnections, peerID)
|
delete(peerConnections, peerID)
|
||||||
|
|
||||||
|
userPeersMutex.Lock()
|
||||||
|
defer userPeersMutex.Unlock()
|
||||||
|
|
||||||
for userID, peers := range userPeers {
|
for userID, peers := range userPeers {
|
||||||
delete(peers, peerID)
|
delete(peers.m, peerID)
|
||||||
if len(peers) == 0 {
|
if len(peers.m) == 0 {
|
||||||
delete(userPeers, userID) // not safe need mutex
|
delete(userPeers, userID) // not safe need mutex
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
connectionPeersMutex.Lock()
|
||||||
|
defer connectionPeersMutex.Unlock()
|
||||||
delete(connectionPeers, peer.conn)
|
delete(connectionPeers, peer.conn)
|
||||||
|
|
||||||
// Close the peer's send channel safely
|
// Close the peer's send channel safely
|
||||||
@@ -105,7 +122,8 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Create a Peer object with a buffered channel for sending messages
|
// Create a Peer object with a buffered channel for sending messages
|
||||||
peer := &Peer{
|
peer := &Peer{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
send: make(chan []byte, 256),
|
send: make(chan []byte, 4096),
|
||||||
|
lastActiveMutex: sync.Mutex{},
|
||||||
lastActive: time.Now(),
|
lastActive: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,9 +137,9 @@ func handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.m.Lock()
|
peer.lastActiveMutex.Lock()
|
||||||
peer.lastActive = time.Now()
|
peer.lastActive = time.Now()
|
||||||
peer.m.Unlock()
|
peer.lastActiveMutex.Unlock()
|
||||||
|
|
||||||
// fmt.Println("ws<-", connectionPeers[conn], ":", string(message[:min(80, len(message))]))
|
// fmt.Println("ws<-", connectionPeers[conn], ":", string(message[:min(80, len(message))]))
|
||||||
|
|
||||||
@@ -187,12 +205,6 @@ func handlePing(message []byte, peer *Peer) ([]byte, error) {
|
|||||||
return []byte(`{"type":"pong"}`), nil
|
return []byte(`{"type":"pong"}`), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type PeerSet map[string]struct{}
|
|
||||||
|
|
||||||
var userPeers = make(map[string]PeerSet)
|
|
||||||
var peerConnections = make(map[string]*Peer)
|
|
||||||
var connectionPeers = make(map[*websocket.Conn]string)
|
|
||||||
|
|
||||||
func handleHello(message []byte, peer *Peer) ([]byte, error) {
|
func handleHello(message []byte, peer *Peer) ([]byte, error) {
|
||||||
|
|
||||||
var m struct {
|
var m struct {
|
||||||
@@ -209,21 +221,32 @@ func handleHello(message []byte, peer *Peer) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Received hello from peer %s:%s, user %s:%s", m.PeerID[0:5], m.PeerName, m.UserID[0:5], m.UserName)
|
log.Printf("Received hello from peer %s:%s, user %s:%s", m.PeerID[0:5], m.PeerName, m.UserID[0:5], m.UserName)
|
||||||
|
userPeersMutex.Lock()
|
||||||
|
defer userPeersMutex.Unlock()
|
||||||
|
|
||||||
if userPeers[m.UserID] == nil {
|
if userPeers[m.UserID] == nil {
|
||||||
userPeers[m.UserID] = make(PeerSet)
|
userPeers[m.UserID] = &PeerSet{
|
||||||
|
sync.Mutex{},
|
||||||
|
make(map[string]struct{}),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, knownUserID := range m.KnownUsers {
|
for _, knownUserID := range m.KnownUsers {
|
||||||
fmt.Printf("Adding user %s for peer %s\n", knownUserID, m.PeerID)
|
fmt.Printf("Adding user %s for peer %s\n", knownUserID, m.PeerID)
|
||||||
if userPeers[knownUserID] == nil {
|
if userPeers[knownUserID] == nil {
|
||||||
userPeers[knownUserID] = make(PeerSet)
|
userPeers[knownUserID] = &PeerSet{
|
||||||
|
sync.Mutex{},
|
||||||
|
make(map[string]struct{}),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userPeers[knownUserID][m.PeerID] = struct{}{}
|
userPeers[knownUserID].Mutex.Lock()
|
||||||
|
defer userPeers[knownUserID].Mutex.Unlock()
|
||||||
|
userPeers[knownUserID].m[m.PeerID] = struct{}{}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
userPeers[m.UserID][m.PeerID] = struct{}{}
|
userPeers[m.UserID].m[m.PeerID] = struct{}{}
|
||||||
peerConnections[m.PeerID] = peer
|
peerConnections[m.PeerID] = peer
|
||||||
connectionPeers[peer.conn] = m.PeerID
|
connectionPeers[peer.conn] = m.PeerID
|
||||||
|
|
||||||
@@ -278,23 +301,29 @@ func handlePeerMessage(message []byte, peer *Peer) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BrotliResponseWriter wraps http.ResponseWriter to support Brotli compression
|
// BrotliResponseWriter wraps http.ResponseWriter to support Brotli compression
|
||||||
type brotliResponseWriter struct {
|
// type brotliResponseWriter struct {
|
||||||
http.ResponseWriter
|
// http.ResponseWriter
|
||||||
Writer io.Writer
|
// Writer io.Writer
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (w *brotliResponseWriter) Write(b []byte) (int, error) {
|
// func (w *brotliResponseWriter) Write(b []byte) (int, error) {
|
||||||
return w.Writer.Write(b)
|
// return w.Writer.Write(b)
|
||||||
}
|
// }
|
||||||
|
|
||||||
// noDirListing wraps an http.FileServer handler to prevent directory listings
|
// noDirListing wraps an http.FileServer handler to prevent directory listings
|
||||||
func noDirListing(h http.Handler, root string) http.HandlerFunc {
|
func noDirListing(h http.Handler, root string) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
if r.URL.Path == "/" {
|
||||||
|
// Early out for root.
|
||||||
|
http.ServeFile(w, r, filepath.Join(root, "/static/index.html"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// For anything under /static/, we serve it, unless it's a directory
|
// For anything under /static/, we serve it, unless it's a directory
|
||||||
// Otherwise we serve index.html and the app does the routing nad rendering.
|
// Otherwise we serve index.html and the app does the routing nad rendering.
|
||||||
|
|
||||||
log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent())
|
// log.Printf("%s %s %s", r.URL.Path, r.RemoteAddr, r.UserAgent())
|
||||||
|
|
||||||
if r.URL.Path == "/sw.js" {
|
if r.URL.Path == "/sw.js" {
|
||||||
http.ServeFile(w, r, filepath.Join(root, "static/sw.js"))
|
http.ServeFile(w, r, filepath.Join(root, "static/sw.js"))
|
||||||
@@ -330,7 +359,7 @@ func noDirListing(h http.Handler, root string) http.HandlerFunc {
|
|||||||
|
|
||||||
// // Serve index.html when root is requested
|
// // Serve index.html when root is requested
|
||||||
// if r.URL.Path == "/" {
|
// if r.URL.Path == "/" {
|
||||||
log.Printf("Serving index %s", r.URL.Path)
|
// log.Printf("Serving index %s", r.URL.Path)
|
||||||
http.ServeFile(w, r, filepath.Join(root, "/static/index.html"))
|
http.ServeFile(w, r, filepath.Join(root, "/static/index.html"))
|
||||||
// return
|
// return
|
||||||
// }
|
// }
|
||||||
@@ -415,12 +444,14 @@ func main() {
|
|||||||
|
|
||||||
// Collect inactive peers
|
// Collect inactive peers
|
||||||
var inactivePeers []string
|
var inactivePeers []string
|
||||||
|
peerConnectionsMutex.Lock()
|
||||||
|
defer peerConnectionsMutex.Unlock()
|
||||||
for peerID, peer := range peerConnections {
|
for peerID, peer := range peerConnections {
|
||||||
peer.m.Lock()
|
peer.lastActiveMutex.Lock()
|
||||||
|
defer peer.lastActiveMutex.Unlock()
|
||||||
if now.Sub(peer.lastActive) > 60*time.Second {
|
if now.Sub(peer.lastActive) > 60*time.Second {
|
||||||
inactivePeers = append(inactivePeers, peerID)
|
inactivePeers = append(inactivePeers, peerID)
|
||||||
}
|
}
|
||||||
peer.m.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove inactive peers
|
// Remove inactive peers
|
||||||
|
|||||||
Reference in New Issue
Block a user