Skip to content

Commit

Permalink
Improve thread safety for TreeNode data structure and refactor relate…
Browse files Browse the repository at this point in the history
…d codes (#730)

* Improve thread safety for TreeNode data structure and refactor for consistency and readability

Signed-off-by: “Gangmuk <[email protected]>

* Updated minor comments based on the review

Signed-off-by: “Gangmuk <[email protected]>

---------

Signed-off-by: “Gangmuk <[email protected]>
  • Loading branch information
gangmuk authored Feb 25, 2025
1 parent fcc0896 commit 51651cf
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 100 deletions.
57 changes: 15 additions & 42 deletions pkg/plugins/gateway/algorithms/prefix_cache_and_load.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (h *SlidingWindowHistogram) getPrefillCost(node *prefixcacheindexer.TreeNod
attnQuad = calculateAttnQuadV100(numTokens, nil)
}
prefillTime := (baseTime + attnQuad) / 0.9
numPods := len(node.ModelToPods) // You might need to adjust this based on your actual GPU allocation tracking
numPods := node.GetModelToPodCount() // You might need to adjust this based on your actual GPU allocation tracking
klog.Infof("numTokens: %d, contextLength: %d, targetGPU: %s", numTokens, contextLength, targetGPU)
klog.Infof("prefillTime: %.2f = (Base time(%.2f) + attnQuad(%.2f)) / 0.9", prefillTime, baseTime, attnQuad)
totalPrefillCost := missRate * float64(h.nodeToCount[node]) * prefillTime / float64(numPods)
Expand Down Expand Up @@ -298,22 +298,18 @@ func (h *SlidingWindowHistogram) removeEvictedNodes(nodes []*prefixcacheindexer.
func (h *SlidingWindowHistogram) removeOldEntries(currentTime time.Time) {
h.mu.Lock()
defer h.mu.Unlock()

windowStart := currentTime.Add(-h.windowDuration)
newTimestamps := make([]histogramEntry, 0)

for _, entry := range h.timestamps {
if entry.timestamp.After(windowStart) {
newTimestamps = append(newTimestamps, entry)
} else {
node := entry.node
leafNode := entry.leafNode

h.histogram[node] -= leafNode.ContextLength()
h.nodeToCount[node]--
h.hitTokens[node] -= leafNode.ContextLength() - leafNode.NumTokens()
h.promptTokens[node] -= leafNode.ContextLength()

if h.histogram[node] <= 0 {
delete(h.histogram, node)
delete(h.nodeToCount, node)
Expand Down Expand Up @@ -368,7 +364,7 @@ func (h *SlidingWindowHistogram) getCurrentAllocationCostPerPod() map[string]flo
costs := make(map[string]float64)
for node := range h.histogram {
// Iterate through all models and their pods for this node
for _, modelPods := range node.ModelToPods {
for _, modelPods := range node.GetModelToPods() {
for podName := range modelPods {
costs[podName] += h.getNodeCost(node, podName)
}
Expand All @@ -387,21 +383,13 @@ func (p *prefixCacheAndLoadRouter) updatePodSet(readyPods []*v1.Pod) {
// Update cache structures
for _, node := range allNodes {
// 1. Update ModelToPods
for model, podMap := range node.ModelToPods {
for podName := range podMap {
if !currentPodSet[podName] {
delete(podMap, podName)
podsChanged = true
}
}
if len(podMap) == 0 {
delete(node.ModelToPods, model)
}
if node.RemovePodsNotInSet(currentPodSet) {
podsChanged = true
}
// 2. Update node's pod-specific data structures
node.EvictedPods = make(map[int]bool) // Reset as pod IDs might change
node.CachedPods = make(map[int]bool) // Reset as pod IDs might change
node.RefCounter = make([]int, len(currentPodSet)) // Resize for new pod count
node.ResetEvictedPods() // Reset as pod IDs might change
node.ResetCachedPods() // Reset as pod IDs might change
node.ResetRefCounter(len(currentPodSet)) // Resize for new pod count
}

// Update router and histogram if pods changed
Expand Down Expand Up @@ -440,26 +428,17 @@ func (p *prefixCacheAndLoadRouter) updatePodSet(readyPods []*v1.Pod) {
// Filter timestamps entries for nodes that still have valid pods
newTimestamps := make([]histogramEntry, 0)
for _, entry := range h.timestamps {
if hasValidPods(entry.node, currentPodSet) {
if entry.node == nil {
continue
}
if entry.node.HasValidPods(currentPodSet) {
newTimestamps = append(newTimestamps, entry)
}
}
h.timestamps = newTimestamps
}
}

// Helper function to check if a node has any valid pods
func hasValidPods(node *prefixcacheindexer.TreeNode, currentPodSet map[string]bool) bool {
for _, podMap := range node.ModelToPods {
for podName := range podMap {
if currentPodSet[podName] {
return true
}
}
}
return false
}

func (p *prefixCacheAndLoadRouter) Route(ctx context.Context, pods map[string]*v1.Pod, model, message string) (string, error) {
readyPods := utils.FilterReadyPods(pods)
if len(readyPods) == 0 {
Expand Down Expand Up @@ -489,7 +468,7 @@ func (p *prefixCacheAndLoadRouter) Route(ctx context.Context, pods map[string]*v
node, matchedTokens, _ := p.cache.AddPrefix(tokens, model, "")
var matchedPods []*v1.Pod
var matchedPodsNames []string
if modelPods, ok := node.ModelToPods[model]; ok {
if modelPods, ok := node.GetModelToPods()[model]; ok {
klog.Infof("node.ModelToPods[model]: %v", modelPods)
for podName := range modelPods {
for _, pod := range readyPods {
Expand All @@ -514,7 +493,7 @@ func (p *prefixCacheAndLoadRouter) Route(ctx context.Context, pods map[string]*v

currentNode := node
for currentNode != nil {
if modelPods, ok := currentNode.ModelToPods[model]; ok {
if modelPods, ok := currentNode.GetModelToPods()[model]; ok {
var nodePods []*v1.Pod
for podName := range modelPods {
for _, pod := range readyPods {
Expand Down Expand Up @@ -586,13 +565,7 @@ func (p *prefixCacheAndLoadRouter) Route(ctx context.Context, pods map[string]*v
// Update pod mapping in ALL nodes from matched node to root
currentNode := node
for currentNode != nil {
if modelPods, ok := currentNode.ModelToPods[model]; !ok {
currentNode.ModelToPods[model] = map[string]time.Time{
targetPod.Name: time.Now(),
}
} else {
modelPods[targetPod.Name] = time.Now()
}
currentNode.AddOrUpdatePodForModel(model, targetPod.Name, time.Now())
currentNode = currentNode.GetParent()
}

Expand All @@ -609,7 +582,7 @@ func (h *SlidingWindowHistogram) getPodLoad(pod *v1.Pod) int {
defer h.mu.RUnlock()
load := 0
for node, count := range h.nodeToCount {
for _, podMap := range node.ModelToPods {
for _, podMap := range node.GetModelToPods() {
if _, exists := podMap[pod.Name]; exists {
load += count
break // Found this pod in this node, no need to check other models
Expand Down
Loading

0 comments on commit 51651cf

Please sign in to comment.