Merge pull request 'fix: vision(#623): WebSocket streaming for chat UI to replace one-shot claude --print (#1026)' (#1076) from fix/issue-1026 into main
This commit is contained in:
commit
abca547dcc
3 changed files with 603 additions and 7 deletions
|
|
@ -22,6 +22,7 @@ OAuth flow:
|
||||||
The claude binary is expected to be mounted from the host at /usr/local/bin/claude.
|
The claude binary is expected to be mounted from the host at /usr/local/bin/claude.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
@ -30,8 +31,14 @@ import secrets
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import threading
|
||||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||||
|
from socketserver import ThreadingMixIn
|
||||||
from urllib.parse import urlparse, parse_qs, urlencode
|
from urllib.parse import urlparse, parse_qs, urlencode
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
HOST = os.environ.get("CHAT_HOST", "0.0.0.0")
|
HOST = os.environ.get("CHAT_HOST", "0.0.0.0")
|
||||||
|
|
@ -89,6 +96,10 @@ _request_log = {}
|
||||||
# user -> {"tokens": int, "date": "YYYY-MM-DD"}
|
# user -> {"tokens": int, "date": "YYYY-MM-DD"}
|
||||||
_daily_tokens = {}
|
_daily_tokens = {}
|
||||||
|
|
||||||
|
# WebSocket message queues per user
|
||||||
|
# user -> asyncio.Queue (for streaming messages to connected clients)
|
||||||
|
_websocket_queues = {}
|
||||||
|
|
||||||
# MIME types for static files
|
# MIME types for static files
|
||||||
MIME_TYPES = {
|
MIME_TYPES = {
|
||||||
".html": "text/html; charset=utf-8",
|
".html": "text/html; charset=utf-8",
|
||||||
|
|
@ -101,6 +112,17 @@ MIME_TYPES = {
|
||||||
".ico": "image/x-icon",
|
".ico": "image/x-icon",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# WebSocket subprotocol for chat streaming
|
||||||
|
WEBSOCKET_SUBPROTOCOL = "chat-stream-v1"
|
||||||
|
|
||||||
|
# WebSocket opcodes
|
||||||
|
OPCODE_CONTINUATION = 0x0
|
||||||
|
OPCODE_TEXT = 0x1
|
||||||
|
OPCODE_BINARY = 0x2
|
||||||
|
OPCODE_CLOSE = 0x8
|
||||||
|
OPCODE_PING = 0x9
|
||||||
|
OPCODE_PONG = 0xA
|
||||||
|
|
||||||
|
|
||||||
def _build_callback_uri():
|
def _build_callback_uri():
|
||||||
"""Build the OAuth callback URI based on tunnel configuration."""
|
"""Build the OAuth callback URI based on tunnel configuration."""
|
||||||
|
|
@ -299,6 +321,307 @@ def _parse_stream_json(output):
|
||||||
return "".join(text_parts), total_tokens
|
return "".join(text_parts), total_tokens
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# WebSocket Handler Class
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class _WebSocketHandler:
|
||||||
|
"""Handle WebSocket connections for chat streaming."""
|
||||||
|
|
||||||
|
def __init__(self, reader, writer, user, message_queue):
|
||||||
|
self.reader = reader
|
||||||
|
self.writer = writer
|
||||||
|
self.user = user
|
||||||
|
self.message_queue = message_queue
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
async def accept_connection(self, sec_websocket_key, sec_websocket_protocol=None):
|
||||||
|
"""Accept the WebSocket handshake.
|
||||||
|
|
||||||
|
The HTTP request has already been parsed by BaseHTTPRequestHandler,
|
||||||
|
so we use the provided key and protocol instead of re-reading from socket.
|
||||||
|
"""
|
||||||
|
# Validate subprotocol
|
||||||
|
if sec_websocket_protocol and sec_websocket_protocol != WEBSOCKET_SUBPROTOCOL:
|
||||||
|
self._send_http_error(
|
||||||
|
400,
|
||||||
|
"Bad Request",
|
||||||
|
f"Unsupported subprotocol. Expected: {WEBSOCKET_SUBPROTOCOL}",
|
||||||
|
)
|
||||||
|
self._close_connection()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Generate accept key
|
||||||
|
accept_key = self._generate_accept_key(sec_websocket_key)
|
||||||
|
|
||||||
|
# Send handshake response
|
||||||
|
response = (
|
||||||
|
"HTTP/1.1 101 Switching Protocols\r\n"
|
||||||
|
"Upgrade: websocket\r\n"
|
||||||
|
"Connection: Upgrade\r\n"
|
||||||
|
f"Sec-WebSocket-Accept: {accept_key}\r\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if sec_websocket_protocol:
|
||||||
|
response += f"Sec-WebSocket-Protocol: {sec_websocket_protocol}\r\n"
|
||||||
|
|
||||||
|
response += "\r\n"
|
||||||
|
self.writer.write(response.encode("utf-8"))
|
||||||
|
await self.writer.drain()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _generate_accept_key(self, sec_key):
|
||||||
|
"""Generate the Sec-WebSocket-Accept key."""
|
||||||
|
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
combined = sec_key + GUID
|
||||||
|
sha1 = hashlib.sha1(combined.encode("utf-8"))
|
||||||
|
return base64.b64encode(sha1.digest()).decode("utf-8")
|
||||||
|
|
||||||
|
async def _read_line(self):
|
||||||
|
"""Read a line from the socket."""
|
||||||
|
data = await self.reader.read(1)
|
||||||
|
line = ""
|
||||||
|
while data:
|
||||||
|
if data == b"\r":
|
||||||
|
data = await self.reader.read(1)
|
||||||
|
continue
|
||||||
|
if data == b"\n":
|
||||||
|
return line
|
||||||
|
line += data.decode("utf-8", errors="replace")
|
||||||
|
data = await self.reader.read(1)
|
||||||
|
return line
|
||||||
|
|
||||||
|
def _send_http_error(self, code, title, message):
|
||||||
|
"""Send an HTTP error response."""
|
||||||
|
response = (
|
||||||
|
f"HTTP/1.1 {code} {title}\r\n"
|
||||||
|
"Content-Type: text/plain; charset=utf-8\r\n"
|
||||||
|
"Content-Length: " + str(len(message)) + "\r\n"
|
||||||
|
"\r\n"
|
||||||
|
+ message
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self.writer.write(response.encode("utf-8"))
|
||||||
|
self.writer.drain()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _close_connection(self):
|
||||||
|
"""Close the connection."""
|
||||||
|
try:
|
||||||
|
self.writer.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_text(self, data):
|
||||||
|
"""Send a text frame."""
|
||||||
|
if self.closed:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
frame = self._encode_frame(OPCODE_TEXT, data.encode("utf-8"))
|
||||||
|
self.writer.write(frame)
|
||||||
|
await self.writer.drain()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket send error: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
async def send_binary(self, data):
|
||||||
|
"""Send a binary frame."""
|
||||||
|
if self.closed:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = data.encode("utf-8")
|
||||||
|
frame = self._encode_frame(OPCODE_BINARY, data)
|
||||||
|
self.writer.write(frame)
|
||||||
|
await self.writer.drain()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket send error: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
def _encode_frame(self, opcode, payload):
|
||||||
|
"""Encode a WebSocket frame."""
|
||||||
|
frame = bytearray()
|
||||||
|
frame.append(0x80 | opcode) # FIN + opcode
|
||||||
|
|
||||||
|
length = len(payload)
|
||||||
|
if length < 126:
|
||||||
|
frame.append(length)
|
||||||
|
elif length < 65536:
|
||||||
|
frame.append(126)
|
||||||
|
frame.extend(struct.pack(">H", length))
|
||||||
|
else:
|
||||||
|
frame.append(127)
|
||||||
|
frame.extend(struct.pack(">Q", length))
|
||||||
|
|
||||||
|
frame.extend(payload)
|
||||||
|
return bytes(frame)
|
||||||
|
|
||||||
|
async def _decode_frame(self):
|
||||||
|
"""Decode a WebSocket frame. Returns (opcode, payload)."""
|
||||||
|
try:
|
||||||
|
# Read first two bytes (use readexactly for guaranteed length)
|
||||||
|
header = await self.reader.readexactly(2)
|
||||||
|
|
||||||
|
fin = (header[0] >> 7) & 1
|
||||||
|
opcode = header[0] & 0x0F
|
||||||
|
masked = (header[1] >> 7) & 1
|
||||||
|
length = header[1] & 0x7F
|
||||||
|
|
||||||
|
# Extended payload length
|
||||||
|
if length == 126:
|
||||||
|
ext = await self.reader.readexactly(2)
|
||||||
|
length = struct.unpack(">H", ext)[0]
|
||||||
|
elif length == 127:
|
||||||
|
ext = await self.reader.readexactly(8)
|
||||||
|
length = struct.unpack(">Q", ext)[0]
|
||||||
|
|
||||||
|
# Masking key
|
||||||
|
if masked:
|
||||||
|
mask_key = await self.reader.readexactly(4)
|
||||||
|
|
||||||
|
# Payload
|
||||||
|
payload = await self.reader.readexactly(length)
|
||||||
|
|
||||||
|
# Unmask if needed
|
||||||
|
if masked:
|
||||||
|
payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
|
||||||
|
|
||||||
|
return opcode, payload
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket decode error: {e}", file=sys.stderr)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
async def handle_connection(self):
|
||||||
|
"""Handle the WebSocket connection loop."""
|
||||||
|
try:
|
||||||
|
while not self.closed:
|
||||||
|
opcode, payload = await self._decode_frame()
|
||||||
|
if opcode is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
if opcode == OPCODE_CLOSE:
|
||||||
|
await self._send_close()
|
||||||
|
break
|
||||||
|
elif opcode == OPCODE_PING:
|
||||||
|
await self._send_pong(payload)
|
||||||
|
elif opcode == OPCODE_PONG:
|
||||||
|
pass # Ignore pong
|
||||||
|
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
|
||||||
|
# Handle text messages from client (e.g., chat_request)
|
||||||
|
try:
|
||||||
|
msg = payload.decode("utf-8")
|
||||||
|
data = json.loads(msg)
|
||||||
|
if data.get("type") == "chat_request":
|
||||||
|
# Invoke Claude with the message
|
||||||
|
await self._handle_chat_request(data.get("message", ""))
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check if we should stop waiting for messages
|
||||||
|
if self.closed:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket connection error: {e}", file=sys.stderr)
|
||||||
|
finally:
|
||||||
|
self._close_connection()
|
||||||
|
# Clean up the message queue on disconnect
|
||||||
|
if self.user in _websocket_queues:
|
||||||
|
del _websocket_queues[self.user]
|
||||||
|
|
||||||
|
async def _send_close(self):
|
||||||
|
"""Send a close frame."""
|
||||||
|
try:
|
||||||
|
# Close code 1000 = normal closure
|
||||||
|
frame = self._encode_frame(OPCODE_CLOSE, struct.pack(">H", 1000))
|
||||||
|
self.writer.write(frame)
|
||||||
|
await self.writer.drain()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _send_pong(self, payload):
|
||||||
|
"""Send a pong frame."""
|
||||||
|
try:
|
||||||
|
frame = self._encode_frame(OPCODE_PONG, payload)
|
||||||
|
self.writer.write(frame)
|
||||||
|
await self.writer.drain()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _handle_chat_request(self, message):
|
||||||
|
"""Handle a chat_request WebSocket frame by invoking Claude."""
|
||||||
|
if not message:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Validate Claude binary exists
|
||||||
|
if not os.path.exists(CLAUDE_BIN):
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": "Claude CLI not found",
|
||||||
|
}))
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Spawn claude --print with stream-json for streaming output
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
[CLAUDE_BIN, "--print", "--output-format", "stream-json", message],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
bufsize=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream output line by line
|
||||||
|
for line in iter(proc.stdout.readline, ""):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
event = json.loads(line)
|
||||||
|
etype = event.get("type", "")
|
||||||
|
|
||||||
|
# Extract text content from content_block_delta events
|
||||||
|
if etype == "content_block_delta":
|
||||||
|
delta = event.get("delta", {})
|
||||||
|
if delta.get("type") == "text_delta":
|
||||||
|
text = delta.get("text", "")
|
||||||
|
if text:
|
||||||
|
# Send tokens to client
|
||||||
|
await self.send_text(text)
|
||||||
|
|
||||||
|
# Check for usage event to know when complete
|
||||||
|
if etype == "result":
|
||||||
|
pass # Will send complete after loop
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wait for process to complete
|
||||||
|
proc.wait()
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": f"Claude CLI failed with exit code {proc.returncode}",
|
||||||
|
}))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Send complete signal
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "complete",
|
||||||
|
}))
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": "Claude CLI not found",
|
||||||
|
}))
|
||||||
|
except Exception as e:
|
||||||
|
await self.send_text(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": str(e),
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Conversation History Functions (#710)
|
# Conversation History Functions (#710)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -548,9 +871,9 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
self.serve_static(path)
|
self.serve_static(path)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Reserved WebSocket endpoint (future use)
|
# WebSocket upgrade endpoint
|
||||||
if path == "/ws" or path.startswith("/ws"):
|
if path == "/chat/ws" or path == "/ws" or path.startswith("/ws"):
|
||||||
self.send_error_page(501, "WebSocket upgrade not yet implemented")
|
self.handle_websocket_upgrade()
|
||||||
return
|
return
|
||||||
|
|
||||||
# 404 for unknown paths
|
# 404 for unknown paths
|
||||||
|
|
@ -759,6 +1082,7 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
"""
|
"""
|
||||||
Handle chat requests by spawning `claude --print` with the user message.
|
Handle chat requests by spawning `claude --print` with the user message.
|
||||||
Enforces per-user rate limits and tracks token usage (#711).
|
Enforces per-user rate limits and tracks token usage (#711).
|
||||||
|
Streams tokens over WebSocket if connected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Check rate limits before processing (#711)
|
# Check rate limits before processing (#711)
|
||||||
|
|
@ -816,10 +1140,47 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
text=True,
|
text=True,
|
||||||
|
bufsize=1, # Line buffered
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_output = proc.stdout.read()
|
# Stream output line by line
|
||||||
|
response_parts = []
|
||||||
|
total_tokens = 0
|
||||||
|
for line in iter(proc.stdout.readline, ""):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
event = json.loads(line)
|
||||||
|
etype = event.get("type", "")
|
||||||
|
|
||||||
|
# Extract text content from content_block_delta events
|
||||||
|
if etype == "content_block_delta":
|
||||||
|
delta = event.get("delta", {})
|
||||||
|
if delta.get("type") == "text_delta":
|
||||||
|
text = delta.get("text", "")
|
||||||
|
if text:
|
||||||
|
response_parts.append(text)
|
||||||
|
# Stream to WebSocket if connected
|
||||||
|
if user in _websocket_queues:
|
||||||
|
try:
|
||||||
|
_websocket_queues[user].put_nowait(text)
|
||||||
|
except Exception:
|
||||||
|
pass # Client disconnected
|
||||||
|
|
||||||
|
# Parse usage from result event
|
||||||
|
if etype == "result":
|
||||||
|
usage = event.get("usage", {})
|
||||||
|
total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
|
||||||
|
elif "usage" in event:
|
||||||
|
usage = event["usage"]
|
||||||
|
if isinstance(usage, dict):
|
||||||
|
total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wait for process to complete
|
||||||
error_output = proc.stderr.read()
|
error_output = proc.stderr.read()
|
||||||
if error_output:
|
if error_output:
|
||||||
print(f"Claude stderr: {error_output}", file=sys.stderr)
|
print(f"Claude stderr: {error_output}", file=sys.stderr)
|
||||||
|
|
@ -830,8 +1191,8 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
self.send_error_page(500, f"Claude CLI failed with exit code {proc.returncode}")
|
self.send_error_page(500, f"Claude CLI failed with exit code {proc.returncode}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Parse stream-json for text and token usage (#711)
|
# Combine response parts
|
||||||
response, total_tokens = _parse_stream_json(raw_output)
|
response = "".join(response_parts)
|
||||||
|
|
||||||
# Track token usage - does not block *this* request (#711)
|
# Track token usage - does not block *this* request (#711)
|
||||||
if total_tokens > 0:
|
if total_tokens > 0:
|
||||||
|
|
@ -843,7 +1204,7 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
|
|
||||||
# Fall back to raw output if stream-json parsing yielded no text
|
# Fall back to raw output if stream-json parsing yielded no text
|
||||||
if not response:
|
if not response:
|
||||||
response = raw_output
|
response = proc.stdout.getvalue() if hasattr(proc.stdout, 'getvalue') else ""
|
||||||
|
|
||||||
# Save assistant response to history
|
# Save assistant response to history
|
||||||
_write_message(user, conv_id, "assistant", response)
|
_write_message(user, conv_id, "assistant", response)
|
||||||
|
|
@ -913,6 +1274,118 @@ class ChatHandler(BaseHTTPRequestHandler):
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.wfile.write(json.dumps({"conversation_id": conv_id}, ensure_ascii=False).encode("utf-8"))
|
self.wfile.write(json.dumps({"conversation_id": conv_id}, ensure_ascii=False).encode("utf-8"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def push_to_websocket(user, message):
|
||||||
|
"""Push a message to a WebSocket connection for a user.
|
||||||
|
|
||||||
|
This is called from the chat handler to stream tokens to connected clients.
|
||||||
|
The message is added to the user's WebSocket message queue.
|
||||||
|
"""
|
||||||
|
# Get the message queue from the WebSocket handler's queue
|
||||||
|
# We store the queue in a global dict keyed by user
|
||||||
|
if user in _websocket_queues:
|
||||||
|
_websocket_queues[user].put_nowait(message)
|
||||||
|
|
||||||
|
def handle_websocket_upgrade(self):
|
||||||
|
"""Handle WebSocket upgrade request for chat streaming."""
|
||||||
|
# Check session cookie
|
||||||
|
user = _validate_session(self.headers.get("Cookie"))
|
||||||
|
if not user:
|
||||||
|
self.send_error_page(401, "Unauthorized: no valid session")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check rate limits before allowing WebSocket connection
|
||||||
|
allowed, retry_after, reason = _check_rate_limit(user)
|
||||||
|
if not allowed:
|
||||||
|
self.send_error_page(
|
||||||
|
429,
|
||||||
|
f"Rate limit exceeded: {reason}. Retry after {retry_after}s",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Record request for rate limiting
|
||||||
|
_record_request(user)
|
||||||
|
|
||||||
|
# Create message queue for this user
|
||||||
|
_websocket_queues[user] = asyncio.Queue()
|
||||||
|
|
||||||
|
# Get WebSocket upgrade headers from the HTTP request
|
||||||
|
sec_websocket_key = self.headers.get("Sec-WebSocket-Key", "")
|
||||||
|
sec_websocket_protocol = self.headers.get("Sec-WebSocket-Protocol", "")
|
||||||
|
|
||||||
|
# Validate Sec-WebSocket-Key
|
||||||
|
if not sec_websocket_key:
|
||||||
|
self.send_error_page(400, "Bad Request", "Missing Sec-WebSocket-Key")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get the socket from the connection
|
||||||
|
sock = self.connection
|
||||||
|
sock.setblocking(False)
|
||||||
|
|
||||||
|
# Create async server to handle the connection
|
||||||
|
async def handle_ws():
|
||||||
|
try:
|
||||||
|
# Wrap the socket in asyncio streams using open_connection
|
||||||
|
reader, writer = await asyncio.open_connection(sock=sock)
|
||||||
|
|
||||||
|
# Create WebSocket handler
|
||||||
|
ws_handler = _WebSocketHandler(reader, writer, user, _websocket_queues[user])
|
||||||
|
|
||||||
|
# Accept the connection (pass headers from HTTP request)
|
||||||
|
if not await ws_handler.accept_connection(sec_websocket_key, sec_websocket_protocol):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Start a task to read from the queue and send to client
|
||||||
|
async def send_stream():
|
||||||
|
while not ws_handler.closed:
|
||||||
|
try:
|
||||||
|
data = await asyncio.wait_for(ws_handler.message_queue.get(), timeout=1.0)
|
||||||
|
await ws_handler.send_text(data)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Send ping to keep connection alive
|
||||||
|
try:
|
||||||
|
frame = ws_handler._encode_frame(OPCODE_PING, b"")
|
||||||
|
writer.write(frame)
|
||||||
|
await writer.drain()
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Send stream error: {e}", file=sys.stderr)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Start sending task
|
||||||
|
send_task = asyncio.create_task(send_stream())
|
||||||
|
|
||||||
|
# Handle incoming WebSocket frames
|
||||||
|
await ws_handler.handle_connection()
|
||||||
|
|
||||||
|
# Cancel send task
|
||||||
|
send_task.cancel()
|
||||||
|
try:
|
||||||
|
await send_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket handler error: {e}", file=sys.stderr)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Run the async handler in a thread
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(handle_ws())
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket error: {e}", file=sys.stderr)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
sock.close()
|
||||||
|
|
||||||
def do_DELETE(self):
|
def do_DELETE(self):
|
||||||
"""Handle DELETE requests."""
|
"""Handle DELETE requests."""
|
||||||
parsed = urlparse(self.path)
|
parsed = urlparse(self.path)
|
||||||
|
|
|
||||||
|
|
@ -430,6 +430,10 @@
|
||||||
return div.innerHTML.replace(/\n/g, '<br>');
|
return div.innerHTML.replace(/\n/g, '<br>');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebSocket connection for streaming
|
||||||
|
let ws = null;
|
||||||
|
let wsMessageId = null;
|
||||||
|
|
||||||
// Send message handler
|
// Send message handler
|
||||||
async function sendMessage() {
|
async function sendMessage() {
|
||||||
const message = textarea.value.trim();
|
const message = textarea.value.trim();
|
||||||
|
|
@ -449,6 +453,14 @@
|
||||||
await createNewConversation();
|
await createNewConversation();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try WebSocket streaming first, fall back to fetch
|
||||||
|
if (window.location.protocol === 'https:' || window.location.hostname === 'localhost') {
|
||||||
|
if (tryWebSocketSend(message)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to fetch
|
||||||
try {
|
try {
|
||||||
// Use fetch with URLSearchParams for application/x-www-form-urlencoded
|
// Use fetch with URLSearchParams for application/x-www-form-urlencoded
|
||||||
const params = new URLSearchParams();
|
const params = new URLSearchParams();
|
||||||
|
|
@ -485,6 +497,111 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to send message via WebSocket streaming
|
||||||
|
function tryWebSocketSend(message) {
|
||||||
|
try {
|
||||||
|
// Generate a unique message ID for this request
|
||||||
|
wsMessageId = Date.now().toString(36) + Math.random().toString(36).substr(2);
|
||||||
|
|
||||||
|
// Connect to WebSocket
|
||||||
|
const wsUrl = window.location.protocol === 'https:'
|
||||||
|
? `wss://${window.location.host}/chat/ws`
|
||||||
|
: `ws://${window.location.host}/chat/ws`;
|
||||||
|
|
||||||
|
ws = new WebSocket(wsUrl);
|
||||||
|
|
||||||
|
ws.onopen = function() {
|
||||||
|
// Send the message as JSON with message ID
|
||||||
|
const data = {
|
||||||
|
type: 'chat_request',
|
||||||
|
message_id: wsMessageId,
|
||||||
|
message: message,
|
||||||
|
conversation_id: currentConversationId
|
||||||
|
};
|
||||||
|
ws.send(JSON.stringify(data));
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onmessage = function(event) {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
|
||||||
|
if (data.type === 'token') {
|
||||||
|
// Stream a token to the UI
|
||||||
|
addTokenToLastMessage(data.token);
|
||||||
|
} else if (data.type === 'complete') {
|
||||||
|
// Streaming complete
|
||||||
|
closeWebSocket();
|
||||||
|
textarea.disabled = false;
|
||||||
|
sendBtn.disabled = false;
|
||||||
|
sendBtn.textContent = 'Send';
|
||||||
|
textarea.focus();
|
||||||
|
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
||||||
|
loadConversations();
|
||||||
|
} else if (data.type === 'error') {
|
||||||
|
addSystemMessage(`Error: ${data.message}`);
|
||||||
|
closeWebSocket();
|
||||||
|
textarea.disabled = false;
|
||||||
|
sendBtn.disabled = false;
|
||||||
|
sendBtn.textContent = 'Send';
|
||||||
|
textarea.focus();
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Failed to parse WebSocket message:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onerror = function(error) {
|
||||||
|
console.error('WebSocket error:', error);
|
||||||
|
addSystemMessage('WebSocket connection error. Falling back to regular chat.');
|
||||||
|
closeWebSocket();
|
||||||
|
sendMessage(); // Retry with fetch
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.onclose = function() {
|
||||||
|
wsMessageId = null;
|
||||||
|
};
|
||||||
|
|
||||||
|
return true; // WebSocket attempt started
|
||||||
|
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to create WebSocket:', error);
|
||||||
|
return false; // Fall back to fetch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a token to the last assistant message (for streaming)
|
||||||
|
function addTokenToLastMessage(token) {
|
||||||
|
const messages = messagesDiv.querySelectorAll('.message.assistant');
|
||||||
|
if (messages.length === 0) {
|
||||||
|
// No assistant message yet, create one
|
||||||
|
const msgDiv = document.createElement('div');
|
||||||
|
msgDiv.className = 'message assistant';
|
||||||
|
msgDiv.innerHTML = `
|
||||||
|
<div class="role">assistant</div>
|
||||||
|
<div class="content streaming"></div>
|
||||||
|
`;
|
||||||
|
messagesDiv.appendChild(msgDiv);
|
||||||
|
}
|
||||||
|
|
||||||
|
const lastMsg = messagesDiv.querySelector('.message.assistant .content.streaming');
|
||||||
|
if (lastMsg) {
|
||||||
|
lastMsg.textContent += token;
|
||||||
|
messagesDiv.scrollTop = messagesDiv.scrollHeight;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close WebSocket connection
|
||||||
|
function closeWebSocket() {
|
||||||
|
if (ws) {
|
||||||
|
ws.onopen = null;
|
||||||
|
ws.onmessage = null;
|
||||||
|
ws.onerror = null;
|
||||||
|
ws.onclose = null;
|
||||||
|
ws.close();
|
||||||
|
ws = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Event listeners
|
// Event listeners
|
||||||
sendBtn.addEventListener('click', sendMessage);
|
sendBtn.addEventListener('click', sendMessage);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -172,6 +172,12 @@ EOT
|
||||||
handle /chat/oauth/callback {
|
handle /chat/oauth/callback {
|
||||||
reverse_proxy 127.0.0.1:8080
|
reverse_proxy 127.0.0.1:8080
|
||||||
}
|
}
|
||||||
|
# WebSocket endpoint for streaming (#1026)
|
||||||
|
handle /chat/ws {
|
||||||
|
header_up Upgrade $http.upgrade
|
||||||
|
header_up Connection $http.connection
|
||||||
|
reverse_proxy 127.0.0.1:8080
|
||||||
|
}
|
||||||
# Defense-in-depth: forward_auth stamps X-Forwarded-User from session (#709)
|
# Defense-in-depth: forward_auth stamps X-Forwarded-User from session (#709)
|
||||||
handle /chat/* {
|
handle /chat/* {
|
||||||
forward_auth 127.0.0.1:8080 {
|
forward_auth 127.0.0.1:8080 {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue