Module directory_server.server
Main directory server implementation using asyncio.
Implements Open/Closed Principle: extensible without modification.
Classes
class DirectoryServer (settings: Settings)-
Expand source code
class DirectoryServer:
    """Asyncio TCP directory server that handshakes peers and routes their messages.

    Each accepted connection is tracked in a ``ConnectionPool`` under a
    ``"host:port"`` connection id, handshaked, registered in the
    ``PeerRegistry`` and then served by a per-connection receive loop that
    rate-limits and forwards envelopes to the ``MessageRouter``.
    """

    def __init__(self, settings: Settings):
        """Wire up registry, connection pool, router, handshake handler and rate limiter.

        Args:
            settings: Server configuration; read for network, peer limits,
                rate limits, message size limits and health-check address.
        """
        self.settings = settings
        self.network = NetworkType(settings.network)
        self.peer_registry = PeerRegistry(max_peers=settings.max_peers)
        self.connections = ConnectionPool(max_connections=settings.max_peers)
        # Maps a peer key (nick or location string) to the owning connection id.
        self.peer_key_to_conn_id: dict[str, str] = {}

        self.message_router = MessageRouter(
            peer_registry=self.peer_registry,
            send_callback=self._send_to_peer,
            broadcast_batch_size=settings.broadcast_batch_size,
            on_send_failed=self._handle_send_failed,
        )
        self.handshake_handler = HandshakeHandler(
            network=self.network, server_nick=f"directory-{settings.network}", motd=settings.motd
        )

        # Rate limit by connection ID to prevent nick spoofing attacks.
        # A malicious peer could claim another's nick and spam to get them rate limited.
        # Using connection ID ensures each physical connection has its own bucket.
        self.rate_limiter = RateLimiter(
            rate_limit=settings.message_rate_limit,
            burst_limit=settings.message_burst_limit,
            # A non-positive threshold disables the disconnect behaviour.
            disconnect_threshold=settings.rate_limit_disconnect_threshold
            if settings.rate_limit_disconnect_threshold > 0
            else None,
        )

        self.server: asyncio.Server | None = None
        self._shutdown = False
        self._start_time = datetime.now(UTC)
        self.health_server = HealthCheckServer(
            host=settings.health_check_host, port=settings.health_check_port
        )

    async def start(self) -> None:
        """Bind the listening socket, start the health server and serve until stopped."""
        self.server = await asyncio.start_server(
            self._handle_client,
            self.settings.host,
            self.settings.port,
            limit=self.settings.max_message_size,
        )
        addr = self.server.sockets[0].getsockname()
        logger.info(
            f"Directory server started on {addr[0]}:{addr[1]} (network: {self.network.value})"
        )
        self.health_server.start(self)
        async with self.server:
            await self.server.serve_forever()

    async def stop(self) -> None:
        """Stop accepting connections, close all peers and wait for the server to close."""
        logger.info("Shutting down directory server...")
        self._shutdown = True
        self.health_server.stop()
        if self.server:
            # Stop accepting new connections first.
            self.server.close()
        # Close client connections BEFORE awaiting wait_closed(): on Python 3.12+
        # wait_closed() blocks until every active connection has been dropped,
        # so awaiting it first can hang while peers are still connected.
        await self.connections.close_all()
        if self.server:
            await self.server.wait_closed()
        logger.info("Directory server stopped")

    async def _handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Serve one accepted connection: handshake, message loop, then cleanup.

        Args:
            reader: Stream reader for the accepted connection.
            writer: Stream writer for the accepted connection.
        """
        peer_addr = writer.get_extra_info("peername")
        if not peer_addr:
            # asyncio reports no peername when the socket is already gone;
            # there is nothing to serve, so just release the transport.
            logger.debug("Connection closed before peer address was available")
            writer.close()
            return
        conn_id = f"{peer_addr[0]}:{peer_addr[1]}"
        logger.trace(f"New connection from {conn_id}")

        transport = writer.transport
        # Set reasonable write buffer limits (64KB high, 16KB low)
        # This allows some buffering while preventing memory bloat
        transport.set_write_buffer_limits(high=65536, low=16384)  # type: ignore[union-attr]
        sock = transport.get_extra_info("socket")
        if sock:
            import socket

            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        connection = TCPConnection(reader, writer, self.settings.max_message_size)
        peer_key: str | None = None
        try:
            self.connections.add(conn_id, connection)
            peer_key = await self._perform_handshake(connection, conn_id)
            if not peer_key:
                return
            await self._handle_peer_messages(connection, conn_id, peer_key)
        except Exception as e:
            logger.error(f"Error handling client {conn_id}: {e}")
        finally:
            await self._cleanup_peer(connection, conn_id, peer_key)

    async def _perform_handshake(self, connection: TCPConnection, conn_id: str) -> str | None:
        """Run the handshake exchange and register the peer.

        Args:
            connection: The connection to read the handshake from and reply on.
            conn_id: Connection id used for logging and for the key mapping.

        Returns:
            The peer key (the nick when the location is ``"NOT-SERVING-ONION"``,
            otherwise the location string), or ``None`` when the handshake
            fails, times out after 30 seconds, or any other error occurs.
        """
        try:
            logger.trace(f"[{conn_id}] Waiting for handshake message...")
            data = await asyncio.wait_for(connection.receive(), timeout=30.0)
            logger.trace(f"[{conn_id}] Received {len(data)} bytes: {data[:200]!r}...")

            envelope = MessageEnvelope.from_bytes(
                data,
                max_line_length=self.settings.max_line_length,
                max_json_nesting_depth=self.settings.max_json_nesting_depth,
            )
            logger.trace(
                f"[{conn_id}] Parsed envelope: type={envelope.message_type}, payload_len={len(envelope.payload)}"
            )

            if envelope.message_type != MessageType.HANDSHAKE:
                logger.warning(f"[{conn_id}] Expected handshake, got {envelope.message_type}")
                return None

            logger.trace(f"[{conn_id}] Processing handshake payload: {envelope.payload[:200]}")
            peer_info, response = self.handshake_handler.process_handshake(
                envelope.payload, conn_id
            )
            logger.trace(
                f"[{conn_id}] Handshake processed: peer_nick={peer_info.nick}, location={peer_info.location_string}"
            )

            response_envelope = MessageEnvelope(
                message_type=MessageType.DN_HANDSHAKE, payload=json.dumps(response)
            )
            response_bytes = response_envelope.to_bytes()
            logger.trace(f"[{conn_id}] Sending handshake response: {len(response_bytes)} bytes")
            logger.trace(f"[{conn_id}] Response content: {response_bytes[:200]!r}...")

            try:
                await connection.send(response_bytes)
                logger.trace(f"[{conn_id}] Handshake response sent successfully")
                # Small delay to let client process the handshake response
                await asyncio.sleep(0.05)
            except Exception as e:
                logger.error(f"[{conn_id}] Failed to send handshake response: {e}")
                raise

            peer_location = peer_info.location_string
            self.peer_registry.register(peer_info)
            logger.trace(f"[{conn_id}] Peer registered in registry")

            # Peers without a reachable location are keyed by nick instead.
            peer_key = peer_info.nick if peer_location == "NOT-SERVING-ONION" else peer_location
            self.peer_registry.update_status(peer_key, PeerStatus.HANDSHAKED)
            self.peer_key_to_conn_id[peer_key] = conn_id
            logger.trace(f"[{conn_id}] Peer key mapped: {peer_key}")
            logger.trace(f"[{conn_id}] Handshake complete for {peer_key} (nick={peer_info.nick})")
            return peer_key
        except HandshakeError as e:
            logger.warning(f"[{conn_id}] Handshake failed: {e}")
            return None
        except TimeoutError:
            logger.warning(f"[{conn_id}] Handshake timeout (30s)")
            return None
        except Exception as e:
            # NOTE(review): exc_info=True is the stdlib-logging way to attach a
            # traceback; confirm the logger in use here honours it.
            logger.error(f"[{conn_id}] Handshake error: {e}", exc_info=True)
            return None

    async def _handle_peer_messages(
        self, connection: TCPConnection, conn_id: str, peer_key: str
    ) -> None:
        """Receive, rate-limit and route messages until disconnect or shutdown.

        Args:
            connection: The handshaked connection to read from.
            conn_id: Connection id; also the rate-limiter bucket key.
            peer_key: Registry key of the peer that owns this connection.
        """
        peer_info = self.peer_registry.get_by_key(peer_key)
        if not peer_info:
            return

        logger.info(f"Peer {peer_info.nick} connected from {peer_info.location_string}")

        while connection.is_connected() and not self._shutdown:
            try:
                data = await connection.receive()

                # Rate limiting by connection ID to prevent nick spoofing attacks.
                # A malicious peer could claim another's nick in handshake and spam
                # to get the legitimate peer rate-limited. Using conn_id ensures
                # each physical connection is rate-limited independently.
                action, delay = self.rate_limiter.check(conn_id)

                if action == RateLimitAction.DISCONNECT:
                    violations = self.rate_limiter.get_violation_count(conn_id)
                    logger.warning(
                        f"Rate limit exceeded for {peer_info.nick} ({conn_id}): "
                        f"{violations} violations, disconnecting"
                    )
                    break
                elif action == RateLimitAction.DELAY:
                    violations = self.rate_limiter.get_violation_count(conn_id)
                    if violations % 50 == 1:  # Log every 50th violation to avoid spam
                        logger.debug(
                            f"Rate limiting {peer_info.nick} ({conn_id}): "
                            f"{violations} violations, delay={delay:.2f}s"
                        )
                    # Drop message but stay connected - this is the "slowdown" approach
                    continue

                envelope = MessageEnvelope.from_bytes(
                    data,
                    max_line_length=self.settings.max_line_length,
                    max_json_nesting_depth=self.settings.max_json_nesting_depth,
                )
                await self.message_router.route_message(envelope, peer_key)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error processing message from {peer_info.nick}: {e}")
                break

    async def _cleanup_peer(
        self, connection: TCPConnection, conn_id: str, peer_key: str | None
    ) -> None:
        """Release everything associated with a finished connection.

        Registry, mapping and offer state for ``peer_key`` are only torn down
        when the key is unmapped or still mapped to this connection, so a
        stale connection cannot evict a peer that has since re-registered
        the same key on a newer connection. Per-connection state (rate
        limiter bucket, pool entry, socket) is always released.

        Args:
            connection: The connection to close.
            conn_id: Connection id of the finished connection.
            peer_key: Peer key from the handshake, or ``None`` if it never completed.
        """
        try:
            if peer_key:
                owner = self.peer_key_to_conn_id.get(peer_key)
                if owner is None or owner == conn_id:
                    peer_info = self.peer_registry.get_by_key(peer_key)
                    if peer_info:
                        logger.info(f"Peer {peer_info.nick} disconnected")
                        await self.message_router.broadcast_peer_disconnect(
                            peer_info.location_string, peer_info.network
                        )
                        self.peer_registry.unregister(peer_key)
                    self.peer_key_to_conn_id.pop(peer_key, None)
                    # Clean up offer tracking
                    self.message_router.remove_peer_offers(peer_key)
                else:
                    logger.debug(
                        f"Peer key {peer_key} now belongs to {owner}; "
                        f"skipping registry cleanup for {conn_id}"
                    )
        finally:
            # Clean up rate limiter state (keyed by conn_id, not peer_key)
            self.rate_limiter.remove_peer(conn_id)
            self.connections.remove(conn_id)
            try:
                await connection.close()
            except Exception as e:
                logger.trace(f"Error closing connection: {e}")

    async def _send_to_peer(self, peer_location: str, data: bytes) -> None:
        """Send raw bytes to the connection mapped to ``peer_location``.

        Args:
            peer_location: Peer key used to look up the connection id.
            data: Bytes to send.

        Raises:
            ValueError: If no connection id or no pooled connection exists for the peer.
        """
        peer_key = peer_location
        conn_id = self.peer_key_to_conn_id.get(peer_key)
        if not conn_id:
            raise ValueError(f"No connection for peer: {peer_location}")
        connection = self.connections.get(conn_id)
        if not connection:
            raise ValueError(f"No connection for conn_id: {conn_id}")
        await connection.send(data)

    async def _handle_send_failed(self, peer_key: str) -> None:
        """
        Called when sending to a peer fails.

        Removes the peer from both the connection mapping and the registry
        to prevent further send attempts to this dead connection.
        """
        if peer_key in self.peer_key_to_conn_id:
            logger.debug(f"Removing failed peer mapping: {peer_key}")
            del self.peer_key_to_conn_id[peer_key]
        # Also unregister from peer registry to prevent further routing attempts
        peer_info = self.peer_registry.get_by_key(peer_key)
        if peer_info:
            logger.debug(f"Unregistering failed peer: {peer_key}")
            self.peer_registry.unregister(peer_key)

    def is_healthy(self) -> bool:
        """Return True when the server is started, not shutting down and below max_peers."""
        return (
            self.server is not None
            and not self._shutdown
            and self.peer_registry.count() < self.settings.max_peers
        )

    def get_stats(self) -> dict:
        """Return a compact summary: network, peer counts, connections and violations."""
        return {
            "network": self.network.value,
            "connected_peers": self.peer_registry.count(),
            "max_peers": self.settings.max_peers,
            "active_connections": len(self.connections),
            "rate_limit_violations": self.rate_limiter.get_stats()["total_violations"],
        }

    def get_detailed_stats(self) -> dict:
        """Return a detailed snapshot including uptime, registry stats and peer nicks."""
        uptime = (datetime.now(UTC) - self._start_time).total_seconds()
        registry_stats = self.peer_registry.get_stats()
        connected_peers = self.peer_registry.get_all_connected()
        passive_peers = self.peer_registry.get_passive_peers()
        active_peers = self.peer_registry.get_active_peers()
        offer_stats = self.message_router.get_offer_stats()
        return {
            "network": self.network.value,
            "uptime_seconds": uptime,
            "server_status": "running" if not self._shutdown else "stopping",
            "max_peers": self.settings.max_peers,
            "stats": registry_stats,
            "rate_limiter": self.rate_limiter.get_stats(),
            "offers": offer_stats,
            "connected_peers": {
                "total": len(connected_peers),
                "nicks": [p.nick for p in connected_peers],
            },
            "passive_peers": {
                "total": len(passive_peers),
                "nicks": [p.nick for p in passive_peers],
            },
            "active_peers": {
                "total": len(active_peers),
                "nicks": [p.nick for p in active_peers],
            },
            "active_connections": len(self.connections),
        }

    def log_status(self) -> None:
        """Log a human-readable status report built from ``get_detailed_stats()``."""
        stats = self.get_detailed_stats()
        connected = stats["connected_peers"]
        passive = stats["passive_peers"]
        active = stats["active_peers"]
        peer_groups = [
            (f"Connected peers: {connected['total']}/{stats['max_peers']}", connected["nicks"]),
            (f"Passive peers (orderbook watchers): {passive['total']}", passive["nicks"]),
            (f"Active peers (makers): {active['total']}", active["nicks"]),
        ]

        logger.info("=== Directory Server Status ===")
        logger.info(f"Network: {stats['network']}")
        logger.info(f"Uptime: {stats['uptime_seconds']:.0f}s")
        logger.info(f"Status: {stats['server_status']}")
        for heading, nicks in peer_groups:
            logger.info(heading)
            # Only the first 10 nicks are listed; the rest are summarised.
            logger.info(f" Nicks: {', '.join(nicks[:10])}")
            if len(nicks) > 10:
                logger.info(f" ... and {len(nicks) - 10} more")
        logger.info(f"Active connections: {stats['active_connections']}")
        logger.info("===============================")
def get_detailed_stats(self) ‑> dict-
Expand source code
def get_detailed_stats(self) -> dict: uptime = (datetime.now(UTC) - self._start_time).total_seconds() registry_stats = self.peer_registry.get_stats() connected_peers = self.peer_registry.get_all_connected() passive_peers = self.peer_registry.get_passive_peers() active_peers = self.peer_registry.get_active_peers() offer_stats = self.message_router.get_offer_stats() return { "network": self.network.value, "uptime_seconds": uptime, "server_status": "running" if not self._shutdown else "stopping", "max_peers": self.settings.max_peers, "stats": registry_stats, "rate_limiter": self.rate_limiter.get_stats(), "offers": offer_stats, "connected_peers": { "total": len(connected_peers), "nicks": [p.nick for p in connected_peers], }, "passive_peers": { "total": len(passive_peers), "nicks": [p.nick for p in passive_peers], }, "active_peers": { "total": len(active_peers), "nicks": [p.nick for p in active_peers], }, "active_connections": len(self.connections), } def get_stats(self) ‑> dict-
Expand source code
def get_stats(self) -> dict:
    """Return a compact summary of the server state.

    Returns:
        A dict with the network name, the registry's peer count, the
        configured peer limit, the number of pooled connections and the
        rate limiter's ``total_violations`` counter.
    """
    summary: dict = {}
    summary["network"] = self.network.value
    summary["connected_peers"] = self.peer_registry.count()
    summary["max_peers"] = self.settings.max_peers
    summary["active_connections"] = len(self.connections)
    limiter_stats = self.rate_limiter.get_stats()
    summary["rate_limit_violations"] = limiter_stats["total_violations"]
    return summary
Expand source code
def is_healthy(self) -> bool:
    """Report whether the server is up and still has room for peers.

    Returns:
        ``True`` only when the listening server exists, shutdown has not
        been requested, and the registry holds fewer than ``max_peers`` peers.
    """
    if self.server is None:
        return False
    if self._shutdown:
        return False
    return self.peer_registry.count() < self.settings.max_peers
Expand source code
def log_status(self) -> None:
    """Log a human-readable status report built from ``get_detailed_stats()``.

    Emits the network, uptime, status, then for each of the connected,
    passive and active peer groups a heading, up to the first 10 nicks and,
    when there are more, a count of the remainder; finally the number of
    active connections.
    """
    report = self.get_detailed_stats()
    connected = report["connected_peers"]
    passive = report["passive_peers"]
    active = report["active_peers"]
    sections = (
        (f"Connected peers: {connected['total']}/{report['max_peers']}", connected["nicks"]),
        (f"Passive peers (orderbook watchers): {passive['total']}", passive["nicks"]),
        (f"Active peers (makers): {active['total']}", active["nicks"]),
    )

    logger.info("=== Directory Server Status ===")
    logger.info(f"Network: {report['network']}")
    logger.info(f"Uptime: {report['uptime_seconds']:.0f}s")
    logger.info(f"Status: {report['server_status']}")
    for heading, nicks in sections:
        logger.info(heading)
        shown = nicks[:10]
        logger.info(f" Nicks: {', '.join(shown)}")
        hidden = len(nicks) - 10
        if hidden > 0:
            logger.info(f" ... and {hidden} more")
    logger.info(f"Active connections: {report['active_connections']}")
    logger.info("===============================")
Expand source code
async def start(self) -> None:
    """Bind the listening socket, start the health server and serve until stopped.

    The listener is created with ``asyncio.start_server`` using the
    configured host and port, with ``max_message_size`` as the stream
    buffer limit. This coroutine then runs ``serve_forever()``.
    """
    config = self.settings
    listener = await asyncio.start_server(
        self._handle_client,
        config.host,
        config.port,
        limit=config.max_message_size,
    )
    self.server = listener

    # Index rather than unpack: the socket name is only read by position.
    bound = listener.sockets[0].getsockname()
    host, port = bound[0], bound[1]
    logger.info(f"Directory server started on {host}:{port} (network: {self.network.value})")

    self.health_server.start(self)
    async with listener:
        await listener.serve_forever()
Expand source code
async def stop(self) -> None:
    """Shut the server down: stop listening, close all peers, then wait for closure.

    Sets the shutdown flag, stops the health server, closes the listening
    socket, closes every pooled connection and finally waits for the
    asyncio server to finish closing.
    """
    logger.info("Shutting down directory server...")
    self._shutdown = True
    self.health_server.stop()
    if self.server:
        # Stop accepting new connections first.
        self.server.close()
    # Close client connections BEFORE awaiting wait_closed(): on Python 3.12+
    # wait_closed() blocks until every active connection has been dropped,
    # so awaiting it first can hang while peers are still connected.
    await self.connections.close_all()
    if self.server:
        await self.server.wait_closed()
    logger.info("Directory server stopped")