1 changed files with 127 additions and 75 deletions
|
|
@ -335,47 +335,14 @@ class _WebSocketHandler:
|
||||||
self.message_queue = message_queue
|
self.message_queue = message_queue
|
||||||
self.closed = False
|
self.closed = False
|
||||||
|
|
||||||
async def accept_connection(self):
|
async def accept_connection(self, sec_websocket_key, sec_websocket_protocol=None):
|
||||||
"""Accept the WebSocket handshake."""
|
"""Accept the WebSocket handshake.
|
||||||
# Read the HTTP request
|
|
||||||
request_line = await self._read_line()
|
|
||||||
if not request_line.startswith("GET "):
|
|
||||||
self._close_connection()
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Parse the request
|
|
||||||
headers = {}
|
|
||||||
while True:
|
|
||||||
line = await self._read_line()
|
|
||||||
if line == "":
|
|
||||||
break
|
|
||||||
if ":" in line:
|
|
||||||
key, value = line.split(":", 1)
|
|
||||||
headers[key.strip().lower()] = value.strip()
|
|
||||||
|
|
||||||
# Validate WebSocket upgrade
|
|
||||||
if headers.get("upgrade", "").lower() != "websocket":
|
|
||||||
self._send_http_error(400, "Bad Request", "WebSocket upgrade required")
|
|
||||||
self._close_connection()
|
|
||||||
return False
|
|
||||||
|
|
||||||
if headers.get("connection", "").lower() != "upgrade":
|
|
||||||
self._send_http_error(400, "Bad Request", "Connection upgrade required")
|
|
||||||
self._close_connection()
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Get Sec-WebSocket-Key
|
|
||||||
sec_key = headers.get("sec-websocket-key", "")
|
|
||||||
if not sec_key:
|
|
||||||
self._send_http_error(400, "Bad Request", "Missing Sec-WebSocket-Key")
|
|
||||||
self._close_connection()
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Get Sec-WebSocket-Protocol if provided
|
|
||||||
sec_protocol = headers.get("sec-websocket-protocol", "")
|
|
||||||
|
|
||||||
|
The HTTP request has already been parsed by BaseHTTPRequestHandler,
|
||||||
|
so we use the provided key and protocol instead of re-reading from socket.
|
||||||
|
"""
|
||||||
# Validate subprotocol
|
# Validate subprotocol
|
||||||
if sec_protocol and sec_protocol != WEBSOCKET_SUBPROTOCOL:
|
if sec_websocket_protocol and sec_websocket_protocol != WEBSOCKET_SUBPROTOCOL:
|
||||||
self._send_http_error(
|
self._send_http_error(
|
||||||
400,
|
400,
|
||||||
"Bad Request",
|
"Bad Request",
|
||||||
|
|
@ -385,7 +352,7 @@ class _WebSocketHandler:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Generate accept key
|
# Generate accept key
|
||||||
accept_key = self._generate_accept_key(sec_key)
|
accept_key = self._generate_accept_key(sec_websocket_key)
|
||||||
|
|
||||||
# Send handshake response
|
# Send handshake response
|
||||||
response = (
|
response = (
|
||||||
|
|
@ -395,8 +362,8 @@ class _WebSocketHandler:
|
||||||
f"Sec-WebSocket-Accept: {accept_key}\r\n"
|
f"Sec-WebSocket-Accept: {accept_key}\r\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
if sec_protocol:
|
if sec_websocket_protocol:
|
||||||
response += f"Sec-WebSocket-Protocol: {sec_protocol}\r\n"
|
response += f"Sec-WebSocket-Protocol: {sec_websocket_protocol}\r\n"
|
||||||
|
|
||||||
response += "\r\n"
|
response += "\r\n"
|
||||||
self.writer.write(response.encode("utf-8"))
|
self.writer.write(response.encode("utf-8"))
|
||||||
|
|
@ -491,10 +458,8 @@ class _WebSocketHandler:
|
||||||
async def _decode_frame(self):
|
async def _decode_frame(self):
|
||||||
"""Decode a WebSocket frame. Returns (opcode, payload)."""
|
"""Decode a WebSocket frame. Returns (opcode, payload)."""
|
||||||
try:
|
try:
|
||||||
# Read first two bytes
|
# Read first two bytes (use readexactly for guaranteed length)
|
||||||
header = await self.reader.read(2)
|
header = await self.reader.readexactly(2)
|
||||||
if len(header) < 2:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
fin = (header[0] >> 7) & 1
|
fin = (header[0] >> 7) & 1
|
||||||
opcode = header[0] & 0x0F
|
opcode = header[0] & 0x0F
|
||||||
|
|
@ -503,18 +468,18 @@ class _WebSocketHandler:
|
||||||
|
|
||||||
# Extended payload length
|
# Extended payload length
|
||||||
if length == 126:
|
if length == 126:
|
||||||
ext = await self.reader.read(2)
|
ext = await self.reader.readexactly(2)
|
||||||
length = struct.unpack(">H", ext)[0]
|
length = struct.unpack(">H", ext)[0]
|
||||||
elif length == 127:
|
elif length == 127:
|
||||||
ext = await self.reader.read(8)
|
ext = await self.reader.readexactly(8)
|
||||||
length = struct.unpack(">Q", ext)[0]
|
length = struct.unpack(">Q", ext)[0]
|
||||||
|
|
||||||
# Masking key
|
# Masking key
|
||||||
if masked:
|
if masked:
|
||||||
mask_key = await self.reader.read(4)
|
mask_key = await self.reader.readexactly(4)
|
||||||
|
|
||||||
# Payload
|
# Payload
|
||||||
payload = await self.reader.read(length)
|
payload = await self.reader.readexactly(length)
|
||||||
|
|
||||||
# Unmask if needed
|
# Unmask if needed
|
||||||
if masked:
|
if masked:
|
||||||
|
|
@ -534,14 +499,21 @@ class _WebSocketHandler:
|
||||||
break
|
break
|
||||||
|
|
||||||
if opcode == OPCODE_CLOSE:
|
if opcode == OPCODE_CLOSE:
|
||||||
self._send_close()
|
await self._send_close()
|
||||||
break
|
break
|
||||||
elif opcode == OPCODE_PING:
|
elif opcode == OPCODE_PING:
|
||||||
self._send_pong(payload)
|
await self._send_pong(payload)
|
||||||
elif opcode == OPCODE_PONG:
|
elif opcode == OPCODE_PONG:
|
||||||
pass # Ignore pong
|
pass # Ignore pong
|
||||||
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
|
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
|
||||||
# Handle text messages from client (e.g., heartbeat ack)
|
# Handle text messages from client (e.g., chat_request)
|
||||||
|
try:
|
||||||
|
msg = payload.decode("utf-8")
|
||||||
|
data = json.loads(msg)
|
||||||
|
if data.get("type") == "chat_request":
|
||||||
|
# Invoke Claude with the message
|
||||||
|
await self._handle_chat_request(data.get("message", ""))
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Check if we should stop waiting for messages
|
# Check if we should stop waiting for messages
|
||||||
|
|
@ -552,25 +524,103 @@ class _WebSocketHandler:
|
||||||
print(f"WebSocket connection error: {e}", file=sys.stderr)
|
print(f"WebSocket connection error: {e}", file=sys.stderr)
|
||||||
finally:
|
finally:
|
||||||
self._close_connection()
|
self._close_connection()
|
||||||
|
# Clean up the message queue on disconnect
|
||||||
|
if self.user in _websocket_queues:
|
||||||
|
del _websocket_queues[self.user]
|
||||||
|
|
||||||
def _send_close(self):
|
async def _send_close(self):
|
||||||
"""Send a close frame."""
|
"""Send a close frame."""
|
||||||
try:
|
try:
|
||||||
frame = self._encode_frame(OPCODE_CLOSE, b"\x03\x00")
|
# Close code 1000 = normal closure
|
||||||
|
frame = self._encode_frame(OPCODE_CLOSE, struct.pack(">H", 1000))
|
||||||
self.writer.write(frame)
|
self.writer.write(frame)
|
||||||
self.writer.drain()
|
await self.writer.drain()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _send_pong(self, payload):
|
async def _send_pong(self, payload):
|
||||||
"""Send a pong frame."""
|
"""Send a pong frame."""
|
||||||
try:
|
try:
|
||||||
frame = self._encode_frame(OPCODE_PONG, payload)
|
frame = self._encode_frame(OPCODE_PONG, payload)
|
||||||
self.writer.write(frame)
|
self.writer.write(frame)
|
||||||
self.writer.drain()
|
await self.writer.drain()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def _handle_chat_request(self, message):
|
||||||
|
"""Handle a chat_request WebSocket frame by invoking Claude."""
|
||||||
|
if not message:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate Claude binary exists
|
||||||
|
if not os.path.exists(CLAUDE_BIN):
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": "Claude CLI not found",
|
||||||
|
}))
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Spawn claude --print with stream-json for streaming output
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
[CLAUDE_BIN, "--print", "--output-format", "stream-json", message],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
bufsize=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream output line by line
|
||||||
|
for line in iter(proc.stdout.readline, ""):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
event = json.loads(line)
|
||||||
|
etype = event.get("type", "")
|
||||||
|
|
||||||
|
# Extract text content from content_block_delta events
|
||||||
|
if etype == "content_block_delta":
|
||||||
|
delta = event.get("delta", {})
|
||||||
|
if delta.get("type") == "text_delta":
|
||||||
|
text = delta.get("text", "")
|
||||||
|
if text:
|
||||||
|
# Send tokens to client
|
||||||
|
await self.send_text(text)
|
||||||
|
|
||||||
|
# Check for usage event to know when complete
|
||||||
|
if etype == "result":
|
||||||
|
pass # Will send complete after loop
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wait for process to complete
|
||||||
|
proc.wait()
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": f"Claude CLI failed with exit code {proc.returncode}",
|
||||||
|
}))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Send complete signal
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "complete",
|
||||||
|
}))
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": "Claude CLI not found",
|
||||||
|
}))
|
||||||
|
except Exception as e:
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": str(e),
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Conversation History Functions (#710)
|
# Conversation History Functions (#710)
|
||||||
|
|
@ -1259,28 +1309,30 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
# Create message queue for this user
|
# Create message queue for this user
|
||||||
_websocket_queues[user] = asyncio.Queue()
|
_websocket_queues[user] = asyncio.Queue()
|
||||||
|
|
||||||
|
# Get WebSocket upgrade headers from the HTTP request
|
||||||
|
sec_websocket_key = self.headers.get("Sec-WebSocket-Key", "")
|
||||||
|
sec_websocket_protocol = self.headers.get("Sec-WebSocket-Protocol", "")
|
||||||
|
|
||||||
|
# Validate Sec-WebSocket-Key
|
||||||
|
if not sec_websocket_key:
|
||||||
|
self.send_error_page(400, "Bad Request", "Missing Sec-WebSocket-Key")
|
||||||
|
return
|
||||||
|
|
||||||
# Get the socket from the connection
|
# Get the socket from the connection
|
||||||
sock = self.connection
|
sock = self.connection
|
||||||
sock.setblocking(False)
|
sock.setblocking(False)
|
||||||
reader = asyncio.StreamReader()
|
|
||||||
protocol = asyncio.StreamReaderProtocol(reader)
|
|
||||||
|
|
||||||
# Create async server to handle the connection
|
# Create async server to handle the connection
|
||||||
async def handle_ws():
|
async def handle_ws():
|
||||||
try:
|
try:
|
||||||
# Wrap the socket in asyncio streams
|
# Wrap the socket in asyncio streams using open_connection
|
||||||
transport, _ = await asyncio.get_event_loop().create_connection(
|
reader, writer = await asyncio.open_connection(sock=sock)
|
||||||
lambda: protocol,
|
|
||||||
sock=sock,
|
|
||||||
)
|
|
||||||
ws_reader = protocol._stream_reader
|
|
||||||
ws_writer = transport
|
|
||||||
|
|
||||||
# Create WebSocket handler
|
# Create WebSocket handler
|
||||||
ws_handler = _WebSocketHandler(ws_reader, ws_writer, user, _websocket_queues[user])
|
ws_handler = _WebSocketHandler(reader, writer, user, _websocket_queues[user])
|
||||||
|
|
||||||
# Accept the connection
|
# Accept the connection (pass headers from HTTP request)
|
||||||
if not await ws_handler.accept_connection():
|
if not await ws_handler.accept_connection(sec_websocket_key, sec_websocket_protocol):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Start a task to read from the queue and send to client
|
# Start a task to read from the queue and send to client
|
||||||
|
|
@ -1293,8 +1345,8 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
# Send ping to keep connection alive
|
# Send ping to keep connection alive
|
||||||
try:
|
try:
|
||||||
frame = ws_handler._encode_frame(OPCODE_PING, b"")
|
frame = ws_handler._encode_frame(OPCODE_PING, b"")
|
||||||
ws_writer.write(frame)
|
writer.write(frame)
|
||||||
await ws_writer.drain()
|
await writer.drain()
|
||||||
except Exception:
|
except Exception:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1318,8 +1370,8 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
print(f"WebSocket handler error: {e}", file=sys.stderr)
|
print(f"WebSocket handler error: {e}", file=sys.stderr)
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
ws_writer.close()
|
writer.close()
|
||||||
await ws_writer.wait_closed()
|
await writer.wait_closed()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue