import asyncio
import inspect
import os
from typing import List, Optional, Tuple, Union, no_type_check
import redis
from lmcache.experimental.memory_management import (MemoryAllocatorInterface,
MemoryObj)
from lmcache.experimental.protocol import RedisMetadata
from lmcache.experimental.storage_backend.connector.base_connector import \
RemoteConnector
from lmcache.logging import init_logger
from lmcache.utils import CacheEngineKey
logger = init_logger(__name__)
# TODO(Jiayi): Use `redis.asyncio`
# NOTE(Jiayi): `redis-py` supports async operations, but a data copy
# cannot be avoided. `hiredis` is lower-level but does not support
# asyncio.
class RedisConnector(RemoteConnector):
    """
    Remote connector that stores KV cache chunks in a single Redis server.

    The remote url should start with "redis://" and only have one host-port
    pair.

    Every cache entry occupies two Redis keys derived from
    ``key.to_string()``: the suffix ``"metadata"`` holds the serialized
    ``RedisMetadata`` and the suffix ``"kv_bytes"`` holds the raw payload.

    NOTE: the methods are coroutines, but they call the synchronous
    ``redis.Redis`` client directly, so each Redis round trip blocks the
    calling event loop.
    """

    def __init__(self, host: str, port: int, loop: asyncio.AbstractEventLoop,
                 memory_allocator: MemoryAllocatorInterface):
        """
        Create the Redis client.

        :param host: Host name of the Redis server.
        :param port: Port of the Redis server.
        :param loop: Event loop, stored on the instance (not used by the
            methods of this class).
        :param memory_allocator: Allocator used to obtain the buffers
            returned by `get` and to release the buffers passed to `put`.
        """
        # Values are raw bytes, so responses must not be decoded to `str`.
        self.connection = redis.Redis(host=host,
                                      port=port,
                                      decode_responses=False)
        self.memory_allocator = memory_allocator
        self.loop = loop

    async def exists(self, key: CacheEngineKey) -> bool:
        """
        Return True if the metadata key for `key` is present in Redis.
        """
        return bool(self.connection.exists(key.to_string() + "metadata"))

    async def get(self, key: CacheEngineKey) -> Optional[MemoryObj]:
        """
        Fetch the entry stored under `key`.

        :param key: Cache key to look up.
        :return: A newly allocated memory object filled with the stored
            payload, or None if the entry does not exist, the allocation
            fails, or the payload key is missing (in which case the stale
            metadata key is deleted).
        """
        key_str = key.to_string()

        redis_metadata_bytes = self.connection.get(key_str + "metadata")
        if redis_metadata_bytes is None:
            return None
        assert not inspect.isawaitable(redis_metadata_bytes)

        redis_metadata = RedisMetadata.deserialize(
            memoryview(redis_metadata_bytes))

        memory_obj = self.memory_allocator.allocate(
            redis_metadata.shape,
            redis_metadata.dtype,
            redis_metadata.fmt,
        )
        if memory_obj is None:
            logger.warning("Failed to allocate memory during remote receive")
            return None

        # From here on the buffer must either be returned to the caller or
        # handed back to the allocator; otherwise it would leak.
        handed_over = False
        try:
            # TODO(Jiayi): Find a way to do `get` inplace
            kv_bytes = self.connection.get(key_str + "kv_bytes")
            assert not inspect.isawaitable(kv_bytes)

            if kv_bytes is None:
                # TODO (Jiayi): We might need a way to better handle
                # consistency issues.
                # TODO (Jiayi): A better way is to aggregate metadata
                # and kv cache in one key.
                logger.warning(
                    "Key exists but KV cache does not exist. "
                    "Might happen when the cache is evicted by redis.")
                self.connection.delete(key_str + "metadata")
                return None

            view = memoryview(memory_obj.byte_array)
            view[:redis_metadata.length] = kv_bytes
            handed_over = True
            return memory_obj
        finally:
            if not handed_over:
                # Drop the reference obtained from `allocate` (same release
                # call `put` uses) on the missing-payload and error paths.
                self.memory_allocator.ref_count_down(memory_obj)

    async def put(self, key: CacheEngineKey, memory_obj: MemoryObj):
        """
        Store `memory_obj` under `key` and release the caller's reference.

        :param key: Cache key to store the entry under.
        :param memory_obj: Memory object whose bytes and metadata are
            written to Redis. Its reference count is decremented before
            returning, whether or not the write succeeded.
        """
        try:
            # TODO(Jiayi): The following code is ugly.
            # Please use a function like `memory_obj.to_meta()`.
            kv_bytes = memory_obj.byte_array
            redis_metadata_bytes = RedisMetadata(
                len(kv_bytes),
                memory_obj.get_shape(),
                memory_obj.get_dtype(),
                memory_obj.get_memory_format(),
            ).serialize()

            key_str = key.to_string()
            # Write the payload before the metadata: `exists`/`get` treat
            # the metadata key as the presence marker, so a concurrent
            # reader must never observe metadata without its payload (it
            # would delete the metadata and orphan the entry).
            self.connection.set(key_str + "kv_bytes", kv_bytes)
            self.connection.set(key_str + "metadata", redis_metadata_bytes)
        finally:
            # Release even if serialization or a Redis call raised.
            self.memory_allocator.ref_count_down(memory_obj)

    # TODO
    @no_type_check
    async def list(self) -> List[str]:
        """
        Not implemented yet; currently does nothing and returns None.
        """
        pass

    async def close(self):
        """
        Close the Redis connection.
        """
        self.connection.close()
        logger.info("Closed the redis connection")
