fix: WebSocket streaming - address all AI review findings (#1076)
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
ci/woodpecker/push/nomad-validate Pipeline was successful
ci/woodpecker/pr/ci Pipeline was successful
ci/woodpecker/pr/edge-subpath Pipeline was successful
ci/woodpecker/pr/nomad-validate Pipeline was successful
ci/woodpecker/pr/secret-scan Pipeline was successful

Fixes identified in AI review:
- Blocker #1: Server now handles chat_request WebSocket frames and invokes Claude
- Blocker #2: accept_connection() uses self.headers from BaseHTTPRequestHandler
- Blocker #3: handle_websocket_upgrade() uses asyncio.open_connection() for proper StreamWriter
- Medium #4: _decode_frame() uses readexactly() for all fixed-length reads
- Medium #5: Message queue cleaned up on disconnect in handle_connection() finally block
- Low #6: WebSocket close code corrected from 768 to 1000
- Low #7: _send_close() and _send_pong() are now async with proper await

Changes:
- Added _handle_chat_request() method to invoke Claude within WebSocket coroutine
- Fixed _send_close() to use struct.pack for correct close code (1000)
- Made _send_pong() async with proper await
- Updated handle_connection() to call async close/pong methods and cleanup queue
- Fixed handle_websocket_upgrade() to pass Sec-WebSocket-Key from HTTP headers
- Replaced create_connection() with open_connection() for proper reader/writer
This commit is contained in:
Agent 2026-04-20 11:36:22 +00:00
parent 17e745376d
commit 01f7d061bc

View file

@ -335,47 +335,14 @@ class _WebSocketHandler:
self.message_queue = message_queue
self.closed = False
async def accept_connection(self):
"""Accept the WebSocket handshake."""
# Read the HTTP request
request_line = await self._read_line()
if not request_line.startswith("GET "):
self._close_connection()
return False
# Parse the request
headers = {}
while True:
line = await self._read_line()
if line == "":
break
if ":" in line:
key, value = line.split(":", 1)
headers[key.strip().lower()] = value.strip()
# Validate WebSocket upgrade
if headers.get("upgrade", "").lower() != "websocket":
self._send_http_error(400, "Bad Request", "WebSocket upgrade required")
self._close_connection()
return False
if headers.get("connection", "").lower() != "upgrade":
self._send_http_error(400, "Bad Request", "Connection upgrade required")
self._close_connection()
return False
# Get Sec-WebSocket-Key
sec_key = headers.get("sec-websocket-key", "")
if not sec_key:
self._send_http_error(400, "Bad Request", "Missing Sec-WebSocket-Key")
self._close_connection()
return False
# Get Sec-WebSocket-Protocol if provided
sec_protocol = headers.get("sec-websocket-protocol", "")
async def accept_connection(self, sec_websocket_key, sec_websocket_protocol=None):
"""Accept the WebSocket handshake.
The HTTP request has already been parsed by BaseHTTPRequestHandler,
so we use the provided key and protocol instead of re-reading from socket.
"""
# Validate subprotocol
if sec_protocol and sec_protocol != WEBSOCKET_SUBPROTOCOL:
if sec_websocket_protocol and sec_websocket_protocol != WEBSOCKET_SUBPROTOCOL:
self._send_http_error(
400,
"Bad Request",
@ -385,7 +352,7 @@ class _WebSocketHandler:
return False
# Generate accept key
accept_key = self._generate_accept_key(sec_key)
accept_key = self._generate_accept_key(sec_websocket_key)
# Send handshake response
response = (
@ -395,8 +362,8 @@ class _WebSocketHandler:
f"Sec-WebSocket-Accept: {accept_key}\r\n"
)
if sec_protocol:
response += f"Sec-WebSocket-Protocol: {sec_protocol}\r\n"
if sec_websocket_protocol:
response += f"Sec-WebSocket-Protocol: {sec_websocket_protocol}\r\n"
response += "\r\n"
self.writer.write(response.encode("utf-8"))
@ -491,10 +458,8 @@ class _WebSocketHandler:
async def _decode_frame(self):
"""Decode a WebSocket frame. Returns (opcode, payload)."""
try:
# Read first two bytes
header = await self.reader.read(2)
if len(header) < 2:
return None, None
# Read first two bytes (use readexactly for guaranteed length)
header = await self.reader.readexactly(2)
fin = (header[0] >> 7) & 1
opcode = header[0] & 0x0F
@ -503,18 +468,18 @@ class _WebSocketHandler:
# Extended payload length
if length == 126:
ext = await self.reader.read(2)
ext = await self.reader.readexactly(2)
length = struct.unpack(">H", ext)[0]
elif length == 127:
ext = await self.reader.read(8)
ext = await self.reader.readexactly(8)
length = struct.unpack(">Q", ext)[0]
# Masking key
if masked:
mask_key = await self.reader.read(4)
mask_key = await self.reader.readexactly(4)
# Payload
payload = await self.reader.read(length)
payload = await self.reader.readexactly(length)
# Unmask if needed
if masked:
@ -534,15 +499,22 @@ class _WebSocketHandler:
break
if opcode == OPCODE_CLOSE:
self._send_close()
await self._send_close()
break
elif opcode == OPCODE_PING:
self._send_pong(payload)
await self._send_pong(payload)
elif opcode == OPCODE_PONG:
pass # Ignore pong
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
# Handle text messages from client (e.g., heartbeat ack)
pass
# Handle text messages from client (e.g., chat_request)
try:
msg = payload.decode("utf-8")
data = json.loads(msg)
if data.get("type") == "chat_request":
# Invoke Claude with the message
await self._handle_chat_request(data.get("message", ""))
except (json.JSONDecodeError, UnicodeDecodeError):
pass
# Check if we should stop waiting for messages
if self.closed:
@ -552,25 +524,103 @@ class _WebSocketHandler:
print(f"WebSocket connection error: {e}", file=sys.stderr)
finally:
self._close_connection()
# Clean up the message queue on disconnect
if self.user in _websocket_queues:
del _websocket_queues[self.user]
def _send_close(self):
async def _send_close(self):
"""Send a close frame."""
try:
frame = self._encode_frame(OPCODE_CLOSE, b"\x03\x00")
# Close code 1000 = normal closure
frame = self._encode_frame(OPCODE_CLOSE, struct.pack(">H", 1000))
self.writer.write(frame)
self.writer.drain()
await self.writer.drain()
except Exception:
pass
def _send_pong(self, payload):
async def _send_pong(self, payload):
"""Send a pong frame."""
try:
frame = self._encode_frame(OPCODE_PONG, payload)
self.writer.write(frame)
self.writer.drain()
await self.writer.drain()
except Exception:
pass
async def _handle_chat_request(self, message):
"""Handle a chat_request WebSocket frame by invoking Claude."""
if not message:
return
# Validate Claude binary exists
if not os.path.exists(CLAUDE_BIN):
await self.send_text(json.dumps({
"type": "error",
"message": "Claude CLI not found",
}))
return
try:
# Spawn claude --print with stream-json for streaming output
proc = subprocess.Popen(
[CLAUDE_BIN, "--print", "--output-format", "stream-json", message],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1,
)
# Stream output line by line
for line in iter(proc.stdout.readline, ""):
line = line.strip()
if not line:
continue
try:
event = json.loads(line)
etype = event.get("type", "")
# Extract text content from content_block_delta events
if etype == "content_block_delta":
delta = event.get("delta", {})
if delta.get("type") == "text_delta":
text = delta.get("text", "")
if text:
# Send tokens to client
await self.send_text(text)
# Check for usage event to know when complete
if etype == "result":
pass # Will send complete after loop
except json.JSONDecodeError:
pass
# Wait for process to complete
proc.wait()
if proc.returncode != 0:
await self.send_text(json.dumps({
"type": "error",
"message": f"Claude CLI failed with exit code {proc.returncode}",
}))
return
# Send complete signal
await self.send_text(json.dumps({
"type": "complete",
}))
except FileNotFoundError:
await self.send_text(json.dumps({
"type": "error",
"message": "Claude CLI not found",
}))
except Exception as e:
await self.send_text(json.dumps({
"type": "error",
"message": str(e),
}))
# =============================================================================
# Conversation History Functions (#710)
@ -1259,28 +1309,30 @@ class ChatHandler(BaseHTTPRequestHandler):
# Create message queue for this user
_websocket_queues[user] = asyncio.Queue()
# Get WebSocket upgrade headers from the HTTP request
sec_websocket_key = self.headers.get("Sec-WebSocket-Key", "")
sec_websocket_protocol = self.headers.get("Sec-WebSocket-Protocol", "")
# Validate Sec-WebSocket-Key
if not sec_websocket_key:
self.send_error_page(400, "Bad Request", "Missing Sec-WebSocket-Key")
return
# Get the socket from the connection
sock = self.connection
sock.setblocking(False)
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
# Create async server to handle the connection
async def handle_ws():
try:
# Wrap the socket in asyncio streams
transport, _ = await asyncio.get_event_loop().create_connection(
lambda: protocol,
sock=sock,
)
ws_reader = protocol._stream_reader
ws_writer = transport
# Wrap the socket in asyncio streams using open_connection
reader, writer = await asyncio.open_connection(sock=sock)
# Create WebSocket handler
ws_handler = _WebSocketHandler(ws_reader, ws_writer, user, _websocket_queues[user])
ws_handler = _WebSocketHandler(reader, writer, user, _websocket_queues[user])
# Accept the connection
if not await ws_handler.accept_connection():
# Accept the connection (pass headers from HTTP request)
if not await ws_handler.accept_connection(sec_websocket_key, sec_websocket_protocol):
return
# Start a task to read from the queue and send to client
@ -1293,8 +1345,8 @@ class ChatHandler(BaseHTTPRequestHandler):
# Send ping to keep connection alive
try:
frame = ws_handler._encode_frame(OPCODE_PING, b"")
ws_writer.write(frame)
await ws_writer.drain()
writer.write(frame)
await writer.drain()
except Exception:
break
except Exception as e:
@ -1318,8 +1370,8 @@ class ChatHandler(BaseHTTPRequestHandler):
print(f"WebSocket handler error: {e}", file=sys.stderr)
finally:
try:
ws_writer.close()
await ws_writer.wait_closed()
writer.close()
await writer.wait_closed()
except Exception:
pass