fix: WebSocket streaming - address all AI review findings (#1076)
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
ci/woodpecker/push/nomad-validate Pipeline was successful
ci/woodpecker/pr/ci Pipeline was successful
ci/woodpecker/pr/edge-subpath Pipeline was successful
ci/woodpecker/pr/nomad-validate Pipeline was successful
ci/woodpecker/pr/secret-scan Pipeline was successful
All checks were successful
ci/woodpecker/push/ci Pipeline was successful
ci/woodpecker/push/nomad-validate Pipeline was successful
ci/woodpecker/pr/ci Pipeline was successful
ci/woodpecker/pr/edge-subpath Pipeline was successful
ci/woodpecker/pr/nomad-validate Pipeline was successful
ci/woodpecker/pr/secret-scan Pipeline was successful
Fixes identified in AI review: - Blocker #1: Server now handles chat_request WebSocket frames and invokes Claude - Blocker #2: accept_connection() uses self.headers from BaseHTTPRequestHandler - Blocker #3: handle_websocket_upgrade() uses asyncio.open_connection() for proper StreamWriter - Medium #4: _decode_frame() uses readexactly() for all fixed-length reads - Medium #5: Message queue cleaned up on disconnect in handle_connection() finally block - Low #6: WebSocket close code corrected from 768 to 1000 - Low #7: _send_close() and _send_pong() are now async with proper await Changes: - Added _handle_chat_request() method to invoke Claude within WebSocket coroutine - Fixed _send_close() to use struct.pack for correct close code (1000) - Made _send_pong() async with proper await - Updated handle_connection() to call async close/pong methods and cleanup queue - Fixed handle_websocket_upgrade() to pass Sec-WebSocket-Key from HTTP headers - Replaced create_connection() with open_connection() for proper reader/writer
This commit is contained in:
parent
17e745376d
commit
01f7d061bc
1 changed files with 127 additions and 75 deletions
|
|
@ -335,47 +335,14 @@ class _WebSocketHandler:
|
|||
self.message_queue = message_queue
|
||||
self.closed = False
|
||||
|
||||
async def accept_connection(self):
|
||||
"""Accept the WebSocket handshake."""
|
||||
# Read the HTTP request
|
||||
request_line = await self._read_line()
|
||||
if not request_line.startswith("GET "):
|
||||
self._close_connection()
|
||||
return False
|
||||
|
||||
# Parse the request
|
||||
headers = {}
|
||||
while True:
|
||||
line = await self._read_line()
|
||||
if line == "":
|
||||
break
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
headers[key.strip().lower()] = value.strip()
|
||||
|
||||
# Validate WebSocket upgrade
|
||||
if headers.get("upgrade", "").lower() != "websocket":
|
||||
self._send_http_error(400, "Bad Request", "WebSocket upgrade required")
|
||||
self._close_connection()
|
||||
return False
|
||||
|
||||
if headers.get("connection", "").lower() != "upgrade":
|
||||
self._send_http_error(400, "Bad Request", "Connection upgrade required")
|
||||
self._close_connection()
|
||||
return False
|
||||
|
||||
# Get Sec-WebSocket-Key
|
||||
sec_key = headers.get("sec-websocket-key", "")
|
||||
if not sec_key:
|
||||
self._send_http_error(400, "Bad Request", "Missing Sec-WebSocket-Key")
|
||||
self._close_connection()
|
||||
return False
|
||||
|
||||
# Get Sec-WebSocket-Protocol if provided
|
||||
sec_protocol = headers.get("sec-websocket-protocol", "")
|
||||
async def accept_connection(self, sec_websocket_key, sec_websocket_protocol=None):
|
||||
"""Accept the WebSocket handshake.
|
||||
|
||||
The HTTP request has already been parsed by BaseHTTPRequestHandler,
|
||||
so we use the provided key and protocol instead of re-reading from socket.
|
||||
"""
|
||||
# Validate subprotocol
|
||||
if sec_protocol and sec_protocol != WEBSOCKET_SUBPROTOCOL:
|
||||
if sec_websocket_protocol and sec_websocket_protocol != WEBSOCKET_SUBPROTOCOL:
|
||||
self._send_http_error(
|
||||
400,
|
||||
"Bad Request",
|
||||
|
|
@ -385,7 +352,7 @@ class _WebSocketHandler:
|
|||
return False
|
||||
|
||||
# Generate accept key
|
||||
accept_key = self._generate_accept_key(sec_key)
|
||||
accept_key = self._generate_accept_key(sec_websocket_key)
|
||||
|
||||
# Send handshake response
|
||||
response = (
|
||||
|
|
@ -395,8 +362,8 @@ class _WebSocketHandler:
|
|||
f"Sec-WebSocket-Accept: {accept_key}\r\n"
|
||||
)
|
||||
|
||||
if sec_protocol:
|
||||
response += f"Sec-WebSocket-Protocol: {sec_protocol}\r\n"
|
||||
if sec_websocket_protocol:
|
||||
response += f"Sec-WebSocket-Protocol: {sec_websocket_protocol}\r\n"
|
||||
|
||||
response += "\r\n"
|
||||
self.writer.write(response.encode("utf-8"))
|
||||
|
|
@ -491,10 +458,8 @@ class _WebSocketHandler:
|
|||
async def _decode_frame(self):
|
||||
"""Decode a WebSocket frame. Returns (opcode, payload)."""
|
||||
try:
|
||||
# Read first two bytes
|
||||
header = await self.reader.read(2)
|
||||
if len(header) < 2:
|
||||
return None, None
|
||||
# Read first two bytes (use readexactly for guaranteed length)
|
||||
header = await self.reader.readexactly(2)
|
||||
|
||||
fin = (header[0] >> 7) & 1
|
||||
opcode = header[0] & 0x0F
|
||||
|
|
@ -503,18 +468,18 @@ class _WebSocketHandler:
|
|||
|
||||
# Extended payload length
|
||||
if length == 126:
|
||||
ext = await self.reader.read(2)
|
||||
ext = await self.reader.readexactly(2)
|
||||
length = struct.unpack(">H", ext)[0]
|
||||
elif length == 127:
|
||||
ext = await self.reader.read(8)
|
||||
ext = await self.reader.readexactly(8)
|
||||
length = struct.unpack(">Q", ext)[0]
|
||||
|
||||
# Masking key
|
||||
if masked:
|
||||
mask_key = await self.reader.read(4)
|
||||
mask_key = await self.reader.readexactly(4)
|
||||
|
||||
# Payload
|
||||
payload = await self.reader.read(length)
|
||||
payload = await self.reader.readexactly(length)
|
||||
|
||||
# Unmask if needed
|
||||
if masked:
|
||||
|
|
@ -534,15 +499,22 @@ class _WebSocketHandler:
|
|||
break
|
||||
|
||||
if opcode == OPCODE_CLOSE:
|
||||
self._send_close()
|
||||
await self._send_close()
|
||||
break
|
||||
elif opcode == OPCODE_PING:
|
||||
self._send_pong(payload)
|
||||
await self._send_pong(payload)
|
||||
elif opcode == OPCODE_PONG:
|
||||
pass # Ignore pong
|
||||
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
|
||||
# Handle text messages from client (e.g., heartbeat ack)
|
||||
pass
|
||||
# Handle text messages from client (e.g., chat_request)
|
||||
try:
|
||||
msg = payload.decode("utf-8")
|
||||
data = json.loads(msg)
|
||||
if data.get("type") == "chat_request":
|
||||
# Invoke Claude with the message
|
||||
await self._handle_chat_request(data.get("message", ""))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
pass
|
||||
|
||||
# Check if we should stop waiting for messages
|
||||
if self.closed:
|
||||
|
|
@ -552,25 +524,103 @@ class _WebSocketHandler:
|
|||
print(f"WebSocket connection error: {e}", file=sys.stderr)
|
||||
finally:
|
||||
self._close_connection()
|
||||
# Clean up the message queue on disconnect
|
||||
if self.user in _websocket_queues:
|
||||
del _websocket_queues[self.user]
|
||||
|
||||
def _send_close(self):
|
||||
async def _send_close(self):
|
||||
"""Send a close frame."""
|
||||
try:
|
||||
frame = self._encode_frame(OPCODE_CLOSE, b"\x03\x00")
|
||||
# Close code 1000 = normal closure
|
||||
frame = self._encode_frame(OPCODE_CLOSE, struct.pack(">H", 1000))
|
||||
self.writer.write(frame)
|
||||
self.writer.drain()
|
||||
await self.writer.drain()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _send_pong(self, payload):
|
||||
async def _send_pong(self, payload):
|
||||
"""Send a pong frame."""
|
||||
try:
|
||||
frame = self._encode_frame(OPCODE_PONG, payload)
|
||||
self.writer.write(frame)
|
||||
self.writer.drain()
|
||||
await self.writer.drain()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _handle_chat_request(self, message):
|
||||
"""Handle a chat_request WebSocket frame by invoking Claude."""
|
||||
if not message:
|
||||
return
|
||||
|
||||
# Validate Claude binary exists
|
||||
if not os.path.exists(CLAUDE_BIN):
|
||||
await self.send_text(json.dumps({
|
||||
"type": "error",
|
||||
"message": "Claude CLI not found",
|
||||
}))
|
||||
return
|
||||
|
||||
try:
|
||||
# Spawn claude --print with stream-json for streaming output
|
||||
proc = subprocess.Popen(
|
||||
[CLAUDE_BIN, "--print", "--output-format", "stream-json", message],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
# Stream output line by line
|
||||
for line in iter(proc.stdout.readline, ""):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
event = json.loads(line)
|
||||
etype = event.get("type", "")
|
||||
|
||||
# Extract text content from content_block_delta events
|
||||
if etype == "content_block_delta":
|
||||
delta = event.get("delta", {})
|
||||
if delta.get("type") == "text_delta":
|
||||
text = delta.get("text", "")
|
||||
if text:
|
||||
# Send tokens to client
|
||||
await self.send_text(text)
|
||||
|
||||
# Check for usage event to know when complete
|
||||
if etype == "result":
|
||||
pass # Will send complete after loop
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Wait for process to complete
|
||||
proc.wait()
|
||||
|
||||
if proc.returncode != 0:
|
||||
await self.send_text(json.dumps({
|
||||
"type": "error",
|
||||
"message": f"Claude CLI failed with exit code {proc.returncode}",
|
||||
}))
|
||||
return
|
||||
|
||||
# Send complete signal
|
||||
await self.send_text(json.dumps({
|
||||
"type": "complete",
|
||||
}))
|
||||
|
||||
except FileNotFoundError:
|
||||
await self.send_text(json.dumps({
|
||||
"type": "error",
|
||||
"message": "Claude CLI not found",
|
||||
}))
|
||||
except Exception as e:
|
||||
await self.send_text(json.dumps({
|
||||
"type": "error",
|
||||
"message": str(e),
|
||||
}))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Conversation History Functions (#710)
|
||||
|
|
@ -1259,28 +1309,30 @@ class ChatHandler(BaseHTTPRequestHandler):
|
|||
# Create message queue for this user
|
||||
_websocket_queues[user] = asyncio.Queue()
|
||||
|
||||
# Get WebSocket upgrade headers from the HTTP request
|
||||
sec_websocket_key = self.headers.get("Sec-WebSocket-Key", "")
|
||||
sec_websocket_protocol = self.headers.get("Sec-WebSocket-Protocol", "")
|
||||
|
||||
# Validate Sec-WebSocket-Key
|
||||
if not sec_websocket_key:
|
||||
self.send_error_page(400, "Bad Request", "Missing Sec-WebSocket-Key")
|
||||
return
|
||||
|
||||
# Get the socket from the connection
|
||||
sock = self.connection
|
||||
sock.setblocking(False)
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
|
||||
# Create async server to handle the connection
|
||||
async def handle_ws():
|
||||
try:
|
||||
# Wrap the socket in asyncio streams
|
||||
transport, _ = await asyncio.get_event_loop().create_connection(
|
||||
lambda: protocol,
|
||||
sock=sock,
|
||||
)
|
||||
ws_reader = protocol._stream_reader
|
||||
ws_writer = transport
|
||||
# Wrap the socket in asyncio streams using open_connection
|
||||
reader, writer = await asyncio.open_connection(sock=sock)
|
||||
|
||||
# Create WebSocket handler
|
||||
ws_handler = _WebSocketHandler(ws_reader, ws_writer, user, _websocket_queues[user])
|
||||
ws_handler = _WebSocketHandler(reader, writer, user, _websocket_queues[user])
|
||||
|
||||
# Accept the connection
|
||||
if not await ws_handler.accept_connection():
|
||||
# Accept the connection (pass headers from HTTP request)
|
||||
if not await ws_handler.accept_connection(sec_websocket_key, sec_websocket_protocol):
|
||||
return
|
||||
|
||||
# Start a task to read from the queue and send to client
|
||||
|
|
@ -1293,8 +1345,8 @@ class ChatHandler(BaseHTTPRequestHandler):
|
|||
# Send ping to keep connection alive
|
||||
try:
|
||||
frame = ws_handler._encode_frame(OPCODE_PING, b"")
|
||||
ws_writer.write(frame)
|
||||
await ws_writer.drain()
|
||||
writer.write(frame)
|
||||
await writer.drain()
|
||||
except Exception:
|
||||
break
|
||||
except Exception as e:
|
||||
|
|
@ -1318,8 +1370,8 @@ class ChatHandler(BaseHTTPRequestHandler):
|
|||
print(f"WebSocket handler error: {e}", file=sys.stderr)
|
||||
finally:
|
||||
try:
|
||||
ws_writer.close()
|
||||
await ws_writer.wait_closed()
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue