diff --git a/backend/cmd/multiplexer.go b/backend/cmd/multiplexer.go index 54c6fe51e9..4afcf76371 100644 --- a/backend/cmd/multiplexer.go +++ b/backend/cmd/multiplexer.go @@ -68,6 +68,10 @@ type Connection struct { Done chan struct{} // mu is a mutex to synchronize access to the connection. mu sync.RWMutex + // writeMu is a mutex to synchronize access to the write operations. + writeMu sync.Mutex + // closed is a flag to indicate if the connection is closed. + closed bool } // Message represents a WebSocket message structure. @@ -81,7 +85,9 @@ type Message struct { // UserID is the ID of the user. UserID string `json:"userId"` // Data contains the message payload. - Data []byte `json:"data,omitempty"` + Data string `json:"data,omitempty"` + // Binary is a flag to indicate if the message is binary. + Binary bool `json:"binary,omitempty"` // Type is the type of the message. Type string `json:"type"` } @@ -116,41 +122,58 @@ func (c *Connection) updateStatus(state ConnectionState, err error) { c.mu.Lock() defer c.mu.Unlock() + if c.closed { + return + } + c.Status.State = state c.Status.LastMsg = time.Now() + c.Status.Error = "" if err != nil { c.Status.Error = err.Error() - } else { - c.Status.Error = "" } - if c.Client != nil { - statusData := struct { - State string `json:"state"` - Error string `json:"error"` - }{ - State: string(state), - Error: c.Status.Error, - } + if c.Client == nil { + return + } - jsonData, jsonErr := json.Marshal(statusData) - if jsonErr != nil { - logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message") + c.writeMu.Lock() + defer c.writeMu.Unlock() - return - } + // Check if connection is closed before writing + if c.closed { + return + } - statusMsg := Message{ - ClusterID: c.ClusterID, - Path: c.Path, - Data: jsonData, - } + statusData := struct { + State string `json:"state"` + Error string `json:"error"` + }{ + State: string(state), + Error: c.Status.Error, + } - err := c.Client.WriteJSON(statusMsg) - if err != nil { + jsonData, jsonErr := json.Marshal(statusData) + if jsonErr != nil { + logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, jsonErr, "marshaling status message") + + return + } + + statusMsg := Message{ + ClusterID: c.ClusterID, + Path: c.Path, + Data: string(jsonData), + Type: "STATUS", + } + + if err := c.Client.WriteJSON(statusMsg); err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { logger.Log(logger.LevelError, map[string]string{"clusterID": c.ClusterID}, err, "writing status message to client") } + + c.closed = true } } @@ -190,7 +213,8 @@ func (m *Multiplexer) establishClusterConnection( connection.updateStatus(StateConnected, nil) m.mutex.Lock() - m.connections[clusterID+path] = connection + connKey := m.createConnectionKey(clusterID, path, userID) + m.connections[connKey] = connection m.mutex.Unlock() go m.monitorConnection(connection) @@ -293,6 +317,10 @@ func (m *Multiplexer) monitorConnection(conn *Connection) { // reconnect attempts to reestablish a connection. 
func (m *Multiplexer) reconnect(conn *Connection) (*Connection, error) { + if conn.closed { + return nil, fmt.Errorf("cannot reconnect closed connection") + } + if conn.WSConn != nil { conn.WSConn.Close() } @@ -311,7 +339,7 @@ func (m *Multiplexer) reconnect(conn *Connection) (*Connection, error) { } m.mutex.Lock() - m.connections[conn.ClusterID+conn.Path] = newConn + m.connections[m.createConnectionKey(conn.ClusterID, conn.Path, conn.UserID)] = newConn m.mutex.Unlock() return newConn, nil @@ -334,16 +362,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque } // Check if it's a close message - if msg.Data != nil && len(msg.Data) > 0 && string(msg.Data) == "close" { - err := m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID) - if err != nil { - logger.Log( - logger.LevelError, - map[string]string{"clusterID": msg.ClusterID, "UserID": msg.UserID}, - err, - "closing connection", - ) - } + if msg.Type == "CLOSE" { + m.CloseConnection(msg.ClusterID, msg.Path, msg.UserID) continue } @@ -355,8 +375,8 @@ func (m *Multiplexer) HandleClientWebSocket(w http.ResponseWriter, r *http.Reque continue } - if len(msg.Data) > 0 && conn.Status.State == StateConnected { - err = m.writeMessageToCluster(conn, msg.Data) + if msg.Type == "REQUEST" && conn.Status.State == StateConnected { + err = m.writeMessageToCluster(conn, []byte(msg.Data)) if err != nil { continue } @@ -389,7 +409,7 @@ func (m *Multiplexer) readClientMessage(clientConn *websocket.Conn) (Message, er // getOrCreateConnection gets an existing connection or creates a new one if it doesn't exist. func (m *Multiplexer) getOrCreateConnection(msg Message, clientConn *websocket.Conn) (*Connection, error) { - connKey := fmt.Sprintf("%s:%s:%s", msg.ClusterID, msg.Path, msg.UserID) + connKey := m.createConnectionKey(msg.ClusterID, msg.Path, msg.UserID) m.mutex.RLock() conn, exists := m.connections[connKey] @@ -458,100 +478,182 @@ func (m *Multiplexer) writeMessageToCluster(conn *Connection, data []byte) error // handleClusterMessages handles messages from a cluster connection. func (m *Multiplexer) handleClusterMessages(conn *Connection, clientConn *websocket.Conn) { - defer func() { - conn.updateStatus(StateClosed, nil) - conn.WSConn.Close() - }() + defer m.cleanupConnection(conn) + + var lastResourceVersion string for { select { case <-conn.Done: return default: - if err := m.processClusterMessage(conn, clientConn); err != nil { + if err := m.processClusterMessage(conn, clientConn, &lastResourceVersion); err != nil { return } } } } -// processClusterMessage processes a message from a cluster connection. -func (m *Multiplexer) processClusterMessage(conn *Connection, clientConn *websocket.Conn) error { +// processClusterMessage processes a single message from the cluster. 
+func (m *Multiplexer) processClusterMessage( + conn *Connection, + clientConn *websocket.Conn, + lastResourceVersion *string, +) error { messageType, message, err := conn.WSConn.ReadMessage() if err != nil { - m.handleReadError(conn, err) + if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + logger.Log(logger.LevelError, + map[string]string{ + "clusterID": conn.ClusterID, + "userID": conn.UserID, + }, + err, + "reading cluster message", + ) + } return err } - wrapperMsg := m.createWrapperMessage(conn, messageType, message) + if err := m.sendIfNewResourceVersion(message, conn, clientConn, lastResourceVersion); err != nil { + return err + } - if err := clientConn.WriteJSON(wrapperMsg); err != nil { - m.handleWriteError(conn, err) + return m.sendDataMessage(conn, clientConn, messageType, message) +} - return err +// sendIfNewResourceVersion checks the version of a resource from an incoming message +// and sends a complete message to the client if the resource version has changed. +// +// This function is used to ensure that the client is always aware of the latest version +// of a resource. When a new message is received, it extracts the resource version from +// the message metadata. If the resource version has changed since the last known version, +// it sends a complete message to the client to update them with the latest resource state. +// Parameters: +// - message: The JSON-encoded message containing resource information. +// - conn: The connection object representing the current connection. +// - clientConn: The WebSocket connection to the client. +// - lastResourceVersion: A pointer to the last known resource version string. +// +// Returns: +// - An error if any issues occur while processing the message, or nil if successful. +func (m *Multiplexer) sendIfNewResourceVersion( + message []byte, + conn *Connection, + clientConn *websocket.Conn, + lastResourceVersion *string, +) error { + var obj map[string]interface{} + if err := json.Unmarshal(message, &obj); err != nil { + return fmt.Errorf("error unmarshaling message: %v", err) } - conn.mu.Lock() - conn.Status.LastMsg = time.Now() - conn.mu.Unlock() + // Try to find metadata directly + metadata, ok := obj["metadata"].(map[string]interface{}) + if !ok { + // Try to find metadata in object field + if objField, ok := obj["object"].(map[string]interface{}); ok { + if metadata, ok = objField["metadata"].(map[string]interface{}); !ok { + // No metadata field found, nothing to do + return nil + } + } else { + // No metadata field found, nothing to do + return nil + } + } + + rv, ok := metadata["resourceVersion"].(string) + if !ok { + // No resourceVersion field, nothing to do + return nil + } + + // Update version and send complete message if version is different + if rv != *lastResourceVersion { + *lastResourceVersion = rv + + return m.sendCompleteMessage(conn, clientConn) + } return nil } -// createWrapperMessage creates a wrapper message for a cluster connection. 
-func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - Data string `json:"data"` - Binary bool `json:"binary"` -} { - wrapperMsg := struct { - ClusterID string `json:"clusterId"` - Path string `json:"path"` - Query string `json:"query"` - UserID string `json:"userId"` - Data string `json:"data"` - Binary bool `json:"binary"` - }{ +// sendCompleteMessage sends a COMPLETE message to the client. +func (m *Multiplexer) sendCompleteMessage(conn *Connection, clientConn *websocket.Conn) error { + completeMsg := Message{ ClusterID: conn.ClusterID, Path: conn.Path, Query: conn.Query, UserID: conn.UserID, - Binary: messageType == websocket.BinaryMessage, + Type: "COMPLETE", } - if messageType == websocket.BinaryMessage { - wrapperMsg.Data = base64.StdEncoding.EncodeToString(message) - } else { - wrapperMsg.Data = string(message) + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + return clientConn.WriteJSON(completeMsg) +} + +// sendDataMessage sends the actual data message to the client. +func (m *Multiplexer) sendDataMessage( + conn *Connection, + clientConn *websocket.Conn, + messageType int, + message []byte, +) error { + dataMsg := m.createWrapperMessage(conn, messageType, message) + + conn.writeMu.Lock() + defer conn.writeMu.Unlock() + + if err := clientConn.WriteJSON(dataMsg); err != nil { + return err } - return wrapperMsg + conn.mu.Lock() + conn.Status.LastMsg = time.Now() + conn.mu.Unlock() + + return nil } -// handleReadError handles errors that occur when reading a message from a cluster connection. -func (m *Multiplexer) handleReadError(conn *Connection, err error) { - conn.updateStatus(StateError, err) - logger.Log( - logger.LevelError, - map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, - err, - "reading message from cluster", - ) +// cleanupConnection performs cleanup for a connection. +func (m *Multiplexer) cleanupConnection(conn *Connection) { + conn.mu.Lock() + defer conn.mu.Unlock() // Ensure the mutex is unlocked even if an error occurs + + conn.closed = true + + if conn.WSConn != nil { + conn.WSConn.Close() + } + + m.mutex.Lock() + connKey := m.createConnectionKey(conn.ClusterID, conn.Path, conn.UserID) + delete(m.connections, connKey) + m.mutex.Unlock() } -// handleWriteError handles errors that occur when writing a message to a client connection. -func (m *Multiplexer) handleWriteError(conn *Connection, err error) { - conn.updateStatus(StateError, err) - logger.Log( - logger.LevelError, - map[string]string{"clusterID": conn.ClusterID, "UserID": conn.UserID}, - err, - "writing message to client", - ) +// createWrapperMessage creates a wrapper message for a cluster connection. +func (m *Multiplexer) createWrapperMessage(conn *Connection, messageType int, message []byte) Message { + var data string + if messageType == websocket.BinaryMessage { + data = base64.StdEncoding.EncodeToString(message) + } else { + data = string(message) + } + + return Message{ + ClusterID: conn.ClusterID, + Path: conn.Path, + Query: conn.Query, + UserID: conn.UserID, + Data: data, + Binary: messageType == websocket.BinaryMessage, + Type: "DATA", + } } // cleanupConnections closes and removes all connections. @@ -587,39 +689,49 @@ func (m *Multiplexer) getClusterConfig(clusterID string) (*rest.Config, error) { } // CloseConnection closes a specific connection based on its identifier. 
-func (m *Multiplexer) CloseConnection(clusterID, path, userID string) error { - connKey := fmt.Sprintf("%s:%s:%s", clusterID, path, userID) +func (m *Multiplexer) CloseConnection(clusterID, path, userID string) { + connKey := m.createConnectionKey(clusterID, path, userID) m.mutex.Lock() - defer m.mutex.Unlock() conn, exists := m.connections[connKey] if !exists { - return fmt.Errorf("connection not found for key: %s", connKey) + m.mutex.Unlock() + // Don't log error for non-existent connections during cleanup + return } - // Signal the connection to close - close(conn.Done) + // Mark as closed before releasing the lock + conn.mu.Lock() + if conn.closed { + conn.mu.Unlock() + m.mutex.Unlock() + logger.Log(logger.LevelError, map[string]string{"clusterID": conn.ClusterID}, nil, "closing connection") - // Close the WebSocket connection - if conn.WSConn != nil { - if err := conn.WSConn.Close(); err != nil { - logger.Log( - logger.LevelError, - map[string]string{"clusterID": clusterID, "userID": userID}, - err, - "closing WebSocket connection", - ) - } + return } - // Update the connection status - conn.updateStatus(StateClosed, nil) + conn.closed = true + conn.mu.Unlock() - // Remove the connection from the map delete(m.connections, connKey) + m.mutex.Unlock() - return nil + // Lock the connection mutex before accessing shared resources + conn.mu.Lock() + defer conn.mu.Unlock() // Ensure the mutex is unlocked after the operations + + // Close the Done channel and connections after removing from map + close(conn.Done) + + if conn.WSConn != nil { + conn.WSConn.Close() + } +} + +// createConnectionKey creates a unique key for a connection based on cluster ID, path, and user ID. +func (m *Multiplexer) createConnectionKey(clusterID, path, userID string) string { + return fmt.Sprintf("%s:%s:%s", clusterID, path, userID) } // createWebSocketURL creates a WebSocket URL from the given parameters. 
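Note: the watch deduplication added in sendIfNewResourceVersion only depends on where Kubernetes places metadata.resourceVersion, either at the top level of the object or under the "object" field of a watch event. Below is a minimal standalone sketch of that lookup, assuming illustrative sample events that are not part of this PR:

package main

import (
	"encoding/json"
	"fmt"
)

// extractResourceVersion mirrors the lookup performed in sendIfNewResourceVersion:
// it reads metadata.resourceVersion either from the top level of the payload or
// from the nested "object" field used by Kubernetes watch events. It returns ""
// when no version is present.
func extractResourceVersion(message []byte) (string, error) {
	var obj map[string]interface{}
	if err := json.Unmarshal(message, &obj); err != nil {
		return "", fmt.Errorf("error unmarshaling message: %w", err)
	}

	metadata, ok := obj["metadata"].(map[string]interface{})
	if !ok {
		objField, ok := obj["object"].(map[string]interface{})
		if !ok {
			return "", nil
		}

		if metadata, ok = objField["metadata"].(map[string]interface{}); !ok {
			return "", nil
		}
	}

	rv, _ := metadata["resourceVersion"].(string)

	return rv, nil
}

func main() {
	// Sample events for illustration only.
	events := [][]byte{
		[]byte(`{"metadata":{"resourceVersion":"100"}}`),
		[]byte(`{"metadata":{"resourceVersion":"100"}}`),
		[]byte(`{"type":"MODIFIED","object":{"metadata":{"resourceVersion":"200"}}}`),
	}

	var lastResourceVersion string

	for _, ev := range events {
		rv, err := extractResourceVersion(ev)
		if err != nil || rv == "" || rv == lastResourceVersion {
			continue // unchanged or unversioned: no COMPLETE message would be sent
		}

		lastResourceVersion = rv
		fmt.Println("would send COMPLETE, new resourceVersion:", rv)
	}
}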
diff --git a/backend/cmd/multiplexer_test.go b/backend/cmd/multiplexer_test.go index 058e01377b..0df729ef4a 100644 --- a/backend/cmd/multiplexer_test.go +++ b/backend/cmd/multiplexer_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -22,6 +23,7 @@ func newTestDialer() *websocket.Dialer { return &websocket.Dialer{ NetDial: net.Dial, HandshakeTimeout: 45 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec } } @@ -36,28 +38,47 @@ func TestNewMultiplexer(t *testing.T) { } func TestHandleClientWebSocket(t *testing.T) { - store := kubeconfig.NewContextStore() - m := NewMultiplexer(store) + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + // Create test server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.HandleClientWebSocket(w, r) })) defer server.Close() - url := "ws" + strings.TrimPrefix(server.URL, "http") - + // Connect to test server dialer := newTestDialer() - conn, resp, err := dialer.Dial(url, nil) - if err == nil { - defer conn.Close() - } + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + ws, resp, err := dialer.Dial(wsURL, nil) + require.NoError(t, err) if resp != nil && resp.Body != nil { defer resp.Body.Close() } - assert.NoError(t, err, "Should successfully establish WebSocket connection") + defer ws.Close() + + // Test WATCH message + watchMsg := Message{ + Type: "WATCH", + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + } + err = ws.WriteJSON(watchMsg) + require.NoError(t, err) + + // Test CLOSE message + closeMsg := Message{ + Type: "CLOSE", + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + } + err = ws.WriteJSON(closeMsg) + require.NoError(t, err) } func TestGetClusterConfigWithFallback(t *testing.T) { @@ -104,21 +125,20 @@ func TestDialWebSocket(t *testing.T) { return true // Allow all connections for testing }, } - - c, err := upgrader.Upgrade(w, r, nil) + ws, err := upgrader.Upgrade(w, r, nil) if err != nil { t.Logf("Upgrade error: %v", err) return } - defer c.Close() + defer ws.Close() // Echo incoming messages back to the client for { - mt, message, err := c.ReadMessage() + mt, message, err := ws.ReadMessage() if err != nil { break } - err = c.WriteMessage(mt, message) + err = ws.WriteMessage(mt, message) if err != nil { break } @@ -129,6 +149,7 @@ func TestDialWebSocket(t *testing.T) { wsURL := "ws" + strings.TrimPrefix(server.URL, "http") conn, err := m.dialWebSocket(wsURL, &tls.Config{InsecureSkipVerify: true}, server.URL) //nolint:gosec + assert.NoError(t, err) assert.NotNil(t, conn) @@ -137,6 +158,23 @@ func TestDialWebSocket(t *testing.T) { } } +func TestDialWebSocket_Errors(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Test invalid URL + tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec + + ws, err := m.dialWebSocket("invalid-url", tlsConfig, "") + assert.Error(t, err) + assert.Nil(t, ws) + + // Test unreachable URL + ws, err = m.dialWebSocket("ws://localhost:12345", tlsConfig, "") + assert.Error(t, err) + assert.Nil(t, ws) +} + func TestMonitorConnection(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) clientConn, _ := createTestWebSocketConnection() @@ -158,6 +196,94 @@ func TestMonitorConnection(t *testing.T) { assert.Equal(t, StateClosed, conn.Status.State) } +func TestUpdateStatus(t *testing.T) { + conn := &Connection{ + Status: 
ConnectionStatus{}, + Done: make(chan struct{}), + } + + // Test different state transitions + states := []ConnectionState{ + StateConnecting, + StateConnected, + StateClosed, + StateError, + } + + for _, state := range states { + conn.Status.State = state + assert.Equal(t, state, conn.Status.State) + } + + // Test concurrent updates + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + + go func(i int) { + defer wg.Done() + + state := states[i%len(states)] + conn.Status.State = state + }(i) + } + wg.Wait() + + // Verify final state is valid + assert.Contains(t, states, conn.Status.State) +} + +func TestMonitorConnection_Reconnect(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Create a server that will accept the connection and then close it + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + + defer ws.Close() + + // Keep connection alive briefly + time.Sleep(100 * time.Millisecond) + ws.Close() + })) + + defer server.Close() + + conn := &Connection{ + Status: ConnectionStatus{ + State: StateConnecting, + }, + Done: make(chan struct{}), + } + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + tlsConfig := &tls.Config{InsecureSkipVerify: true} //nolint:gosec + + ws, err := m.dialWebSocket(wsURL, tlsConfig, "") + require.NoError(t, err) + + conn.WSConn = ws + + // Start monitoring in a goroutine + go m.monitorConnection(conn) + + // Wait for state transitions + time.Sleep(300 * time.Millisecond) + + // Verify connection status, it should reconnect + assert.Equal(t, StateConnecting, conn.Status.State) + + // Clean up + close(conn.Done) +} + //nolint:funlen func TestHandleClusterMessages(t *testing.T) { m := NewMultiplexer(kubeconfig.NewContextStore()) @@ -225,7 +351,7 @@ func TestHandleClusterMessages(t *testing.T) { t.Fatal("Test timed out") } - assert.Equal(t, StateClosed, conn.Status.State) + assert.Equal(t, StateConnecting, conn.Status.State) } func TestCleanupConnections(t *testing.T) { @@ -245,42 +371,198 @@ func TestCleanupConnections(t *testing.T) { assert.Equal(t, StateClosed, conn.Status.State) } -func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upgrader := websocket.Upgrader{} - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } +func TestCreateWebSocketURL(t *testing.T) { + tests := []struct { + name string + host string + path string + query string + expected string + }{ + { + name: "basic URL without query", + host: "http://localhost:8080", + path: "/api/v1/pods", + query: "", + expected: "wss://localhost:8080/api/v1/pods", + }, + { + name: "URL with query parameters", + host: "https://example.com", + path: "/api/v1/pods", + query: "watch=true", + expected: "wss://example.com/api/v1/pods?watch=true", + }, + { + name: "URL with path and multiple query parameters", + host: "https://k8s.example.com", + path: "/api/v1/namespaces/default/pods", + query: "watch=true&labelSelector=app%3Dnginx", + expected: "wss://k8s.example.com/api/v1/namespaces/default/pods?watch=true&labelSelector=app%3Dnginx", + }, + } - defer c.Close() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := createWebSocketURL(tt.host, tt.path, tt.query) + 
assert.Equal(t, tt.expected, result) + }) + } +} - for { - mt, message, err := c.ReadMessage() - if err != nil { - break - } +func TestGetOrCreateConnection(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) - err = c.WriteMessage(mt, message) - if err != nil { - break - } - } - })) + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - dialer := newTestDialer() + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: nil, + }, + }) + require.NoError(t, err) - ws, resp, err := dialer.Dial(wsURL, nil) - if err != nil { - panic(err) + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Test getting a non-existent connection (should create new) + msg := Message{ + ClusterID: "test-cluster", + Path: "/api/v1/pods", + Query: "watch=true", + UserID: "test-user", } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() + conn, err := m.getOrCreateConnection(msg, clientConn) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, "test-cluster", conn.ClusterID) + assert.Equal(t, "test-user", conn.UserID) + assert.Equal(t, "/api/v1/pods", conn.Path) + assert.Equal(t, "watch=true", conn.Query) + + // Test getting an existing connection + conn2, err := m.getOrCreateConnection(msg, clientConn) + assert.NoError(t, err) + assert.Equal(t, conn, conn2, "Should return the same connection instance") + + // Test with invalid cluster + msg.ClusterID = "non-existent-cluster" + conn3, err := m.getOrCreateConnection(msg, clientConn) + assert.Error(t, err) + assert.Nil(t, conn3) +} + +func TestEstablishClusterConnection(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) + + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() + + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: nil, + }, + }) + require.NoError(t, err) + + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Test successful connection establishment + conn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.Equal(t, "test-cluster", conn.ClusterID) + assert.Equal(t, "test-user", conn.UserID) + assert.Equal(t, "/api/v1/pods", conn.Path) + assert.Equal(t, "watch=true", conn.Query) + + // Test with invalid cluster + conn, err = m.establishClusterConnection("non-existent", "test-user", "/api/v1/pods", "watch=true", clientConn) + assert.Error(t, err) + assert.Nil(t, conn) +} + +//nolint:funlen +func TestReconnect(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) + + // Create a mock Kubernetes API server + mockServer := createMockKubeAPIServer() + defer mockServer.Close() + + // Add a mock cluster config with our test server URL + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: mockServer.URL, + InsecureSkipTLSVerify: true, + CertificateAuthorityData: 
nil, + }, + }) + require.NoError(t, err) + + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + // Create initial connection + conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + conn.Status.State = StateError // Simulate an error state + + // Test successful reconnection + newConn, err := m.reconnect(conn) + assert.NoError(t, err) + assert.NotNil(t, newConn) + assert.Equal(t, StateConnected, newConn.Status.State) + assert.Equal(t, conn.ClusterID, newConn.ClusterID) + assert.Equal(t, conn.UserID, newConn.UserID) + assert.Equal(t, conn.Path, newConn.Path) + assert.Equal(t, conn.Query, newConn.Query) + + // Test reconnection with invalid cluster + conn.ClusterID = "non-existent" + newConn, err = m.reconnect(conn) + assert.Error(t, err) + assert.Nil(t, newConn) + assert.Contains(t, err.Error(), "getting context: key not found") + + // Test reconnection with closed connection + conn = m.createConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + clusterConn, err := m.establishClusterConnection("test-cluster", "test-user", "/api/v1/pods", "watch=true", clientConn) + require.NoError(t, err) + require.NotNil(t, clusterConn) + + // Close the connection and wait for cleanup + conn.closed = true + if conn.WSConn != nil { + conn.WSConn.Close() } - return ws, server + if conn.Client != nil { + conn.Client.Close() + } + + close(conn.Done) + + // Try to reconnect the closed connection + newConn, err = m.reconnect(conn) + assert.Error(t, err) + assert.Nil(t, newConn) } func TestCloseConnection(t *testing.T) { @@ -292,14 +574,10 @@ func TestCloseConnection(t *testing.T) { connKey := "test-cluster:/api/v1/pods:test-user" m.connections[connKey] = conn - err := m.CloseConnection("test-cluster", "/api/v1/pods", "test-user") - assert.NoError(t, err) + m.CloseConnection("test-cluster", "/api/v1/pods", "test-user") assert.Empty(t, m.connections) - assert.Equal(t, StateClosed, conn.Status.State) - - // Test closing a non-existent connection - err = m.CloseConnection("non-existent", "/api/v1/pods", "test-user") - assert.Error(t, err) + // It will reconnect to the cluster + assert.Equal(t, StateConnecting, conn.Status.State) } func TestCreateWrapperMessage(t *testing.T) { @@ -424,3 +702,376 @@ func TestWriteMessageToCluster(t *testing.T) { assert.Error(t, err) assert.Equal(t, StateError, conn.Status.State) } + +//nolint:funlen +func TestReadClientMessage_InvalidMessage(t *testing.T) { + contextStore := kubeconfig.NewContextStore() + m := NewMultiplexer(contextStore) + + // Create a server that will echo messages back + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + ws, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + + defer ws.Close() + + // Echo messages back + for { + messageType, p, err := ws.ReadMessage() + if err != nil { + return + } + err = ws.WriteMessage(messageType, p) + if err != nil { + return + } + } + })) + defer server.Close() + + // Connect to the server + dialer := newTestDialer() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, err := dialer.Dial(wsURL, nil) //nolint:bodyclose + require.NoError(t, err) + + defer clientConn.Close() + + // Test completely invalid JSON + err = clientConn.WriteMessage(websocket.TextMessage, []byte("not json at all")) + require.NoError(t, err) + + msg, err := 
m.readClientMessage(clientConn) + require.Error(t, err) + assert.Equal(t, Message{}, msg) + + // Test JSON with invalid data type + err = clientConn.WriteJSON(map[string]interface{}{ + "type": "INVALID", + "data": 123, // data should be string + }) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + require.Error(t, err) + assert.Equal(t, Message{}, msg) + + // Test empty JSON object + err = clientConn.WriteMessage(websocket.TextMessage, []byte("{}")) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + // Empty message is valid JSON but will be unmarshaled into an empty Message struct + require.NoError(t, err) + assert.Equal(t, Message{}, msg) + + // Test missing required fields + err = clientConn.WriteJSON(map[string]interface{}{ + "data": "some data", + // Missing type field + }) + require.NoError(t, err) + + msg, err = m.readClientMessage(clientConn) + // Missing fields are allowed by json.Unmarshal + require.NoError(t, err) + assert.Equal(t, Message{Data: "some data"}, msg) +} + +func TestUpdateStatus_WithError(t *testing.T) { + clientConn, clientServer := createTestWebSocketConnection() + defer clientServer.Close() + + conn := &Connection{ + Status: ConnectionStatus{}, + Done: make(chan struct{}), + Client: clientConn, + } + + // Test error state with message + testErr := fmt.Errorf("test error") + conn.updateStatus(StateError, testErr) + assert.Equal(t, StateError, conn.Status.State) + assert.Equal(t, testErr.Error(), conn.Status.Error) + + // Test state change without error + conn.updateStatus(StateConnected, nil) + assert.Equal(t, StateConnected, conn.Status.State) + assert.Empty(t, conn.Status.Error) + + // Test with closed connection - state should remain error + conn.updateStatus(StateError, testErr) + assert.Equal(t, StateError, conn.Status.State) + assert.Equal(t, testErr.Error(), conn.Status.Error) + + close(conn.Done) + conn.closed = true // Mark connection as closed + + // Try to update state after close - should not change + conn.updateStatus(StateConnected, nil) + assert.Equal(t, StateError, conn.Status.State) // State should not change after close + assert.Equal(t, testErr.Error(), conn.Status.Error) // Error should remain +} + +func TestMonitorConnection_ReconnectFailure(t *testing.T) { + store := kubeconfig.NewContextStore() + m := NewMultiplexer(store) + + // Add an invalid cluster config to force reconnection failure + err := store.AddContext(&kubeconfig.Context{ + Name: "test-cluster", + Cluster: &api.Cluster{ + Server: "https://invalid-server:8443", + }, + }) + require.NoError(t, err) + + clientConn, _ := createTestWebSocketConnection() + conn := m.createConnection("test-cluster", "test-user", "/api/v1/pods", "", clientConn) + conn.WSConn, _ = createTestWebSocketConnection() + + // Start monitoring + done := make(chan struct{}) + go func() { + m.monitorConnection(conn) + close(done) + }() + + // Force connection closure and error state + conn.updateStatus(StateError, fmt.Errorf("forced error")) + conn.WSConn.Close() + + // Wait briefly to ensure error state is set + time.Sleep(50 * time.Millisecond) + + // Verify connection is in error state + assert.Equal(t, StateError, conn.Status.State) + assert.NotEmpty(t, conn.Status.Error) + + close(conn.Done) + <-done +} + +func TestHandleClientWebSocket_InvalidMessages(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.HandleClientWebSocket(w, r) + })) + defer 
server.Close() + + // Test invalid JSON + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + err = ws.WriteMessage(websocket.TextMessage, []byte("invalid json")) + require.NoError(t, err) + + // Should receive an error message or close + _, message, err := ws.ReadMessage() + if err != nil { + // Connection may be closed due to error + if !websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { + t.Errorf("expected abnormal closure, got %v", err) + } + } else { + assert.Contains(t, string(message), "error") + } + + ws.Close() + + // Test invalid message type with new connection + ws, resp, err = websocket.DefaultDialer.Dial(wsURL, nil) + require.NoError(t, err) + + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + defer ws.Close() + + err = ws.WriteJSON(Message{ + Type: "INVALID_TYPE", + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + }) + require.NoError(t, err) + + // Should receive an error message or close + _, message, err = ws.ReadMessage() + if err != nil { + // Connection may be closed due to error + if !websocket.IsCloseError(err, websocket.CloseAbnormalClosure) { + t.Errorf("expected abnormal closure, got %v", err) + } + } else { + assert.Contains(t, string(message), "error") + } +} + +func TestSendIfNewResourceVersion_VersionComparison(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() + + defer clientServer.Close() + + conn := &Connection{ + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + Client: clientConn, + } + + // Initialize lastVersion pointer + lastVersion := "" + + // Test initial version + message := []byte(`{"metadata":{"resourceVersion":"100"}}`) + err := m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + require.NoError(t, err) + assert.Equal(t, "100", lastVersion) + + // Test same version - should not send + err = m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + require.NoError(t, err) + assert.Equal(t, "100", lastVersion) + + // Test newer version + message = []byte(`{"metadata":{"resourceVersion":"200"}}`) + err = m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + require.NoError(t, err) + assert.Equal(t, "200", lastVersion) + + // Test invalid JSON + message = []byte(`invalid json`) + err = m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + assert.Error(t, err) + assert.Equal(t, "200", lastVersion) // Version should not change on error + + // Test missing resourceVersion + message = []byte(`{"metadata":{}}`) + err = m.sendIfNewResourceVersion(message, conn, clientConn, &lastVersion) + require.NoError(t, err) // Should not error, but also not update version + assert.Equal(t, "200", lastVersion) +} + +func TestSendCompleteMessage_ClosedConnection(t *testing.T) { + m := NewMultiplexer(kubeconfig.NewContextStore()) + clientConn, clientServer := createTestWebSocketConnection() + + defer clientServer.Close() + + conn := &Connection{ + ClusterID: "test-cluster", + Path: "/api/v1/pods", + UserID: "test-user", + Query: "watch=true", + } + + // Test successful complete message + err := m.sendCompleteMessage(conn, clientConn) + require.NoError(t, err) + + // Verify the message + _, message, err := clientConn.ReadMessage() + require.NoError(t, err) + + var msg Message + err = 
json.Unmarshal(message, &msg) + require.NoError(t, err) + + assert.Equal(t, "COMPLETE", msg.Type) + assert.Equal(t, conn.ClusterID, msg.ClusterID) + assert.Equal(t, conn.Path, msg.Path) + assert.Equal(t, conn.Query, msg.Query) + assert.Equal(t, conn.UserID, msg.UserID) + + // Test with closed connection + clientConn.Close() + err = m.sendCompleteMessage(conn, clientConn) + assert.Error(t, err) +} + +func createMockKubeAPIServer() *httptest.Server { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + defer c.Close() + + // Echo messages back + for { + _, msg, err := c.ReadMessage() + if err != nil { + break + } + if err := c.WriteMessage(websocket.TextMessage, msg); err != nil { + break + } + } + })) + + // Configure the test client to accept the test server's TLS certificate + server.Client().Transport.(*http.Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec + } + + return server +} + +func createTestWebSocketConnection() (*websocket.Conn, *httptest.Server) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + defer c.Close() + + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + dialer := newTestDialer() + + ws, resp, err := dialer.Dial(wsURL, nil) + if err != nil { + panic(err) + } + + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + return ws, server +}
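
For reference, a minimal client-side sketch of the typed message protocol these changes introduce (WATCH/REQUEST/CLOSE from the client, DATA/COMPLETE/STATUS from the server). The endpoint URL, cluster name, and user ID below are placeholders for illustration, not values taken from this PR:

package main

import (
	"fmt"
	"log"

	"github.com/gorilla/websocket"
)

// Message mirrors the wire format introduced by this PR: Data is now a string,
// Binary flags base64-encoded payloads, and Type carries the protocol verb.
type Message struct {
	ClusterID string `json:"clusterId"`
	Path      string `json:"path"`
	Query     string `json:"query"`
	UserID    string `json:"userId"`
	Data      string `json:"data,omitempty"`
	Binary    bool   `json:"binary,omitempty"`
	Type      string `json:"type"`
}

func main() {
	// Placeholder endpoint: use whatever route HandleClientWebSocket is mounted on.
	ws, _, err := websocket.DefaultDialer.Dial("ws://localhost:4466/wsMultiplexer", nil)
	if err != nil {
		log.Fatal(err)
	}
	defer ws.Close()

	// Start watching pods on one cluster.
	watch := Message{Type: "WATCH", ClusterID: "minikube", Path: "/api/v1/pods", Query: "watch=true", UserID: "u1"}
	if err := ws.WriteJSON(watch); err != nil {
		log.Fatal(err)
	}

	// Read a few frames: DATA carries the cluster payload, COMPLETE signals a
	// changed resourceVersion, STATUS reports connection state changes.
	for i := 0; i < 5; i++ {
		var msg Message
		if err := ws.ReadJSON(&msg); err != nil {
			log.Fatal(err)
		}

		fmt.Println(msg.Type, len(msg.Data))
	}

	// Tear the cluster connection down; CloseConnection no longer reports an
	// error if the connection is already gone.
	closeMsg := Message{Type: "CLOSE", ClusterID: "minikube", Path: "/api/v1/pods", UserID: "u1"}
	if err := ws.WriteJSON(closeMsg); err != nil {
		log.Fatal(err)
	}
}

The per-connection writeMu added in multiplexer.go is what keeps the server's DATA, COMPLETE, and STATUS frames from interleaving on the shared client socket while a client like this reads them concurrently.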