diff --git a/docker/chat/server.py b/docker/chat/server.py index ef37fb1..0623955 100644 --- a/docker/chat/server.py +++ b/docker/chat/server.py @@ -22,6 +22,7 @@ OAuth flow: The claude binary is expected to be mounted from the host at /usr/local/bin/claude. """ +import asyncio import datetime import json import os @@ -30,8 +31,14 @@ import secrets import subprocess import sys import time +import threading from http.server import HTTPServer, BaseHTTPRequestHandler +from socketserver import ThreadingMixIn from urllib.parse import urlparse, parse_qs, urlencode +import socket +import struct +import base64 +import hashlib # Configuration HOST = os.environ.get("CHAT_HOST", "0.0.0.0") @@ -89,6 +96,10 @@ _request_log = {} # user -> {"tokens": int, "date": "YYYY-MM-DD"} _daily_tokens = {} +# WebSocket message queues per user +# user -> asyncio.Queue (for streaming messages to connected clients) +_websocket_queues = {} + # MIME types for static files MIME_TYPES = { ".html": "text/html; charset=utf-8", @@ -101,6 +112,17 @@ MIME_TYPES = { ".ico": "image/x-icon", } +# WebSocket subprotocol for chat streaming +WEBSOCKET_SUBPROTOCOL = "chat-stream-v1" + +# WebSocket opcodes +OPCODE_CONTINUATION = 0x0 +OPCODE_TEXT = 0x1 +OPCODE_BINARY = 0x2 +OPCODE_CLOSE = 0x8 +OPCODE_PING = 0x9 +OPCODE_PONG = 0xA + def _build_callback_uri(): """Build the OAuth callback URI based on tunnel configuration.""" @@ -299,6 +321,307 @@ def _parse_stream_json(output): return "".join(text_parts), total_tokens +# ============================================================================= +# WebSocket Handler Class +# ============================================================================= + +class _WebSocketHandler: + """Handle WebSocket connections for chat streaming.""" + + def __init__(self, reader, writer, user, message_queue): + self.reader = reader + self.writer = writer + self.user = user + self.message_queue = message_queue + self.closed = False + + async def accept_connection(self, sec_websocket_key, sec_websocket_protocol=None): + """Accept the WebSocket handshake. + + The HTTP request has already been parsed by BaseHTTPRequestHandler, + so we use the provided key and protocol instead of re-reading from socket. + """ + # Validate subprotocol + if sec_websocket_protocol and sec_websocket_protocol != WEBSOCKET_SUBPROTOCOL: + self._send_http_error( + 400, + "Bad Request", + f"Unsupported subprotocol. Expected: {WEBSOCKET_SUBPROTOCOL}", + ) + self._close_connection() + return False + + # Generate accept key + accept_key = self._generate_accept_key(sec_websocket_key) + + # Send handshake response + response = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {accept_key}\r\n" + ) + + if sec_websocket_protocol: + response += f"Sec-WebSocket-Protocol: {sec_websocket_protocol}\r\n" + + response += "\r\n" + self.writer.write(response.encode("utf-8")) + await self.writer.drain() + return True + + def _generate_accept_key(self, sec_key): + """Generate the Sec-WebSocket-Accept key.""" + GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + combined = sec_key + GUID + sha1 = hashlib.sha1(combined.encode("utf-8")) + return base64.b64encode(sha1.digest()).decode("utf-8") + + async def _read_line(self): + """Read a line from the socket.""" + data = await self.reader.read(1) + line = "" + while data: + if data == b"\r": + data = await self.reader.read(1) + continue + if data == b"\n": + return line + line += data.decode("utf-8", errors="replace") + data = await self.reader.read(1) + return line + + def _send_http_error(self, code, title, message): + """Send an HTTP error response.""" + response = ( + f"HTTP/1.1 {code} {title}\r\n" + "Content-Type: text/plain; charset=utf-8\r\n" + "Content-Length: " + str(len(message)) + "\r\n" + "\r\n" + + message + ) + try: + self.writer.write(response.encode("utf-8")) + self.writer.drain() + except Exception: + pass + + def _close_connection(self): + """Close the connection.""" + try: + self.writer.close() + except Exception: + pass + + async def send_text(self, data): + """Send a text frame.""" + if self.closed: + return + try: + frame = self._encode_frame(OPCODE_TEXT, data.encode("utf-8")) + self.writer.write(frame) + await self.writer.drain() + except Exception as e: + print(f"WebSocket send error: {e}", file=sys.stderr) + + async def send_binary(self, data): + """Send a binary frame.""" + if self.closed: + return + try: + if isinstance(data, str): + data = data.encode("utf-8") + frame = self._encode_frame(OPCODE_BINARY, data) + self.writer.write(frame) + await self.writer.drain() + except Exception as e: + print(f"WebSocket send error: {e}", file=sys.stderr) + + def _encode_frame(self, opcode, payload): + """Encode a WebSocket frame.""" + frame = bytearray() + frame.append(0x80 | opcode) # FIN + opcode + + length = len(payload) + if length < 126: + frame.append(length) + elif length < 65536: + frame.append(126) + frame.extend(struct.pack(">H", length)) + else: + frame.append(127) + frame.extend(struct.pack(">Q", length)) + + frame.extend(payload) + return bytes(frame) + + async def _decode_frame(self): + """Decode a WebSocket frame. Returns (opcode, payload).""" + try: + # Read first two bytes (use readexactly for guaranteed length) + header = await self.reader.readexactly(2) + + fin = (header[0] >> 7) & 1 + opcode = header[0] & 0x0F + masked = (header[1] >> 7) & 1 + length = header[1] & 0x7F + + # Extended payload length + if length == 126: + ext = await self.reader.readexactly(2) + length = struct.unpack(">H", ext)[0] + elif length == 127: + ext = await self.reader.readexactly(8) + length = struct.unpack(">Q", ext)[0] + + # Masking key + if masked: + mask_key = await self.reader.readexactly(4) + + # Payload + payload = await self.reader.readexactly(length) + + # Unmask if needed + if masked: + payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload)) + + return opcode, payload + except Exception as e: + print(f"WebSocket decode error: {e}", file=sys.stderr) + return None, None + + async def handle_connection(self): + """Handle the WebSocket connection loop.""" + try: + while not self.closed: + opcode, payload = await self._decode_frame() + if opcode is None: + break + + if opcode == OPCODE_CLOSE: + await self._send_close() + break + elif opcode == OPCODE_PING: + await self._send_pong(payload) + elif opcode == OPCODE_PONG: + pass # Ignore pong + elif opcode in (OPCODE_TEXT, OPCODE_BINARY): + # Handle text messages from client (e.g., chat_request) + try: + msg = payload.decode("utf-8") + data = json.loads(msg) + if data.get("type") == "chat_request": + # Invoke Claude with the message + await self._handle_chat_request(data.get("message", "")) + except (json.JSONDecodeError, UnicodeDecodeError): + pass + + # Check if we should stop waiting for messages + if self.closed: + break + + except Exception as e: + print(f"WebSocket connection error: {e}", file=sys.stderr) + finally: + self._close_connection() + # Clean up the message queue on disconnect + if self.user in _websocket_queues: + del _websocket_queues[self.user] + + async def _send_close(self): + """Send a close frame.""" + try: + # Close code 1000 = normal closure + frame = self._encode_frame(OPCODE_CLOSE, struct.pack(">H", 1000)) + self.writer.write(frame) + await self.writer.drain() + except Exception: + pass + + async def _send_pong(self, payload): + """Send a pong frame.""" + try: + frame = self._encode_frame(OPCODE_PONG, payload) + self.writer.write(frame) + await self.writer.drain() + except Exception: + pass + + async def _handle_chat_request(self, message): + """Handle a chat_request WebSocket frame by invoking Claude.""" + if not message: + return + + # Validate Claude binary exists + if not os.path.exists(CLAUDE_BIN): + await self.send_text(json.dumps({ + "type": "error", + "message": "Claude CLI not found", + })) + return + + try: + # Spawn claude --print with stream-json for streaming output + proc = subprocess.Popen( + [CLAUDE_BIN, "--print", "--output-format", "stream-json", message], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + # Stream output line by line + for line in iter(proc.stdout.readline, ""): + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + etype = event.get("type", "") + + # Extract text content from content_block_delta events + if etype == "content_block_delta": + delta = event.get("delta", {}) + if delta.get("type") == "text_delta": + text = delta.get("text", "") + if text: + # Send tokens to client + await self.send_text(text) + + # Check for usage event to know when complete + if etype == "result": + pass # Will send complete after loop + + except json.JSONDecodeError: + pass + + # Wait for process to complete + proc.wait() + + if proc.returncode != 0: + await self.send_text(json.dumps({ + "type": "error", + "message": f"Claude CLI failed with exit code {proc.returncode}", + })) + return + + # Send complete signal + await self.send_text(json.dumps({ + "type": "complete", + })) + + except FileNotFoundError: + await self.send_text(json.dumps({ + "type": "error", + "message": "Claude CLI not found", + })) + except Exception as e: + await self.send_text(json.dumps({ + "type": "error", + "message": str(e), + })) + + # ============================================================================= # Conversation History Functions (#710) # ============================================================================= @@ -548,9 +871,9 @@ class ChatHandler(BaseHTTPRequestHandler): self.serve_static(path) return - # Reserved WebSocket endpoint (future use) - if path == "/ws" or path.startswith("/ws"): - self.send_error_page(501, "WebSocket upgrade not yet implemented") + # WebSocket upgrade endpoint + if path == "/chat/ws" or path == "/ws" or path.startswith("/ws"): + self.handle_websocket_upgrade() return # 404 for unknown paths @@ -759,6 +1082,7 @@ class ChatHandler(BaseHTTPRequestHandler): """ Handle chat requests by spawning `claude --print` with the user message. Enforces per-user rate limits and tracks token usage (#711). + Streams tokens over WebSocket if connected. """ # Check rate limits before processing (#711) @@ -816,10 +1140,47 @@ class ChatHandler(BaseHTTPRequestHandler): stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, + bufsize=1, # Line buffered ) - raw_output = proc.stdout.read() + # Stream output line by line + response_parts = [] + total_tokens = 0 + for line in iter(proc.stdout.readline, ""): + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + etype = event.get("type", "") + # Extract text content from content_block_delta events + if etype == "content_block_delta": + delta = event.get("delta", {}) + if delta.get("type") == "text_delta": + text = delta.get("text", "") + if text: + response_parts.append(text) + # Stream to WebSocket if connected + if user in _websocket_queues: + try: + _websocket_queues[user].put_nowait(text) + except Exception: + pass # Client disconnected + + # Parse usage from result event + if etype == "result": + usage = event.get("usage", {}) + total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0) + elif "usage" in event: + usage = event["usage"] + if isinstance(usage, dict): + total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0) + + except json.JSONDecodeError: + pass + + # Wait for process to complete error_output = proc.stderr.read() if error_output: print(f"Claude stderr: {error_output}", file=sys.stderr) @@ -830,8 +1191,8 @@ class ChatHandler(BaseHTTPRequestHandler): self.send_error_page(500, f"Claude CLI failed with exit code {proc.returncode}") return - # Parse stream-json for text and token usage (#711) - response, total_tokens = _parse_stream_json(raw_output) + # Combine response parts + response = "".join(response_parts) # Track token usage - does not block *this* request (#711) if total_tokens > 0: @@ -843,7 +1204,7 @@ class ChatHandler(BaseHTTPRequestHandler): # Fall back to raw output if stream-json parsing yielded no text if not response: - response = raw_output + response = proc.stdout.getvalue() if hasattr(proc.stdout, 'getvalue') else "" # Save assistant response to history _write_message(user, conv_id, "assistant", response) @@ -913,6 +1274,118 @@ class ChatHandler(BaseHTTPRequestHandler): self.end_headers() self.wfile.write(json.dumps({"conversation_id": conv_id}, ensure_ascii=False).encode("utf-8")) + @staticmethod + def push_to_websocket(user, message): + """Push a message to a WebSocket connection for a user. + + This is called from the chat handler to stream tokens to connected clients. + The message is added to the user's WebSocket message queue. + """ + # Get the message queue from the WebSocket handler's queue + # We store the queue in a global dict keyed by user + if user in _websocket_queues: + _websocket_queues[user].put_nowait(message) + + def handle_websocket_upgrade(self): + """Handle WebSocket upgrade request for chat streaming.""" + # Check session cookie + user = _validate_session(self.headers.get("Cookie")) + if not user: + self.send_error_page(401, "Unauthorized: no valid session") + return + + # Check rate limits before allowing WebSocket connection + allowed, retry_after, reason = _check_rate_limit(user) + if not allowed: + self.send_error_page( + 429, + f"Rate limit exceeded: {reason}. Retry after {retry_after}s", + ) + return + + # Record request for rate limiting + _record_request(user) + + # Create message queue for this user + _websocket_queues[user] = asyncio.Queue() + + # Get WebSocket upgrade headers from the HTTP request + sec_websocket_key = self.headers.get("Sec-WebSocket-Key", "") + sec_websocket_protocol = self.headers.get("Sec-WebSocket-Protocol", "") + + # Validate Sec-WebSocket-Key + if not sec_websocket_key: + self.send_error_page(400, "Bad Request", "Missing Sec-WebSocket-Key") + return + + # Get the socket from the connection + sock = self.connection + sock.setblocking(False) + + # Create async server to handle the connection + async def handle_ws(): + try: + # Wrap the socket in asyncio streams using open_connection + reader, writer = await asyncio.open_connection(sock=sock) + + # Create WebSocket handler + ws_handler = _WebSocketHandler(reader, writer, user, _websocket_queues[user]) + + # Accept the connection (pass headers from HTTP request) + if not await ws_handler.accept_connection(sec_websocket_key, sec_websocket_protocol): + return + + # Start a task to read from the queue and send to client + async def send_stream(): + while not ws_handler.closed: + try: + data = await asyncio.wait_for(ws_handler.message_queue.get(), timeout=1.0) + await ws_handler.send_text(data) + except asyncio.TimeoutError: + # Send ping to keep connection alive + try: + frame = ws_handler._encode_frame(OPCODE_PING, b"") + writer.write(frame) + await writer.drain() + except Exception: + break + except Exception as e: + print(f"Send stream error: {e}", file=sys.stderr) + break + + # Start sending task + send_task = asyncio.create_task(send_stream()) + + # Handle incoming WebSocket frames + await ws_handler.handle_connection() + + # Cancel send task + send_task.cancel() + try: + await send_task + except asyncio.CancelledError: + pass + + except Exception as e: + print(f"WebSocket handler error: {e}", file=sys.stderr) + finally: + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + + # Run the async handler in a thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(handle_ws()) + except Exception as e: + print(f"WebSocket error: {e}", file=sys.stderr) + finally: + loop.close() + sock.close() + def do_DELETE(self): """Handle DELETE requests.""" parsed = urlparse(self.path) diff --git a/docker/chat/ui/index.html b/docker/chat/ui/index.html index bd920f9..b045873 100644 --- a/docker/chat/ui/index.html +++ b/docker/chat/ui/index.html @@ -430,6 +430,10 @@ return div.innerHTML.replace(/\n/g, '
'); } + // WebSocket connection for streaming + let ws = null; + let wsMessageId = null; + // Send message handler async function sendMessage() { const message = textarea.value.trim(); @@ -449,6 +453,14 @@ await createNewConversation(); } + // Try WebSocket streaming first, fall back to fetch + if (window.location.protocol === 'https:' || window.location.hostname === 'localhost') { + if (tryWebSocketSend(message)) { + return; + } + } + + // Fallback to fetch try { // Use fetch with URLSearchParams for application/x-www-form-urlencoded const params = new URLSearchParams(); @@ -485,6 +497,111 @@ } } + // Try to send message via WebSocket streaming + function tryWebSocketSend(message) { + try { + // Generate a unique message ID for this request + wsMessageId = Date.now().toString(36) + Math.random().toString(36).substr(2); + + // Connect to WebSocket + const wsUrl = window.location.protocol === 'https:' + ? `wss://${window.location.host}/chat/ws` + : `ws://${window.location.host}/chat/ws`; + + ws = new WebSocket(wsUrl); + + ws.onopen = function() { + // Send the message as JSON with message ID + const data = { + type: 'chat_request', + message_id: wsMessageId, + message: message, + conversation_id: currentConversationId + }; + ws.send(JSON.stringify(data)); + }; + + ws.onmessage = function(event) { + try { + const data = JSON.parse(event.data); + + if (data.type === 'token') { + // Stream a token to the UI + addTokenToLastMessage(data.token); + } else if (data.type === 'complete') { + // Streaming complete + closeWebSocket(); + textarea.disabled = false; + sendBtn.disabled = false; + sendBtn.textContent = 'Send'; + textarea.focus(); + messagesDiv.scrollTop = messagesDiv.scrollHeight; + loadConversations(); + } else if (data.type === 'error') { + addSystemMessage(`Error: ${data.message}`); + closeWebSocket(); + textarea.disabled = false; + sendBtn.disabled = false; + sendBtn.textContent = 'Send'; + textarea.focus(); + } + } catch (e) { + console.error('Failed to parse WebSocket message:', e); + } + }; + + ws.onerror = function(error) { + console.error('WebSocket error:', error); + addSystemMessage('WebSocket connection error. Falling back to regular chat.'); + closeWebSocket(); + sendMessage(); // Retry with fetch + }; + + ws.onclose = function() { + wsMessageId = null; + }; + + return true; // WebSocket attempt started + + } catch (error) { + console.error('Failed to create WebSocket:', error); + return false; // Fall back to fetch + } + } + + // Add a token to the last assistant message (for streaming) + function addTokenToLastMessage(token) { + const messages = messagesDiv.querySelectorAll('.message.assistant'); + if (messages.length === 0) { + // No assistant message yet, create one + const msgDiv = document.createElement('div'); + msgDiv.className = 'message assistant'; + msgDiv.innerHTML = ` +
assistant
+
+ `; + messagesDiv.appendChild(msgDiv); + } + + const lastMsg = messagesDiv.querySelector('.message.assistant .content.streaming'); + if (lastMsg) { + lastMsg.textContent += token; + messagesDiv.scrollTop = messagesDiv.scrollHeight; + } + } + + // Close WebSocket connection + function closeWebSocket() { + if (ws) { + ws.onopen = null; + ws.onmessage = null; + ws.onerror = null; + ws.onclose = null; + ws.close(); + ws = null; + } + } + // Event listeners sendBtn.addEventListener('click', sendMessage); diff --git a/nomad/jobs/edge.hcl b/nomad/jobs/edge.hcl index bf82b3d..afc57c3 100644 --- a/nomad/jobs/edge.hcl +++ b/nomad/jobs/edge.hcl @@ -172,6 +172,12 @@ EOT handle /chat/oauth/callback { reverse_proxy 127.0.0.1:8080 } + # WebSocket endpoint for streaming (#1026) + handle /chat/ws { + header_up Upgrade $http.upgrade + header_up Connection $http.connection + reverse_proxy 127.0.0.1:8080 + } # Defense-in-depth: forward_auth stamps X-Forwarded-User from session (#709) handle /chat/* { forward_auth 127.0.0.1:8080 {