Tuan-Dat Tran b2cc80bb09 refactor: Restructure Repository to add eta optimization
Signed-off-by: Tuan-Dat Tran <tuan-dat.tran@tudattr.dev>
2024-11-29 21:49:59 +01:00

307 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
from enum import Enum
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import simpy
# Eviction policies the cache can be configured with
class CacheType(Enum):
    """Eviction policy applied when the cache has to make room."""

    # Evict the entry with the largest age counter.
    LRU = 1
    # Evict a uniformly chosen entry.
    RANDOM_EVICTION = 2
# Simulation-wide constants
SEED = 42
DATABASE_OBJECTS = 100  # Number of objects in the database
# Access count the stop-condition object must reach before the simulation
# is stopped (this is a number of requests, not a duration).
ACCESS_COUNT_LIMIT = 10
EXPORT_NAME = "./logs/export.csv"
# Shape parameter for the Zipf distribution (controls skewness); must be > 1.
ZIPF_CONSTANT = 2

# Seed both random number generators so runs are reproducible
random.seed(SEED)
np.random.seed(SEED)

# Initialize simulation environment
env = simpy.Environment()

# Each configuration is a tuple of
#   (cache capacity, max refresh rate, cache type, cache TTL)
# * cache capacity:   maximum number of objects the cache can hold.
# * max refresh rate: upper bound of the uniform distribution mu is drawn
#                     from; 0 means objects are never refreshed.
# * cache type:       eviction policy (see CacheType).
# * cache TTL:        TTL assigned when an object is pulled into the cache;
#                     0 means the TTL is infinite.
configurations = {
    "default": (DATABASE_OBJECTS, 10, CacheType.LRU, 5),
    "No Refresh": (DATABASE_OBJECTS, 0, CacheType.LRU, 5),
    "Infinite TTL": (DATABASE_OBJECTS // 2, 0, CacheType.LRU, 0),
    "Random Eviction": (DATABASE_OBJECTS // 2, 10, CacheType.RANDOM_EVICTION, 5),
    "RE without Refresh": (DATABASE_OBJECTS // 2, 0, CacheType.RANDOM_EVICTION, 5),
}

# Select the active configuration and expose its fields as module constants
config = configurations["default"]
CACHE_CAPACITY, MAX_REFRESH_RATE, cache_type, CACHE_TTL = config
class Database:
    """Backing store holding every object together with its workload parameters.

    Object ids run from 1 to DATABASE_OBJECTS (inclusive). All random draws
    use the global NumPy random state.
    """

    def __init__(self):
        object_ids = range(1, DATABASE_OBJECTS + 1)

        # Payload stored for each object id
        self.data = {obj_id: f"Object {obj_id}" for obj_id in object_ids}

        # Request parameter 'lambda' for each object, Zipf-distributed.
        # NOTE(review): lambda and mu are later passed to
        # np.random.exponential, whose argument is the scale (mean), so they
        # act as mean intervals rather than rates - confirm this is intended.
        self.lambda_values = {
            obj_id: np.random.zipf(ZIPF_CONSTANT) for obj_id in object_ids
        }

        # Refresh parameter 'mu' for each object; all zero when refreshing is
        # disabled, otherwise uniform on [1, MAX_REFRESH_RATE).
        if MAX_REFRESH_RATE == 0:
            self.mu_values = dict.fromkeys(object_ids, 0)
        else:
            self.mu_values = {
                obj_id: np.random.uniform(1, MAX_REFRESH_RATE)
                for obj_id in object_ids
            }

        # Time of the first request for each object
        self.next_request = {
            obj_id: np.random.exponential(self.lambda_values[obj_id])
            for obj_id in object_ids
        }

    def get_object(self, obj_id):
        """Return the stored object for obj_id, or None if the id is unknown."""
        return self.data.get(obj_id)
class Cache:
    """Simulated cache in front of a database, with TTL, aging and statistics.

    Behaviour is driven by the module-level constants CACHE_CAPACITY,
    CACHE_TTL (0 means entries never expire), MAX_REFRESH_RATE (0 means no
    refresh times are scheduled) and DATABASE_OBJECTS (range of object ids
    the per-object statistics are pre-populated for).
    """

    def __init__(self, env, db, cache_type):
        """Create an empty cache.

        Args:
            env: Simulation environment; only its ``now`` attribute is read.
            db: Backing store providing ``get_object``, ``mu_values`` and
                ``next_request``.
            cache_type: CacheType member selecting the eviction policy.
        """
        self.cache_type = cache_type
        self.env = env
        self.db = db
        self.storage = {}  # obj_id -> cached object
        self.ttl = {}  # obj_id -> absolute expiry time (0 when CACHE_TTL == 0)
        self.age = {}  # obj_id -> age_objects() passes since last store/refresh
        self.cache_size_over_time = []  # (time, number of cached objects)
        self.cache_next_request_over_time = []  # (time, copy of db.next_request)
        object_ids = range(1, DATABASE_OBJECTS + 1)
        self.request_log = {obj_id: [] for obj_id in object_ids}
        self.hits = dict.fromkeys(object_ids, 0)  # Hits per object
        self.misses = dict.fromkeys(object_ids, 0)  # Misses per object
        # Sum of the age observed at every hit, per object
        self.cumulative_age = dict.fromkeys(object_ids, 0)
        self.access_count = dict.fromkeys(object_ids, 0)  # Hits + misses
        self.next_refresh = {}  # obj_id -> time of the next scheduled refresh

    def _store(self, obj_id, obj):
        """Insert or overwrite obj_id with a fresh TTL and an age of zero."""
        self.storage[obj_id] = obj
        self.ttl[obj_id] = (self.env.now + CACHE_TTL) if CACHE_TTL != 0 else 0
        self.age[obj_id] = 0

    def _remove(self, obj_id):
        """Drop every piece of per-entry state kept for obj_id."""
        del self.storage[obj_id]
        del self.ttl[obj_id]
        del self.age[obj_id]
        # Also forget the refresh schedule so evicted/expired entries do not
        # leave stale timestamps behind. It is absent when refreshing is off.
        self.next_refresh.pop(obj_id, None)

    def get(self, obj_id):
        """Request obj_id, recording a hit or a miss.

        A hit requires the object to be cached and not expired (or
        CACHE_TTL == 0). On a miss the object is fetched from the database,
        an entry is evicted first if the cache is full, and - when refreshing
        is enabled - its next refresh time is scheduled.

        Raises:
            KeyError: If obj_id is outside the pre-populated id range.
        """
        self.access_count[obj_id] += 1
        now = self.env.now

        if obj_id in self.storage and (CACHE_TTL == 0 or self.ttl[obj_id] > now):
            # Cache hit: count it and accumulate the age seen by the client
            self.hits[obj_id] += 1
            self.cumulative_age[obj_id] += self.age[obj_id]
            return

        # Cache miss: fetch the object from the database
        self.misses[obj_id] += 1
        obj = self.db.get_object(obj_id)

        # The entry may still be resident but expired (expiry is only swept
        # by age_objects). Drop it first so that re-fetching it frees its own
        # slot instead of evicting an unrelated object.
        if obj_id in self.storage:
            self._remove(obj_id)

        # If the cache is full, make space according to the eviction policy
        if len(self.storage) >= CACHE_CAPACITY:
            if self.cache_type == CacheType.LRU:
                self.evict_oldest()
            elif self.cache_type == CacheType.RANDOM_EVICTION:
                self.evict_random()

        # Add the object to the cache with a fresh TTL and age
        self._store(obj_id, obj)
        if MAX_REFRESH_RATE != 0:
            # Schedule the first refresh for this entry
            self.next_refresh[obj_id] = now + np.random.exponential(
                self.db.mu_values[obj_id]
            )

    def evict_oldest(self):
        """Remove the entry with the largest age to make space.

        NOTE(review): age is reset when an entry is stored or refreshed, not
        when it is hit, so this is "oldest copy" rather than strict
        least-recently-used - confirm that is the intended policy.
        """
        oldest_id = max(self.age, key=self.age.get)
        print(
            f"[{self.env.now:.2f}] Cache: Evicting oldest object {oldest_id} to make space at {self.ttl[oldest_id]:.2f}"
        )
        self._remove(oldest_id)

    def evict_random(self):
        """Remove a randomly chosen entry to make space."""
        random_id = np.random.choice(list(self.storage))
        print(
            f"[{self.env.now:.2f}] Cache: Evicting random object {random_id} to make space at {self.ttl[random_id]:.2f}"
        )
        self._remove(random_id)

    def refresh_object(self, obj_id):
        """Re-fetch obj_id from the database; TTL is renewed and age reset.

        The next refresh time is not touched here; the caller schedules it.
        """
        self._store(obj_id, self.db.get_object(obj_id))

    def age_objects(self):
        """Increment the age of each live entry and remove expired ones."""
        # Iterate over a snapshot because expired entries are removed in-loop
        for obj_id in list(self.age):
            if CACHE_TTL == 0 or self.ttl[obj_id] > self.env.now:
                self.age[obj_id] += 1
            else:
                # TTL expired
                self._remove(obj_id)

    def record_cache_state(self):
        """Append the current cache size and a copy of the request schedule."""
        now = self.env.now
        self.cache_size_over_time.append((now, len(self.storage)))
        self.cache_next_request_over_time.append((now, self.db.next_request.copy()))
def age_cache_process(env, cache):
    """Housekeeping process: age, expire, refresh and sample the cache.

    Each pass ages the cached objects (removing expired ones), refreshes
    every cached object whose scheduled refresh time has been reached (only
    when MAX_REFRESH_RATE is non-zero), records the cache state and then
    waits one unit of simulation time.
    """
    while True:
        # Age objects and remove expired ones
        cache.age_objects()

        if MAX_REFRESH_RATE != 0:
            # Snapshot of the ids so the dict is not iterated while it is
            # being written to.
            for cached_id in list(cache.storage):
                refresh_due_at = cache.next_refresh[cached_id]
                if env.now >= refresh_due_at:
                    cache.refresh_object(cached_id)
                    # Draw the following refresh time from the object's mu
                    delay = np.random.exponential(cache.db.mu_values[cached_id])
                    cache.next_refresh[cached_id] = env.now + delay

        # Record cache state at each time step
        cache.record_cache_state()
        yield env.timeout(1)
def client_request_process(env, cache, event):
    """Client process that issues requests for objects through the cache.

    Repeatedly picks the object with the earliest scheduled request, waits
    until that time, requests it and schedules its next request. The
    simulation is stopped by triggering ``event`` once the object with the
    largest lambda value has been accessed ACCESS_COUNT_LIMIT times.

    Args:
        env: SimPy environment driving the simulation.
        cache: Cache the requests are sent to; its ``db`` supplies
            ``lambda_values`` and ``next_request``.
        event: Event that is succeeded when the stop condition is met.
    """
    lambda_values = cache.db.lambda_values
    # Id of the object with the largest lambda (first one on ties). Lambda is
    # used as the scale of the exponential inter-request time below, so this
    # is the object with the longest mean gap between requests.
    # Only the id is kept: iterating the (id, lambda) pair would wrongly use
    # the lambda value as a second object id in the stop condition.
    stop_obj_id = max(lambda_values, key=lambda_values.get)

    while True:
        # Next request to serve: the object with the smallest scheduled time
        obj_id, next_request = min(
            cache.db.next_request.items(), key=lambda item: item[1]
        )
        yield env.timeout(next_request - env.now)

        if env.now >= next_request:
            cache.get(obj_id)
            # Schedule the following request for this object
            next_request = env.now + np.random.exponential(lambda_values[obj_id])
            cache.request_log[obj_id].append(next_request)
            cache.db.next_request[obj_id] = next_request

        if cache.access_count[stop_obj_id] >= ACCESS_COUNT_LIMIT:
            event.succeed()
            # End the process so the event can never be triggered twice
            return
# Instantiate components
db = Database()
cache = Cache(env, db, cache_type)
stop_event = env.event()

# Start processes
env.process(age_cache_process(env, cache))
env.process(client_request_process(env, cache, stop_event))

# Run the simulation until the client process triggers the stop event
env.run(until=stop_event)

# Print hit rate and average age for each object that was requested
for obj_id in range(1, DATABASE_OBJECTS + 1):
    if cache.access_count[obj_id] != 0:
        hit_rate = cache.hits[obj_id] / cache.access_count[obj_id]
        # Average over hits only; max() guards objects that never hit
        avg_age = cache.cumulative_age[obj_id] / max(1, cache.hits[obj_id])
        print(
            f"Object {obj_id}: Hit Rate = {hit_rate:.2f}, Average Age = {avg_age:.2f}"
        )

# Extract recorded data for plotting
times, cache_sizes = zip(*cache.cache_size_over_time)

# Plot the cache size over time
graph_path = "./graphs/objects_in_cache_over_time.pdf"
# Make sure the output directory exists so the results are not lost
Path(graph_path).parent.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(30, 5))
plt.plot(times, cache_sizes, label="Objects in Cache")
plt.xlabel("Time (s)")
plt.ylabel("Number of Cached Objects")
plt.title("Number of Objects in Cache Over Time")
plt.legend()
plt.grid(True)
plt.savefig(graph_path)

# Per-object statistics, one single-column frame each, indexed by object id
access_count = pd.DataFrame.from_dict(
    cache.access_count, orient="index", columns=["access_count"]
)
hits = pd.DataFrame.from_dict(cache.hits, orient="index", columns=["hits"])
misses = pd.DataFrame.from_dict(cache.misses, orient="index", columns=["misses"])
mu = pd.DataFrame.from_dict(db.mu_values, orient="index", columns=["mu"])
lmbda = pd.DataFrame.from_dict(db.lambda_values, orient="index", columns=["lambda"])
# Hit rate in percent (NaN for objects that were never requested). The index
# must be the object ids: a default 0-based index would shift every hit rate
# by one object and drop the last object in the index merge below.
hit_rate = pd.DataFrame(
    np.round((hits.to_numpy() / access_count.to_numpy()) * 100, 2),
    index=access_count.index,
    columns=["hit_rate"],
)

merged = (
    access_count.merge(hits, left_index=True, right_index=True)
    .merge(misses, left_index=True, right_index=True)
    .merge(mu, left_index=True, right_index=True)
    .merge(lmbda, left_index=True, right_index=True)
    .merge(hit_rate, left_index=True, right_index=True)
)

# Export the merged statistics, creating the target directory if needed
Path(EXPORT_NAME).parent.mkdir(parents=True, exist_ok=True)
merged.to_csv(EXPORT_NAME)