Commit
Cleanup cache_manager (#41)
FanhaiLu1 authored Apr 25, 2024
1 parent 8bd4986 commit 2abc73e
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions jetstream_pt/cache_manager.py
@@ -18,7 +18,10 @@
import torch


# pylint: disable-next=all
class CacheInterface:
  """Kv cache interface"""

  # cache for ONE layer

  def update(self, key, value):
@@ -33,6 +36,7 @@ def update(self, key, value):


class KVCachePrefill:
  """Prefill kv cache"""

  def __init__(self, kv_quantize=False):
    self.kv_quantize = kv_quantize
@@ -49,20 +53,23 @@ def update(self, key, value):
          jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)
      )
      return key, value, ones, ones

    return key, value

  def state(self):
    """Get prefill cache state"""
    return self.cache_k, self.cache_v

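For context, here is a minimal usage sketch of the prefill cache above. The tensor shapes and values are assumptions for illustration, not part of the commit; update presumably stores the tensors it receives, since state() returns them.

# Hypothetical usage sketch (shapes assumed, not part of this diff).
import torch

cache = KVCachePrefill(kv_quantize=False)
key = torch.zeros((1, 8, 16, 128))   # (batch, heads, seqlen, head_dim), assumed layout
value = torch.zeros((1, 8, 16, 128))
k, v = cache.update(key, value)      # non-quantized path returns the inputs unchanged
saved_k, saved_v = cache.state()     # returns the remembered key/value tensors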

# pylint: disable-next=all
def KVCachePrefill_flatten(cache):
  return (
      torch_xla2.tensor.unwrap((cache.cache_k, cache.cache_v)),
      cache.kv_quantize,
  )

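KVCachePrefill_flatten and the KVCachePrefill_unflatten function just below follow JAX's pytree protocol (a tuple of children plus auxiliary data), presumably so cache objects can cross jit boundaries. A hedged sketch of the registration, assuming the standard jax.tree_util API; the call itself is not shown in this diff:

# Assumed pytree registration (not shown in this commit):
import jax

jax.tree_util.register_pytree_node(
    KVCachePrefill,
    KVCachePrefill_flatten,    # returns (children, aux_data)
    KVCachePrefill_unflatten,  # takes (aux_data, children)
)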

# pylint: disable-next=all
def KVCachePrefill_unflatten(auxdata, data):
  cache = KVCachePrefill(auxdata)
  cache_k, cache_v = torch_xla2.tensor.wrap(data)
@@ -78,6 +85,7 @@ def KVCachePrefill_unflatten(auxdata, data):
# Refactor out cache management
# Easier to test for quantized kv cache
class KVCacheGenerate:
  """KV cache generator without quantization"""

  def __init__(
      self,
@@ -93,31 +101,38 @@ def __init__(
    self.sharding = sharding

  def update(self, key, value):
    """Update kv cache"""
    keyj, valuej = torch_xla2.tensor.unwrap((key, value))
    # pylint: disable-next=all
    self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
    # pylint: disable-next=all
    self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
    return self.cache_k, self.cache_v

  def state(self):
    """Get kv cache state"""
    # pylint: disable-next=all
    return self.cache_k._elem, self.cache_v._elem

  @classmethod
  def empty(cls, shape, device, bf16_enable):
    """Create empty kv caches"""
    default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
    k = jnp.zeros(shape, device=device, dtype=default_dtype)
    v = jnp.zeros(shape, device=device, dtype=default_dtype)
    k, v = torch_xla2.tensor.wrap((k, v))
    pos = jnp.array([0])  # replicated
    return cls(k, v, 0, device)

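The update method above relies on JAX's functional array updates: jnp arrays are immutable, so writing the current position into the cache goes through .at[...].set(...), which returns a new array rather than mutating in place. A minimal standalone illustration with assumed shapes:

# Minimal JAX sketch of the functional update pattern (shapes assumed).
import jax.numpy as jnp

cache = jnp.zeros((1, 2, 4, 3))       # (batch, heads, seqlen, head_dim), assumed layout
entry = jnp.ones((1, 2, 3))           # one slot along the seqlen axis
cache = cache.at[:, :, 1].set(entry)  # returns a new array with position 1 filled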

# pylint: disable-next=all
def KVCacheGenerate_flatten(cache):
  return torch_xla2.tensor.unwrap((cache.cache_k, cache.cache_v)), (
      cache.pos,
      cache.sharding,
  )


# pylint: disable-next=all
def KVCacheGenerate_unflatten(auxdata, data):
  position, sharding = auxdata
  cache_k, cache_v = torch_xla2.tensor.wrap(data)
@@ -131,7 +146,9 @@


class Int8KVCacheGenerate:
  """Int8 quantized KV cache with scalers"""

  # pylint: disable-next=all
  def __init__(
      self,
      cache_k,
@@ -147,15 +164,20 @@ def __init__(
    self.k_scaler = cache_k_scaler
    self.v_scaler = cache_v_scaler
    self.input_pos = input_pos
    self.sharding = sharding

  def state(self):
    """Get kv cache state"""
    return torch_xla2.tensor.unwrap((self.cache_k, self.cache_v))

  def scalers(self):
    """Get kv cache scalers"""
    return torch_xla2.tensor.unwrap((self.k_scaler, self.v_scaler))

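state() and scalers() expose the int8 arrays and their scale factors separately; a consumer would presumably dequantize by multiplying them back together. A sketch under that assumption, since the attention-side read path is not part of this diff:

# Assumed read-side dequantization (not shown in this commit):
k_int8, v_int8 = cache.state()      # raw int8 cache contents
k_scale, v_scale = cache.scalers()  # bfloat16 scale factors, broadcastable shapes
k = k_int8 * k_scale                # approximately recovers the original keys
v = v_int8 * v_scale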
  @classmethod
  # pylint: disable-next=all
  def empty(cls, shape, device, bf16_enable):
    """Create empty kv caches"""
    cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8)
    cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8)
    # bf16_enable is a placeholder parameter; it's not used in Int8KVCache
@@ -168,12 +190,14 @@ def empty(cls, shape, device, bf16_enable):
    return cls(cache_k, cache_v, kscaler, vscaler, 0, device)

  def quantize(self, val):
    """Quantize value"""
    # val is (batch, heads, seqlen, dim)
    scale = torch.amax(val.abs(), axis=(1, 3), keepdim=True)
    scale = scale / 127
    return (val / scale).to(torch.int8), scale

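The scale in quantize is the per-(batch, seqlen) absolute maximum taken over the heads and dim axes, divided by 127, so the largest-magnitude entry maps to plus or minus 127 and everything else scales linearly, with truncation toward zero from .to(torch.int8). A small worked example with assumed values:

# Worked example of the symmetric int8 scheme above (values assumed).
import torch

val = torch.tensor([[[[-25.4, 12.7, 63.5]]]])                   # shape (1, 1, 1, 3)
scale = torch.amax(val.abs(), axis=(1, 3), keepdim=True) / 127  # 63.5 / 127 = 0.5
quant = (val / scale).to(torch.int8)                            # [-50, 25, 127]
dequant = quant * scale                                         # [-25.0, 12.5, 63.5]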
  def update(self, xk, xv):
    """Update kv cache"""
    k_quant, kscale = self.quantize(xk)
    v_quant, vscale = self.quantize(xv)
    self.cache_k[:, :, self.input_pos, :] = k_quant
