LMCache Engine
The LMCache Engine has two main functions: store and retrieve.
The store function:
- Breaks input tokens and KV caches into manageable chunks.
- Stores the chunks in a dictionary managed by the LMCache Backend.
- Uses prefix token hashes (currently SHA-256) to index the chunks, combined with other arguments (e.g., format) to form the key (see the sketch after this list).
- Efficiently stores the KV caches while avoiding redundancy (if skip_existing=True). The store operation can be blocking or non-blocking, depending on the blocking argument.
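For intuition, here is a minimal, self-contained sketch of the chunk-and-hash idea. It is not LMCache's actual implementation: the chunk size, the token serialization, and the exact way the hash is chained over the prefix are assumptions for illustration.

```python
import hashlib
from typing import Iterable, List, Tuple

def chunk_prefix_hashes(
    tokens: List[int], chunk_size: int = 256
) -> Iterable[Tuple[str, List[int]]]:
    """Yield (prefix_hash, chunk) pairs for a token sequence.

    Illustrative only: the serialization and hash chaining here are
    assumptions, not LMCache's actual scheme.
    """
    prefix = hashlib.sha256()
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start : start + chunk_size]
        # Fold this chunk into the running hash so that a chunk's key
        # depends on every token before it (a "prefix token hash").
        prefix.update(" ".join(map(str, chunk)).encode())
        yield prefix.hexdigest(), chunk
```

Because each chunk's key commits to the entire token prefix, two requests that share a prompt prefix map to the same keys and can reuse each other's cached chunks.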
The retrieve function:
- Retrieves KV caches for the input tokens, using the same chunking and hashing mechanism as store.
- Supports partial retrieval via the mask parameter, allowing retrieval of suffixes or specific portions of the token sequence.
- Concatenates the retrieved KV cache chunks into a usable format for model inference (see the end-to-end sketch after this list).
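The following hypothetical end-to-end sketch shows the two functions together. It assumes engine is an already-constructed LMCacheEngine (the config and metadata setup is omitted) and that kv_tuples is the model's KV cache as nested tuples ((K, V) per layer, with no batch dimension).

```python
import torch

tokens = torch.randint(0, 32000, (1024,))  # stand-in token ids

# Store the KV cache; chunks already held by the backend are skipped.
engine.store(tokens, kv_tuples, skip_existing=True, blocking=True)

# Later, retrieve whatever prefix of these tokens is cached.
kv, ret_mask = engine.retrieve(tokens, return_tuple=True)
print(f"retrieved KV for {int(ret_mask.sum())} of {len(tokens)} tokens")
```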
The details of the LMCacheEngine class are listed below.
- class lmcache.cache_engine.LMCacheEngine(config: LMCacheEngineConfig, metadata: LMCacheEngineMetadata)
- _blob_to_tuple_kv(blob: Tensor) → Tuple[Tuple[Tensor, Tensor], ...]
Convert a single big tensor to the nested tuple of KV tensors.
- _chunk_kv(kv_tensors: Tensor, fmt: str) → Iterable[Tensor]
Chunk the KV cache into chunks of size self.chunk_size.
- Parameters:
kv_tensors – the KV cache of the tokens, as a single tensor
fmt – either 'huggingface' or 'vllm'
- Returns:
a generator of KV cache chunks, each covering at most self.chunk_size tokens
- _chunk_tokens(tokens: Tensor) → Iterable[Tensor]
Chunk the tokens into chunks of size self.chunk_size.
- Parameters:
tokens – the input tokens, with shape [seq_len]
- Returns:
a generator of chunks of tokens, each with shape [chunk_size]
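One plausible implementation of this chunking with PyTorch (a sketch, not necessarily the actual code):

```python
import torch
from typing import Iterable

def chunk_tokens(tokens: torch.Tensor, chunk_size: int) -> Iterable[torch.Tensor]:
    # torch.split yields views of `tokens` of length chunk_size
    # (the last chunk may be shorter), matching the generator
    # behavior described above.
    yield from torch.split(tokens, chunk_size)
```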
- _make_chunks(tokens: Tensor, kv_tensors: Tensor, fmt: str, num_skip_prefix_chunk=0, skip_existing=True) → Iterable[Tuple[str, Tensor]]
Returns a generator of zipped (chunk_hash, chunk_kv) tuples.
- _make_chunks_skip_existing(tokens: Tensor, kv_tensors: Tensor, fmt: str, num_skip_prefix_chunk=0) → Iterable[Tuple[str, Tensor]]
Skips chunks that already exist in the cache and returns the remaining chunks.
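The skip-existing behavior can be pictured as a filter over the (chunk_hash, chunk_kv) stream. This sketch uses a hypothetical backend_contains callback in place of the real LMCache Backend lookup:

```python
from typing import Callable, Iterable, Tuple
import torch

def skip_existing_chunks(
    chunks: Iterable[Tuple[str, torch.Tensor]],
    backend_contains: Callable[[str], bool],  # hypothetical backend probe
) -> Iterable[Tuple[str, torch.Tensor]]:
    for chunk_hash, chunk_kv in chunks:
        if backend_contains(chunk_hash):
            continue  # chunk already cached: avoid redundant storage
        yield chunk_hash, chunk_kv
```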
- _make_key(chunk_hash: str, fmt: str) → CacheEngineKey
Builds the CacheEngineKey for a chunk from its hash and the KV format.
- _slice_kv_at(start_idx: int, kv_tensors: Tensor, fmt: str) → List[Tensor]
Slices the KV tensors along the token dimension, starting at start_idx.
vllm format: [num_layer, 2, num_tokens, num_kv_head, head_size]; huggingface format: [num_layer, 2, num_kv_head, num_tokens, head_size]
- _tuple_kv_to_blob(kv_tensors: Tuple[Tuple[Tensor, Tensor], ...]) → Tensor
Convert the nested tuple of KV tensors to a single big tensor with 2 extra dimensions.
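A sketch of what this blob round-trip could look like, assuming all layers' K and V tensors share one shape (not necessarily the actual implementation):

```python
import torch
from typing import Tuple

KVTuples = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]

def tuple_kv_to_blob(kv: KVTuples) -> torch.Tensor:
    # Stacking K/V adds one dimension (size 2); stacking layers adds
    # another (size num_layers): the blob is [num_layers, 2, ...].
    return torch.stack([torch.stack(layer_kv) for layer_kv in kv])

def blob_to_tuple_kv(blob: torch.Tensor) -> KVTuples:
    # Inverse: peel the two leading dimensions back off.
    return tuple((layer[0], layer[1]) for layer in blob)
```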
- lookup(tokens: Tensor) → int
Checks how many prefix tokens of the input already have KV cache in the cache engine.
- Parameters:
tokens – the input tokens, with shape [seq_len]
- Returns:
An int indicating how many prefix tokens are cached.
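Hypothetical usage, with engine construction omitted:

```python
# How many leading tokens of this prompt already have cached KV?
num_cached = engine.lookup(tokens)  # tokens: 1-D tensor of shape [seq_len]
print(f"{num_cached} prefix tokens are cached")
```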
- retrieve(tokens: Tensor, mask: Tensor | None = None, return_tuple: bool = True) → Tuple[Tuple[Tuple[Tensor, Tensor], ...] | Tensor, Tensor]
Retrieve the KV cache of the tokens from the cache engine. The retrieved KV cache is a prefix of the input tokens.
- Parameters:
tokens – the input tokens, with shape [seq_len]
mask – a boolean mask over the tokens indicating which tokens' KV cache should be retrieved. Currently, only suffix masks are supported.
return_tuple – whether to return the KV cache as nested tuples or as a single tensor
- Returns:
A tuple (kv_tensors, ret_mask). kv_tensors is the KV cache of the tokens, either as nested tuples or as a single tensor with shape [num_layers, 2, hidden_dim, num_tokens] (huggingface) or [num_layers, 2, num_tokens, hidden_dim] (vllm); it will be an empty tuple if no KV cache is retrieved, regardless of return_tuple. ret_mask is a boolean mask indicating which tokens' KV cache was retrieved.
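For instance, a suffix mask selecting the last 100 tokens could be built like this (a sketch; engine construction omitted):

```python
import torch

mask = torch.zeros(tokens.shape[0], dtype=torch.bool)
mask[-100:] = True  # only suffix masks are currently supported
kv, ret_mask = engine.retrieve(tokens, mask=mask, return_tuple=True)
num_hit = int(ret_mask.sum())  # tokens whose KV cache was actually found
```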
- store(tokens: Tensor, kv_tensors_raw: Tuple[Tuple[Tensor, Tensor], ...], kv_tensors_mask: Tensor | None = None, skip_existing=True, blocking=True) → None
Store the KV cache of the tokens into the cache engine. The format is either 'huggingface' or 'vllm':
For huggingface, each layer's K/V tensor should have the shape [num_heads, num_tokens, head_size]
For vllm, each layer's K/V tensor should have the shape [num_tokens, num_heads, head_size]
- Parameters:
tokens – the input tokens, with shape [seq_len]
kv_tensors_raw – the KV cache of the tokens, in the format of nested tuples. If kv_tensors_mask is not None, the number of tokens in kv_tensors_raw should equal the number of True values in the mask; otherwise, it should equal the number of input tokens.
kv_tensors_mask – a boolean mask indicating which tokens' KV cache should be stored. Only suffix masks are supported; None is treated as all tokens True. len(kv_tensors_mask) should equal len(tokens), and the number of True values should equal the number of tokens in kv_tensors_raw.
skip_existing – whether to skip chunks that already exist in the cache
blocking – whether to wait for the store operation to finish
- Returns:
None
Note
The KV cache should NOT have the “batch” dimension.
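Finally, a hedged sketch of a masked, non-blocking store, e.g. after appending newly decoded tokens to an already-cached prefix. Here engine, tokens, and kv_suffix (nested-tuple KV covering only the masked tokens, without a batch dimension) are assumed to exist:

```python
import torch

new_tokens = 128  # assumed number of freshly decoded tokens
mask = torch.zeros(tokens.shape[0], dtype=torch.bool)
mask[-new_tokens:] = True  # suffix mask: store KV only for the new tokens
# kv_suffix must cover exactly as many tokens as there are Trues in mask.
engine.store(tokens, kv_suffix, kv_tensors_mask=mask,
             skip_existing=True, blocking=False)
```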