b2cc80bb09
Signed-off-by: Tuan-Dat Tran <tuan-dat.tran@tudattr.dev>
307 lines
11 KiB
Python
307 lines
11 KiB
Python
import simpy
|
||
import random
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import pandas as pd
|
||
from enum import Enum
|
||
|
||
|
||
# Eviction policies available to the cache
class CacheType(Enum):
    """Policy the cache uses to pick a victim once it is full."""

    # Evict the cached object with the largest age (Cache.evict_oldest).
    LRU = 1
    # Evict a uniformly chosen cached object (Cache.evict_random).
    RANDOM_EVICTION = 2
|
||
|
||
|
||
# Constants
SEED = 42  # Seed shared by Python's `random` and NumPy's global RNG
DATABASE_OBJECTS = 100  # Number of objects in the database (ids 1..N)
# Access count that the object(s) with the largest lambda value must reach
# before the client process ends the simulation (not a time limit).
ACCESS_COUNT_LIMIT = 10
EXPORT_NAME = "./logs/export.csv"  # Destination of the per-object statistics CSV

# Shape parameter of the Zipf distribution each object's lambda is drawn
# from (controls skewness). numpy.random.zipf requires it to be > 1.
ZIPF_CONSTANT = 2
|
||
|
||
# Reproducibility: seed NumPy's and Python's global RNGs with the same value.
np.random.seed(SEED)
random.seed(SEED)