class RedisSentinelConnector(RemoteConnector):
    """
    Uses redis.Sentinel to connect to a Redis cluster.

    The hosts are specified in the config file, started with
    "redis-sentinel://" and separated by commas.

    Example:
        remote_url: "redis-sentinel://localhost:26379,localhost:26380,localhost:26381"

    Extra environment variables:
    - REDIS_SERVICE_NAME (optional) -- service name for redis; a warning is
      logged and "redismaster" is used if not set.
    - REDIS_TIMEOUT (optional) -- Timeout in seconds, default is 1 if not set.

    Reads (`exists`, `get`) go to the slave client; writes and deletes go to
    the master client. Every cache entry occupies two Redis keys derived
    from ``key.to_string()``: the suffix ``"metadata"`` holds the serialized
    ``RedisMetadata`` and the suffix ``"kv_bytes"`` holds the raw payload.

    NOTE: the methods are coroutines, but they call the synchronous redis
    clients directly, so each Redis round trip blocks the calling event
    loop.
    """

    ENV_REDIS_TIMEOUT = "REDIS_TIMEOUT"
    ENV_REDIS_SERVICE_NAME = "REDIS_SERVICE_NAME"

    def __init__(self, hosts_and_ports: List[Tuple[str, Union[str, int]]],
                 loop: asyncio.AbstractEventLoop,
                 memory_allocator: MemoryAllocatorInterface):
        """
        Create the sentinel client and the master/slave clients.

        :param hosts_and_ports: (host, port) pairs of the sentinel nodes.
        :param loop: Event loop, stored on the instance (not used by the
            methods of this class).
        :param memory_allocator: Allocator used to obtain the buffers
            returned by `get` and to release the buffers passed to `put`.
        :raises ValueError: If REDIS_TIMEOUT is set but is not a number.
        """
        # Get service name
        service_name = os.environ.get(self.ENV_REDIS_SERVICE_NAME)
        if service_name is None:
            logger.warning(
                f"Environment variable {self.ENV_REDIS_SERVICE_NAME} is not "
                f"found, using default value 'redismaster'")
            service_name = "redismaster"

        # Get timeout
        raw_timeout = os.environ.get(self.ENV_REDIS_TIMEOUT)
        if raw_timeout is None:
            timeout = 1.0
        else:
            try:
                timeout = float(raw_timeout)
            except ValueError as err:
                raise ValueError(
                    f"Environment variable {self.ENV_REDIS_TIMEOUT} must be "
                    f"a number of seconds, got {raw_timeout!r}") from err

        logger.info(f"Host and ports: {hosts_and_ports}")
        # The timeout must be passed by keyword: the second positional
        # parameter of `redis.Sentinel` is `min_other_sentinels`, not a
        # timeout.
        self.sentinel = redis.Sentinel(hosts_and_ports,
                                       socket_timeout=timeout)
        self.master = self.sentinel.master_for(service_name,
                                               socket_timeout=timeout)
        self.slave = self.sentinel.slave_for(service_name,
                                             socket_timeout=timeout)

        self.memory_allocator = memory_allocator
        self.loop = loop

    async def exists(self, key: CacheEngineKey) -> bool:
        """
        Return True if the metadata key for `key` is present on the slave.
        """
        return bool(self.slave.exists(key.to_string() + "metadata"))

    async def get(self, key: CacheEngineKey) -> Optional[MemoryObj]:
        """
        Fetch the entry stored under `key` from the slave.

        :param key: Cache key to look up.
        :return: A newly allocated memory object filled with the stored
            payload, or None if the entry does not exist, the allocation
            fails, or the payload key is missing (in which case the stale
            metadata key is deleted on the master).
        """
        key_str = key.to_string()

        redis_metadata_bytes = self.slave.get(key_str + "metadata")
        if redis_metadata_bytes is None:
            return None
        assert not inspect.isawaitable(redis_metadata_bytes)

        redis_metadata = RedisMetadata.deserialize(redis_metadata_bytes)

        memory_obj = self.memory_allocator.allocate(
            redis_metadata.shape,
            redis_metadata.dtype,
            redis_metadata.fmt,
        )
        if memory_obj is None:
            logger.warning("Failed to allocate memory during remote receive")
            return None

        # From here on the buffer must either be returned to the caller or
        # handed back to the allocator; otherwise it would leak.
        handed_over = False
        try:
            # TODO(Jiayi): Find a way to do `get` inplace
            kv_bytes = self.slave.get(key_str + "kv_bytes")
            assert not inspect.isawaitable(kv_bytes)

            if kv_bytes is None:
                # TODO (Jiayi): We might need a way to better handle
                # consistency issues.
                # TODO (Jiayi): A background sweeper might be better
                # for the sake of performance.
                logger.warning(
                    "Key exists but KV cache does not exist. "
                    "Might happen when the cache is evicted by redis.")
                self.master.delete(key_str + "metadata")
                return None

            view = memoryview(memory_obj.byte_array)
            view[0:redis_metadata.length] = kv_bytes
            handed_over = True
            return memory_obj
        finally:
            if not handed_over:
                # Drop the reference obtained from `allocate` (same release
                # call `put` uses) on the missing-payload and error paths.
                self.memory_allocator.ref_count_down(memory_obj)

    async def put(self, key: CacheEngineKey, memory_obj: MemoryObj):
        """
        Store `memory_obj` under `key` on the master and release the
        caller's reference.

        :param key: Cache key to store the entry under.
        :param memory_obj: Memory object whose bytes and metadata are
            written to Redis. Its reference count is decremented before
            returning, whether or not the write succeeded.
        """
        try:
            # TODO(Jiayi): The following code is ugly.
            # Please use a function like `memory_obj.to_meta()`.
            kv_bytes = memory_obj.byte_array
            redis_metadata_bytes = RedisMetadata(
                len(kv_bytes),
                memory_obj.get_shape(),
                memory_obj.get_dtype(),
                memory_obj.get_memory_format(),
            ).serialize()

            key_str = key.to_string()
            # Write the payload before the metadata: `exists`/`get` treat
            # the metadata key as the presence marker, so a concurrent
            # reader must never observe metadata without its payload (it
            # would delete the metadata and orphan the entry).
            self.master.set(key_str + "kv_bytes", kv_bytes)
            self.master.set(key_str + "metadata", redis_metadata_bytes)
        finally:
            # Release even if serialization or a Redis call raised.
            self.memory_allocator.ref_count_down(memory_obj)

    # TODO
    @no_type_check
    async def list(self) -> List[str]:
        """
        Not implemented yet; currently does nothing and returns None.
        """
        pass

    async def close(self):
        """
        Close the master and slave Redis clients.
        """
        self.master.close()
        self.slave.close()
        logger.info("Closed the redis connection")