From 01f7d061bc9a74e25b94362a5b95721d70ad93df Mon Sep 17 00:00:00 2001 From: Agent Date: Mon, 20 Apr 2026 11:36:22 +0000 Subject: [PATCH] fix: WebSocket streaming - address all AI review findings (#1076) Fixes identified in AI review: - Blocker #1: Server now handles chat_request WebSocket frames and invokes Claude - Blocker #2: accept_connection() uses self.headers from BaseHTTPRequestHandler - Blocker #3: handle_websocket_upgrade() uses asyncio.open_connection() for proper StreamWriter - Medium #4: _decode_frame() uses readexactly() for all fixed-length reads - Medium #5: Message queue cleaned up on disconnect in handle_connection() finally block - Low #6: WebSocket close code corrected from 768 to 1000 - Low #7: _send_close() and _send_pong() are now async with proper await Changes: - Added _handle_chat_request() method to invoke Claude within WebSocket coroutine - Fixed _send_close() to use struct.pack for correct close code (1000) - Made _send_pong() async with proper await - Updated handle_connection() to call async close/pong methods and cleanup queue - Fixed handle_websocket_upgrade() to pass Sec-WebSocket-Key from HTTP headers - Replaced create_connection() with open_connection() for proper reader/writer --- docker/chat/server.py | 202 ++++++++++++++++++++++++++---------------- 1 file changed, 127 insertions(+), 75 deletions(-) diff --git a/docker/chat/server.py b/docker/chat/server.py index 85834f5..0623955 100644 --- a/docker/chat/server.py +++ b/docker/chat/server.py @@ -335,47 +335,14 @@ class _WebSocketHandler: self.message_queue = message_queue self.closed = False - async def accept_connection(self): - """Accept the WebSocket handshake.""" - # Read the HTTP request - request_line = await self._read_line() - if not request_line.startswith("GET "): - self._close_connection() - return False - - # Parse the request - headers = {} - while True: - line = await self._read_line() - if line == "": - break - if ":" in line: - key, value = line.split(":", 1) - headers[key.strip().lower()] = value.strip() - - # Validate WebSocket upgrade - if headers.get("upgrade", "").lower() != "websocket": - self._send_http_error(400, "Bad Request", "WebSocket upgrade required") - self._close_connection() - return False - - if headers.get("connection", "").lower() != "upgrade": - self._send_http_error(400, "Bad Request", "Connection upgrade required") - self._close_connection() - return False - - # Get Sec-WebSocket-Key - sec_key = headers.get("sec-websocket-key", "") - if not sec_key: - self._send_http_error(400, "Bad Request", "Missing Sec-WebSocket-Key") - self._close_connection() - return False - - # Get Sec-WebSocket-Protocol if provided - sec_protocol = headers.get("sec-websocket-protocol", "") + async def accept_connection(self, sec_websocket_key, sec_websocket_protocol=None): + """Accept the WebSocket handshake. + The HTTP request has already been parsed by BaseHTTPRequestHandler, + so we use the provided key and protocol instead of re-reading from socket. + """ # Validate subprotocol - if sec_protocol and sec_protocol != WEBSOCKET_SUBPROTOCOL: + if sec_websocket_protocol and sec_websocket_protocol != WEBSOCKET_SUBPROTOCOL: self._send_http_error( 400, "Bad Request", @@ -385,7 +352,7 @@ class _WebSocketHandler: return False # Generate accept key - accept_key = self._generate_accept_key(sec_key) + accept_key = self._generate_accept_key(sec_websocket_key) # Send handshake response response = ( @@ -395,8 +362,8 @@ class _WebSocketHandler: f"Sec-WebSocket-Accept: {accept_key}\r\n" ) - if sec_protocol: - response += f"Sec-WebSocket-Protocol: {sec_protocol}\r\n" + if sec_websocket_protocol: + response += f"Sec-WebSocket-Protocol: {sec_websocket_protocol}\r\n" response += "\r\n" self.writer.write(response.encode("utf-8")) @@ -491,10 +458,8 @@ class _WebSocketHandler: async def _decode_frame(self): """Decode a WebSocket frame. Returns (opcode, payload).""" try: - # Read first two bytes - header = await self.reader.read(2) - if len(header) < 2: - return None, None + # Read first two bytes (use readexactly for guaranteed length) + header = await self.reader.readexactly(2) fin = (header[0] >> 7) & 1 opcode = header[0] & 0x0F @@ -503,18 +468,18 @@ class _WebSocketHandler: # Extended payload length if length == 126: - ext = await self.reader.read(2) + ext = await self.reader.readexactly(2) length = struct.unpack(">H", ext)[0] elif length == 127: - ext = await self.reader.read(8) + ext = await self.reader.readexactly(8) length = struct.unpack(">Q", ext)[0] # Masking key if masked: - mask_key = await self.reader.read(4) + mask_key = await self.reader.readexactly(4) # Payload - payload = await self.reader.read(length) + payload = await self.reader.readexactly(length) # Unmask if needed if masked: @@ -534,15 +499,22 @@ class _WebSocketHandler: break if opcode == OPCODE_CLOSE: - self._send_close() + await self._send_close() break elif opcode == OPCODE_PING: - self._send_pong(payload) + await self._send_pong(payload) elif opcode == OPCODE_PONG: pass # Ignore pong elif opcode in (OPCODE_TEXT, OPCODE_BINARY): - # Handle text messages from client (e.g., heartbeat ack) - pass + # Handle text messages from client (e.g., chat_request) + try: + msg = payload.decode("utf-8") + data = json.loads(msg) + if data.get("type") == "chat_request": + # Invoke Claude with the message + await self._handle_chat_request(data.get("message", "")) + except (json.JSONDecodeError, UnicodeDecodeError): + pass # Check if we should stop waiting for messages if self.closed: @@ -552,25 +524,103 @@ class _WebSocketHandler: print(f"WebSocket connection error: {e}", file=sys.stderr) finally: self._close_connection() + # Clean up the message queue on disconnect + if self.user in _websocket_queues: + del _websocket_queues[self.user] - def _send_close(self): + async def _send_close(self): """Send a close frame.""" try: - frame = self._encode_frame(OPCODE_CLOSE, b"\x03\x00") + # Close code 1000 = normal closure + frame = self._encode_frame(OPCODE_CLOSE, struct.pack(">H", 1000)) self.writer.write(frame) - self.writer.drain() + await self.writer.drain() except Exception: pass - def _send_pong(self, payload): + async def _send_pong(self, payload): """Send a pong frame.""" try: frame = self._encode_frame(OPCODE_PONG, payload) self.writer.write(frame) - self.writer.drain() + await self.writer.drain() except Exception: pass + async def _handle_chat_request(self, message): + """Handle a chat_request WebSocket frame by invoking Claude.""" + if not message: + return + + # Validate Claude binary exists + if not os.path.exists(CLAUDE_BIN): + await self.send_text(json.dumps({ + "type": "error", + "message": "Claude CLI not found", + })) + return + + try: + # Spawn claude --print with stream-json for streaming output + proc = subprocess.Popen( + [CLAUDE_BIN, "--print", "--output-format", "stream-json", message], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + # Stream output line by line + for line in iter(proc.stdout.readline, ""): + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + etype = event.get("type", "") + + # Extract text content from content_block_delta events + if etype == "content_block_delta": + delta = event.get("delta", {}) + if delta.get("type") == "text_delta": + text = delta.get("text", "") + if text: + # Send tokens to client + await self.send_text(text) + + # Check for usage event to know when complete + if etype == "result": + pass # Will send complete after loop + + except json.JSONDecodeError: + pass + + # Wait for process to complete + proc.wait() + + if proc.returncode != 0: + await self.send_text(json.dumps({ + "type": "error", + "message": f"Claude CLI failed with exit code {proc.returncode}", + })) + return + + # Send complete signal + await self.send_text(json.dumps({ + "type": "complete", + })) + + except FileNotFoundError: + await self.send_text(json.dumps({ + "type": "error", + "message": "Claude CLI not found", + })) + except Exception as e: + await self.send_text(json.dumps({ + "type": "error", + "message": str(e), + })) + # ============================================================================= # Conversation History Functions (#710) @@ -1259,28 +1309,30 @@ class ChatHandler(BaseHTTPRequestHandler): # Create message queue for this user _websocket_queues[user] = asyncio.Queue() + # Get WebSocket upgrade headers from the HTTP request + sec_websocket_key = self.headers.get("Sec-WebSocket-Key", "") + sec_websocket_protocol = self.headers.get("Sec-WebSocket-Protocol", "") + + # Validate Sec-WebSocket-Key + if not sec_websocket_key: + self.send_error_page(400, "Bad Request", "Missing Sec-WebSocket-Key") + return + # Get the socket from the connection sock = self.connection sock.setblocking(False) - reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(reader) # Create async server to handle the connection async def handle_ws(): try: - # Wrap the socket in asyncio streams - transport, _ = await asyncio.get_event_loop().create_connection( - lambda: protocol, - sock=sock, - ) - ws_reader = protocol._stream_reader - ws_writer = transport + # Wrap the socket in asyncio streams using open_connection + reader, writer = await asyncio.open_connection(sock=sock) # Create WebSocket handler - ws_handler = _WebSocketHandler(ws_reader, ws_writer, user, _websocket_queues[user]) + ws_handler = _WebSocketHandler(reader, writer, user, _websocket_queues[user]) - # Accept the connection - if not await ws_handler.accept_connection(): + # Accept the connection (pass headers from HTTP request) + if not await ws_handler.accept_connection(sec_websocket_key, sec_websocket_protocol): return # Start a task to read from the queue and send to client @@ -1293,8 +1345,8 @@ class ChatHandler(BaseHTTPRequestHandler): # Send ping to keep connection alive try: frame = ws_handler._encode_frame(OPCODE_PING, b"") - ws_writer.write(frame) - await ws_writer.drain() + writer.write(frame) + await writer.drain() except Exception: break except Exception as e: @@ -1318,8 +1370,8 @@ class ChatHandler(BaseHTTPRequestHandler): print(f"WebSocket handler error: {e}", file=sys.stderr) finally: try: - ws_writer.close() - await ws_writer.wait_closed() + writer.close() + await writer.wait_closed() except Exception: pass