# Shared SimPy environment (simulation clock starts at t = 0).
env = simpy.Environment(initial_time=0)
|
||
|
||
|
||
# Named simulation setups. Each tuple is
# (CACHE_CAPACITY, MAX_REFRESH_RATE, cache_type, CACHE_TTL):
#   CACHE_CAPACITY   - maximum number of objects the cache can hold.
#   MAX_REFRESH_RATE - upper bound of the uniform distribution each object's
#                      refresh parameter 'mu' is drawn from (lower bound 1);
#                      0 disables refreshes entirely.
#   cache_type       - eviction policy applied when the cache is full.
#   CACHE_TTL        - TTL given to an object when it is pulled into (or
#                      refreshed in) the cache; 0 means the TTL is infinite.
configurations = {
    "default": (DATABASE_OBJECTS, 10, CacheType.LRU, 5),
    "No Refresh": (DATABASE_OBJECTS, 0, CacheType.LRU, 5),
    "Infinite TTL": (DATABASE_OBJECTS // 2, 0, CacheType.LRU, 0),
    "Random Eviction": (DATABASE_OBJECTS // 2, 10, CacheType.RANDOM_EVICTION, 5),
    "RE without Refresh": (DATABASE_OBJECTS // 2, 0, CacheType.RANDOM_EVICTION, 5),
}

# Select the active setup. The four settings below are read by the classes
# and processes that follow, so they are assigned exactly once, here; change
# the selected key (or a tuple above) to alter them.
config = configurations["default"]
CACHE_CAPACITY, MAX_REFRESH_RATE, cache_type, CACHE_TTL = config
|
||
|
||
|
||
class Database:
    """In-memory backing store for objects with ids ``1..DATABASE_OBJECTS``.

    Per-object random parameters are drawn once, at construction time, from
    NumPy's global RNG in this fixed order (which matters for reproducibility
    under a fixed seed):

    * ``lambda_values[i]`` - Zipf(ZIPF_CONSTANT) sample, the request
      parameter 'lambda' of object ``i``.
    * ``mu_values[i]``     - refresh parameter 'mu', uniform on
      [1, MAX_REFRESH_RATE), or 0 for every object when MAX_REFRESH_RATE is 0.
    * ``next_request[i]``  - time of the first request for object ``i``.
    """

    def __init__(self):
        object_ids = range(1, DATABASE_OBJECTS + 1)

        self.data = {oid: f"Object {oid}" for oid in object_ids}

        # Request parameter 'lambda' for each object.
        self.lambda_values = {oid: np.random.zipf(ZIPF_CONSTANT) for oid in object_ids}

        # Refresh parameter 'mu' for each object; all zeros means "never refresh".
        if MAX_REFRESH_RATE != 0:
            self.mu_values = {
                oid: np.random.uniform(1, MAX_REFRESH_RATE) for oid in object_ids
            }
        else:
            self.mu_values = dict.fromkeys(object_ids, 0)

        # NOTE(review): np.random.exponential takes a *scale* (mean = 1/rate),
        # so lambda acts as a mean inter-request time here, not a rate.
        # Confirm this is the intended model.
        self.next_request = {
            oid: np.random.exponential(self.lambda_values[oid]) for oid in object_ids
        }

    def get_object(self, obj_id):
        """Return the stored object for ``obj_id``, or ``None`` if it does not exist."""
        return self.data.get(obj_id)
|
||
|
||
|
||
class Cache:
    """Capacity-bounded object cache in front of a :class:`Database`.

    Behavior is controlled by the module-level settings ``CACHE_CAPACITY``,
    ``CACHE_TTL`` (0 means entries never expire) and ``MAX_REFRESH_RATE``
    (0 means no refresh times are scheduled), plus the eviction policy passed
    in as ``cache_type``. Per-object statistics are kept for object ids
    ``1..DATABASE_OBJECTS``.
    """

    def __init__(self, env, db, cache_type):
        """Create an empty cache.

        Args:
            env: Simulation environment; only ``env.now`` is read.
            db: Backing store; its ``get_object``, ``mu_values`` and
                ``next_request`` are used.
            cache_type: ``CacheType`` member selecting the eviction policy.
        """
        self.cache_type = cache_type
        self.env = env
        self.db = db
        self.storage = {}  # obj_id -> cached object
        self.ttl = {}  # obj_id -> absolute expiry time (0 when CACHE_TTL == 0)
        self.age = {}  # obj_id -> age_objects() steps since last fetch/refresh
        self.cache_size_over_time = []  # (time, number of cached objects)
        self.cache_next_request_over_time = []  # (time, copy of db.next_request)
        self.next_refresh = {}  # obj_id -> time of the next scheduled refresh

        object_ids = range(1, DATABASE_OBJECTS + 1)
        self.request_log = {i: [] for i in object_ids}  # one independent list per object
        self.hits = dict.fromkeys(object_ids, 0)
        self.misses = dict.fromkeys(object_ids, 0)
        self.cumulative_age = dict.fromkeys(object_ids, 0)  # sum of ages observed on hits
        self.access_count = dict.fromkeys(object_ids, 0)

    def _expiry_time(self):
        """Return the TTL for an object (re)loaded now; 0 encodes 'never expires'."""
        return self.env.now + CACHE_TTL if CACHE_TTL != 0 else 0

    def _remove(self, obj_id):
        """Drop every per-entry record kept for ``obj_id``."""
        del self.storage[obj_id]
        del self.ttl[obj_id]
        del self.age[obj_id]
        # Forget the pending refresh too, so no stale schedule outlives the entry.
        self.next_refresh.pop(obj_id, None)

    def get(self, obj_id):
        """Serve a request for ``obj_id`` and update its hit/miss statistics.

        A hit is a cached object whose TTL has not passed (any cached object
        when CACHE_TTL == 0); its current age is added to ``cumulative_age``.
        On a miss the object is fetched from the database and stored with a
        fresh TTL, age 0 and, if refreshes are enabled, a new refresh time.
        Another entry is evicted first only if the object needs a new slot in
        a full cache.
        """
        self.access_count[obj_id] += 1

        if obj_id in self.storage and (CACHE_TTL == 0 or self.ttl[obj_id] > self.env.now):
            self.hits[obj_id] += 1
            self.cumulative_age[obj_id] += self.age[obj_id]
            return

        self.misses[obj_id] += 1
        obj = self.db.get_object(obj_id)

        # An expired entry for obj_id is overwritten in place, so evicting
        # another object is only necessary when obj_id is not stored yet.
        if obj_id not in self.storage and len(self.storage) >= CACHE_CAPACITY:
            if self.cache_type == CacheType.LRU:
                self.evict_oldest()
            elif self.cache_type == CacheType.RANDOM_EVICTION:
                self.evict_random()

        self.storage[obj_id] = obj
        self.ttl[obj_id] = self._expiry_time()
        self.age[obj_id] = 0
        if MAX_REFRESH_RATE != 0:
            # Schedule the first refresh of this copy.
            self.next_refresh[obj_id] = self.env.now + np.random.exponential(
                self.db.mu_values[obj_id]
            )

    def evict_oldest(self):
        """Remove the cached object with the largest age to make space.

        Ties go to the first such object in insertion order (``max`` semantics).
        NOTE(review): age is reset on fetch/refresh but not on a hit, so this
        evicts the oldest copy rather than the least-recently-used entry.
        """
        oldest_id = max(self.age, key=self.age.get)
        print(
            f"[{self.env.now:.2f}] Cache: Evicting oldest object {oldest_id} to make space at {self.ttl[oldest_id]:.2f}"
        )
        self._remove(oldest_id)

    def evict_random(self):
        """Remove a uniformly chosen cached object (NumPy global RNG) to make space."""
        random_id = np.random.choice(list(self.storage.keys()))
        print(
            f"[{self.env.now:.2f}] Cache: Evicting random object {random_id} to make space at {self.ttl[random_id]:.2f}"
        )
        self._remove(random_id)

    def refresh_object(self, obj_id):
        """Re-fetch ``obj_id`` from the database, renewing its TTL and resetting its age."""
        self.storage[obj_id] = self.db.get_object(obj_id)
        self.ttl[obj_id] = self._expiry_time()
        self.age[obj_id] = 0

    def age_objects(self):
        """Advance every cached object's age by one step.

        With a finite TTL, objects whose TTL has passed are removed instead of aged.
        """
        for obj_id in list(self.age):
            if CACHE_TTL != 0 and self.ttl[obj_id] <= self.env.now:
                self._remove(obj_id)
            else:
                self.age[obj_id] += 1

    def record_cache_state(self):
        """Append the current cache size and a snapshot of db.next_request to the history."""
        self.cache_size_over_time.append((self.env.now, len(self.storage)))
        self.cache_next_request_over_time.append((self.env.now, self.db.next_request.copy()))
|
||
|
||
|
||
def age_cache_process(env, cache):
    """SimPy process that runs once per time unit, forever.

    Each step ages cached objects (letting the cache drop expired ones),
    refreshes every cached object whose ``cache.next_refresh`` time has been
    reached, and schedules that object's next refresh ``exponential(mu)``
    later. It then records the cache state. Refreshing is skipped whenever
    MAX_REFRESH_RATE is 0.
    """
    while True:
        cache.age_objects()

        if MAX_REFRESH_RATE != 0:
            for oid in list(cache.storage):
                if env.now < cache.next_refresh[oid]:
                    continue  # not due yet
                cache.refresh_object(oid)
                delay = np.random.exponential(cache.db.mu_values[oid])
                cache.next_refresh[oid] = env.now + delay

        cache.record_cache_state()
        yield env.timeout(1)
|
||
|
||
|
||
def client_request_process(env, cache, event):
    """SimPy process that replays client requests against ``cache``.

    Repeatedly waits until the earliest pending time in
    ``cache.db.next_request`` and requests that object via ``cache.get``.
    It then schedules the object's next request ``exponential(lambda)``
    later; the new time is also appended to ``cache.request_log``.

    Once every object sharing the largest lambda value has been accessed at
    least ACCESS_COUNT_LIMIT times, ``event`` is triggered (at most once) and
    the process ends.
    """
    lambda_values = cache.db.lambda_values
    # Ids of all objects whose lambda equals the maximum (ties included).
    # NOTE(review): lambda is used as the exponential *scale* (mean gap), so
    # these are the least frequently requested objects.
    max_lambda = max(lambda_values.values())
    lowest_lambda_objects = [oid for oid, lam in lambda_values.items() if lam == max_lambda]

    while True:
        obj_id, next_request = min(cache.db.next_request.items(), key=lambda x: x[1])
        # Clamp so floating-point rounding can never yield a negative delay.
        yield env.timeout(max(0, next_request - env.now))
        if env.now >= next_request:
            cache.get(obj_id)

            next_request = env.now + np.random.exponential(lambda_values[obj_id])
            cache.request_log[obj_id].append(next_request)
            cache.db.next_request[obj_id] = next_request

            if not event.triggered and all(
                cache.access_count[oid] >= ACCESS_COUNT_LIMIT for oid in lowest_lambda_objects
            ):
                event.succeed()
                return
|
||
|
||
|
||
# Build the model: database (draws the per-object random parameters), the
# cache in front of it, and the event whose triggering ends the run.
db = Database()
cache = Cache(env=env, db=db, cache_type=cache_type)
stop_event = env.event()

# Register the aging/refresh process first, then the client process, and
# simulate until the client signals the stop condition.
env.process(age_cache_process(env, cache))
env.process(client_request_process(env, cache, stop_event))
env.run(until=stop_event)
|
||
|
||
# Per-object summary: hit rate over all accesses, mean age over hits only.
for obj_id in range(1, DATABASE_OBJECTS + 1):
    if cache.access_count[obj_id] == 0:
        continue  # never requested, nothing to report
    # access_count is a positive counter here, so no zero-division guard is needed.
    hit_rate = cache.hits[obj_id] / cache.access_count[obj_id]
    avg_age = cache.cumulative_age[obj_id] / max(1, cache.hits[obj_id])
    print(
        f"Object {obj_id}: Hit Rate = {hit_rate:.2f}, Average Age = {avg_age:.2f}"
    )
|
||
|
||
# Split the recorded (time, size) samples into two parallel sequences.
times, cache_sizes = zip(*cache.cache_size_over_time)

# Line chart of the number of cached objects over simulated time.
plt.figure(figsize=(30, 5))
plt.plot(times, cache_sizes, label="Objects in Cache")
plt.title("Number of Objects in Cache Over Time")
plt.xlabel("Time (s)")
plt.ylabel("Number of Cached Objects")
plt.grid(True)
plt.legend()
plt.savefig("./graphs/objects_in_cache_over_time.pdf")
|
||
|
||
# Per-object statistics, one row per object; the index is the object id.
access_count = pd.DataFrame.from_dict(
    cache.access_count, orient="index", columns=["access_count"]
)
hits = pd.DataFrame.from_dict(cache.hits, orient="index", columns=["hits"])
misses = pd.DataFrame.from_dict(cache.misses, orient="index", columns=["misses"])
mu = pd.DataFrame.from_dict(db.mu_values, orient="index", columns=["mu"])
lmbda = pd.DataFrame.from_dict(db.lambda_values, orient="index", columns=["lambda"])

# Hit rate in percent, built on the same object-id index as the frames above.
# Without an explicit index the ndarray would get a 0-based RangeIndex, so the
# index merge below would attach each object's rate to the previous id and
# drop the last object. Never-accessed objects yield NaN (0 / 0).
hit_rate = pd.DataFrame(
    np.round((hits.to_numpy() / access_count.to_numpy()) * 100, 2),
    index=access_count.index,
    columns=["hit_rate"],
)

merged = (
    access_count.merge(hits, left_index=True, right_index=True)
    .merge(misses, left_index=True, right_index=True)
    .merge(mu, left_index=True, right_index=True)
    .merge(lmbda, left_index=True, right_index=True)
    .merge(hit_rate, left_index=True, right_index=True)
)
merged.to_csv(EXPORT_NAME)
|