"""Client connector for a remote LMCache server.

Module: ``lmcache.storage_backend.connector.lm_connector``.
"""

import socket
import threading
from typing import List, Optional

from lmcache.logging import init_logger
from lmcache.protocol import ClientMetaMessage, Constants, ServerMetaMessage
from lmcache.storage_backend.connector.base_connector import \
    RemoteBytesConnector
from lmcache.utils import _lmcache_nvtx_annotate

logger = init_logger(__name__)


# TODO: performance optimization for this class, consider using C/C++/Rust
# for communication + deserialization
class LMCServerConnector(RemoteBytesConnector):
    """Blocking TCP client for a remote LMCache server.

    Each request is a serialized ``ClientMetaMessage`` header, followed by
    ``len(obj)`` payload bytes for PUT. For EXIST/GET/LIST the server
    replies with a fixed-size ``ServerMetaMessage`` header. For GET/LIST a
    successful header is followed by ``meta.length`` payload bytes.

    One socket is shared by all callers. Each request/response exchange is
    therefore performed while holding ``socket_lock``. Otherwise concurrent
    threads could interleave their bytes on the stream or read each other's
    replies.
    """

    def __init__(self, host, port):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((host, port))
        except OSError:
            # Do not leak the file descriptor when the server is unreachable.
            sock.close()
            raise
        self.client_socket = sock
        # Re-entrant so a public method can hold the lock for a whole
        # request/response exchange while still calling send_all().
        self.socket_lock = threading.RLock()

    def receive_all(self, n):
        """Read exactly ``n`` bytes from the socket.

        Returns:
            A ``bytearray`` of length ``n``, or ``None`` if the peer closed
            the connection before ``n`` bytes arrived.
        """
        received = 0
        buffer = bytearray(n)
        view = memoryview(buffer)
        while received < n:
            # recv_into may deliver fewer bytes than requested; keep reading.
            num_bytes = self.client_socket.recv_into(view[received:],
                                                     n - received)
            if num_bytes == 0:
                return None
            received += num_bytes
        return buffer

    def send_all(self, data):
        """
        Thread-safe function to send the data
        """
        with self.socket_lock:
            self.client_socket.sendall(data)

    def _receive_meta(self) -> ServerMetaMessage:
        """Read and decode one complete ``ServerMetaMessage`` header.

        Uses ``receive_all`` rather than a single ``recv`` call, which may
        return a partial header.

        Raises:
            ConnectionError: if the server closed the connection before a
                full header arrived.
        """
        data = self.receive_all(ServerMetaMessage.packlength())
        if data is None:
            raise ConnectionError(
                "LMCServerConnector: connection closed by the remote server")
        return ServerMetaMessage.deserialize(bytes(data))

    def exists(self, key: str) -> bool:
        """Return True if the server reports SERVER_SUCCESS for ``key``.

        Raises:
            ConnectionError: if the connection closes before the reply.
        """
        logger.debug("Call to exists()!")
        with self.socket_lock:
            self.send_all(
                ClientMetaMessage(Constants.CLIENT_EXIST, key, 0).serialize())
            meta = self._receive_meta()
        return meta.code == Constants.SERVER_SUCCESS

    def set(self, key: str, obj: bytes):  # type: ignore[override]
        """Send a PUT request: a header carrying ``len(obj)``, then ``obj``."""
        logger.debug("Call to set()!")
        header = ClientMetaMessage(Constants.CLIENT_PUT, key,
                                   len(obj)).serialize()
        # Header and payload must be adjacent on the stream. Holding the
        # lock across both sends prevents another thread's message from
        # being written between them.
        with self.socket_lock:
            self.send_all(header)
            self.send_all(obj)
        # NOTE(review): no reply is read for PUT, which presumes the server
        # does not acknowledge PUT requests. If it does, the unread reply
        # would be consumed by the next request -- confirm against the server.

    @_lmcache_nvtx_annotate
    def get(self, key: str) -> Optional[bytes]:
        """Fetch the value stored under ``key``.

        Returns:
            The payload bytes. Returns ``None`` if the server reply is not
            SERVER_SUCCESS, or if the connection closes mid-payload.

        Raises:
            ConnectionError: if the connection closes before the reply
                header arrives.
        """
        with self.socket_lock:
            self.send_all(
                ClientMetaMessage(Constants.CLIENT_GET, key, 0).serialize())
            meta = self._receive_meta()
            if meta.code != Constants.SERVER_SUCCESS:
                return None
            data = self.receive_all(meta.length)
        return None if data is None else bytes(data)

    def list(self) -> List[str]:
        """Return the keys stored on the server.

        The payload is decoded as text and split on newlines, and empty
        entries are dropped. Returns ``[]`` (after logging an error) if the
        server does not reply SERVER_SUCCESS or the payload is cut short.

        Raises:
            ConnectionError: if the connection closes before the reply
                header arrives.
        """
        with self.socket_lock:
            self.send_all(
                ClientMetaMessage(Constants.CLIENT_LIST, "", 0).serialize())
            meta = self._receive_meta()
            if meta.code != Constants.SERVER_SUCCESS:
                logger.error(
                    "LMCServerConnector: Cannot list keys from the remote server!")
                return []
            data = self.receive_all(meta.length)
        if data is None:
            logger.error("LMCServerConnector: connection closed while "
                         "receiving the key list from the remote server!")
            return []
        return [k for k in data.decode().split("\n") if k]

    def close(self):
        """Close the underlying socket."""
        self.client_socket.close()
        logger.info("Closed the lmserver connection")