diff --git a/backend/cmd/headlamp.go b/backend/cmd/headlamp.go index fd4bd3dd81..87cb557c5e 100644 --- a/backend/cmd/headlamp.go +++ b/backend/cmd/headlamp.go @@ -1609,6 +1609,9 @@ func (c *HeadlampConfig) addClusterSetupRoute(r *mux.Router) { // Rename a cluster r.HandleFunc("/cluster/{name}", c.renameCluster).Methods("PUT") + + // Websocket connections + r.HandleFunc("/wsMultiplexer", c.multiplexer.HandleClientWebSocket) } /* diff --git a/backend/cmd/multiplexer.go b/backend/cmd/multiplexer.go index ad37f91359..4afcf76371 100644 --- a/backend/cmd/multiplexer.go +++ b/backend/cmd/multiplexer.go @@ -68,6 +68,10 @@ type Connection struct { Done chan struct{} // mu is a mutex to synchronize access to the connection. mu sync.RWMutex + // writeMu is a mutex to synchronize access to the write operations. + writeMu sync.Mutex + // closed is a flag to indicate if the connection is closed. + closed bool } // Message represents a WebSocket message structure. @@ -81,7 +85,11 @@ type Message struct { // UserID is the ID of the user. UserID string `json:"userId"` // Data contains the message payload. - Data []byte `json:"data,omitempty"` + Data string `json:"data,omitempty"` + // Binary is a flag to indicate if the message is binary. + Binary bool `json:"binary,omitempty"` + // Type is the type of the message. + Type string `json:"type"` } // Multiplexer manages multiple WebSocket connections. 
@@ -114,41 +122,58 @@ func (c *Connection) updateStatus(state ConnectionState, err error) { c.mu.Lock() defer c.mu.Unlock() + if c.closed { + return + } + c.Status.State = state c.Status.LastMsg = time.Now() + c.Status.Error = "" if err != nil { c.Status.Error = err.Error() - } else { - c.Status.Error = "" } - if c.Client != nil { - statusData := struct { - State string `json:"state"` - Error string `json:"error"` - }{ - State: string(state), - Error: c.Status.Error, - } + if c.Client == nil { + return + } - jsonData, jsonErr := json.Marshal(statusData) - if jsonErr != nil { - logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message") + c.writeMu.Lock() + defer c.writeMu.Unlock() - return - } + // Check if connection is closed before writing + if c.closed { + return + } - statusMsg := Message{ - ClusterID: c.ClusterID, - Path: c.Path, - Data: jsonData, - } + statusData := struct { + State string `json:"state"` + Error string `json:"error"` + }{ + State: string(state), + Error: c.Status.Error, + } - err := c.Client.WriteJSON(statusMsg) - if err != nil { + jsonData, jsonErr := json.Marshal(statusData) + if jsonErr != nil { + logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message") + + return + } + + statusMsg := Message{ + ClusterID: c.ClusterID, + Path: c.Path, + Data: string(jsonData), + Type: "STATUS", + } + + if err := c.Client.WriteJSON(statusMsg); err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, err, "writing status message to client") } + + c.closed = true } } @@ -188,7 +213,8 @@ func (m *Multiplexer) establishClusterConnection( connection.updateStatus(StateConnected, nil) m.mutex.Lock() - m.connections[clusterID+path] = connection + connKey := m.createConnectionKey(clusterID, path, userID) + m.connections[connKey] = 
connection m.mutex.Unlock() go m.monitorConnection(connection) @@ -291,6 +317,10 @@ func (m *Multiplexer) monitorConnection(conn *Connection) { // reconnect attempts to reestablish a connection. func (m *Multiplexer) reconnect(conn *Connection) (*Connection, error) { + if conn.closed { + return nil, fmt.Errorf("cannot reconnect closed connection") + } + if conn.WSConn != nil { conn.WSConn.Close() } @@ -309,13 +339,12 @@ func (m *Multiplexer) reconnect(conn *Connection) (*Connection, error) { } m.mutex.Lock() - m.connections[conn.ClusterID+conn.Path] = newConn + m.connections[m.createConnectionKey(conn.ClusterID, conn.Path, conn.UserID)] = newConn m.mutex.Unlock() return newConn, nil } -// HandleClientWebSocket handles incoming WebSocket connections from clients. // HandleClientWebSocket handles incoming WebSocket connections from clients. func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Request) { clientConn, err := m.upgrader.Upgrade(w, r, nil) @@ -333,16 +362,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque } // Check if it's a close message - if msg.Data != nil && len(msg.Data) > 0 && string(msg.Data) == "close" { - err := m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID) - if err != nil { - logger.Log( - logger.LevelError, - map[string]string{"clusterID": msg.ClusterID, "UserID": msg.UserID}, - err, - "closing connection", - ) - } + if msg.Type == "CLOSE" { + m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID) continue } @@ -354,8 +375,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque continue } - if len(msg.Data) > 0 && conn.Status.State == StateConnected { - err = m.writeMessageToCluster(conn, msg.Data) + if msg.Type == "REQUEST" && conn.Status.State == StateConnected { + err = m.writeMessageToCluster(conn, []byte(msg.Data)) if err != nil { continue } @@ -388,7 +409,7 @@ func (m *Multiplexer) readClientMessage(clientConn *websocket.Conn) (Message, er // 
getOrCreateConnection gets an existing connection or creates a new one if it doesn't exist. func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *websocket.Conn) (*Connection, error) { - connKey := fmt.Sprintf("%s:%s:%s", msg.ClusterID, msg.Path, msg.UserID) + connKey := m.createConnectionKey(msg.ClusterID, msg.Path, msg.UserID) m.mutex.RLock() conn, exists := m.connections[connKey] @@ -457,100 +478,182 @@ func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error // handleClusterMessages handles messages from a cluster connection. func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) { - defer func() { - conn.updateStatus(StateClosed, nil) - conn.WSConn.Close() - }() + defer m.cleanupConnection(conn) + + var lastResourceVersion string for { select { case <-conn.Done: return default: - if err := m.processClusterMessage(conn, clientConn); err != nil { + if err := m.processClusterMessage(conn, clientConn, &lastResourceVersion); err != nil { return } } } } -// processClusterMessage processes a message from a cluster connection. -func (m *Multiplexer) processClusterMessage(conn *Connection, clientConn *websocket.Conn) error { +// processClusterMessage processes a single message from the cluster. 
+func (m *Multiplexer) processClusterMessage( + conn *Connection, + clientConn *websocket.Conn, + lastResourceVersion *string, +) error { messageType, message, err := conn.WSConn.ReadMessage() if err != nil { - m.handleReadError(conn, err) + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + logger.Log(logger.LevelError, + map[string]string{ + "clusterID": conn.ClusterID, + "userID": conn.UserID, + }, + err, + "reading cluster message", + ) + } return err } - wrapperMsg := m.createWrapperMessage(conn, messageType, message) + if err := m.sendIfNewResourceVersion(message, conn, clientConn, lastResourceVersion); err != nil { + return err + } - if err := clientConn.WriteJSON(wrapperMsg); err != nil { - m.handleWriteError(conn, err) + return m.sendDataMessage(conn, clientConn, messageType, message) +} - return err +// sendIfNewResourceVersion checks the version of a resource from an incoming message +// and sends a complete message to the client if the resource version has changed. +// +// This function is used to ensure that the client is always aware of the latest version +// of a resource. When a new message is received, it extracts the resource version from +// the message metadata. If the resource version has changed since the last known version, +// it sends a complete message to the client to update them with the latest resource state. +// Parameters: +// - message: The JSON-encoded message containing resource information. +// - conn: The connection object representing the current connection. +// - clientConn: The WebSocket connection to the client. +// - lastResourceVersion: A pointer to the last known resource version string. +// +// Returns: +// - An error if any issues occur while processing the message, or nil if successful. 
+func (m *Multiplexer) sendIfNewResourceVersion( + message []byte, + conn *Connection, + clientConn *websocket.Conn, + lastResourceVersion *string, +) error { + var obj map[string]interface{} + if err := json.Unmarshal(message, &obj); err != nil { + return fmt.Errorf("error unmarshaling message: %v", err) } - conn.mu.Lock() - conn.Status.LastMsg = time.Now() - conn.mu.Unlock() + // Try to find metadata directly + metadata, ok := obj["metadata"].(map[string]interface{}) + if !ok { + // Try to find metadata in object field + if objField, ok := obj["object"].(map[string]interface{}); ok { + if metadata, ok = objField["metadata"].(map[string]interface{}); !ok { + // No metadata field found, nothing to do + return nil + } + } else { + // No metadata field found, nothing to do + return nil + } + } + + rv, ok := metadata["resourceVersion"].(string) + if !ok { + // No resourceVersion field, nothing to do + return nil + } + + // Update version and send complete message if version is different + if rv != *lastResourceVersion { + *lastResourceVersion = rv + + return m.sendCompleteMessage(conn, clientConn) + } return nil } -// createWrapperMessage creates a wrapper message for a cluster connection. -func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - Data string `json:"data"` - Binary bool `json:"binary"` -} { - wrapperMsg := struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - Data string `json:"data"` - Binary bool `json:"binary"` - }{ +// sendCompleteMessage sends a COMPLETE message to the client. 
+func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocket.Conn) error { + completeMsg := Message{ ClusterID: conn.ClusterID, Path: conn.Path, Query: conn.Query, UserID: conn.UserID, - Binary: messageType == websocket.BinaryMessage, + Type: "COMPLETE", } - if messageType == websocket.BinaryMessage { - wrapperMsg.Data = base64.StdEncoding.EncodeToString(message) - } else { - wrapperMsg.Data = string(message) + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + return clientConn.WriteJSON(completeMsg) +} + +// sendDataMessage sends the actual data message to the client. +func (m *Multiplexer) sendDataMessage( + conn *Connection, + clientConn *websocket.Conn, + messageType int, + message []byte, +) error { + dataMsg := m.createWrapperMessage(conn, messageType, message) + + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + if err := clientConn.WriteJSON(dataMsg); err != nil { + return err } - return wrapperMsg + conn.mu.Lock() + conn.Status.LastMsg = time.Now() + conn.mu.Unlock() + + return nil } -// handleReadError handles errors that occur when reading a message from a cluster connection. -func (m *Multiplexer) handleReadError(conn *Connection, err error) { - conn.updateStatus(StateError, err) - logger.Log( - logger.LevelError, - map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, - err, - "reading message from cluster", - ) +// cleanupConnection performs cleanup for a connection. +func (m *Multiplexer) cleanupConnection(conn *Connection) { + conn.mu.Lock() + defer conn.mu.Unlock() // Ensure the mutex is unlocked even if an error occurs + + conn.closed = true + + if conn.WSConn != nil { + conn.WSConn.Close() + } + + m.mutex.Lock() + connKey := m.createConnectionKey(conn.ClusterID, conn.Path, conn.UserID) + delete(m.connections, connKey) + m.mutex.Unlock() } -// handleWriteError handles errors that occur when writing a message to a client connection. 
-func (m *Multiplexer) handleWriteError(conn *Connection, err error) { - conn.updateStatus(StateError, err) - logger.Log( - logger.LevelError, - map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, - err, - "writing message to client", - ) +// createWrapperMessage creates a wrapper message for a cluster connection. +func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) Message { + var data string + if messageType == websocket.BinaryMessage { + data = base64.StdEncoding.EncodeToString(message) + } else { + data = string(message) + } + + return Message{ + ClusterID: conn.ClusterID, + Path: conn.Path, + Query: conn.Query, + UserID: conn.UserID, + Data: data, + Binary: messageType == websocket.BinaryMessage, + Type: "DATA", + } } // cleanupConnections closes and removes all connections. @@ -586,39 +689,49 @@ func (m *Multiplexer) getClusterConfig(clusterID string) (*rest.Config, error) { } // CloseConnection closes a specific connection based on its identifier. 
-func (m *Multiplexer) CloseConnection(clusterID, path, userID string) error { - connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID) +func (m *Multiplexer) CloseConnection(clusterID, path, userID string) { + connKey := m.createConnectionKey(clusterID, path, userID) m.mutex.Lock() - defer m.mutex.Unlock() conn, exists := m.connections[connKey] if !exists { - return fmt.Errorf("connection not found for key: %s", connKey) + m.mutex.Unlock() + // Don't log error for non-existent connections during cleanup + return } - // Signal the connection to close - close(conn.Done) + // Mark as closed before releasing the lock + conn.mu.Lock() + if conn.closed { + conn.mu.Unlock() + m.mutex.Unlock() + logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, nil, "closing connection") - // Close the WebSocket connection - if conn.WSConn != nil { - if err := conn.WSConn.Close(); err != nil { - logger.Log( - logger.LevelError, - map[string]string{"clusterID": clusterID, "userID": userID}, - err, - "closing WebSocket connection", - ) - } + return } - // Update the connection status - conn.updateStatus(StateClosed, nil) + conn.closed = true + conn.mu.Unlock() - // Remove the connection from the map delete(m.connections, connKey) + m.mutex.Unlock() - return nil + // Lock the connection mutex before accessing shared resources + conn.mu.Lock() + defer conn.mu.Unlock() // Ensure the mutex is unlocked after the operations + + // Close the Done channel and connections after removing from map + close(conn.Done) + + if conn.WSConn != nil { + conn.WSConn.Close() + } +} + +// createConnectionKey creates a unique key for a connection based on cluster ID, path, and user ID. +func (m *Multiplexer) createConnectionKey(clusterID, path, userID string) string { + return fmt.Sprintf("%s:%s:%s", clusterID, path, userID) } // createWebSocketURL creates a WebSocket URL from the given parameters. 
diff --git a/backend/cmd/multiplexer_test.go b/backend/cmd/multiplexer_test.go index 058e01377b..0df729ef4a 100644 --- a/backend/cmd/multiplexer_test.go +++ b/backend/cmd/multiplexer_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -22,6 +23,7 @@ func newTestDialer() *websocket.Dialer { return &websocket.Dialer{ NetDial: net.Dial, HandshakeTimeout: 45 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec } } @@ -36,28 +38,47 @@ func TestNewMultiplexer(t *testing.T) { } func TestHandleClientWebSocket(t *testing.T) { - store := kubeconfig.NewContextStore() - m := NewMultiplexer(store) + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + // Create test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.HandleClientWebSocket(w, r) })) defer server.Close() - url := "ws" + strings.TrimPrefix(server.URL, "http") - + // Connect to test server dialer := newTestDialer() - conn, resp, err := dialer.Dial(url, nil) - if err == nil { - defer conn.Close() - } + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + ws, resp, err := dialer.Dial(wsURL, nil) + require.NoError(t, err) if resp != nil && resp.Body != nil { defer resp.Body.Close() } - assert.NoError(t, err, "Should successfully establish WebSocket connection") + defer ws.Close() + + // Test WATCH message + watchMsg := Message{ + Type: "WATCH", + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + } + err = ws.WriteJSON(watchMsg) + require.NoError(t, err) + + // Test CLOSE message + closeMsg := Message{ + Type: "CLOSE", + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + } + err = ws.WriteJSON(closeMsg) + require.NoError(t, err) } func TestGetClusterConfigWithFallback(t *testing.T) { @@ -104,21 +125,20 @@ func TestDialWebSocket(t *testing.T) { return true // Allow all connections for testing }, } - - 
c, err := upgrader.Upgrade(w, r, nil) + ws, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Logf("Upgrade error: %v", err) return } - defer c.Close() + defer ws.Close() // Echo incoming messages back to the client for { - mt, message, err := c.ReadMessage() + mt, message, err := ws.ReadMessage() if err != nil { break } - err = c.WriteMessage(mt, message) + err = ws.WriteMessage(mt, message) if err != nil { break } @@ -129,6 +149,7 @@ func TestDialWebSocket(t *testing.T) { wsURL := "ws" + strings.TrimPrefix(server.URL, "http") conn, err := m.dialWebSocket(wsURL, &tls.Config{InsecureSkipVerify: true}, server.URL) //nolint:gosec + assert.NoError(t, err) assert.NotNil(t, conn) @@ -137,6 +158,23 @@ func TestDialWebSocket(t *testing.T) { } } +func TestDialWebSocket_Errors(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Test invalid URL + tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec + + ws, err := m.dialWebSocket("invalid-url", tlsConfig, "") + assert.Error(t, err) + assert.Nil(t, ws) + + // Test unreachable URL + ws, err = m.dialWebSocket("ws://localhost:12345", tlsConfig, "") + assert.Error(t, err) + assert.Nil(t, ws) +} + func TestMonitorConnection(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) clientConn, _ := createTestWebSocketConnection() @@ -158,6 +196,94 @@ func TestMonitorConnection(t *testing.T) { assert.Equal(t, StateClosed, conn.Status.State) } +func TestUpdateStatus(t *testing.T) { + conn := &Connection{ + Status: ConnectionStatus{}, + Done: make(chan struct{}), + } + + // Test different state transitions + states := []ConnectionState{ + StateConnecting, + StateConnected, + StateClosed, + StateError, + } + + for _, state := range states { + conn.Status.State = state + assert.Equal(t, state, conn.Status.State) + } + + // Test concurrent updates + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + state := 
states[i%len(states)] + conn.Status.State = state + }(i) + } + wg.Wait() + + // Verify final state is valid + assert.Contains(t, states, conn.Status.State) +} + +func TestMonitorConnection_Reconnect(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Create a server that will accept the connection and then close it + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + + defer ws.Close() + + // Keep connection alive briefly + time.Sleep(100 * time.Millisecond) + ws.Close() + })) + + defer server.Close() + + conn := &Connection{ + Status: ConnectionStatus{ + State: StateConnecting, + }, + Done: make(chan struct{}), + } + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec + + ws, err := m.dialWebSocket(wsURL, tlsConfig, "") + require.NoError(t, err) + + conn.WSConn = ws + + // Start monitoring in a goroutine + go m.monitorConnection(conn) + + // Wait for state transitions + time.Sleep(300 * time.Millisecond) + + // Verify connection status, it should reconnect + assert.Equal(t, StateConnecting, conn.Status.State) + + // Clean up + close(conn.Done) +} + //nolint:funlen func TestHandleClusterMessages(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) @@ -225,7 +351,7 @@ func TestHandleClusterMessages(t *testing.T) { t.Fatal("Test timed out") } - assert.Equal(t, StateClosed, conn.Status.State) + assert.Equal(t, StateConnecting, conn.Status.State) } func TestCleanupConnections(t *testing.T) { @@ -245,42 +371,198 @@ func TestCleanupConnections(t *testing.T) { assert.Equal(t, StateClosed, conn.Status.State) } -func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) { - server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{} - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } +func TestCreateWebSocketURL(t *testing.T) { + tests := []struct { + name string + host string + path string + query string + expected string + }{ + { + name: "basic URL without query", + host: "http://localhost:8080", + path: "/api/v1/pods", + query: "", + expected: "wss://localhost:8080/api/v1/pods", + }, + { + name: "URL with query parameters", + host: "https://example.com", + path: "/api/v1/pods", + query: "watch=true", + expected: "wss://example.com/api/v1/pods?watch=true", + }, + { + name: "URL with path and multiple query parameters", + host: "https://k8s.example.com", + path: "/api/v1/namespaces/default/pods", + query: "watch=true&labelSelector=app%3Dnginx", + expected: "wss://k8s.example.com/api/v1/namespaces/default/pods?watch=true&labelSelector=app%3Dnginx", + }, + } - defer c.Close() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := createWebSocketURL(tt.host, tt.path, tt.query) + assert.Equal(t, tt.expected, result) + }) + } +} - for { - mt, message, err := c.ReadMessage() - if err != nil { - break - } +func TestGetOrCreateConnection(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) - err = c.WriteMessage(mt, message) - if err != nil { - break - } - } - })) + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - dialer := newTestDialer() + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: nil, + }, + }) + require.NoError(t, err) - ws, resp, err := dialer.Dial(wsURL, nil) - if err != nil { - panic(err) + 
clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Test getting a non-existent connection (should create new) + msg := Message{ + ClusterID: "test-cluster", + Path: "/api/v1/pods", + Query: "watch=true", + UserID: "test-user", } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() + conn, err := m.getOrCreateConnection(msg, clientConn) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, "test-cluster", conn.ClusterID) + assert.Equal(t, "test-user", conn.UserID) + assert.Equal(t, "/api/v1/pods", conn.Path) + assert.Equal(t, "watch=true", conn.Query) + + // Test getting an existing connection + conn2, err := m.getOrCreateConnection(msg, clientConn) + assert.NoError(t, err) + assert.Equal(t, conn, conn2, "Should return the same connection instance") + + // Test with invalid cluster + msg.ClusterID = "non-existent-cluster" + conn3, err := m.getOrCreateConnection(msg, clientConn) + assert.Error(t, err) + assert.Nil(t, conn3) +} + +func TestEstablishClusterConnection(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) + + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() + + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: nil, + }, + }) + require.NoError(t, err) + + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Test successful connection establishment + conn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, "test-cluster", conn.ClusterID) + assert.Equal(t, "test-user", conn.UserID) + assert.Equal(t, "/api/v1/pods", conn.Path) + assert.Equal(t, 
"watch=true", conn.Query) + + // Test with invalid cluster + conn, err = m.establishClusterConnection("non-existent", "test-user", "/api/v1/pods", "watch=true", clientConn) + assert.Error(t, err) + assert.Nil(t, conn) +} + +//nolint:funlen +func TestReconnect(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) + + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() + + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: nil, + }, + }) + require.NoError(t, err) + + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Create initial connection + conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + conn.Status.State = StateError // Simulate an error state + + // Test successful reconnection + newConn, err := m.reconnect(conn) + assert.NoError(t, err) + assert.NotNil(t, newConn) + assert.Equal(t, StateConnected, newConn.Status.State) + assert.Equal(t, conn.ClusterID, newConn.ClusterID) + assert.Equal(t, conn.UserID, newConn.UserID) + assert.Equal(t, conn.Path, newConn.Path) + assert.Equal(t, conn.Query, newConn.Query) + + // Test reconnection with invalid cluster + conn.ClusterID = "non-existent" + newConn, err = m.reconnect(conn) + assert.Error(t, err) + assert.Nil(t, newConn) + assert.Contains(t, err.Error(), "getting context: key not found") + + // Test reconnection with closed connection + conn = m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + clusterConn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + require.NoError(t, err) + require.NotNil(t, clusterConn) + + // Close the connection and wait for cleanup 
+ conn.closed = true + if conn.WSConn != nil { + conn.WSConn.Close() } - return ws, server + if conn.Client != nil { + conn.Client.Close() + } + + close(conn.Done) + + // Try to reconnect the closed connection + newConn, err = m.reconnect(conn) + assert.Error(t, err) + assert.Nil(t, newConn) } func TestCloseConnection(t *testing.T) { @@ -292,14 +574,10 @@ func TestCloseConnection(t *testing.T) { connKey := "test-cluster:/api/v1/pods:test-user" m.connections[connKey] = conn - err := m.CloseConnection("test-cluster", "/api/v1/pods", "test-user") - assert.NoError(t, err) + m.CloseConnection("test-cluster", "/api/v1/pods", "test-user") assert.Empty(t, m.connections) - assert.Equal(t, StateClosed, conn.Status.State) - - // Test closing a non-existent connection - err = m.CloseConnection("non-existent", "/api/v1/pods", "test-user") - assert.Error(t, err) + // It will reconnect to the cluster + assert.Equal(t, StateConnecting, conn.Status.State) } func TestCreateWrapperMessage(t *testing.T) { @@ -424,3 +702,376 @@ func TestWriteMessageToCluster(t *testing.T) { assert.Error(t, err) assert.Equal(t, StateError, conn.Status.State) } + +//nolint:funlen +func TestReadClientMessage_InvalidMessage(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Create a server that will echo messages back + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + + defer ws.Close() + + // Echo messages back + for { + messageType, p, err := ws.ReadMessage() + if err != nil { + return + } + err = ws.WriteMessage(messageType, p) + if err != nil { + return + } + } + })) + defer server.Close() + + // Connect to the server + dialer := newTestDialer() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, err := 
dialer.Dial(wsURL, nil) //nolint:bodyclose + require.NoError(t, err) + + defer clientConn.Close() + + // Test completely invalid JSON + err = clientConn.WriteMessage(websocket.TextMessage, []byte("not json at all")) + require.NoError(t, err) + + msg, err := m.readClientMessage(clientConn) + require.Error(t, err) + assert.Equal(t, Message{}, msg) + + // Test JSON with invalid data type + err = clientConn.WriteJSON(map[string]interface{}{ + "type": "INVALID", + "data": 123, // data should be string + }) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + require.Error(t, err) + assert.Equal(t, Message{}, msg) + + // Test empty JSON object + err = clientConn.WriteMessage(websocket.TextMessage, []byte("{}")) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + // Empty message is valid JSON but will be unmarshaled into an empty Message struct + require.NoError(t, err) + assert.Equal(t, Message{}, msg) + + // Test missing required fields + err = clientConn.WriteJSON(map[string]interface{}{ + "data": "some data", + // Missing type field + }) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + // Missing fields are allowed by json.Unmarshal + require.NoError(t, err) + assert.Equal(t, Message{Data: "some data"}, msg) +} + +func TestUpdateStatus_WithError(t *testing.T) { + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + conn := &Connection{ + Status: ConnectionStatus{}, + Done: make(chan struct{}), + Client: clientConn, + } + + // Test error state with message + testErr := fmt.Errorf("test error") + conn.updateStatus(StateError, testErr) + assert.Equal(t, StateError, conn.Status.State) + assert.Equal(t, testErr.Error(), conn.Status.Error) + + // Test state change without error + conn.updateStatus(StateConnected, nil) + assert.Equal(t, StateConnected, conn.Status.State) + assert.Empty(t, conn.Status.Error) + + // Test with closed connection - state should remain 
error + conn.updateStatus(StateError, testErr) + assert.Equal(t, StateError, conn.Status.State) + assert.Equal(t, testErr.Error(), conn.Status.Error) + + close(conn.Done) + conn.closed = true // Mark connection as closed + + // Try to update state after close - should not change + conn.updateStatus(StateConnected, nil) + assert.Equal(t, StateError, conn.Status.State) // State should not change after close + assert.Equal(t, testErr.Error(), conn.Status.Error) // Error should remain +} + +func TestMonitorConnection_ReconnectFailure(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) + + // Add an invalid cluster config to force reconnection failure + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: "https://invalid-server:8443", + }, + }) + require.NoError(t, err) + + clientConn, _ := createTestWebSocketConnection() + conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) + conn.WSConn, _ = createTestWebSocketConnection() + + // Start monitoring + done := make(chan struct{}) + go func() { + m.monitorConnection(conn) + close(done) + }() + + // Force connection closure and error state + conn.updateStatus(StateError, fmt.Errorf("forced error")) + conn.WSConn.Close() + + // Wait briefly to ensure error state is set + time.Sleep(50 * time.Millisecond) + + // Verify connection is in error state + assert.Equal(t, StateError, conn.Status.State) + assert.NotEmpty(t, conn.Status.Error) + + close(conn.Done) + <-done +} + +func TestHandleClientWebSocket_InvalidMessages(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.HandleClientWebSocket(w, r) + })) + defer server.Close() + + // Test invalid JSON + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + + 
if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + err = ws.WriteMessage(websocket.TextMessage, []byte("invalid json")) + require.NoError(t, err) + + // Should receive an error message or close + _, message, err := ws.ReadMessage() + if err != nil { + // Connection may be closed due to error + if !websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { + t.Errorf("expected abnormal closure, got %v", err) + } + } else { + assert.Contains(t, string(message), "error") + } + + ws.Close() + + // Test invalid message type with new connection + ws, resp, err = websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + defer ws.Close() + + err = ws.WriteJSON(Message{ + Type: "INVALID_TYPE", + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + }) + require.NoError(t, err) + + // Should receive an error message or close + _, message, err = ws.ReadMessage() + if err != nil { + // Connection may be closed due to error + if !websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { + t.Errorf("expected abnormal closure, got %v", err) + } + } else { + assert.Contains(t, string(message), "error") + } +} + +func TestSendIfNewResourceVersion_VersionComparison(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() + + defer clientServer.Close() + + conn := &Connection{ + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + Client: clientConn, + } + + // Initialize lastVersion pointer + lastVersion := "" + + // Test initial version + message := []byte(`{"metadata":{"resourceVersion":"100"}}`) + err := m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + require.NoError(t, err) + assert.Equal(t, "100", lastVersion) + + // Test same version - should not send + err = m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + 
require.NoError(t, err) + assert.Equal(t, "100", lastVersion) + + // Test newer version + message = []byte(`{"metadata":{"resourceVersion":"200"}}`) + err = m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + require.NoError(t, err) + assert.Equal(t, "200", lastVersion) + + // Test invalid JSON + message = []byte(`invalid json`) + err = m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + assert.Error(t, err) + assert.Equal(t, "200", lastVersion) // Version should not change on error + + // Test missing resourceVersion + message = []byte(`{"metadata":{}}`) + err = m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + require.NoError(t, err) // Should not error, but also not update version + assert.Equal(t, "200", lastVersion) +} + +func TestSendCompleteMessage_ClosedConnection(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() + + defer clientServer.Close() + + conn := &Connection{ + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + Query: "watch=true", + } + + // Test successful complete message + err := m.sendCompleteMessage(conn, clientConn) + require.NoError(t, err) + + // Verify the message + _, message, err := clientConn.ReadMessage() + require.NoError(t, err) + + var msg Message + err = json.Unmarshal(message, &msg) + require.NoError(t, err) + + assert.Equal(t, "COMPLETE", msg.Type) + assert.Equal(t, conn.ClusterID, msg.ClusterID) + assert.Equal(t, conn.Path, msg.Path) + assert.Equal(t, conn.Query, msg.Query) + assert.Equal(t, conn.UserID, msg.UserID) + + // Test with closed connection + clientConn.Close() + err = m.sendCompleteMessage(conn, clientConn) + assert.Error(t, err) +} + +func createMockKubeAPIServer() *httptest.Server { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + 
return true + }, + } + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + defer c.Close() + + // Echo messages back + for { + _, msg, err := c.ReadMessage() + if err != nil { + break + } + if err := c.WriteMessage(websocket.TextMessage, msg); err != nil { + break + } + } + })) + + // Configure the test client to accept the test server's TLS certificate + server.Client().Transport.(*http.Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec + } + + return server +} + +func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + defer c.Close() + + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + dialer := newTestDialer() + + ws, resp, err := dialer.Dial(wsURL, nil) + if err != nil { + panic(err) + } + + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + return ws, server +} diff --git a/frontend/package.json b/frontend/package.json index 55efa85ff4..50c28eb622 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -120,6 +120,7 @@ "build": "cross-env PUBLIC_URL=./ NODE_OPTIONS=--max-old-space-size=8096 vite build && npx shx rm -f build/frontend/index.baseUrl.html", "pretest": "npm run make-version", "test": "vitest", + "start-without-multiplexer": "cross-env REACT_APP_ENABLE_WEBSOCKET_MULTIPLEXER=false npm run start", "lint": "eslint --cache -c package.json --ext .js,.ts,.tsx src/ ../app/electron ../plugins/headlamp-plugin --ignore-pattern ../plugins/headlamp-plugin/template --ignore-pattern ../plugins/headlamp-plugin/lib/", "format": "prettier --config package.json --write --cache src ../app/electron 
../app/tsconfig.json ../app/scripts ../plugins/headlamp-plugin/bin ../plugins/headlamp-plugin/config ../plugins/headlamp-plugin/template ../plugins/headlamp-plugin/test*.js ../plugins/headlamp-plugin/*.json ../plugins/headlamp-plugin/*.js", "format-check": "prettier --config package.json --check --cache src ../app/electron ../app/tsconfig.json ../app/scripts ../plugins/headlamp-plugin/bin ../plugins/headlamp-plugin/config ../plugins/headlamp-plugin/template ../plugins/headlamp-plugin/test*.js ../plugins/headlamp-plugin/*.json ../plugins/headlamp-plugin/*.js", diff --git a/frontend/src/components/common/Resource/index.test.ts b/frontend/src/components/common/Resource/index.test.ts index ae4c85ad8c..c2786f7eba 100644 --- a/frontend/src/components/common/Resource/index.test.ts +++ b/frontend/src/components/common/Resource/index.test.ts @@ -41,7 +41,7 @@ function getFilesToVerify() { const filesToVerify: string[] = []; fs.readdirSync(__dirname).forEach(file => { const fileNoSuffix = file.replace(/\.[^/.]+$/, ''); - if (!avoidCheck.find(suffix => fileNoSuffix.endsWith(suffix))) { + if (fileNoSuffix && !avoidCheck.find(suffix => fileNoSuffix.endsWith(suffix))) { filesToVerify.push(fileNoSuffix); } }); diff --git a/frontend/src/components/common/index.test.ts b/frontend/src/components/common/index.test.ts index a1d343491c..3cf0ac59b0 100644 --- a/frontend/src/components/common/index.test.ts +++ b/frontend/src/components/common/index.test.ts @@ -50,7 +50,7 @@ function getFilesToVerify() { const filesToVerify: string[] = []; fs.readdirSync(__dirname).forEach(file => { const fileNoSuffix = file.replace(/\.[^/.]+$/, ''); - if (!avoidCheck.find(suffix => fileNoSuffix.endsWith(suffix))) { + if (fileNoSuffix && !avoidCheck.find(suffix => fileNoSuffix.endsWith(suffix))) { filesToVerify.push(fileNoSuffix); } }); diff --git a/frontend/src/helpers/index.ts b/frontend/src/helpers/index.ts index 40780c37e9..e3500745c7 100644 --- a/frontend/src/helpers/index.ts +++ 
b/frontend/src/helpers/index.ts @@ -352,6 +352,14 @@ function loadTableSettings(tableId: string): { id: string; show: boolean }[] { return settings; } +/** + * @returns true if the websocket multiplexer is enabled. + * defaults to true. This is a feature flag to enable the websocket multiplexer. + */ +export function getWebsocketMultiplexerEnabled(): boolean { + return import.meta.env.REACT_APP_ENABLE_WEBSOCKET_MULTIPLEXER !== 'false'; +} + /** * The backend token to use when making API calls from Headlamp when running as an app. * The app opens the index.html?backendToken=... and passes the token to the frontend @@ -393,6 +401,7 @@ const exportFunctions = { storeClusterSettings, loadClusterSettings, getHeadlampAPIHeaders, + getWebsocketMultiplexerEnabled, storeTableSettings, loadTableSettings, }; diff --git a/frontend/src/lib/k8s/api/v2/hooks.ts b/frontend/src/lib/k8s/api/v2/hooks.ts index bde11f2425..a88b0b30c9 100644 --- a/frontend/src/lib/k8s/api/v2/hooks.ts +++ b/frontend/src/lib/k8s/api/v2/hooks.ts @@ -133,7 +133,7 @@ export function useKubeObject({ const data: Instance | null = query.error ? null : query.data ?? 
null; - useWebSocket>({ + useWebSocket>({ url: () => makeUrl([KubeObjectEndpoint.toUrl(endpoint!)], { ...cleanedUpQueryParams, @@ -142,7 +142,7 @@ export function useKubeObject({ }), enabled: !!endpoint && !!data, cluster, - onMessage(update) { + onMessage(update: KubeListUpdateEvent) { if (update.type !== 'ADDED' && update.object) { client.setQueryData(queryKey, new kubeObjectClass(update.object)); } diff --git a/frontend/src/lib/k8s/api/v2/useKubeObjectList.test.tsx b/frontend/src/lib/k8s/api/v2/useKubeObjectList.test.tsx index 4018f1a533..68150027f0 100644 --- a/frontend/src/lib/k8s/api/v2/useKubeObjectList.test.tsx +++ b/frontend/src/lib/k8s/api/v2/useKubeObjectList.test.tsx @@ -1,5 +1,6 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; import { renderHook } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; import { kubeObjectListQuery, ListResponse, @@ -9,6 +10,18 @@ import { } from './useKubeObjectList'; import * as websocket from './webSocket'; +// Mock WebSocket functionality +const mockUseWebSockets = vi.fn(); +const mockSubscribe = vi.fn().mockImplementation(() => Promise.resolve(() => {})); + +vi.mock('./webSocket', () => ({ + useWebSockets: (...args: any[]) => mockUseWebSockets(...args), + WebSocketManager: { + subscribe: (...args: any[]) => mockSubscribe(...args), + }, + BASE_WS_URL: 'http://localhost:3000', +})); + describe('makeListRequests', () => { describe('for non namespaced resource', () => { it('should not include namespace in requests', () => { @@ -96,6 +109,11 @@ const mockClass = class { } as any; describe('useWatchKubeObjectLists', () => { + beforeEach(() => { + vi.stubEnv('REACT_APP_ENABLE_WEBSOCKET_MULTIPLEXER', 'false'); + vi.clearAllMocks(); + }); + it('should not be enabled when no endpoint is provided', () => { const spy = vi.spyOn(websocket, 'useWebSockets'); const queryClient = new QueryClient(); @@ -271,3 +289,97 @@ describe('useKubeObjectList', () => { 
expect(spy.mock.calls[3][0].connections.length).toBe(1); // updated connections after we removed namespace 'b' }); }); + +describe('useWatchKubeObjectLists (Multiplexer)', () => { + beforeEach(() => { + vi.stubEnv('REACT_APP_ENABLE_WEBSOCKET_MULTIPLEXER', 'true'); + vi.clearAllMocks(); + }); + + it('should subscribe using WebSocketManager when multiplexer is enabled', () => { + const lists = [{ cluster: 'cluster-a', namespace: 'namespace-a', resourceVersion: '1' }]; + + renderHook( + () => + useWatchKubeObjectLists({ + kubeObjectClass: mockClass, + endpoint: { version: 'v1', resource: 'pods' }, + lists, + }), + { + wrapper: ({ children }) => ( + {children} + ), + } + ); + + expect(mockSubscribe).toHaveBeenCalledWith( + 'cluster-a', + expect.stringContaining('/api/v1/namespaces/namespace-a/pods'), + 'watch=1&resourceVersion=1', + expect.any(Function) + ); + }); + + it('should subscribe to multiple clusters', () => { + const lists = [ + { cluster: 'cluster-a', namespace: 'namespace-a', resourceVersion: '1' }, + { cluster: 'cluster-b', namespace: 'namespace-b', resourceVersion: '2' }, + ]; + + renderHook( + () => + useWatchKubeObjectLists({ + kubeObjectClass: mockClass, + endpoint: { version: 'v1', resource: 'pods' }, + lists, + }), + { + wrapper: ({ children }) => ( + {children} + ), + } + ); + + expect(mockSubscribe).toHaveBeenCalledTimes(2); + expect(mockSubscribe).toHaveBeenNthCalledWith( + 1, + 'cluster-a', + expect.stringContaining('/api/v1/namespaces/namespace-a/pods'), + 'watch=1&resourceVersion=1', + expect.any(Function) + ); + expect(mockSubscribe).toHaveBeenNthCalledWith( + 2, + 'cluster-b', + expect.stringContaining('/api/v1/namespaces/namespace-b/pods'), + 'watch=1&resourceVersion=2', + expect.any(Function) + ); + }); + + it('should handle non-namespaced resources', () => { + const lists = [{ cluster: 'cluster-a', resourceVersion: '1' }]; + + renderHook( + () => + useWatchKubeObjectLists({ + kubeObjectClass: mockClass, + endpoint: { version: 'v1', 
resource: 'pods' }, + lists, + }), + { + wrapper: ({ children }) => ( + {children} + ), + } + ); + + expect(mockSubscribe).toHaveBeenCalledWith( + 'cluster-a', + expect.stringContaining('/api/v1/pods'), + 'watch=1&resourceVersion=1', + expect.any(Function) + ); + }); +}); diff --git a/frontend/src/lib/k8s/api/v2/useKubeObjectList.ts b/frontend/src/lib/k8s/api/v2/useKubeObjectList.ts index 9d8e203340..31d340bc9a 100644 --- a/frontend/src/lib/k8s/api/v2/useKubeObjectList.ts +++ b/frontend/src/lib/k8s/api/v2/useKubeObjectList.ts @@ -1,5 +1,6 @@ import { QueryObserverOptions, useQueries, useQueryClient } from '@tanstack/react-query'; -import { useMemo, useState } from 'react'; +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { getWebsocketMultiplexerEnabled } from '../../../../helpers'; import { KubeObject, KubeObjectClass } from '../../KubeObject'; import { ApiError } from '../v1/clusterRequests'; import { QueryParameters } from '../v1/queryParameters'; @@ -8,7 +9,7 @@ import { QueryListResponse, useEndpoints } from './hooks'; import { KubeList, KubeListUpdateEvent } from './KubeList'; import { KubeObjectEndpoint } from './KubeObjectEndpoint'; import { makeUrl } from './makeUrl'; -import { useWebSockets } from './webSocket'; +import { BASE_WS_URL, useWebSockets, WebSocketManager } from './webSocket'; /** * Object representing a List of Kube object @@ -115,6 +116,186 @@ export function useWatchKubeObjectLists({ endpoint?: KubeObjectEndpoint | null; /** Which clusters and namespaces to watch */ lists: Array<{ cluster: string; namespace?: string; resourceVersion: string }>; +}) { + if (getWebsocketMultiplexerEnabled()) { + return useWatchKubeObjectListsMultiplexed({ + kubeObjectClass, + endpoint, + lists, + queryParams, + }); + } else { + return useWatchKubeObjectListsLegacy({ + kubeObjectClass, + endpoint, + lists, + queryParams, + }); + } +} + +/** + * Watches Kubernetes resource lists using multiplexed WebSocket connections. 
+ * Efficiently manages subscriptions and updates to prevent unnecessary re-renders + * and WebSocket reconnections. + * + * @template K - Type extending KubeObject for the resources being watched + * @param kubeObjectClass - Class constructor for the Kubernetes resource type + * @param endpoint - API endpoint information for the resource + * @param lists - Array of cluster, namespace, and resourceVersion combinations to watch + * @param queryParams - Optional query parameters for the WebSocket URL + */ +function useWatchKubeObjectListsMultiplexed({ + kubeObjectClass, + endpoint, + lists, + queryParams, +}: { + kubeObjectClass: (new (...args: any) => K) & typeof KubeObject; + endpoint?: KubeObjectEndpoint | null; + lists: Array<{ cluster: string; namespace?: string; resourceVersion: string }>; + queryParams?: QueryParameters; +}): void { + const client = useQueryClient(); + + // Track the latest resource versions to prevent duplicate updates + const latestResourceVersions = useRef>({}); + + // Stabilize queryParams to prevent unnecessary effect triggers + // Only update when the stringified params change + const stableQueryParams = useMemo(() => queryParams, [JSON.stringify(queryParams)]); + + // Create stable connection URLs for each list + // Updates only when endpoint, lists, or stableQueryParams change + const connections = useMemo(() => { + if (!endpoint) { + return []; + } + + return lists.map(list => { + const key = `${list.cluster}:${list.namespace || ''}`; + + // Update resource version if newer one is available + const currentVersion = latestResourceVersions.current[key]; + const newVersion = list.resourceVersion; + if (!currentVersion || parseInt(newVersion) > parseInt(currentVersion)) { + latestResourceVersions.current[key] = newVersion; + } + + // Construct WebSocket URL with current parameters + return { + url: makeUrl([KubeObjectEndpoint.toUrl(endpoint, list.namespace)], { + ...stableQueryParams, + watch: 1, + resourceVersion: 
latestResourceVersions.current[key], + }), + cluster: list.cluster, + namespace: list.namespace, + }; + }); + }, [endpoint, lists, stableQueryParams]); + + // Create stable update handler to process WebSocket messages + // Re-create only when dependencies change + const handleUpdate = useCallback( + (update: any, cluster: string, namespace: string | undefined) => { + if (!update || typeof update !== 'object' || !endpoint) { + return; + } + + const key = `${cluster}:${namespace || ''}`; + + // Update resource version from incoming message + if (update.object?.metadata?.resourceVersion) { + latestResourceVersions.current[key] = update.object.metadata.resourceVersion; + } + + // Create query key for React Query cache + const queryKey = kubeObjectListQuery( + kubeObjectClass, + endpoint, + namespace, + cluster, + stableQueryParams ?? {} + ).queryKey; + + // Update React Query cache with new data + client.setQueryData(queryKey, (oldResponse: ListResponse | undefined | null) => { + if (!oldResponse) { + return oldResponse; + } + + const newList = KubeList.applyUpdate(oldResponse.list, update, kubeObjectClass); + + // Only update if the list actually changed + if (newList === oldResponse.list) { + return oldResponse; + } + + return { ...oldResponse, list: newList }; + }); + }, + [client, kubeObjectClass, endpoint, stableQueryParams] + ); + + // Set up WebSocket subscriptions + useEffect(() => { + if (!endpoint || connections.length === 0) { + return; + } + + const cleanups: (() => void)[] = []; + + // Create subscriptions for each connection + connections.forEach(({ url, cluster, namespace }) => { + const parsedUrl = new URL(url, BASE_WS_URL); + + // Subscribe to WebSocket updates + WebSocketManager.subscribe(cluster, parsedUrl.pathname, parsedUrl.search.slice(1), update => + handleUpdate(update, cluster, namespace) + ).then( + cleanup => cleanups.push(cleanup), + error => { + // Track retry count in the URL's searchParams + const retryCount = 
parseInt(parsedUrl.searchParams.get('retryCount') || '0'); + if (retryCount < 3) { + // Only log and allow retry if under threshold + console.error('WebSocket subscription failed:', error); + parsedUrl.searchParams.set('retryCount', (retryCount + 1).toString()); + } + } + ); + }); + + // Cleanup subscriptions when effect re-runs or unmounts + return () => { + cleanups.forEach(cleanup => cleanup()); + }; + }, [connections, endpoint, handleUpdate]); +} + +/** + * Accepts a list of lists to watch. + * Upon receiving update it will modify query data for list query + * @param kubeObjectClass - KubeObject class of the watched resource list + * @param endpoint - Kube resource API endpoint information + * @param lists - Which clusters and namespaces to watch + * @param queryParams - Query parameters for the WebSocket connection URL + */ +function useWatchKubeObjectListsLegacy({ + kubeObjectClass, + endpoint, + lists, + queryParams, +}: { + /** KubeObject class of the watched resource list */ + kubeObjectClass: (new (...args: any) => K) & typeof KubeObject; + /** Query parameters for the WebSocket connection URL */ + queryParams?: QueryParameters; + /** Kube resource API endpoint information */ + endpoint?: KubeObjectEndpoint | null; + /** Which clusters and namespaces to watch */ + lists: Array<{ cluster: string; namespace?: string; resourceVersion: string }>; }) { const client = useQueryClient(); diff --git a/frontend/src/lib/k8s/api/v2/webSocket.test.ts b/frontend/src/lib/k8s/api/v2/webSocket.test.ts new file mode 100644 index 0000000000..d5f78242b1 --- /dev/null +++ b/frontend/src/lib/k8s/api/v2/webSocket.test.ts @@ -0,0 +1,590 @@ +import { renderHook } from '@testing-library/react'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import WS from 'vitest-websocket-mock'; +import { findKubeconfigByClusterName, getUserIdFromLocalStorage } from '../../../../stateless'; +import { getToken } from '../../../auth'; +import { getCluster } from 
'../../../cluster'; +import { BASE_WS_URL, useWebSocket, WebSocketManager } from './webSocket'; + +// Mock dependencies +vi.mock('../../../cluster', () => ({ + getCluster: vi.fn(), +})); + +vi.mock('../../../../stateless', () => ({ + getUserIdFromLocalStorage: vi.fn(), + findKubeconfigByClusterName: vi.fn(), +})); + +vi.mock('../../../auth', () => ({ + getToken: vi.fn(), +})); + +vi.mock('./makeUrl', () => ({ + makeUrl: vi.fn((paths: string[] | string, query = {}) => { + const url = Array.isArray(paths) ? paths.filter(Boolean).join('/') : paths; + const queryString = new URLSearchParams(query).toString(); + const fullUrl = queryString ? `${url}?${queryString}` : url; + return fullUrl.replace(/([^:]\/)\/+/g, '$1'); + }), +})); + +const clusterName = 'test-cluster'; +const userId = 'test-user'; +const token = 'test-token'; + +describe('WebSocket Tests', () => { + let mockServer: WS; + let onMessage: ReturnType; + let onError: ReturnType; + + beforeEach(() => { + vi.stubEnv('REACT_APP_ENABLE_WEBSOCKET_MULTIPLEXER', 'true'); + vi.clearAllMocks(); + onMessage = vi.fn(); + onError = vi.fn(); + (getCluster as ReturnType).mockReturnValue(clusterName); + (getUserIdFromLocalStorage as ReturnType).mockReturnValue(userId); + (getToken as ReturnType).mockReturnValue(token); + (findKubeconfigByClusterName as ReturnType).mockResolvedValue({}); + + mockServer = new WS(`${BASE_WS_URL}wsMultiplexer`); + }); + + afterEach(async () => { + WS.clean(); + vi.restoreAllMocks(); + vi.unstubAllEnvs(); + WebSocketManager.socketMultiplexer = null; + WebSocketManager.connecting = false; + WebSocketManager.isReconnecting = false; + WebSocketManager.listeners.clear(); + WebSocketManager.completedPaths.clear(); + WebSocketManager.activeSubscriptions.clear(); + WebSocketManager.pendingUnsubscribes.clear(); + }); + + describe('WebSocketManager', () => { + it('should establish connection and handle messages', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + + // Subscribe 
to pod updates + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await mockServer.connected; + + // Get the subscription message + const subscribeMsg = JSON.parse((await mockServer.nextMessage) as string); + expect(subscribeMsg).toEqual({ + clusterId: clusterName, + path, + query, + userId, + type: 'REQUEST', + }); + + // Send a message from server + const podData = { kind: 'Pod', metadata: { name: 'test-pod' } }; + const serverMessage = { + clusterId: clusterName, + path, + query, + data: JSON.stringify(podData), // Important: data needs to be stringified + type: 'DATA', + }; + + await mockServer.send(JSON.stringify(serverMessage)); + + // Wait for message processing + await vi.waitFor(() => { + expect(onMessage).toHaveBeenCalledWith(podData); + }); + }); + + it('should handle multiple subscriptions', async () => { + const subs = [ + { path: '/api/v1/pods', query: 'watch=true' }, + { path: '/api/v1/services', query: 'watch=true' }, + ]; + + // Subscribe to multiple resources + await Promise.all( + subs.map(sub => WebSocketManager.subscribe(clusterName, sub.path, sub.query, onMessage)) + ); + + await mockServer.connected; + + // Verify subscription messages + for (const sub of subs) { + const msg = JSON.parse((await mockServer.nextMessage) as string); + expect(msg).toEqual({ + clusterId: clusterName, + path: sub.path, + query: sub.query, + userId, + type: 'REQUEST', + }); + + // Send data for this subscription + const resourceData = { + kind: sub.path.includes('pods') ? 
'Pod' : 'Service', + metadata: { name: `test-${sub.path}` }, + }; + + const serverMessage = { + clusterId: clusterName, + path: sub.path, + query: sub.query, + data: JSON.stringify(resourceData), + type: 'DATA', + }; + + await mockServer.send(JSON.stringify(serverMessage)); + } + + // Verify all messages were received + await vi.waitFor(() => { + expect(onMessage).toHaveBeenCalledTimes(2); + }); + }); + + it('should handle COMPLETE messages', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await mockServer.connected; + + // Skip subscription message + await mockServer.nextMessage; + + // Send COMPLETE message + const completeMessage = { + clusterId: clusterName, + path, + query, + type: 'COMPLETE', + }; + + await mockServer.send(JSON.stringify(completeMessage)); + + // Verify the path is marked as completed + const key = WebSocketManager.createKey(clusterName, path, query); + expect(WebSocketManager.completedPaths.has(key)).toBe(true); + }); + + it('should handle unsubscribe', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + + const cleanup = await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await mockServer.connected; + + // Skip subscription message + await mockServer.nextMessage; + + // Unsubscribe + cleanup(); + + // Wait for unsubscribe message (after debounce) + await vi.waitFor(async () => { + const msg = JSON.parse((await mockServer.nextMessage) as string); + expect(msg).toEqual({ + clusterId: clusterName, + path, + query, + userId, + type: 'CLOSE', + }); + }); + + // Verify subscription is removed + const key = WebSocketManager.createKey(clusterName, path, query); + expect(WebSocketManager.activeSubscriptions.has(key)).toBe(false); + }); + + it('should handle connection errors', async () => { + // Close the server to simulate connection failure + await mockServer.close(); + + // Attempt to subscribe 
should fail + await expect( + WebSocketManager.subscribe(clusterName, '/api/v1/pods', 'watch=true', onMessage) + ).rejects.toThrow('WebSocket connection failed'); + }); + + it('should handle duplicate subscriptions', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + + // Create two subscriptions with the same parameters + const onMessage2 = vi.fn(); + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await WebSocketManager.subscribe(clusterName, path, query, onMessage2); + + await mockServer.connected; + + // Should only receive one subscription message + const subMsg = JSON.parse((await mockServer.nextMessage) as string); + expect(subMsg.type).toBe('REQUEST'); + + // Send a message + const podData = { kind: 'Pod', metadata: { name: 'test-pod' } }; + await mockServer.send( + JSON.stringify({ + clusterId: clusterName, + path, + query, + data: JSON.stringify(podData), + type: 'DATA', + }) + ); + + // Both handlers should receive the message + await vi.waitFor(() => { + expect(onMessage).toHaveBeenCalledWith(podData); + expect(onMessage2).toHaveBeenCalledWith(podData); + }); + }); + + it('should debounce unsubscribe operations', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + + const cleanup = await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await mockServer.connected; + + // Skip subscription message + await mockServer.nextMessage; + + // Unsubscribe + cleanup(); + + // Subscribe again immediately + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + + // Wait for potential unsubscribe message + await vi.waitFor(() => { + const key = WebSocketManager.createKey(clusterName, path, query); + expect(WebSocketManager.activeSubscriptions.has(key)).toBe(true); + }); + + // Verify no CLOSE message was sent + try { + const msg = JSON.parse((await mockServer.nextMessage) as string); + expect(msg.type).not.toBe('CLOSE'); + } catch (e) { + // No message is 
also acceptable + } + }); + }); + + describe('useWebSocket hook', () => { + it('should not connect when disabled', () => { + renderHook(() => + useWebSocket({ + url: () => '/api/v1/pods', + enabled: false, + cluster: clusterName, + onMessage, + onError, + }) + ); + + expect(WebSocketManager.socketMultiplexer).toBeNull(); + }); + + it('should handle successful connection and messages', async () => { + const fullUrl = `${BASE_WS_URL}api/v1/pods`; + + renderHook(() => + useWebSocket({ + url: () => fullUrl, + enabled: true, + cluster: clusterName, + onMessage, + onError, + }) + ); + + await mockServer.connected; + + // Skip subscription message + await mockServer.nextMessage; + + // Send test message + const podData = { kind: 'Pod', metadata: { name: 'test-pod' } }; + await mockServer.send( + JSON.stringify({ + clusterId: clusterName, + path: '/api/v1/pods', + data: JSON.stringify(podData), + type: 'DATA', + }) + ); + + await vi.waitFor(() => { + expect(onMessage).toHaveBeenCalledWith(podData); + }); + }, 10000); + + it('should handle connection errors', async () => { + const fullUrl = `${BASE_WS_URL}api/v1/pods`; + + // Close the server to simulate connection failure + await mockServer.close(); + + renderHook(() => + useWebSocket({ + url: () => fullUrl, + enabled: true, + cluster: clusterName, + onMessage, + onError, + }) + ); + + await vi.waitFor(() => { + expect(onError).toHaveBeenCalled(); + }); + }); + + it('should cleanup on unmount', async () => { + const fullUrl = `${BASE_WS_URL}api/v1/pods`; + + const { unmount } = renderHook(() => + useWebSocket({ + url: () => fullUrl, + enabled: true, + cluster: clusterName, + onMessage, + onError, + }) + ); + + await mockServer.connected; + await mockServer.nextMessage; // Skip subscription + + // Unmount and wait for cleanup + unmount(); + + await vi.waitFor( + async () => { + const msg = JSON.parse((await mockServer.nextMessage) as string); + expect(msg.type).toBe('CLOSE'); + }, + { timeout: 10000 } + ); + }); + }); + + 
describe('WebSocket error handling', () => { + it('should handle polling timeout', async () => { + // Mock WebSocket to never open + const mockWS = vi.spyOn(window, 'WebSocket').mockImplementation(() => { + const ws = new EventTarget() as WebSocket; + Object.defineProperty(ws, 'readyState', { value: WebSocket.CONNECTING }); + Object.defineProperty(ws, 'send', { value: null }); + return ws; + }); + + const path = '/api/v1/pods'; + const query = 'watch=true'; + + let error: Error | null = null; + try { + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + } catch (e) { + error = e as Error; + } + + expect(error).toBeTruthy(); + expect(error?.message).toBe("Cannot read properties of null (reading 'send')"); + + mockWS.mockRestore(); + }); + + it('should handle reconnection and resubscribe', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + + // First connection + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await mockServer.connected; + await mockServer.nextMessage; // Skip initial subscription + + // Close the connection to trigger reconnect + mockServer.close(); + + // Verify WebSocketManager state after close + expect(WebSocketManager.socketMultiplexer).toBeNull(); + expect(WebSocketManager.isReconnecting).toBe(true); + expect(WebSocketManager.connecting).toBe(false); + + // Try to use connection again to trigger reconnect + const newServer = new WS(`${BASE_WS_URL}wsMultiplexer`); + await WebSocketManager.connect(); + await newServer.connected; + + // Should get resubscription message + const resubMsg = JSON.parse((await newServer.nextMessage) as string); + expect(resubMsg).toEqual({ + clusterId: clusterName, + path, + query, + userId, + type: 'REQUEST', + }); + + // Verify reconnection state is reset + expect(WebSocketManager.isReconnecting).toBe(false); + + newServer.close(); + }); + + it('should handle WebSocket close event', async () => { + const path = '/api/v1/pods'; + const query = 
'watch=true'; + + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await mockServer.connected; + + // Close the connection + mockServer.close(); + + // Verify WebSocket state after close + expect(WebSocketManager.socketMultiplexer).toBeNull(); + expect(WebSocketManager.connecting).toBe(false); + expect(WebSocketManager.completedPaths.size).toBe(0); + expect(WebSocketManager.isReconnecting).toBe(true); // Should be true since we have active subscriptions + }); + + it('should handle error in message callback', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + const error = new Error('Message processing failed'); + const errorCallback = vi.fn().mockImplementation(() => { + throw error; + }); + + await WebSocketManager.subscribe(clusterName, path, query, errorCallback); + await mockServer.connected; + await mockServer.nextMessage; // Skip subscription message + + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + + // Send message that will trigger error in callback + await mockServer.send( + JSON.stringify({ + clusterId: clusterName, + path, + query, + data: JSON.stringify({ kind: 'Pod' }), + type: 'DATA', + }) + ); + + expect(consoleSpy).toHaveBeenCalledWith('Failed to process WebSocket message:', error); + consoleSpy.mockRestore(); + }); + + it('should handle invalid message format', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await mockServer.connected; + await mockServer.nextMessage; // Skip subscription + + // Send invalid message + await mockServer.send('invalid json'); + + expect(onMessage).not.toHaveBeenCalled(); + }); + + it('should handle parse errors in message data', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + + await WebSocketManager.subscribe(clusterName, path, 
query, onMessage); + await mockServer.connected; + + // Skip subscription message + await mockServer.nextMessage; + + // Send malformed data + await mockServer.send( + JSON.stringify({ + clusterId: clusterName, + path, + query, + data: 'invalid{json', + type: 'DATA', + }) + ); + + expect(onMessage).not.toHaveBeenCalled(); + expect(consoleSpy).toHaveBeenCalledWith('Failed to parse update data:', expect.any(Error)); + + consoleSpy.mockRestore(); + }); + + it('should handle message callback errors in useWebSocket', async () => { + const errorMessage = 'Message processing failed'; + const errorFn = vi.fn().mockImplementation(() => { + throw new Error(errorMessage); + }); + + renderHook(() => + useWebSocket({ + url: () => `${BASE_WS_URL}api/v1/pods`, + enabled: true, + cluster: clusterName, + onMessage: errorFn, + onError, + }) + ); + + await mockServer.connected; + await mockServer.nextMessage; // Skip subscription + + // Send message that will cause error in callback + await mockServer.send( + JSON.stringify({ + clusterId: clusterName, + path: '/api/v1/pods', + data: JSON.stringify({ kind: 'Pod' }), + type: 'DATA', + }) + ); + + expect(onError).toHaveBeenCalledWith(expect.any(Error)); + expect(onError).toHaveBeenCalledWith( + expect.objectContaining({ + message: errorMessage, + }) + ); + }); + + it('should handle missing fields in messages', async () => { + const path = '/api/v1/pods'; + const query = 'watch=true'; + + await WebSocketManager.subscribe(clusterName, path, query, onMessage); + await mockServer.connected; + + // Skip subscription message + await mockServer.nextMessage; + + // Send message without required fields + await mockServer.send( + JSON.stringify({ + data: JSON.stringify({ kind: 'Pod' }), + }) + ); + + expect(onMessage).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/frontend/src/lib/k8s/api/v2/webSocket.ts b/frontend/src/lib/k8s/api/v2/webSocket.ts index 0ff934a00d..e8c12162a6 100644 --- a/frontend/src/lib/k8s/api/v2/webSocket.ts +++ 
b/frontend/src/lib/k8s/api/v2/webSocket.ts @@ -1,11 +1,468 @@ -import { useEffect, useMemo } from 'react'; +import { useCallback, useEffect, useMemo } from 'react'; import { findKubeconfigByClusterName, getUserIdFromLocalStorage } from '../../../../stateless'; import { getToken } from '../../../auth'; import { getCluster } from '../../../cluster'; import { BASE_HTTP_URL } from './fetch'; import { makeUrl } from './makeUrl'; -const BASE_WS_URL = BASE_HTTP_URL.replace('http', 'ws'); +// Constants for WebSocket connection +export const BASE_WS_URL = BASE_HTTP_URL.replace('http', 'ws'); + +/** + * Multiplexer endpoint for WebSocket connections + * This endpoint allows multiple subscriptions over a single connection + */ +const MULTIPLEXER_ENDPOINT = 'wsMultiplexer'; + +/** + * Message format for WebSocket communication between client and server. + * Used to manage subscriptions to Kubernetes resource updates. + */ +interface WebSocketMessage { + /** + * Cluster identifier used to route messages to the correct Kubernetes cluster. + * This is particularly important in multi-cluster environments. + */ + clusterId: string; + + /** + * API resource path that identifies the Kubernetes resource being watched. + * Example: '/api/v1/pods' or '/apis/apps/v1/deployments' + */ + path: string; + + /** + * Query parameters for filtering or modifying the watch request. + * Example: 'labelSelector=app%3Dnginx&fieldSelector=status.phase%3DRunning' + */ + query: string; + + /** + * User identifier for authentication and authorization. + * Used to ensure users only receive updates for resources they have access to. 
+ */ + userId: string; + + /** + * Message type that indicates the purpose of the message: + * - REQUEST: Client is requesting to start watching a resource + * - CLOSE: Client wants to stop watching a resource + * - COMPLETE: Server indicates the watch request has completed (e.g., due to timeout or error) + */ + type: 'REQUEST' | 'CLOSE' | 'COMPLETE'; +} + +/** + * WebSocket manager to handle connections across the application. + * Provides a singleton-like interface for managing WebSocket connections, + * subscriptions, and message handling. Implements connection multiplexing + * to optimize network usage. + */ +export const WebSocketManager = { + /** Current WebSocket connection instance */ + socketMultiplexer: null as WebSocket | null, + + /** Flag to track if a connection attempt is in progress */ + connecting: false, + + /** Flag to track if we're reconnecting after a disconnect */ + isReconnecting: false, + + /** Map of message handlers for each subscription path */ + listeners: new Map void>>(), + + /** Set of paths that have received a COMPLETE message */ + completedPaths: new Set(), + + /** Map of active WebSocket subscriptions with their details */ + activeSubscriptions: new Map(), + + /** Map to track pending unsubscribe operations for debouncing */ + pendingUnsubscribes: new Map(), + + /** + * Creates a unique key for identifying WebSocket subscriptions + * @param clusterId - Cluster identifier + * @param path - API resource path + * @param query - Query parameters + * @returns Unique subscription key + */ + createKey(clusterId: string, path: string, query: string): string { + return `${clusterId}:${path}:${query}`; + }, + + /** + * Establishes or returns an existing WebSocket connection. + * + * This implementation uses a polling approach to handle concurrent connection attempts. + * While not ideal, it's a simple solution that works for most cases. + * + * Known limitations: + * 1. Polls every 100ms which may not be optimal for performance + * 2. 
No timeout - could theoretically run forever if connection never opens + * 3. May miss state changes that happen between polls + * + * A more robust solution would use event listeners and Promise caching, + * but that adds complexity and potential race conditions to handle. + * The current polling approach, while not perfect, is simple and mostly reliable. + * + * @returns Promise resolving to WebSocket connection + */ + async connect(): Promise { + // Return existing connection if available + if (this.socketMultiplexer?.readyState === WebSocket.OPEN) { + return this.socketMultiplexer; + } + + // Wait for existing connection attempt if in progress + if (this.connecting) { + return new Promise(resolve => { + const checkConnection = setInterval(() => { + if (this.socketMultiplexer?.readyState === WebSocket.OPEN) { + clearInterval(checkConnection); + resolve(this.socketMultiplexer); + } + }, 100); + }); + } + + this.connecting = true; + const wsUrl = `${BASE_WS_URL}${MULTIPLEXER_ENDPOINT}`; + + return new Promise((resolve, reject) => { + const socket = new WebSocket(wsUrl); + + socket.onopen = () => { + this.socketMultiplexer = socket; + this.connecting = false; + + // Only resubscribe if we're reconnecting after a disconnect + if (this.isReconnecting) { + this.resubscribeAll(socket); + } + this.isReconnecting = false; + + resolve(socket); + }; + + socket.onmessage = this.handleWebSocketMessage.bind(this); + + socket.onerror = event => { + console.error('WebSocket error:', event); + this.connecting = false; + reject(new Error('WebSocket connection failed')); + }; + + socket.onclose = () => { + this.handleWebSocketClose(); + }; + }); + }, + + /** + * Resubscribes all active subscriptions to a new socket + * @param socket - WebSocket connection to subscribe to + */ + resubscribeAll(socket: WebSocket): void { + this.activeSubscriptions.forEach(({ clusterId, path, query }) => { + const userId = getUserIdFromLocalStorage(); + const requestMsg: WebSocketMessage = { + 
clusterId, + path, + query, + userId: userId || '', + type: 'REQUEST', + }; + socket.send(JSON.stringify(requestMsg)); + }); + }, + + /** + * Subscribe to WebSocket updates for a specific resource + * @param clusterId - Cluster identifier + * @param path - API resource path + * @param query - Query parameters + * @param onMessage - Callback for handling incoming messages + * @returns Promise resolving to cleanup function + */ + async subscribe( + clusterId: string, + path: string, + query: string, + onMessage: (data: any) => void + ): Promise<() => void> { + const key = this.createKey(clusterId, path, query); + + // Add to active subscriptions + this.activeSubscriptions.set(key, { clusterId, path, query }); + + // Add message listener + const listeners = this.listeners.get(key) || new Set(); + listeners.add(onMessage); + this.listeners.set(key, listeners); + + // Establish connection and send REQUEST + const socket = await this.connect(); + const userId = getUserIdFromLocalStorage(); + const requestMsg: WebSocketMessage = { + clusterId, + path, + query, + userId: userId || '', + type: 'REQUEST', + }; + socket.send(JSON.stringify(requestMsg)); + + // Return cleanup function + return () => this.unsubscribe(key, clusterId, path, query, onMessage); + }, + + /** + * Unsubscribes from WebSocket updates with debouncing to prevent rapid subscribe/unsubscribe cycles. + * + * State Management: + * - Manages pendingUnsubscribes: Map of timeouts for delayed unsubscription + * - Manages listeners: Map of message handlers for each subscription + * - Manages activeSubscriptions: Set of currently active WebSocket subscriptions + * - Manages completedPaths: Set of paths that have completed their initial data fetch + * + * Debouncing Logic: + * 1. Clears any pending unsubscribe timeout for the subscription + * 2. Removes the message handler from listeners + * 3. If no listeners remain, sets a timeout before actually unsubscribing + * 4. 
Only sends CLOSE message if no new listeners are added during timeout + * + * @param key - Subscription key that uniquely identifies this subscription + * @param clusterId - Cluster identifier for routing to correct cluster + * @param path - API resource path being watched + * @param query - Query parameters for filtering + * @param onMessage - Message handler to remove from subscription + */ + unsubscribe( + key: string, + clusterId: string, + path: string, + query: string, + onMessage: (data: any) => void + ): void { + // Clear any pending unsubscribe for this key + const pendingTimeout = this.pendingUnsubscribes.get(key); + if (pendingTimeout) { + clearTimeout(pendingTimeout); + this.pendingUnsubscribes.delete(key); + } + + // Remove the listener + const listeners = this.listeners.get(key); + if (listeners) { + listeners.delete(onMessage); + if (listeners.size === 0) { + this.listeners.delete(key); + + // Delay unsubscription to handle rapid re-subscriptions + // This prevents unnecessary WebSocket churn when a component quickly unmounts and remounts + // For example: during route changes or component updates in React's strict mode + const timeout = setTimeout(() => { + // Only unsubscribe if there are still no listeners + if (!this.listeners.has(key)) { + this.activeSubscriptions.delete(key); + this.completedPaths.delete(key); + + if (this.socketMultiplexer?.readyState === WebSocket.OPEN) { + const userId = getUserIdFromLocalStorage(); + const closeMsg: WebSocketMessage = { + clusterId, + path, + query, + userId: userId || '', + type: 'CLOSE', + }; + this.socketMultiplexer.send(JSON.stringify(closeMsg)); + } + } + this.pendingUnsubscribes.delete(key); + }, 100); // 100ms debounce + + this.pendingUnsubscribes.set(key, timeout); + } + } + }, + + /** + * Handles WebSocket connection close event + * Sets up state for potential reconnection + */ + handleWebSocketClose(): void { + this.socketMultiplexer = null; + this.connecting = false; + 
this.completedPaths.clear(); + + // Set reconnecting flag if we have active subscriptions + this.isReconnecting = this.activeSubscriptions.size > 0; + }, + + /** + * Handles incoming WebSocket messages + * Processes different message types and notifies appropriate listeners + * @param event - WebSocket message event + */ + handleWebSocketMessage(event: MessageEvent): void { + try { + const data = JSON.parse(event.data); + if (!data.clusterId || !data.path) { + return; + } + + const key = this.createKey(data.clusterId, data.path, data.query || ''); + + // Handle COMPLETE messages + if (data.type === 'COMPLETE') { + this.completedPaths.add(key); + return; + } + + // Skip if path is already completed + if (this.completedPaths.has(key)) { + return; + } + + // Parse and validate update data + let update; + try { + update = data.data ? JSON.parse(data.data) : data; + } catch (err) { + console.error('Failed to parse update data:', err); + return; + } + + // Notify listeners if update is valid + if (update && typeof update === 'object') { + const listeners = this.listeners.get(key); + if (listeners) { + listeners.forEach(listener => listener(update)); + } + } + } catch (err) { + console.error('Failed to process WebSocket message:', err); + } + }, +}; + +/** + * Configuration for establishing a WebSocket connection to watch Kubernetes resources. + * Used by the multiplexer to manage multiple WebSocket connections efficiently. + * + * @template T The expected type of data that will be received over the WebSocket + */ +export type WebSocketConnectionRequest = { + /** + * The Kubernetes cluster identifier to connect to. + * Used for routing WebSocket messages in multi-cluster environments. + */ + cluster: string; + + /** + * The WebSocket endpoint URL to connect to. + * Should be a full URL including protocol and any query parameters. 
+ * Example: 'https://cluster.example.com/api/v1/pods/watch' + */ + url: string; + + /** + * Callback function that handles incoming messages from the WebSocket. + * @param data The message payload, typed as T (e.g., K8s Pod, Service, etc.) + */ + onMessage: (data: T) => void; +}; + +/** + * React hook for WebSocket subscription to Kubernetes resources + * @template T - Type of data expected from the WebSocket + * @param options - Configuration options for the WebSocket connection + * @param options.url - Function that returns the WebSocket URL to connect to + * @param options.enabled - Whether the WebSocket connection should be active + * @param options.cluster - The Kubernetes cluster ID to watch + * @param options.onMessage - Callback function to handle incoming messages + * @param options.onError - Callback function to handle connection errors + */ +export function useWebSocket({ + url: createUrl, + enabled = true, + cluster = '', + onMessage, + onError, +}: { + /** Function that returns the WebSocket URL to connect to */ + url: () => string; + /** Whether the WebSocket connection should be active */ + enabled?: boolean; + /** The Kubernetes cluster ID to watch */ + cluster?: string; + /** Callback function to handle incoming messages */ + onMessage: (data: T) => void; + /** Callback function to handle connection errors */ + onError?: (error: Error) => void; +}) { + const url = useMemo(() => (enabled ? createUrl() : ''), [enabled, createUrl]); + + const stableOnMessage = useCallback( + (rawData: any) => { + try { + let parsedData: T; + try { + parsedData = typeof rawData === 'string' ? 
JSON.parse(rawData) : rawData; + } catch (parseError) { + console.error('Failed to parse WebSocket message:', parseError); + onError?.(parseError as Error); + return; + } + + onMessage(parsedData); + } catch (err) { + console.error('Failed to process WebSocket message:', err); + onError?.(err as Error); + } + }, + [onMessage, onError] + ); + + useEffect(() => { + if (!enabled || !url) { + return; + } + + let cleanup: (() => void) | undefined; + + const connectWebSocket = async () => { + try { + const parsedUrl = new URL(url); + cleanup = await WebSocketManager.subscribe( + cluster, + parsedUrl.pathname, + parsedUrl.search.slice(1), + stableOnMessage + ); + } catch (err) { + console.error('WebSocket connection failed:', err); + onError?.(err as Error); + } + }; + + connectWebSocket(); + + return () => { + if (cleanup) { + cleanup(); + } + }; + }, [url, enabled, cluster, stableOnMessage, onError]); +} + +/** + * Keeps track of open WebSocket connections and active listeners + */ +const sockets = new Map(); +const listeners = new Map void>>(); /** * Create new WebSocket connection to the backend @@ -78,60 +535,6 @@ export async function openWebSocket( return socket; } -// Global state for useWebSocket hook -// Keeps track of open WebSocket connections and active listeners -const sockets = new Map(); -const listeners = new Map void>>(); - -/** - * Creates or joins existing WebSocket connection - * - * @param url - endpoint URL - * @param options - WebSocket options - */ -export function useWebSocket({ - url: createUrl, - enabled = true, - protocols, - type = 'json', - cluster, - onMessage, -}: { - url: () => string; - enabled?: boolean; - /** - * Any additional protocols to include in WebSocket connection - */ - protocols?: string | string[]; - /** - * Type of websocket data - */ - type?: 'json' | 'binary'; - /** - * Cluster name - */ - cluster?: string; - /** - * Message callback - */ - onMessage: (data: T) => void; -}) { - const url = useMemo(() => (enabled ? 
createUrl() : ''), [enabled]); - const connections = useMemo(() => [{ cluster: cluster ?? '', url, onMessage }], [cluster, url]); - - return useWebSockets({ - connections, - protocols, - type, - }); -} - -export type WebSocketConnectionRequest = { - cluster: string; - url: string; - onMessage: (data: T) => void; -}; - /** * Creates or joins multiple existing WebSocket connections * diff --git a/frontend/src/plugin/__snapshots__/pluginLib.snapshot b/frontend/src/plugin/__snapshots__/pluginLib.snapshot index dafdc1d599..d5206ac49c 100644 --- a/frontend/src/plugin/__snapshots__/pluginLib.snapshot +++ b/frontend/src/plugin/__snapshots__/pluginLib.snapshot @@ -16114,4 +16114,4 @@ "registerSidebarEntry": [Function], "registerSidebarEntryFilter": [Function], "runCommand": [Function], -} \ No newline at end of file +}