diff --git a/docker/chat/server.py b/docker/chat/server.py
index ef37fb1..0623955 100644
--- a/docker/chat/server.py
+++ b/docker/chat/server.py
@@ -22,6 +22,7 @@ OAuth flow:
The claude binary is expected to be mounted from the host at /usr/local/bin/claude.
"""
+import asyncio
import datetime
import json
import os
@@ -30,8 +31,14 @@ import secrets
import subprocess
import sys
import time
+import threading
from http.server import HTTPServer, BaseHTTPRequestHandler
+from socketserver import ThreadingMixIn
from urllib.parse import urlparse, parse_qs, urlencode
+import socket
+import struct
+import base64
+import hashlib
# Configuration
HOST = os.environ.get("CHAT_HOST", "0.0.0.0")
@@ -89,6 +96,10 @@ _request_log = {}
# user -> {"tokens": int, "date": "YYYY-MM-DD"}
_daily_tokens = {}
+# WebSocket message queues per user
+# user -> asyncio.Queue (for streaming messages to connected clients)
+# There is one entry per user: handle_websocket_upgrade() assigns a fresh
+# queue for the user on each upgrade, replacing any earlier one, and the entry
+# is removed again when the WebSocket connection loop exits.
+_websocket_queues = {}
+
# MIME types for static files
MIME_TYPES = {
".html": "text/html; charset=utf-8",
@@ -101,6 +112,17 @@ MIME_TYPES = {
".ico": "image/x-icon",
}
+# WebSocket subprotocol for chat streaming.
+# Checked against the client's Sec-WebSocket-Protocol header during the
+# handshake (see _WebSocketHandler.accept_connection).
+WEBSOCKET_SUBPROTOCOL = "chat-stream-v1"
+
+# WebSocket opcodes: the low four bits of the first byte of every frame
+# (see _WebSocketHandler._encode_frame / _decode_frame).
+OPCODE_CONTINUATION = 0x0
+OPCODE_TEXT = 0x1
+OPCODE_BINARY = 0x2
+OPCODE_CLOSE = 0x8
+OPCODE_PING = 0x9
+OPCODE_PONG = 0xA
+
def _build_callback_uri():
"""Build the OAuth callback URI based on tunnel configuration."""
@@ -299,6 +321,307 @@ def _parse_stream_json(output):
return "".join(text_parts), total_tokens
+# =============================================================================
+# WebSocket Handler Class
+# =============================================================================
+
+class _WebSocketHandler:
+    """Serve one WebSocket (RFC 6455) connection used for chat streaming.
+
+    The HTTP upgrade request is parsed by the caller; this class finishes the
+    handshake over the asyncio stream pair it is given, then reads client
+    frames and writes unmasked server frames.
+
+    Known limitations: fragmented messages (continuation frames) are not
+    reassembled, and unmasked client frames are accepted.
+    """
+
+    # Largest client payload accepted, in bytes.  Larger frames end the
+    # connection instead of being buffered, so a client cannot make the
+    # server allocate an arbitrary amount of memory with one length header.
+    MAX_FRAME_SIZE = 1 << 20
+
+    def __init__(self, reader, writer, user, message_queue):
+        self.reader = reader  # asyncio stream reader for the client socket
+        self.writer = writer  # asyncio stream writer for the client socket
+        self.user = user  # key of this connection in _websocket_queues
+        self.message_queue = message_queue  # queue drained by the caller
+        self.closed = False  # set once the connection has been closed
+
+    async def accept_connection(self, sec_websocket_key, sec_websocket_protocol=None):
+        """Complete the WebSocket handshake.
+
+        The HTTP request has already been parsed by BaseHTTPRequestHandler,
+        so the key and protocol are passed in instead of re-read from the
+        socket.
+
+        Args:
+            sec_websocket_key: Value of the Sec-WebSocket-Key request header.
+            sec_websocket_protocol: Value of the Sec-WebSocket-Protocol
+                request header; may be empty/None or a comma-separated list.
+
+        Returns:
+            True if the 101 response was sent, False if the request was
+            rejected (a 400 response is written and the connection closed).
+        """
+        # The header carries every subprotocol the client offers, separated
+        # by commas; the server picks at most one of them.
+        offered = [
+            name.strip()
+            for name in (sec_websocket_protocol or "").split(",")
+            if name.strip()
+        ]
+        if offered and WEBSOCKET_SUBPROTOCOL not in offered:
+            self._send_http_error(
+                400,
+                "Bad Request",
+                f"Unsupported subprotocol. Expected: {WEBSOCKET_SUBPROTOCOL}",
+            )
+            try:
+                await self.writer.drain()
+            except OSError as e:
+                print(f"WebSocket handshake error: {e}", file=sys.stderr)
+            self._close_connection()
+            return False
+
+        accept_key = self._generate_accept_key(sec_websocket_key)
+        header_lines = [
+            "HTTP/1.1 101 Switching Protocols",
+            "Upgrade: websocket",
+            "Connection: Upgrade",
+            f"Sec-WebSocket-Accept: {accept_key}",
+        ]
+        if offered:
+            header_lines.append(f"Sec-WebSocket-Protocol: {WEBSOCKET_SUBPROTOCOL}")
+
+        self.writer.write(("\r\n".join(header_lines) + "\r\n\r\n").encode("utf-8"))
+        await self.writer.drain()
+        return True
+
+    def _generate_accept_key(self, sec_key):
+        """Return the Sec-WebSocket-Accept value for a Sec-WebSocket-Key."""
+        GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+        digest = hashlib.sha1((sec_key + GUID).encode("utf-8")).digest()
+        return base64.b64encode(digest).decode("utf-8")
+
+    async def _read_line(self):
+        """Read one line from the socket, without CR/LF.
+
+        The whole line is decoded at once so multi-byte UTF-8 sequences
+        survive.  Returns whatever was read so far if the stream ends first.
+        """
+        raw = await self.reader.readline()
+        return raw.decode("utf-8", errors="replace").replace("\r", "").rstrip("\n")
+
+    def _send_http_error(self, code, title, message):
+        """Queue a plain-text HTTP error response on the writer.
+
+        This only buffers the bytes; the caller awaits ``writer.drain()``
+        because this method is synchronous.
+        """
+        body = message.encode("utf-8")
+        head = (
+            f"HTTP/1.1 {code} {title}\r\n"
+            "Content-Type: text/plain; charset=utf-8\r\n"
+            f"Content-Length: {len(body)}\r\n"  # bytes, not characters
+            "\r\n"
+        )
+        try:
+            self.writer.write(head.encode("utf-8") + body)
+        except Exception as e:
+            print(f"WebSocket handshake error: {e}", file=sys.stderr)
+
+    def _close_connection(self):
+        """Mark the handler closed and close the writer (best effort)."""
+        self.closed = True
+        try:
+            self.writer.close()
+        except Exception:
+            pass  # the transport may already be gone
+
+    async def _send_frame(self, opcode, payload):
+        """Write one frame and wait for the transport to accept it."""
+        self.writer.write(self._encode_frame(opcode, payload))
+        await self.writer.drain()
+
+    async def send_text(self, data):
+        """Send ``data`` (str) as a text frame; no-op once closed."""
+        if self.closed:
+            return
+        try:
+            await self._send_frame(OPCODE_TEXT, data.encode("utf-8"))
+        except Exception as e:
+            print(f"WebSocket send error: {e}", file=sys.stderr)
+
+    async def send_binary(self, data):
+        """Send ``data`` (bytes, or str encoded as UTF-8) as a binary frame."""
+        if self.closed:
+            return
+        try:
+            if isinstance(data, str):
+                data = data.encode("utf-8")
+            await self._send_frame(OPCODE_BINARY, data)
+        except Exception as e:
+            print(f"WebSocket send error: {e}", file=sys.stderr)
+
+    def _encode_frame(self, opcode, payload):
+        """Encode a single unmasked frame with FIN set.
+
+        Args:
+            opcode: One of the OPCODE_* values.
+            payload: Frame payload as bytes.
+
+        Returns:
+            The encoded frame as bytes.
+        """
+        frame = bytearray()
+        frame.append(0x80 | opcode)  # FIN + opcode
+
+        length = len(payload)
+        if length < 126:
+            frame.append(length)
+        elif length < 65536:
+            frame.append(126)  # 16-bit extended length follows
+            frame.extend(struct.pack(">H", length))
+        else:
+            frame.append(127)  # 64-bit extended length follows
+            frame.extend(struct.pack(">Q", length))
+
+        frame.extend(payload)
+        return bytes(frame)
+
+    async def _decode_frame(self):
+        """Read one frame from the client.
+
+        Returns:
+            ``(opcode, payload)``, or ``(None, None)`` when the stream ends,
+            a read fails, or the frame exceeds MAX_FRAME_SIZE.
+        """
+        try:
+            # readexactly() guarantees the requested length or raises.
+            header = await self.reader.readexactly(2)
+            opcode = header[0] & 0x0F
+            masked = bool(header[1] & 0x80)
+            length = header[1] & 0x7F
+
+            # Extended payload length
+            if length == 126:
+                length = struct.unpack(">H", await self.reader.readexactly(2))[0]
+            elif length == 127:
+                length = struct.unpack(">Q", await self.reader.readexactly(8))[0]
+
+            if length > self.MAX_FRAME_SIZE:
+                print(
+                    f"WebSocket decode error: frame of {length} bytes exceeds "
+                    f"limit of {self.MAX_FRAME_SIZE}",
+                    file=sys.stderr,
+                )
+                return None, None
+
+            mask_key = await self.reader.readexactly(4) if masked else None
+            payload = await self.reader.readexactly(length)
+            if mask_key is not None:
+                payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
+
+            return opcode, payload
+        except Exception as e:
+            print(f"WebSocket decode error: {e}", file=sys.stderr)
+            return None, None
+
+    async def handle_connection(self):
+        """Read frames until the client closes or the stream fails.
+
+        On exit the connection is closed and this handler's queue is removed
+        from _websocket_queues.
+        """
+        try:
+            while not self.closed:
+                opcode, payload = await self._decode_frame()
+                if opcode is None:
+                    break
+
+                if opcode == OPCODE_CLOSE:
+                    await self._send_close()
+                    break
+                if opcode == OPCODE_PING:
+                    await self._send_pong(payload)
+                elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
+                    await self._dispatch_message(payload)
+                # Pong and continuation frames are ignored.
+        except Exception as e:
+            print(f"WebSocket connection error: {e}", file=sys.stderr)
+        finally:
+            self._close_connection()
+            # Only drop the registry entry if it is still ours: a newer
+            # connection for the same user may have replaced it.
+            if _websocket_queues.get(self.user) is self.message_queue:
+                _websocket_queues.pop(self.user, None)
+
+    async def _dispatch_message(self, payload):
+        """Decode a client message and run it if it is a chat_request."""
+        try:
+            data = json.loads(payload.decode("utf-8"))
+        except (json.JSONDecodeError, UnicodeDecodeError):
+            return  # not a JSON message; ignore it
+        # Valid JSON that is not an object (list, number, ...) has no "type".
+        if isinstance(data, dict) and data.get("type") == "chat_request":
+            await self._handle_chat_request(data.get("message", ""))
+
+    async def _send_close(self):
+        """Send a close frame with code 1000 (best effort)."""
+        try:
+            # Close code 1000 = normal closure
+            await self._send_frame(OPCODE_CLOSE, struct.pack(">H", 1000))
+        except Exception:
+            pass  # the peer may already have dropped the connection
+
+    async def _send_pong(self, payload):
+        """Answer a ping by echoing its payload in a pong (best effort)."""
+        try:
+            await self._send_frame(OPCODE_PONG, payload)
+        except Exception:
+            pass  # the peer may already have dropped the connection
+
+    @staticmethod
+    def _extract_text_delta(line):
+        """Return the text of a stream-json text_delta line, else ""."""
+        line = line.strip()
+        if not line:
+            return ""
+        try:
+            event = json.loads(line)
+        except json.JSONDecodeError:
+            return ""
+        if not isinstance(event, dict) or event.get("type") != "content_block_delta":
+            return ""
+        delta = event.get("delta", {})
+        if not isinstance(delta, dict) or delta.get("type") != "text_delta":
+            return ""
+        return delta.get("text", "") or ""
+
+    async def _handle_chat_request(self, message):
+        """Run Claude on ``message`` and stream the reply to the client.
+
+        Sends one ``{"type": "token", "token": ...}`` frame per text delta,
+        then ``{"type": "complete"}``, or ``{"type": "error", "message": ...}``
+        on failure.  All frames are JSON so the client can parse every
+        message the same way.
+
+        NOTE(review): unlike the HTTP chat handler, this path does not apply
+        the per-request rate limit, token accounting or history writes;
+        confirm whether that is intended.
+        """
+        if not message or not isinstance(message, str):
+            return
+
+        # Validate Claude binary exists
+        if not os.path.exists(CLAUDE_BIN):
+            await self.send_text(json.dumps({
+                "type": "error",
+                "message": "Claude CLI not found",
+            }))
+            return
+
+        loop = asyncio.get_running_loop()
+        proc = None
+        try:
+            # Spawn claude --print with stream-json for streaming output
+            proc = subprocess.Popen(
+                [CLAUDE_BIN, "--print", "--output-format", "stream-json", message],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                text=True,
+                bufsize=1,
+            )
+
+            # Drain stderr concurrently so a chatty child cannot fill the
+            # pipe and block while we are still reading stdout.
+            stderr_future = loop.run_in_executor(None, proc.stderr.read)
+
+            # The pipe reads block, so run them in the default executor and
+            # keep the event loop free for pings and queued messages.
+            while True:
+                line = await loop.run_in_executor(None, proc.stdout.readline)
+                if not line:
+                    break  # EOF
+                text = self._extract_text_delta(line)
+                if text:
+                    await self.send_text(json.dumps({
+                        "type": "token",
+                        "token": text,
+                    }))
+
+            returncode = await loop.run_in_executor(None, proc.wait)
+            error_output = await stderr_future
+            if error_output:
+                print(f"Claude stderr: {error_output}", file=sys.stderr)
+
+            if returncode != 0:
+                await self.send_text(json.dumps({
+                    "type": "error",
+                    "message": f"Claude CLI failed with exit code {returncode}",
+                }))
+                return
+
+            # Send complete signal
+            await self.send_text(json.dumps({
+                "type": "complete",
+            }))
+
+        except FileNotFoundError:
+            await self.send_text(json.dumps({
+                "type": "error",
+                "message": "Claude CLI not found",
+            }))
+        except Exception as e:
+            await self.send_text(json.dumps({
+                "type": "error",
+                "message": str(e),
+            }))
+        finally:
+            # Do not leave the child running if streaming was interrupted.
+            if proc is not None and proc.poll() is None:
+                proc.kill()
+
+
# =============================================================================
# Conversation History Functions (#710)
# =============================================================================
@@ -548,9 +871,9 @@ class ChatHandler(BaseHTTPRequestHandler):
self.serve_static(path)
return
- # Reserved WebSocket endpoint (future use)
- if path == "/ws" or path.startswith("/ws"):
- self.send_error_page(501, "WebSocket upgrade not yet implemented")
+ # WebSocket upgrade endpoint
+ if path == "/chat/ws" or path == "/ws" or path.startswith("/ws"):
+ self.handle_websocket_upgrade()
return
# 404 for unknown paths
@@ -759,6 +1082,7 @@ class ChatHandler(BaseHTTPRequestHandler):
"""
Handle chat requests by spawning `claude --print` with the user message.
Enforces per-user rate limits and tracks token usage (#711).
+ Streams tokens over WebSocket if connected.
"""
# Check rate limits before processing (#711)
@@ -816,10 +1140,47 @@ class ChatHandler(BaseHTTPRequestHandler):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
+ bufsize=1, # Line buffered
)
- raw_output = proc.stdout.read()
+ # Stream output line by line
+ response_parts = []
+ total_tokens = 0
+ for line in iter(proc.stdout.readline, ""):
+ line = line.strip()
+ if not line:
+ continue
+ try:
+ event = json.loads(line)
+ etype = event.get("type", "")
+ # Extract text content from content_block_delta events
+ if etype == "content_block_delta":
+ delta = event.get("delta", {})
+ if delta.get("type") == "text_delta":
+ text = delta.get("text", "")
+ if text:
+ response_parts.append(text)
+ # Stream to WebSocket if connected
+ if user in _websocket_queues:
+ try:
+ _websocket_queues[user].put_nowait(text)
+ except Exception:
+ pass # Client disconnected
+
+ # Parse usage from result event
+ if etype == "result":
+ usage = event.get("usage", {})
+ total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
+ elif "usage" in event:
+ usage = event["usage"]
+ if isinstance(usage, dict):
+ total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
+
+ except json.JSONDecodeError:
+ pass
+
+ # Wait for process to complete
error_output = proc.stderr.read()
if error_output:
print(f"Claude stderr: {error_output}", file=sys.stderr)
@@ -830,8 +1191,8 @@ class ChatHandler(BaseHTTPRequestHandler):
self.send_error_page(500, f"Claude CLI failed with exit code {proc.returncode}")
return
- # Parse stream-json for text and token usage (#711)
- response, total_tokens = _parse_stream_json(raw_output)
+ # Combine response parts
+ response = "".join(response_parts)
# Track token usage - does not block *this* request (#711)
if total_tokens > 0:
@@ -843,7 +1204,7 @@ class ChatHandler(BaseHTTPRequestHandler):
# Fall back to raw output if stream-json parsing yielded no text
if not response:
- response = raw_output
+ response = proc.stdout.getvalue() if hasattr(proc.stdout, 'getvalue') else ""
# Save assistant response to history
_write_message(user, conv_id, "assistant", response)
@@ -913,6 +1274,118 @@ class ChatHandler(BaseHTTPRequestHandler):
self.end_headers()
self.wfile.write(json.dumps({"conversation_id": conv_id}, ensure_ascii=False).encode("utf-8"))
+    @staticmethod
+    def push_to_websocket(user, message):
+        """Push a message to the WebSocket queue registered for a user.
+
+        Does nothing when the user has no registered queue.
+
+        Args:
+            user: Key into the module-level _websocket_queues dict.
+            message: Item to enqueue for the connection's sender task.
+
+        NOTE(review): asyncio.Queue is not thread-safe; if this is called
+        from a thread other than the one running the connection's event
+        loop, delivery may be delayed until that loop next wakes up.
+        """
+        # Single lookup: the connection handler may remove the entry between
+        # a membership test and a subscript, which would raise KeyError.
+        queue = _websocket_queues.get(user)
+        if queue is not None:
+            queue.put_nowait(message)
+
+    def handle_websocket_upgrade(self):
+        """Upgrade this HTTP request to a WebSocket and serve it.
+
+        Validates the session cookie, the Sec-WebSocket-Key header and the
+        user's rate limit, then runs a private asyncio event loop on the
+        request's socket.  The call blocks until the WebSocket connection
+        ends, after which the socket is closed.
+
+        Error responses: 401 without a valid session, 400 without a
+        Sec-WebSocket-Key, 429 when the rate limit is exceeded.
+        """
+        # Check session cookie
+        user = _validate_session(self.headers.get("Cookie"))
+        if not user:
+            self.send_error_page(401, "Unauthorized: no valid session")
+            return
+
+        # Validate the handshake headers before touching any shared state,
+        # so a malformed request neither counts against the rate limit nor
+        # leaves a queue registered for the user.
+        sec_websocket_key = self.headers.get("Sec-WebSocket-Key", "")
+        sec_websocket_protocol = self.headers.get("Sec-WebSocket-Protocol", "")
+        if not sec_websocket_key:
+            self.send_error_page(400, "Bad Request: Missing Sec-WebSocket-Key")
+            return
+
+        # Check rate limits before allowing WebSocket connection
+        allowed, retry_after, reason = _check_rate_limit(user)
+        if not allowed:
+            self.send_error_page(
+                429,
+                f"Rate limit exceeded: {reason}. Retry after {retry_after}s",
+            )
+            return
+
+        # Record request for rate limiting
+        _record_request(user)
+
+        # From here on the socket carries WebSocket frames and is closed
+        # below; the HTTP server must not try to read another request from it.
+        self.close_connection = True
+
+        sock = self.connection
+        sock.setblocking(False)  # asyncio requires a non-blocking socket
+
+        async def send_stream(ws_handler, writer):
+            """Forward queued tokens to the client; ping while idle."""
+            while not ws_handler.closed:
+                try:
+                    data = await asyncio.wait_for(ws_handler.message_queue.get(), timeout=1.0)
+                    # The client parses every message as JSON, so wrap the
+                    # raw token text in the same envelope as other messages.
+                    await ws_handler.send_text(json.dumps({
+                        "type": "token",
+                        "token": data,
+                    }))
+                except asyncio.TimeoutError:
+                    # Send ping to keep connection alive
+                    try:
+                        writer.write(ws_handler._encode_frame(OPCODE_PING, b""))
+                        await writer.drain()
+                    except Exception:
+                        break
+                except Exception as e:
+                    print(f"Send stream error: {e}", file=sys.stderr)
+                    break
+
+        async def handle_ws(queue):
+            """Run the handshake, the sender task and the receive loop."""
+            writer = None
+            send_task = None
+            try:
+                # Wrap the existing socket in asyncio streams
+                reader, writer = await asyncio.open_connection(sock=sock)
+                ws_handler = _WebSocketHandler(reader, writer, user, queue)
+
+                # Accept the connection (pass headers from HTTP request)
+                if not await ws_handler.accept_connection(sec_websocket_key, sec_websocket_protocol):
+                    return
+
+                send_task = asyncio.create_task(send_stream(ws_handler, writer))
+
+                # Handle incoming WebSocket frames until the client leaves
+                await ws_handler.handle_connection()
+            except Exception as e:
+                print(f"WebSocket handler error: {e}", file=sys.stderr)
+            finally:
+                # Always stop the sender, even if the receive loop raised.
+                if send_task is not None:
+                    send_task.cancel()
+                    try:
+                        await send_task
+                    except asyncio.CancelledError:
+                        pass
+                # writer stays None when open_connection itself failed.
+                if writer is not None:
+                    try:
+                        writer.close()
+                        await writer.wait_closed()
+                    except Exception:
+                        pass
+
+        # Run the async handler on a private loop in this request's thread
+        loop = asyncio.new_event_loop()
+        queue = None
+        try:
+            asyncio.set_event_loop(loop)
+            # Create the queue only after the loop is current in this
+            # thread: on older Python versions asyncio.Queue() looks up the
+            # current loop in its constructor.
+            queue = asyncio.Queue()
+            _websocket_queues[user] = queue
+            loop.run_until_complete(handle_ws(queue))
+        except Exception as e:
+            print(f"WebSocket error: {e}", file=sys.stderr)
+        finally:
+            # Remove our queue if the connection loop never got to do so,
+            # but leave a newer connection's queue alone.
+            if queue is not None and _websocket_queues.get(user) is queue:
+                _websocket_queues.pop(user, None)
+            asyncio.set_event_loop(None)
+            loop.close()
+            sock.close()
+
def do_DELETE(self):
"""Handle DELETE requests."""
parsed = urlparse(self.path)
diff --git a/docker/chat/ui/index.html b/docker/chat/ui/index.html
index bd920f9..b045873 100644
--- a/docker/chat/ui/index.html
+++ b/docker/chat/ui/index.html
@@ -430,6 +430,10 @@
return div.innerHTML.replace(/\n/g, '
');
}
+ // WebSocket connection for streaming
+ let ws = null;
+ let wsMessageId = null;
+
// Send message handler
async function sendMessage() {
const message = textarea.value.trim();
@@ -449,6 +453,14 @@
await createNewConversation();
}
+ // Try WebSocket streaming first, fall back to fetch
+ if (window.location.protocol === 'https:' || window.location.hostname === 'localhost') {
+ if (tryWebSocketSend(message)) {
+ return;
+ }
+ }
+
+ // Fallback to fetch
try {
// Use fetch with URLSearchParams for application/x-www-form-urlencoded
const params = new URLSearchParams();
@@ -485,6 +497,111 @@
}
}
+        // Try to send message via WebSocket streaming.
+        // Returns true when a WebSocket attempt was started (the reply then
+        // arrives through the handlers below), false when the caller should
+        // use the fetch path instead.
+        function tryWebSocketSend(message) {
+            // Once a connection attempt has failed, stop offering WebSocket.
+            // Without this, the sendMessage() retry in onerror would come
+            // straight back here and retry the failing WebSocket forever.
+            if (tryWebSocketSend.disabled) {
+                return false;
+            }
+
+            // Re-enable the input controls after a request has ended.
+            function resetInputState() {
+                textarea.disabled = false;
+                sendBtn.disabled = false;
+                sendBtn.textContent = 'Send';
+                textarea.focus();
+            }
+
+            let opened = false;    // the socket reached the open state
+            let finished = false;  // the request ended (complete or error)
+
+            try {
+                // Generate a unique message ID for this request
+                wsMessageId = Date.now().toString(36) + Math.random().toString(36).slice(2);
+
+                // Connect to WebSocket, matching the page's security level
+                const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
+                const socket = new WebSocket(`${scheme}://${window.location.host}/chat/ws`);
+                ws = socket;
+
+                socket.onopen = function() {
+                    opened = true;
+                    // Send the message as JSON with message ID
+                    socket.send(JSON.stringify({
+                        type: 'chat_request',
+                        message_id: wsMessageId,
+                        message: message,
+                        conversation_id: currentConversationId
+                    }));
+                };
+
+                socket.onmessage = function(event) {
+                    try {
+                        const data = JSON.parse(event.data);
+
+                        if (data.type === 'token') {
+                            // Stream a token to the UI
+                            addTokenToLastMessage(data.token);
+                        } else if (data.type === 'complete') {
+                            // Streaming complete
+                            finished = true;
+                            closeWebSocket();
+                            resetInputState();
+                            messagesDiv.scrollTop = messagesDiv.scrollHeight;
+                            loadConversations();
+                        } else if (data.type === 'error') {
+                            finished = true;
+                            addSystemMessage(`Error: ${data.message}`);
+                            closeWebSocket();
+                            resetInputState();
+                        }
+                    } catch (e) {
+                        console.error('Failed to parse WebSocket message:', e);
+                    }
+                };
+
+                socket.onerror = function(error) {
+                    console.error('WebSocket error:', error);
+                    if (finished) {
+                        return;
+                    }
+                    finished = true;
+                    closeWebSocket();
+
+                    if (!opened) {
+                        // The connection never came up: nothing was sent, so
+                        // it is safe to retry over fetch.
+                        tryWebSocketSend.disabled = true;
+                        addSystemMessage('WebSocket connection error. Falling back to regular chat.');
+                        // NOTE(review): sendMessage() re-reads textarea.value;
+                        // confirm the text is still present at this point.
+                        sendMessage(); // Retry with fetch
+                    } else {
+                        // The request was already sent; retrying would submit
+                        // the same message a second time.
+                        addSystemMessage('WebSocket connection error.');
+                        resetInputState();
+                    }
+                };
+
+                socket.onclose = function() {
+                    wsMessageId = null;
+                    // Closed without a complete/error message: do not leave
+                    // the input disabled.
+                    if (!finished) {
+                        finished = true;
+                        addSystemMessage('Connection closed before the response completed.');
+                        resetInputState();
+                    }
+                };
+
+                return true; // WebSocket attempt started
+
+            } catch (error) {
+                console.error('Failed to create WebSocket:', error);
+                return false; // Fall back to fetch
+            }
+        }
+
+ // Add a token to the last assistant message (for streaming)
+ function addTokenToLastMessage(token) {
+ const messages = messagesDiv.querySelectorAll('.message.assistant');
+ if (messages.length === 0) {
+ // No assistant message yet, create one
+ const msgDiv = document.createElement('div');
+ msgDiv.className = 'message assistant';
+ msgDiv.innerHTML = `
+