fix: vision(#623): WebSocket streaming for chat UI to replace one-shot claude --print (#1026) #1076

Merged
dev-qwen merged 2 commits from fix/issue-1026 into main 2026-04-20 11:48:11 +00:00
3 changed files with 551 additions and 7 deletions
Showing only changes of commit 17e745376d - Show all commits

View file

@ -22,6 +22,7 @@ OAuth flow:
The claude binary is expected to be mounted from the host at /usr/local/bin/claude. The claude binary is expected to be mounted from the host at /usr/local/bin/claude.
""" """
import asyncio
import datetime import datetime
import json import json
import os import os
@ -30,8 +31,14 @@ import secrets
import subprocess import subprocess
import sys import sys
import time import time
import threading
from http.server import HTTPServer, BaseHTTPRequestHandler from http.server import HTTPServer, BaseHTTPRequestHandler
from socketserver import ThreadingMixIn
from urllib.parse import urlparse, parse_qs, urlencode from urllib.parse import urlparse, parse_qs, urlencode
import socket
import struct
import base64
import hashlib
# Configuration # Configuration
HOST = os.environ.get("CHAT_HOST", "0.0.0.0") HOST = os.environ.get("CHAT_HOST", "0.0.0.0")
@ -89,6 +96,10 @@ _request_log = {}
# user -> {"tokens": int, "date": "YYYY-MM-DD"} # user -> {"tokens": int, "date": "YYYY-MM-DD"}
_daily_tokens = {} _daily_tokens = {}
# WebSocket message queues per user
# user -> asyncio.Queue (for streaming messages to connected clients)
# One entry per user: a new connection for the same user replaces the
# previous queue.
# NOTE(review): asyncio.Queue is not thread-safe; if producers run in a
# different thread than the event loop that consumes the queue, they should
# hand items over with loop.call_soon_threadsafe() - confirm against callers.
_websocket_queues = {}
# MIME types for static files # MIME types for static files
MIME_TYPES = { MIME_TYPES = {
".html": "text/html; charset=utf-8", ".html": "text/html; charset=utf-8",
@ -101,6 +112,17 @@ MIME_TYPES = {
".ico": "image/x-icon", ".ico": "image/x-icon",
} }
# WebSocket subprotocol for chat streaming
# (the value compared against the client's Sec-WebSocket-Protocol header)
WEBSOCKET_SUBPROTOCOL = "chat-stream-v1"
# WebSocket opcodes
# (low four bits of the first byte of every frame, RFC 6455 section 5.2)
OPCODE_CONTINUATION = 0x0  # continuation of a fragmented message
OPCODE_TEXT = 0x1  # UTF-8 text data frame
OPCODE_BINARY = 0x2  # binary data frame
OPCODE_CLOSE = 0x8  # control frame: close the connection
OPCODE_PING = 0x9  # control frame: ping
OPCODE_PONG = 0xA  # control frame: pong (reply to a ping)
def _build_callback_uri(): def _build_callback_uri():
"""Build the OAuth callback URI based on tunnel configuration.""" """Build the OAuth callback URI based on tunnel configuration."""
@ -299,6 +321,257 @@ def _parse_stream_json(output):
return "".join(text_parts), total_tokens return "".join(text_parts), total_tokens
# =============================================================================
# WebSocket Handler Class
# =============================================================================
class _WebSocketHandler:
"""Handle WebSocket connections for chat streaming."""
def __init__(self, reader, writer, user, message_queue):
self.reader = reader
self.writer = writer
self.user = user
self.message_queue = message_queue
self.closed = False
async def accept_connection(self):
"""Accept the WebSocket handshake."""
# Read the HTTP request
request_line = await self._read_line()
if not request_line.startswith("GET "):
self._close_connection()
return False
# Parse the request
headers = {}
while True:
line = await self._read_line()
if line == "":
break
if ":" in line:
key, value = line.split(":", 1)
headers[key.strip().lower()] = value.strip()
# Validate WebSocket upgrade
if headers.get("upgrade", "").lower() != "websocket":
self._send_http_error(400, "Bad Request", "WebSocket upgrade required")
self._close_connection()
return False
if headers.get("connection", "").lower() != "upgrade":
self._send_http_error(400, "Bad Request", "Connection upgrade required")
self._close_connection()
return False
# Get Sec-WebSocket-Key
sec_key = headers.get("sec-websocket-key", "")
if not sec_key:
self._send_http_error(400, "Bad Request", "Missing Sec-WebSocket-Key")
self._close_connection()
return False
# Get Sec-WebSocket-Protocol if provided
sec_protocol = headers.get("sec-websocket-protocol", "")
# Validate subprotocol
if sec_protocol and sec_protocol != WEBSOCKET_SUBPROTOCOL:
self._send_http_error(
400,
"Bad Request",
f"Unsupported subprotocol. Expected: {WEBSOCKET_SUBPROTOCOL}",
)
self._close_connection()
return False
# Generate accept key
accept_key = self._generate_accept_key(sec_key)
# Send handshake response
response = (
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
f"Sec-WebSocket-Accept: {accept_key}\r\n"
)
if sec_protocol:
response += f"Sec-WebSocket-Protocol: {sec_protocol}\r\n"
response += "\r\n"
self.writer.write(response.encode("utf-8"))
await self.writer.drain()
return True
def _generate_accept_key(self, sec_key):
"""Generate the Sec-WebSocket-Accept key."""
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
combined = sec_key + GUID
sha1 = hashlib.sha1(combined.encode("utf-8"))
return base64.b64encode(sha1.digest()).decode("utf-8")
async def _read_line(self):
"""Read a line from the socket."""
data = await self.reader.read(1)
line = ""
while data:
if data == b"\r":
data = await self.reader.read(1)
continue
if data == b"\n":
return line
line += data.decode("utf-8", errors="replace")
data = await self.reader.read(1)
return line
def _send_http_error(self, code, title, message):
"""Send an HTTP error response."""
response = (
f"HTTP/1.1 {code} {title}\r\n"
"Content-Type: text/plain; charset=utf-8\r\n"
"Content-Length: " + str(len(message)) + "\r\n"
"\r\n"
+ message
)
try:
self.writer.write(response.encode("utf-8"))
self.writer.drain()
except Exception:
pass
def _close_connection(self):
"""Close the connection."""
try:
self.writer.close()
except Exception:
pass
async def send_text(self, data):
"""Send a text frame."""
if self.closed:
return
try:
frame = self._encode_frame(OPCODE_TEXT, data.encode("utf-8"))
self.writer.write(frame)
await self.writer.drain()
except Exception as e:
print(f"WebSocket send error: {e}", file=sys.stderr)
async def send_binary(self, data):
"""Send a binary frame."""
if self.closed:
return
try:
if isinstance(data, str):
data = data.encode("utf-8")
frame = self._encode_frame(OPCODE_BINARY, data)
self.writer.write(frame)
await self.writer.drain()
except Exception as e:
print(f"WebSocket send error: {e}", file=sys.stderr)
def _encode_frame(self, opcode, payload):
"""Encode a WebSocket frame."""
frame = bytearray()
frame.append(0x80 | opcode) # FIN + opcode
length = len(payload)
if length < 126:
frame.append(length)
elif length < 65536:
frame.append(126)
frame.extend(struct.pack(">H", length))
else:
frame.append(127)
frame.extend(struct.pack(">Q", length))
frame.extend(payload)
return bytes(frame)
async def _decode_frame(self):
"""Decode a WebSocket frame. Returns (opcode, payload)."""
try:
# Read first two bytes
header = await self.reader.read(2)
if len(header) < 2:
return None, None
fin = (header[0] >> 7) & 1
opcode = header[0] & 0x0F
masked = (header[1] >> 7) & 1
length = header[1] & 0x7F
# Extended payload length
if length == 126:
ext = await self.reader.read(2)
length = struct.unpack(">H", ext)[0]
elif length == 127:
ext = await self.reader.read(8)
length = struct.unpack(">Q", ext)[0]
# Masking key
if masked:
mask_key = await self.reader.read(4)
# Payload
payload = await self.reader.read(length)
# Unmask if needed
if masked:
payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
return opcode, payload
except Exception as e:
print(f"WebSocket decode error: {e}", file=sys.stderr)
return None, None
async def handle_connection(self):
"""Handle the WebSocket connection loop."""
try:
while not self.closed:
opcode, payload = await self._decode_frame()
if opcode is None:
break
if opcode == OPCODE_CLOSE:
self._send_close()
break
elif opcode == OPCODE_PING:
self._send_pong(payload)
elif opcode == OPCODE_PONG:
pass # Ignore pong
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
# Handle text messages from client (e.g., heartbeat ack)
pass
# Check if we should stop waiting for messages
if self.closed:
break
except Exception as e:
print(f"WebSocket connection error: {e}", file=sys.stderr)
finally:
self._close_connection()
def _send_close(self):
"""Send a close frame."""
try:
frame = self._encode_frame(OPCODE_CLOSE, b"\x03\x00")
self.writer.write(frame)
self.writer.drain()
except Exception:
pass
def _send_pong(self, payload):
"""Send a pong frame."""
try:
frame = self._encode_frame(OPCODE_PONG, payload)
self.writer.write(frame)
self.writer.drain()
except Exception:
pass
# ============================================================================= # =============================================================================
# Conversation History Functions (#710) # Conversation History Functions (#710)
# ============================================================================= # =============================================================================
@ -548,9 +821,9 @@ class ChatHandler(BaseHTTPRequestHandler):
self.serve_static(path) self.serve_static(path)
return return
# Reserved WebSocket endpoint (future use) # WebSocket upgrade endpoint
if path == "/ws" or path.startswith("/ws"): if path == "/chat/ws" or path == "/ws" or path.startswith("/ws"):
self.send_error_page(501, "WebSocket upgrade not yet implemented") self.handle_websocket_upgrade()
return return
# 404 for unknown paths # 404 for unknown paths
@ -759,6 +1032,7 @@ class ChatHandler(BaseHTTPRequestHandler):
""" """
Handle chat requests by spawning `claude --print` with the user message. Handle chat requests by spawning `claude --print` with the user message.
Enforces per-user rate limits and tracks token usage (#711). Enforces per-user rate limits and tracks token usage (#711).
Streams tokens over WebSocket if connected.
""" """
# Check rate limits before processing (#711) # Check rate limits before processing (#711)
@ -816,10 +1090,47 @@ class ChatHandler(BaseHTTPRequestHandler):
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
bufsize=1, # Line buffered
) )
raw_output = proc.stdout.read() # Stream output line by line
response_parts = []
total_tokens = 0
for line in iter(proc.stdout.readline, ""):
line = line.strip()
if not line:
continue
try:
event = json.loads(line)
etype = event.get("type", "")
# Extract text content from content_block_delta events
if etype == "content_block_delta":
delta = event.get("delta", {})
if delta.get("type") == "text_delta":
text = delta.get("text", "")
if text:
response_parts.append(text)
# Stream to WebSocket if connected
if user in _websocket_queues:
try:
_websocket_queues[user].put_nowait(text)
except Exception:
pass # Client disconnected
# Parse usage from result event
if etype == "result":
usage = event.get("usage", {})
total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
elif "usage" in event:
usage = event["usage"]
if isinstance(usage, dict):
total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
except json.JSONDecodeError:
pass
# Wait for process to complete
error_output = proc.stderr.read() error_output = proc.stderr.read()
if error_output: if error_output:
print(f"Claude stderr: {error_output}", file=sys.stderr) print(f"Claude stderr: {error_output}", file=sys.stderr)
@ -830,8 +1141,8 @@ class ChatHandler(BaseHTTPRequestHandler):
self.send_error_page(500, f"Claude CLI failed with exit code {proc.returncode}") self.send_error_page(500, f"Claude CLI failed with exit code {proc.returncode}")
return return
# Parse stream-json for text and token usage (#711) # Combine response parts
response, total_tokens = _parse_stream_json(raw_output) response = "".join(response_parts)
# Track token usage - does not block *this* request (#711) # Track token usage - does not block *this* request (#711)
if total_tokens > 0: if total_tokens > 0:
@ -843,7 +1154,7 @@ class ChatHandler(BaseHTTPRequestHandler):
# Fall back to raw output if stream-json parsing yielded no text # Fall back to raw output if stream-json parsing yielded no text
if not response: if not response:
response = raw_output response = proc.stdout.getvalue() if hasattr(proc.stdout, 'getvalue') else ""
# Save assistant response to history # Save assistant response to history
_write_message(user, conv_id, "assistant", response) _write_message(user, conv_id, "assistant", response)
@ -913,6 +1224,116 @@ class ChatHandler(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
self.wfile.write(json.dumps({"conversation_id": conv_id}, ensure_ascii=False).encode("utf-8")) self.wfile.write(json.dumps({"conversation_id": conv_id}, ensure_ascii=False).encode("utf-8"))
@staticmethod
def push_to_websocket(user, message):
    """Queue *message* for delivery to *user*'s WebSocket connection.

    Looks the user up in the module-level ``_websocket_queues`` mapping and,
    when an entry exists, appends the message with ``put_nowait``. Users
    without an entry are ignored silently.

    NOTE(review): the mapping is populated with asyncio.Queue objects, which
    are not thread-safe; verify which thread calls this helper.
    """
    try:
        queue = _websocket_queues[user]
    except KeyError:
        # No queue registered for this user: nothing to deliver to.
        return
    queue.put_nowait(message)
def handle_websocket_upgrade(self):
    """Upgrade the current HTTP request to a WebSocket and stream queued text to it.

    Steps:
      1. Authenticate via the session cookie and apply the per-user rate limit.
      2. Validate the upgrade headers that the HTTP server already parsed into
         ``self.headers`` and answer with "101 Switching Protocols".
      3. Register a per-user queue in ``_websocket_queues`` and forward every
         queued item to the client as a text frame until the client disconnects.

    The call blocks the current request thread for the lifetime of the
    connection (it runs a private event loop to completion).
    NOTE(review): this presumes the HTTP server handles requests in separate
    threads - confirm, otherwise one open WebSocket stalls the whole server.
    """
    user = _validate_session(self.headers.get("Cookie"))
    if not user:
        self.send_error_page(401, "Unauthorized: no valid session")
        return

    allowed, retry_after, reason = _check_rate_limit(user)
    if not allowed:
        self.send_error_page(
            429,
            f"Rate limit exceeded: {reason}. Retry after {retry_after}s",
        )
        return

    # Record request for rate limiting
    _record_request(user)

    # The request line and headers were already consumed from the socket by
    # the HTTP server before this method runs, so the handshake must be
    # validated from self.headers; re-reading it from the socket would wait
    # forever for bytes the client has already sent.
    connection_tokens = {
        token.strip().lower()
        for token in self.headers.get("Connection", "").split(",")
    }
    sec_key = self.headers.get("Sec-WebSocket-Key", "").strip()
    if (
        self.headers.get("Upgrade", "").strip().lower() != "websocket"
        or "upgrade" not in connection_tokens
        or not sec_key
    ):
        self.send_error_page(400, "WebSocket upgrade required")
        return

    offered = [
        name.strip()
        for name in self.headers.get("Sec-WebSocket-Protocol", "").split(",")
        if name.strip()
    ]
    if offered and WEBSOCKET_SUBPROTOCOL not in offered:
        self.send_error_page(
            400, f"Unsupported subprotocol. Expected: {WEBSOCKET_SUBPROTOCOL}"
        )
        return

    # After the upgrade the socket no longer speaks HTTP: stop the request
    # loop from trying to parse another request on it when we return.
    self.close_connection = True

    class _LoopBoundQueue:
        """put_nowait() facade that hands items to an asyncio.Queue on its loop.

        asyncio.Queue is not thread-safe; producers run outside the
        connection's event loop, so each put is scheduled onto that loop.
        """

        def __init__(self, target_loop, target_queue):
            self._loop = target_loop
            self._queue = target_queue

        def put_nowait(self, item):
            self._loop.call_soon_threadsafe(self._queue.put_nowait, item)

    sock = self.connection
    sock.setblocking(False)
    loop = asyncio.new_event_loop()
    ping_interval = 1.0  # seconds of idleness before a keep-alive ping

    async def handle_ws():
        writer = None
        registered = None
        try:
            # Stream objects are created here, inside the running loop, so
            # they bind to this loop rather than to a (missing) thread default.
            reader = asyncio.StreamReader()
            protocol = asyncio.StreamReaderProtocol(reader)
            transport, _ = await loop.create_connection(lambda: protocol, sock=sock)
            # A bare transport has no drain(); wrap it in a StreamWriter.
            writer = asyncio.StreamWriter(transport, protocol, reader, loop)

            outgoing = asyncio.Queue()
            ws_handler = _WebSocketHandler(reader, writer, user, outgoing)

            # Complete the handshake from the already-parsed request headers.
            response_lines = [
                "HTTP/1.1 101 Switching Protocols",
                "Upgrade: websocket",
                "Connection: Upgrade",
                f"Sec-WebSocket-Accept: {ws_handler._generate_accept_key(sec_key)}",
            ]
            if offered:
                response_lines.append(f"Sec-WebSocket-Protocol: {WEBSOCKET_SUBPROTOCOL}")
            writer.write(("\r\n".join(response_lines) + "\r\n\r\n").encode("utf-8"))
            await writer.drain()

            # Publish the queue only once the connection is usable. Producers
            # visible in this file only call put_nowait() on the entry.
            registered = _LoopBoundQueue(loop, outgoing)
            _websocket_queues[user] = registered

            async def send_stream():
                """Forward queued items as text frames; ping when idle."""
                while not ws_handler.closed:
                    try:
                        data = await asyncio.wait_for(outgoing.get(), timeout=ping_interval)
                        await ws_handler.send_text(data)
                    except asyncio.TimeoutError:
                        # Send ping to keep connection alive
                        try:
                            writer.write(ws_handler._encode_frame(OPCODE_PING, b""))
                            await writer.drain()
                        except Exception:
                            break
                    except Exception as e:
                        print(f"Send stream error: {e}", file=sys.stderr)
                        break

            send_task = asyncio.create_task(send_stream())
            try:
                # Returns when the client closes or the socket fails.
                await ws_handler.handle_connection()
            finally:
                send_task.cancel()
                try:
                    await send_task
                except asyncio.CancelledError:
                    pass
        except Exception as e:
            print(f"WebSocket handler error: {e}", file=sys.stderr)
        finally:
            # Unregister only our own entry: a newer connection for the same
            # user may already have replaced it.
            if registered is not None and _websocket_queues.get(user) is registered:
                del _websocket_queues[user]
            if writer is not None:
                try:
                    writer.close()
                    await writer.wait_closed()
                except (OSError, RuntimeError):
                    pass

    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(handle_ws())
    except Exception as e:
        print(f"WebSocket error: {e}", file=sys.stderr)
    finally:
        # Do not leave a closed loop installed as this thread's current loop.
        asyncio.set_event_loop(None)
        loop.close()
        sock.close()
def do_DELETE(self): def do_DELETE(self):
"""Handle DELETE requests.""" """Handle DELETE requests."""
parsed = urlparse(self.path) parsed = urlparse(self.path)

View file

@ -430,6 +430,10 @@
return div.innerHTML.replace(/\n/g, '<br>'); return div.innerHTML.replace(/\n/g, '<br>');
} }
// WebSocket connection for streaming
// ws: the WebSocket currently in use, or null when none is open.
let ws = null;
// wsMessageId: client-generated id of the request being streamed over ws;
// reset to null when the socket's close handler runs.
let wsMessageId = null;
// Send message handler // Send message handler
async function sendMessage() { async function sendMessage() {
const message = textarea.value.trim(); const message = textarea.value.trim();
@ -449,6 +453,14 @@
await createNewConversation(); await createNewConversation();
} }
// Try WebSocket streaming first, fall back to fetch
if (window.location.protocol === 'https:' || window.location.hostname === 'localhost') {
if (tryWebSocketSend(message)) {
return;
}
}
// Fallback to fetch
try { try {
// Use fetch with URLSearchParams for application/x-www-form-urlencoded // Use fetch with URLSearchParams for application/x-www-form-urlencoded
const params = new URLSearchParams(); const params = new URLSearchParams();
@ -485,6 +497,111 @@
} }
} }
// Try to send message via WebSocket streaming
function tryWebSocketSend(message) {
    // Opens a WebSocket to /chat/ws, sends `message` as a `chat_request`, and
    // renders incoming tokens into the chat view.
    //
    // Returns true when a WebSocket attempt was started (the caller must not
    // also send via fetch), false when the caller should fall back to fetch.
    //
    // NOTE(review): the server-side handler added in this same change discards
    // inbound text frames and emits raw text rather than JSON envelopes, so
    // the `chat_request` / `complete` protocol used here needs to be
    // reconciled with the server before this path can finish a request.

    // A previous attempt failed: report "not started" so the fallback really
    // reaches fetch instead of re-entering this function forever.
    if (tryWebSocketSend.unavailable) {
        return false;
    }

    let settled = false;      // request finished, failed, or handed to fetch
    let receivedAny = false;  // at least one token was already rendered

    // Re-enable the input controls.
    function releaseInput() {
        textarea.disabled = false;
        sendBtn.disabled = false;
        sendBtn.textContent = 'Send';
        textarea.focus();
    }

    try {
        // Generate a unique message ID for this request
        // (slice() replaces the deprecated substr()).
        wsMessageId = Date.now().toString(36) + Math.random().toString(36).slice(2);

        const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
        const socket = new WebSocket(`${scheme}://${window.location.host}/chat/ws`);
        ws = socket;

        socket.onopen = function() {
            // Send the message as JSON with message ID
            socket.send(JSON.stringify({
                type: 'chat_request',
                message_id: wsMessageId,
                message: message,
                conversation_id: currentConversationId
            }));
        };

        socket.onmessage = function(event) {
            let data = null;
            try {
                data = JSON.parse(event.data);
            } catch (e) {
                data = null;
            }

            // Frames that are not a JSON object with a string `type` are
            // treated as raw token text instead of being dropped.
            if (data === null || typeof data !== 'object' || typeof data.type !== 'string') {
                if (typeof event.data === 'string') {
                    receivedAny = true;
                    addTokenToLastMessage(event.data);
                }
                return;
            }

            if (data.type === 'token') {
                // Stream a token to the UI
                receivedAny = true;
                addTokenToLastMessage(data.token);
            } else if (data.type === 'complete') {
                // Streaming complete
                settled = true;
                closeWebSocket();
                releaseInput();
                messagesDiv.scrollTop = messagesDiv.scrollHeight;
                loadConversations();
            } else if (data.type === 'error') {
                settled = true;
                addSystemMessage(`Error: ${data.message}`);
                closeWebSocket();
                releaseInput();
            }
        };

        socket.onerror = function(error) {
            console.error('WebSocket error:', error);
            closeWebSocket();
            if (settled) {
                return;
            }
            settled = true;
            // Skip the WebSocket path from now on so retries go to fetch.
            tryWebSocketSend.unavailable = true;

            if (receivedAny) {
                // Part of the answer is already on screen; resending would
                // duplicate the request, so just hand control back.
                addSystemMessage('WebSocket connection lost before the response completed.');
                releaseInput();
                return;
            }

            addSystemMessage('WebSocket connection error. Falling back to regular chat.');
            // sendMessage() reads the text from the textarea, so put the
            // message back if it is no longer there.
            // NOTE(review): if sendMessage() renders the user's bubble before
            // calling this function, the retry shows it twice - confirm.
            if (!textarea.value.trim()) {
                textarea.value = message;
            }
            releaseInput();
            sendMessage(); // Retry with fetch
        };

        socket.onclose = function() {
            wsMessageId = null;
            if (ws === socket) {
                ws = null;
            }
            if (!settled) {
                // Closed without 'complete' or 'error': do not leave the
                // input disabled.
                settled = true;
                addSystemMessage('WebSocket connection closed before the response completed.');
                releaseInput();
            }
        };

        return true; // WebSocket attempt started
    } catch (error) {
        console.error('Failed to create WebSocket:', error);
        return false; // Fall back to fetch
    }
}
// Add a token to the last assistant message (for streaming)
function addTokenToLastMessage(token) {
    // Appends `token` to the assistant bubble that belongs to the request
    // currently being streamed, creating that bubble on the first token.
    //
    // Bubbles are tagged with the request id (wsMessageId) so that tokens of
    // a new request never land in the bubble of an earlier one, and a new
    // bubble is created even when older assistant messages already exist.
    const requestId = wsMessageId || '';

    // Last streaming bubble tagged with this request, if any.
    let target = null;
    messagesDiv
        .querySelectorAll('.message.assistant .content.streaming')
        .forEach(function(node) {
            const owner = node.closest('.message');
            if (owner && (owner.dataset.messageId || '') === requestId) {
                target = node;
            }
        });

    if (!target) {
        // First token of this request: create its assistant message.
        const msgDiv = document.createElement('div');
        msgDiv.className = 'message assistant';
        msgDiv.dataset.messageId = requestId;
        msgDiv.innerHTML = `
            <div class="role">assistant</div>
            <div class="content streaming"></div>
        `;
        messagesDiv.appendChild(msgDiv);
        target = msgDiv.querySelector('.content.streaming');
    }

    // textContent keeps model output from being interpreted as HTML.
    target.textContent += token;
    messagesDiv.scrollTop = messagesDiv.scrollHeight;
}
// Close WebSocket connection
function closeWebSocket() {
    // Detaches every handler from the active socket, closes it, and clears
    // the shared `ws` reference. Does nothing when no socket is open.
    if (!ws) {
        return;
    }
    // Handlers are removed before close() so none of them fires for a
    // shutdown that was requested locally.
    for (const handlerName of ['onopen', 'onmessage', 'onerror', 'onclose']) {
        ws[handlerName] = null;
    }
    ws.close();
    ws = null;
}
// Event listeners // Event listeners
sendBtn.addEventListener('click', sendMessage); sendBtn.addEventListener('click', sendMessage);

View file

@ -172,6 +172,12 @@ EOT
handle /chat/oauth/callback { handle /chat/oauth/callback {
reverse_proxy 127.0.0.1:8080 reverse_proxy 127.0.0.1:8080
} }
# WebSocket endpoint for streaming (#1026)
# reverse_proxy forwards WebSocket upgrade requests (Upgrade / Connection
# headers) on its own. header_up is only valid inside a reverse_proxy block,
# and $http.upgrade is not Caddy placeholder syntax, so no header rewriting
# is configured here.
# NOTE(review): this block matches before the forward_auth-protected
# /chat/* block, so the backend must authenticate the WebSocket request
# itself - confirm it validates the session cookie.
handle /chat/ws {
    reverse_proxy 127.0.0.1:8080
}
# Defense-in-depth: forward_auth stamps X-Forwarded-User from session (#709) # Defense-in-depth: forward_auth stamps X-Forwarded-User from session (#709)
handle /chat/* { handle /chat/* {
forward_auth 127.0.0.1:8080 { forward_auth 127.0.0.1:8080 {