// Source file: hashmap_concurrent/hashmap_concurrent.zig (2026-02-26).
//! Copyright (c) 2026 Pascal Zittlau, BSD 3-Clause
//! Version: 0.1.0
//!
//! A thread-safe, fixed-capacity, open-addressing Hash Table.
//!
//! It uses Robin Hood hashing to minimize the Variance of Probe Sequence Lengths (PSL).
//! Deletion is handled via Backward-Shift. Unlike tombstone-based deletion, this maintains the
//! Robin Hood invariant by shifting subsequent entries back to fill gaps, keeping the table compact
//! and preventing PSL inflation over time.
//!
//! Concurrency Model:
//! - **Writers:** Use Shard-Level Locking. The table is partitioned into contiguous "blocks"
//! (shards). A writer acquires a mutex for the initial shard and acquires subsequent shard locks
//! if a probe sequence crosses a boundary. It is possible to deadlock under very specific
//! circumstances if a writer wraps around while locking. More Information below.
//! - **Readers:** Use Sequence Locking (Optimistic Concurrency Control). Readers do not take
//! locks; they read a version counter, perform the probe, and validate the version. This
//! provides extremely high throughput for read-heavy workloads.
//!
//! This implementation is **NOT** wait-free. Writers can starve readers if they constantly update
//! the same shards.
//!
//! Design Constraints:
//! - POD keys. In particular no slices.
//! - The table does not resize.
//! - Always has a power of two number of entries.
//! - Uses block sharding. Entries are grouped into contiguous chunks protected by the same lock to
//! optimize cache locality during linear probing.
//! - Always uses a power of two number of shards.
//!
//! ## Pointers and Memory Reclamation
//! This table does not provide a memory reclamation scheme.
//! - If `K` is a pointer, `eql` should very likely **not** dereference it. If it does, a concurrent
//! `remove` + `free` by a writer will cause the reader to segfault during the equality check.
//! - If `V` is a pointer, the data it points to is not guaranteed to be valid after `get()`
//! returns, as another thread may have removed and freed it.
//!
//! ## Choosing the Number of Shards
//!
//! The `num_shards` parameter balances throughput, memory overhead, and deadlock risk. Here are a
//! few considerations for choosing an appropriate shard count:
//! - Higher shard counts reduce lock contention among writers.
//! - Higher shard counts reduce retry probabilities for readers.
//! - Each shard is aligned to a cache line(64-128 bytes) to prevent false sharing. 1024 shards
//! consume ~64-128KB which increases cache pressure.
//! - This implementation uses a fixed-size buffer of size `fetch_version_array_size = 8` to track
//! shard versions during optimistic reads. If a probe sequence spans more shards than this
//! buffer, `get()` will return `ProbeLimitExceeded`. Smaller shards make this more likely.
//! - Smaller shards increase the number of locks that may be required for writers. This is
//! additional overhead.
//! - If more threads are used than shards are available the deadlock risk is increased.
//! - Deadlocks only occur during array wrap-around. More information below.
//!
//! Usually **64-1024** shards with 4-16 shards per thread are the "sweet spot".
//! `const num_shards = @min(@max(64, num_threads * 16), 1024);`
//!
//! ## Deadlock safety
//!
//! A deadlock can only occur if a circular dependency is created between shards. Because writers
//! always acquire locks in increasing index order, a deadlock is only possible during a
//! wrap-around, where a probe starts at the end of the `entries` array and continues at index 0.
//!
//! Why a deadlock is extremely unlikely:
//! - Robin Hood hashing keeps PSLs very short. Even at 90% load, the average PSL is ~4.5. In
//! particular the expected value for the length of a PSL grows with O(ln(n)) where `n` is the
//! capacity of the table.
//! - For a deadlock to occur, one thread must be probing across the wrap-around boundary
//! (Shard[N-1] -> Shard[0]) while the other threads simultaneously bridge it from the other side.
//! - Given that shards typically contain hundreds or thousands of slots and there are usually
//! dozens or hundreds of shards, the probability of multiple simultaneous probe sequences being
//! long enough to bridge the whole table is practically zero for well-sized tables and shards.
//!
//! The following experimental results show PSLs for various `load_factors` for a table with a
//! capacity of 2**21(2 million entries).
//!
//! load_factor | max PSL | avg PSL | median PSL
//! 0.50 | 12 | 0.50 | 0
//! 0.70 | 23 | 1.17 | 1
//! 0.80 | 31 | 2.01 | 1
//! 0.90 | 65 | 4.56 | 3
//! 0.95 | 131 | 9.72 | 6
//! 0.99 | 371 | 78.72 | 34
//!
//! ## Benchmarks
//!
//! You can run the included benchmark suite yourself to evaluate throughput and scaling on your own
//! hardware. Simply run the `main` function of this file (e.g., using `zig run <filename>.zig -O
//! ReleaseFast`). The benchmark compares this concurrent implementation against a standard
//! `std.AutoHashMap` wrapped in a `std.Thread.Mutex` across varying read/write/delete workloads.
//!
//! ## License
//!
//! Copyright 2026 Pascal Zittlau
//!
//! Redistribution and use in source and binary forms, with or without modification, are permitted
//! provided that the following conditions are met:
//! 1. Redistributions of source code must retain the above copyright notice, this list of
//! conditions and the following disclaimer.
//! 2. Redistributions in binary form must reproduce the above copyright notice, this list of
//! conditions and the following disclaimer in the documentation and/or other materials provided
//! with the distribution.
//!
//! 3. Neither the name of the copyright holder nor the names of its contributors may be used to
//! endorse or promote products derived from this software without specific prior written
//! permission.
//!
//! THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR
//! IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
//! FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
//! CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
//! CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
//! SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
//! THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
//! OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
//! POSSIBILITY OF SUCH DAMAGE.
const std = @import("std");
const builtin = @import("builtin");
const atomic = std.atomic;
const math = std.math;
const mem = std.mem;
const testing = std.testing;
const assert = std.debug.assert;
const cache_line_size = atomic.cacheLineForCpu(builtin.cpu);
const Allocator = mem.Allocator;
/// Returns an automatically generated hash function for `K`.
/// Keys with a unique bit representation are hashed as raw bytes; otherwise
/// a field-wise hash is used so that padding bits do not influence the result.
pub fn getAutoHashFn(comptime K: type) (fn (K) u64) {
    const Hasher = struct {
        fn hash(key: K) u64 {
            // Fast path: every bit pattern is meaningful, so hash the raw bytes.
            if (comptime std.meta.hasUniqueRepresentation(K)) {
                return std.hash.Wyhash.hash(0, std.mem.asBytes(&key));
            }
            // Fallback: field-wise hashing that ignores padding.
            var state = std.hash.Wyhash.init(0);
            std.hash.autoHash(&state, key);
            return state.final();
        }
    };
    return Hasher.hash;
}
/// Returns an automatically generated equality function for `K`
/// based on `std.meta.eql` (deep, field-wise comparison).
pub fn getAutoEqlFn(comptime K: type) (fn (K, K) bool) {
    const Comparator = struct {
        fn eql(lhs: K, rhs: K) bool {
            return std.meta.eql(lhs, rhs);
        }
    };
    return Comparator.eql;
}
/// Convenience constructor that wires auto-generated `hash`/`eql` functions
/// into `HashMapConcurrent` with the default `Config`.
/// Rejects key types whose unlocked (optimistic) reads could tear.
pub fn AutoHashMapConcurrent(comptime K: type, comptime V: type) type {
    if (!isSafe(K)) {
        @compileError("AutoHashMapConcurrent: Key type '" ++ @typeName(K) ++
            "' is potentially tearable (e.g. contains slices or is larger than a machine word). " ++
            "Optimistic readers do not take locks and might see a partially updated key. " ++
            "This causes automatically generated 'hash' and 'eql' functions to produce garbage " ++
            "results or trigger memory safety violations (segfaults on torn slices). " ++
            "To use this safely, you must provide your own robust 'hash' and 'eql' " ++
            "functions and use 'HashMapConcurrent' directly.");
    }
    return HashMapConcurrent(K, V, getAutoHashFn(K), getAutoEqlFn(K), .{});
}
// By default mark all composite types as unsafe and assume that values smaller than `usize` are
// never torn.
fn isSafe(comptime T: type) bool {
    return switch (@typeInfo(T)) {
        .bool => true,
        .int => |info| info.bits <= @bitSizeOf(usize),
        .float => |info| info.bits <= @bitSizeOf(usize),
        .@"enum" => |info| @sizeOf(info.tag_type) <= @sizeOf(usize),
        // Single-item, many-item and C pointers are one word; slices are two and can tear.
        .pointer => |info| info.size != .slice,
        // TODO: add more safe ones as needed
        else => false,
    };
}
/// Tuning knobs for `HashMapConcurrent`. Passed at `comptime`.
pub const Config = struct {
    /// The maximum load factor before the table is considered "at capacity". Default: 0.8
    /// This is not really a strict maximum. The table still continues to perform okay up to
    /// about 95% fill. At 99% the performance degrades sharply. Even higher and the risk of
    /// deadlocks increases. See the module documentation for more information or run a benchmark.
    max_fill_factor: f64 = 0.8,
    /// When fetching we need to save versions of the shards we visit. This is the maximum
    /// number of shards we traverse in search for the entry. Visiting more shards than this
    /// could happen under the following circumstances:
    /// - the hash function is bad or
    /// - the shards are small or
    /// - the fill factor is close to 1
    fetch_version_array_size: u64 = 8,
};
pub fn HashMapConcurrent(
    /// The key has to be carefully chosen such that the `hash` and `eql` functions work correctly.
    /// In particular the combination shouldn't be susceptible to torn reads.
    K: type,
    V: type,
    hash: fn (K) u64,
    eql: fn (K, K) bool,
    config: Config,
) type {
    return struct {
        const Self = @This();
        /// A key-value pair as stored in (and returned from) the table.
        pub const Entry = struct {
            key: K,
            value: V,
        };
        // `dists` stores `PSL + 1` in a u8; 0 marks an unoccupied slot.
        const EMPTY_DIST = 0;
        // Stored distances cap at this value; the real PSL must then be recomputed
        // from the key's hash (see the saturated branches in put/remove).
        const SATURATED_DIST = 255;
        /// How many entries can be saved in the table at most.
        capacity: usize,
        /// Number of used entries.
        // This is heap allocated with a complete cache line to remove false sharing when other
        // threads just try to access members of the HashTable struct.
        count: *align(cache_line_size) atomic.Value(usize),
        // NOTE: No slices because we don't want the bounds checks. All accesses are protected by
        // masking of the excess bits with the masks below.
        /// The entries saved in the table.
        /// `entries.len == capacity`
        entries: [*]align(cache_line_size) Entry,
        /// The distances of the entries. `dists[i]` refers to the distance of `entries[i]`.
        /// `dists.len == entries.len == capacity`
        dists: [*]align(cache_line_size) u8,
        /// Locks for the shards. `locks.len == num_shards`
        locks: [*]align(cache_line_size) ShardLock,
        /// Number of shard locks (always a power of two).
        num_shards: usize,
        /// Maximum capacity before the table signals it is full in `atCapacity()`.
        /// Pre-calculated to avoid math in hot paths.
        max_count_threshold: usize,
        /// `capacity - 1`; masks an index into `entries`/`dists`.
        entry_mask: usize,
        /// `num_shards - 1`; masks an index into `locks`.
        shard_mask: usize,
        /// shift **right** to get shard index.
        shard_shift: u6,
        /// Used to efficiently detect shard boundaries during linear probing.
        /// `(entries_per_shard - 1)`.
        partition_mask: usize,
/// Initializes the Hash Table.
///
/// `capacity`: The number of elements the table should hold. Must be a power of two.
///             If you just have a minimum capacity that should fit, use
///             `capacityForMin` to get value that respects the fill factor and is a power
///             of two.
///
/// `num_shards`: Number of mutexes. Must be a power of two and `<= capacity`.
///               See the module documentation for guidance on selecting this value.
///
/// bytes_used ≈
///   @sizeOf(ConcurrentHashTable) // structure
///   + capacity * @sizeOf(K)      // keys
///   + capacity * @sizeOf(V)      // values
///   + capacity                   // distances
///   + num_shards * cache_line_size // locks
///   + cache_line_size            // count
pub fn init(allocator: Allocator, capacity: usize, num_shards: usize) !Self {
    assert(math.isPowerOfTwo(capacity));
    assert(math.isPowerOfTwo(num_shards));
    assert(num_shards > 0);
    assert(capacity >= num_shards);
    // Each allocation is cache-line aligned; errdefers unwind in reverse order
    // if a later allocation fails.
    const entries = try allocator.alignedAlloc(
        Entry,
        .fromByteUnits(cache_line_size),
        capacity,
    );
    errdefer allocator.free(entries);
    const dists = try allocator.alignedAlloc(
        u8,
        .fromByteUnits(cache_line_size),
        capacity,
    );
    errdefer allocator.free(dists);
    // All-zero distances mean every slot starts empty.
    @memset(dists, EMPTY_DIST);
    // Enforce that `count` has a complete cache line for itself to remove false sharing.
    const count_line = try allocator.alignedAlloc(
        u8,
        .fromByteUnits(cache_line_size),
        cache_line_size,
    );
    errdefer allocator.free(count_line);
    const count = mem.bytesAsValue(atomic.Value(usize), count_line);
    count.*.raw = 0;
    const shards = try allocator.alignedAlloc(
        ShardLock,
        .fromByteUnits(cache_line_size),
        num_shards,
    );
    errdefer allocator.free(shards);
    // Default-initialize every lock (version 0 == unlocked).
    @memset(shards, .{});
    const entries_per_shard = capacity / num_shards;
    assert(math.isPowerOfTwo(entries_per_shard));
    const shift = math.log2_int(usize, entries_per_shard);
    const capacity_f: f64 = @floatFromInt(capacity);
    return .{
        .entries = entries.ptr,
        .dists = dists.ptr,
        .locks = shards.ptr,
        .capacity = capacity,
        .num_shards = num_shards,
        .count = count,
        .entry_mask = capacity - 1,
        .shard_mask = num_shards - 1,
        .shard_shift = @intCast(shift),
        .partition_mask = entries_per_shard - 1,
        .max_count_threshold = @intFromFloat(capacity_f * config.max_fill_factor),
    };
}
/// Calculates the required power-of-two capacity to hold `min_capacity` entries while
/// staying under the `max_fill_factor`.
/// A `min_capacity` of 0 yields 1 (the smallest valid power of two).
pub fn capacityForMin(min_capacity: usize) usize {
    // Guard the degenerate case: `ceilPowerOfTwoAssert(u64, 0)` would assert.
    if (min_capacity == 0) return 1;
    const min_capacity_f: f64 = @floatFromInt(min_capacity);
    const adj_min_capacity: u64 = @intFromFloat(min_capacity_f / config.max_fill_factor);
    return math.ceilPowerOfTwoAssert(u64, adj_min_capacity);
}
/// **Not** threadsafe. Frees all memory owned by the table.
/// `allocator` must be the same allocator that was passed to `init`.
pub fn deinit(ht: *Self, allocator: Allocator) void {
    allocator.free(ht.entries[0..ht.capacity]);
    allocator.free(ht.dists[0..ht.capacity]);
    allocator.free(ht.locks[0..ht.num_shards]);
    // Rebuild the cache-line slice that backs `count` so it can be freed.
    // NOTE: `const`, not `var` — the pointer itself is never mutated and an
    // unmutated `var` is a compile error in current Zig.
    const count_line_ptr: [*]align(cache_line_size) u8 = @ptrCast(ht.count);
    allocator.free(count_line_ptr[0..cache_line_size]);
    ht.* = undefined;
}
/// Thread-safely removes all entries from the table.
/// Acquires every shard lock (in ascending order) so no writer or reader can
/// observe a partially cleared table through the locks.
pub fn clearRetainingCapacity(ht: *Self) void {
    var shard: usize = 0;
    while (shard < ht.num_shards) : (shard += 1) ht.locks[shard].lock();
    defer for (0..ht.num_shards) |i| ht.locks[i].unlock();
    // Zeroed distances mark every slot as empty; entry payloads become garbage.
    @memset(ht.dists[0..ht.capacity], EMPTY_DIST);
    ht.count.store(0, .release);
}
/// Returns true once the number of entries exceeds the pre-calculated threshold
/// derived from `max_fill_factor`.
pub fn atCapacity(ht: *Self) bool {
    const current = ht.count.load(.monotonic);
    return current > ht.max_count_threshold;
}
/// Returns the current load factor of the table (0.0 to 1.0).
pub fn loadFactor(ht: *Self) f64 {
    const count: f64 = @floatFromInt(ht.count.load(.monotonic));
    // BUGFIX: `entries` is a many-item pointer ([*]Entry) and has no `.len`;
    // `capacity` is the authoritative length of the entries array.
    const capacity: f64 = @floatFromInt(ht.capacity);
    return count / capacity;
}
/// Maps an entry index to the index of the shard lock protecting it.
inline fn getShardIndex(ht: *const Self, entry_idx: usize) usize {
    const raw_shard = entry_idx >> ht.shard_shift;
    return raw_shard & ht.shard_mask;
}
/// If the key exists, it **is overwritten**.
///
/// This function asserts that the table has enough physical capacity to perform the
/// insertion. Use `atCapacity()` to check load factors, or `tryPut()` if you need
/// to handle absolute capacity limits gracefully.
pub fn put(ht: *Self, key: K, value: V) void {
    const previous = ht.fetchPut(key, value);
    _ = previous;
}
/// Inserts or updates an entry. If the key already exists, the value is updated.
/// Returns the previous value if the key was present, otherwise null.
///
/// This function asserts that the table has enough physical capacity.
pub fn fetchPut(ht: *Self, key: K, value: V) ?V {
    const previous = ht.tryFetchPut(key, value) catch unreachable;
    return previous;
}
pub const TryPutError = error{TableFull};
/// Attempts to insert or update a key-value pair.
///
/// If the table is physically full (the probe sequence wraps around), it returns
/// `TryPutError.TableFull`.
pub fn tryPut(ht: *Self, key: K, value: V) TryPutError!void {
    _ = try ht.tryFetchPut(key, value);
}
/// Attempts to insert or update an entry, returning the previous value if it existed.
///
/// This is the primary insertion logic for the table. It handles:
/// 1. Shard-level locking (acquiring additional locks if the probe crosses boundaries).
/// 2. Robin Hood swaps to minimize probe sequence length.
/// 3. Wrap-around detection to prevent infinite loops in a 100% full table.
///
/// Returns:
/// - `V`: The previous value if the key was already present.
/// - `null`: If the key was newly inserted.
/// - `TryPutError.TableFull`: If no space could be found.
pub fn tryFetchPut(ht: *Self, key: K, value: V) TryPutError!?V {
    const incoming = Entry{ .key = key, .value = value };
    // `current` is whatever entry we are currently trying to place; a Robin Hood
    // swap may replace it with a displaced resident entry.
    var current = incoming;
    var idx = hash(current.key) & ht.entry_mask;
    // Distance from the ideal position for the `current` entry we are trying to insert.
    var dist: u64 = 0;
    const start_lock = ht.getShardIndex(idx);
    var end_lock = start_lock;
    ht.locks[start_lock].lock();
    // Unlock all touched shards
    defer {
        var i = start_lock;
        while (true) {
            ht.locks[i].unlock();
            if (i == end_lock) break;
            i = (i + 1) & ht.shard_mask;
        }
    }
    while (true) {
        const stored_dist = ht.dists[idx];
        if (stored_dist == EMPTY_DIST) {
            // Free slot: store the entry. Stored distance is PSL + 1 (0 marks
            // empty), capped at SATURATED_DIST.
            ht.entries[idx] = current;
            ht.dists[idx] = @intCast(@min(dist + 1, SATURATED_DIST));
            _ = ht.count.fetchAdd(1, .monotonic);
            return null;
        } else if (eql(ht.entries[idx].key, current.key)) {
            // A key match can only happen before the first swap — displaced
            // entries are already unique in the table.
            assert(eql(current.key, incoming.key));
            const old_val = ht.entries[idx].value;
            ht.entries[idx].value = current.value;
            return old_val;
        }
        // Collision, so apply Robin Hood logic.
        var existing_dist: u64 = stored_dist - 1;
        if (stored_dist == SATURATED_DIST) {
            @branchHint(.cold);
            // If saturated, we must recompute the hash to find the real distance.
            const existing_ideal = hash(ht.entries[idx].key) & ht.entry_mask;
            existing_dist = (idx -% existing_ideal) & ht.entry_mask;
        }
        if (dist > existing_dist) {
            // "Rob the rich": take the slot from the entry that is closer to its
            // ideal position and continue inserting the displaced entry.
            mem.swap(Entry, &current, &ht.entries[idx]);
            ht.dists[idx] = @intCast(@min(dist + 1, SATURATED_DIST));
            dist = existing_dist;
        }
        idx = (idx + 1) & ht.entry_mask;
        dist += 1;
        if ((idx & ht.partition_mask) == 0) {
            @branchHint(.unlikely);
            // Since we move linearly, the new shard is simply the next one.
            const next_lock = (end_lock + 1) & ht.shard_mask;
            // Wrapping back to the start lock means every shard was probed:
            // the table is physically full.
            if (next_lock == start_lock) return TryPutError.TableFull;
            ht.locks[next_lock].lock();
            end_lock = next_lock;
        }
    }
}
/// Fetches a value from the table.
///
/// This asserts that the probe sequence does not span more shards than the internal
/// tracking buffer allows. For properly sized tables and shards, this is the standard
/// lookup method.
pub fn get(ht: *const Self, key: K) ?V {
    return ht.getChecked(key) catch unreachable;
}
pub const GetError = error{ProbeLimitExceeded};
/// Performs an optimistic concurrent lookup.
///
/// Use this if you are operating at extreme load factors (>99%), with a small capacity, or
/// with a hash function that might produce very long clusters.
///
/// Returns:
/// - `null` if the key is not found.
/// - `GetError.ProbeLimitExceeded` if the probe sequence spans more than
///   `fetch_version_array_size` shards. This can happen with very poor hash distributions
///   or extremely high load factors. Use `getBuffer` to provide a larger stack-allocated
///   buffer if this occurs.
pub fn getChecked(ht: *const Self, key: K) GetError!?V {
    var version_buffer: [config.fetch_version_array_size]u64 = undefined;
    return ht.getBuffer(key, version_buffer[0..]);
}
/// Low-level lookup allowing the caller to provide a version buffer.
///
/// Use this with a bigger buffer if `get()` returns `ProbeLimitExceeded` due to high shard
/// density.
///
/// This does not acquire any mutexes. It reads a version counter before and after reading
/// the data. If the version changed (meaning a writer touched the shard), the read is
/// retried.
pub fn getBuffer(ht: *const Self, key: K, versions: []u64) GetError!?V {
    assert(versions.len > 0);
    var shard_count: usize = 0;
    var first_shard_idx: usize = undefined;
    retry: while (true) {
        var idx = hash(key) & ht.entry_mask;
        var dist: u64 = 0;
        var current_shard = ht.getShardIndex(idx);
        first_shard_idx = current_shard;
        // Snapshot the first shard's version before touching its data.
        versions[0] = ht.locks[current_shard].readBegin();
        shard_count = 1;
        const val: ?V = while (true) {
            const stored_dist = ht.dists[idx];
            if (stored_dist == EMPTY_DIST) break null;
            // Robin Hood invariant check
            if (stored_dist < (dist + 1) and stored_dist != SATURATED_DIST) break null;
            // Copy key and value out of the slot before comparing; the copy may be
            // torn, but then validation below fails and we retry.
            const slot = &ht.entries[idx];
            const k = slot.key;
            const v = slot.value;
            if (eql(k, key)) break v;
            idx = (idx + 1) & ht.entry_mask;
            dist += 1;
            if ((idx & ht.partition_mask) == 0) {
                @branchHint(.unlikely);
                // Crossed into the next shard: snapshot its version too.
                if (shard_count == versions.len) return GetError.ProbeLimitExceeded;
                current_shard = (current_shard + 1) & ht.shard_mask;
                versions[shard_count] = ht.locks[current_shard].readBegin();
                shard_count += 1;
            }
        };
        // We need a memory barrier here to ensure the data reads from the entries aren't
        // reordered after the sequence version validations below.
        loadFence();
        // Validate all traversed shards
        var check_idx = first_shard_idx;
        for (0..shard_count) |i| {
            if (!ht.locks[check_idx].readValid(versions[i])) {
                @branchHint(.unlikely);
                // A writer touched a traversed shard; back off briefly and retry.
                atomic.spinLoopHint();
                continue :retry;
            }
            check_idx = (check_idx + 1) & ht.shard_mask;
        }
        return val;
    }
}
/// Removes an entry from the table and performs a backward-shift to fill the gap.
pub fn remove(ht: *Self, key: K) void {
    _ = ht.fetchRemove(key);
}
/// Removes an entry from the table and performs a backward-shift to fill the gap.
/// Returns the value of the removed entry, or null if the key was not found.
pub fn fetchRemove(ht: *Self, key: K) ?V {
    var idx = hash(key) & ht.entry_mask;
    var dist: u64 = 0;
    const start_lock = ht.getShardIndex(idx);
    var end_lock = start_lock;
    ht.locks[start_lock].lock();
    // Release all shards locked during probe and shift.
    defer {
        var i = start_lock;
        while (true) {
            ht.locks[i].unlock();
            if (i == end_lock) break;
            i = (i + 1) & ht.shard_mask;
        }
    }
    // Find slot
    while (true) {
        const stored_dist = ht.dists[idx];
        if (stored_dist == EMPTY_DIST) return null;
        // Robin Hood invariant check: if the resident entry is closer to its ideal
        // slot than our current probe distance, the key cannot appear further on.
        if (stored_dist < (dist + 1) and stored_dist != SATURATED_DIST) return null;
        if (eql(ht.entries[idx].key, key)) break;
        idx = (idx + 1) & ht.entry_mask;
        dist += 1;
        if ((idx & ht.partition_mask) == 0) {
            @branchHint(.unlikely);
            const next_lock = (end_lock + 1) & ht.shard_mask;
            if (next_lock == start_lock) return null; // Wrap-around safety
            ht.locks[next_lock].lock();
            end_lock = next_lock;
        }
    }
    const removed_value = ht.entries[idx].value;
    _ = ht.count.fetchSub(1, .monotonic);
    // Backward Shift Deletion
    while (true) {
        // Mark current slot as empty (temporarily, will be filled if shifting)
        ht.dists[idx] = EMPTY_DIST;
        const next_idx = (idx + 1) & ht.entry_mask;
        if ((next_idx & ht.partition_mask) == 0) {
            @branchHint(.unlikely);
            // The shift may pull entries across a shard boundary, so that shard
            // must be locked too before we read it.
            const next_lock = (end_lock + 1) & ht.shard_mask;
            if (next_lock != start_lock) {
                @branchHint(.unlikely);
                ht.locks[next_lock].lock();
                end_lock = next_lock;
            }
        }
        const next_dist_stored = ht.dists[next_idx];
        // If the next element is empty (0) the item is at its ideal position, we can't
        // shift it back.
        if (next_dist_stored == EMPTY_DIST or next_dist_stored == 1) {
            return removed_value;
        }
        // Shift back into the hole
        ht.entries[idx] = ht.entries[next_idx];
        // Moving one slot closer to the ideal position decrements the stored
        // distance by one.
        var new_dist: u64 = next_dist_stored - 1;
        if (next_dist_stored == SATURATED_DIST) {
            @branchHint(.cold);
            // Recompute real distance
            const ideal = hash(ht.entries[next_idx].key) & ht.entry_mask;
            const real_dist = (next_idx -% ideal) & ht.entry_mask;
            new_dist = real_dist;
        }
        ht.dists[idx] = @intCast(@min(new_dist, SATURATED_DIST));
        idx = next_idx;
    }
}
/// Removes all entries from the table using lock-crabbing. This *is thread-safe* but
/// depending on the usage the table might never be fully empty.
///
/// It locks shards sequentially (holding at most two at a time) to avoid stalling the
/// entire table while preventing elements from shifting across boundaries. Concurrent
/// writers are still allowed, though depending on where the cleaning logic is, their entry
/// might be overwritten shortly after.
pub fn clear(ht: *Self) void {
    ht.locks[0].lock();
    var current_shard: usize = 0;
    var elements_cleared: usize = 0;
    for (0..ht.capacity) |i| {
        // Check if we crossed a shard boundary
        if (i > 0 and (i & ht.partition_mask) == 0) {
            // Flush the local counter before moving on so the global count
            // never lags by more than one shard's worth of entries.
            if (elements_cleared > 0) {
                _ = ht.count.fetchSub(elements_cleared, .monotonic);
                elements_cleared = 0;
            }
            // Lock-coupling: acquire the next shard before releasing the current
            // one, so entries cannot shift across the boundary mid-clear.
            const next_shard = current_shard + 1;
            if (next_shard < ht.num_shards) {
                ht.locks[next_shard].lock();
            }
            ht.locks[current_shard].unlock();
            current_shard = next_shard;
        }
        if (ht.dists[i] != EMPTY_DIST) {
            ht.dists[i] = EMPTY_DIST;
            elements_cleared += 1;
        }
    }
    // Flush remaining cleared elements and release the last lock
    if (elements_cleared > 0) {
        _ = ht.count.fetchSub(elements_cleared, .monotonic);
    }
    ht.locks[current_shard].unlock();
}
pub const LockingIterator = struct {
    ht: *Self,
    /// Next slot index to examine.
    current_index: usize,
    /// Shard currently locked; set to `num_shards` once released.
    current_shard: usize,
    /// Releases any shard locks held by the iterator.
    /// This MUST be called if you stop iterating before `next()` returns `null`.
    /// It is safe to call this even if the iterator is fully exhausted.
    pub fn deinit(self: *LockingIterator) void {
        if (self.current_shard < self.ht.num_shards) {
            self.ht.locks[self.current_shard].unlock();
            self.current_shard = self.ht.num_shards; // Prevent double-unlock
        }
    }
    /// Returns the next occupied entry, or `null` when the table is exhausted.
    /// Exhaustion also releases the final shard lock.
    pub fn next(self: *LockingIterator) ?Entry {
        while (self.current_index < self.ht.capacity) {
            // Check if we crossed a shard boundary
            if (self.current_index > 0 and (self.current_index & self.ht.partition_mask) == 0) {
                // Lock-coupling: take the next shard before releasing the current one.
                const next_shard = self.current_shard + 1;
                if (next_shard < self.ht.num_shards) {
                    self.ht.locks[next_shard].lock();
                }
                self.ht.locks[self.current_shard].unlock();
                self.current_shard = next_shard;
            }
            const dist = self.ht.dists[self.current_index];
            const entry = self.ht.entries[self.current_index];
            self.current_index += 1;
            if (dist != EMPTY_DIST) {
                return entry;
            }
        }
        // Reached the end, release the last lock
        self.deinit();
        return null;
    }
};
/// Returns a thread-safe iterator using lock coupling. This will **not** provide a
/// consistent state because concurrent writers are still allowed. Though no elements will
/// be returned twice.
///
/// You MUST either exhaust the iterator (until `next()` returns `null`) or explicitly call
/// `it.deinit()`(or both) if you break early. Otherwise, a lock will be leaked and the
/// table will deadlock.
///
/// Do not call a locking function on the same table while an iterator is active in the same
/// thread.
pub fn lockingIterator(ht: *Self) LockingIterator {
    // Acquire the first shard up front; `next()`/`deinit()` will release it.
    ht.locks[0].lock();
    const it = LockingIterator{
        .ht = ht,
        .current_index = 0,
        .current_shard = 0,
    };
    return it;
}
pub const ApproximateIterator = struct {
    ht: *const Self,
    /// Next slot index to examine.
    current_index: usize,
    /// Returns the next occupied entry (read without any locks), or `null` at the end.
    /// Each returned entry is internally consistent, but the iteration as a whole is not.
    pub fn next(self: *ApproximateIterator) ?Entry {
        while (self.current_index < self.ht.capacity) {
            defer self.current_index += 1;
            const shard_idx = self.ht.getShardIndex(self.current_index);
            var dist: u8 = undefined;
            var entry: Entry = undefined;
            // Optimistic read loop for this specific slot
            while (true) {
                const version = self.ht.locks[shard_idx].readBegin();
                dist = self.ht.dists[self.current_index];
                // Only copy the entry if the slot isn't empty
                if (dist != EMPTY_DIST) {
                    entry = self.ht.entries[self.current_index];
                }
                // Order the data reads before the version re-check below.
                loadFence();
                if (self.ht.locks[shard_idx].readValid(version)) {
                    break; // Successfully read a consistent state
                }
                atomic.spinLoopHint();
            }
            if (dist != EMPTY_DIST) {
                return entry;
            }
        }
        return null;
    }
};
/// Returns a non-locking, approximate iterator using optimistic concurrency control.
///
/// Because this iterator does not hold locks, concurrent `put` and `remove` operations can
/// shift elements backwards or forwards. As a result, this iterator may miss entries or
/// return the same entry multiple times. However, it strictly guarantees that any returned
/// `Entry` is internally consistent (no torn reads).
///
/// It is perfectly safe to break out of this iterator early, as no locks are held.
pub fn approximateIterator(ht: *const Self) ApproximateIterator {
    const it: ApproximateIterator = .{
        .ht = ht,
        .current_index = 0,
    };
    return it;
}
/// Collects health statistics about the table.
/// This is a slow, non-atomic operation. Use only for monitoring or debugging.
/// Concurrent writers may make `count` and the scanned slots disagree slightly.
pub fn collectStatistics(ht: *const Self) Statistics {
    var total_psl: usize = 0;
    var max_psl: usize = 0;
    const count: usize = ht.count.load(.acquire);
    // Histogram for median calculation.
    const hist_size = 1024; // If more are needed you likely have other worse problems
    var psl_histogram = [_]usize{0} ** (hist_size);
    var actual_count: usize = 0;
    for (0..ht.capacity) |i| {
        const slot = ht.entries[i];
        const k = slot.key;
        if (ht.dists[i] != EMPTY_DIST) {
            // PSL = how far the key sits from the slot its hash ideally maps to.
            const ideal = hash(k) & ht.entry_mask;
            const psl = (i -% ideal) & ht.entry_mask;
            total_psl += psl;
            actual_count += 1;
            if (psl > max_psl) max_psl = psl;
            // PSLs beyond the histogram are clamped into the last bucket.
            const bucket = @min(psl, psl_histogram.len - 1);
            psl_histogram[bucket] += 1;
        }
    }
    // Calculate Median from Histogram
    var median_psl: usize = 0;
    if (actual_count > 0) {
        const target = actual_count / 2;
        var accumulated: usize = 0;
        for (psl_histogram, 0..) |freq, psl_val| {
            accumulated += freq;
            if (accumulated >= target) {
                median_psl = psl_val;
                break;
            }
        }
    }
    const count_f: f64 = @floatFromInt(count);
    return .{
        .capacity = ht.capacity,
        .count = count,
        .load_factor = count_f / @as(f64, @floatFromInt(ht.capacity)),
        .max_psl = max_psl,
        .avg_psl = if (count > 0) @as(f64, @floatFromInt(total_psl)) / count_f else 0,
        .median_psl = median_psl,
        .num_shards = ht.num_shards,
    };
}
/// Exhaustively validates the internal state of the table.
/// This is slow, not threadsafe, and should only really be used in tests.
pub fn verifyIntegrity(ht: *const Self) !void {
    assert(builtin.is_test);
    // Number of occupied slots found by the scan; must match the atomic counter.
    var count: usize = 0;
    for (0..ht.capacity) |i| {
        const entry = ht.entries[i];
        if (ht.dists[i] == EMPTY_DIST) continue;
        count += 1;
        // Ensure the key can actually be found by the get() logic
        const found_val = try ht.getChecked(entry.key);
        try testing.expectEqual(entry.value, found_val);
        // Validate Robin Hood Invariant. A slot's PSL cannot be less than (next_slot.PSL - 1)
        const next_idx = (i + 1) & ht.entry_mask;
        const next_stored_dist = ht.dists[next_idx];
        if (next_stored_dist != EMPTY_DIST) {
            const current_ideal = hash(entry.key) & ht.entry_mask;
            const next_ideal = hash(ht.entries[next_idx].key) & ht.entry_mask;
            const current_psl = (i -% current_ideal) & ht.entry_mask;
            const next_psl = (next_idx -% next_ideal) & ht.entry_mask;
            try testing.expect(next_psl <= current_psl + 1);
        }
    }
    try testing.expectEqual(ht.count.load(.acquire), count);
}
};
}
/// Health statistics returned by `collectStatistics`.
pub const Statistics = struct {
    /// Total number of slots in the table.
    capacity: usize,
    /// Entry count as reported by the atomic counter at collection time.
    count: usize,
    /// `count / capacity` (0.0 to 1.0).
    load_factor: f64,
    /// Longest probe sequence length observed during the scan.
    max_psl: usize,
    /// Mean probe sequence length over all occupied slots.
    avg_psl: f64,
    /// Median probe sequence length, derived from a capped histogram.
    median_psl: usize,
    /// Number of shard locks in the table.
    num_shards: usize,
};
/// Emits an optimal architecture-specific LoadLoad barrier.
/// Required for the read-side of sequence locks to ensure the data reads are not reordered before
/// the first version read, or after the second version read.
inline fn loadFence() void {
    switch (builtin.cpu.arch) {
        .x86_64, .x86 => {
            // x86 memory model is TSO. Hardware does not reorder loads with other loads.
            // A compiler barrier is strictly sufficient.
            asm volatile ("" ::: .{ .memory = true });
        },
        .aarch64, .aarch64_be => {
            // Load-only barrier, inner-shareable domain.
            asm volatile ("dmb ishld" ::: .{ .memory = true });
        },
        .riscv64 => {
            // Orders prior reads before subsequent reads.
            asm volatile ("fence r, r" ::: .{ .memory = true });
        },
        else => {
            // Fallback: emulate a full sequence point using a dummy atomic RMW.
            var dummy: u8 = 0;
            _ = @cmpxchgWeak(u8, &dummy, 0, 0, .seq_cst, .seq_cst);
        },
    }
}
/// A hybrid Spinlock / Sequence Lock.
/// - Writers use `lock()` which utilizes a CAS-based spinlock and `unlock()`. Both increase the
///   version/timestamp.
/// - Readers use `readBegin()` and `readValid()` to perform lock-free reads.
const ShardLock = struct {
    /// Even = Unlocked, Odd = Locked
    version: atomic.Value(u64) = atomic.Value(u64).init(0),
    /// To remove false_sharing.
    /// OPTIM: If there are a lot of shards compared to the number of cores, we could probably drop
    /// the padding and save some memory, while having similar performance.
    padding: [cache_line_size - @sizeOf(u64)]u8 = undefined,
    /// Spins until no writer holds the lock, then returns the (even) version
    /// snapshot to pass to `readValid` after the data has been read.
    fn readBegin(self: *const ShardLock) u64 {
        while (true) {
            const current = self.version.load(.acquire);
            if (current & 1 == 0) return current;
            atomic.spinLoopHint();
        }
    }
    /// Returns true if no writer has touched the shard since `readBegin` returned `ts`.
    fn readValid(self: *const ShardLock, ts: u64) bool {
        return ts == self.version.load(.acquire);
    }
    /// Acquires the write lock by flipping the version from even to odd via CAS.
    fn lock(self: *ShardLock) void {
        var current = self.version.load(.acquire);
        while (true) {
            // Wait for even
            while (current & 1 != 0) {
                atomic.spinLoopHint();
                current = self.version.load(.monotonic);
            }
            // CAS to switch to odd
            if (self.version.cmpxchgWeak(current, current + 1, .acquire, .monotonic)) |c| {
                current = c;
            } else {
                return; // Locked successfully
            }
        }
    }
    /// Releases the write lock (odd -> even); the release ordering publishes the writes.
    fn unlock(self: *ShardLock) void {
        const before = self.version.fetchAdd(1, .release);
        assert(before & 1 == 1);
    }
};
test "basic usage" {
    var table = try AutoHashMapConcurrent(u64, u64).init(testing.allocator, 4, 2);
    defer table.deinit(testing.allocator);
    // Insert three distinct keys and read them back.
    table.put(1, 10);
    table.put(2, 20);
    table.put(3, 30);
    try testing.expectEqual(10, table.get(1));
    try testing.expectEqual(20, table.get(2));
    try testing.expectEqual(30, table.get(3));
    // A key that was never inserted is absent.
    try testing.expectEqual(null, table.get(99));
    // Overwriting an existing key replaces its value.
    table.put(2, 22);
    try testing.expectEqual(22, table.get(2).?);
    // fetchRemove yields the removed value and leaves the other entries intact.
    try testing.expectEqual(22, table.fetchRemove(2));
    try testing.expectEqual(null, table.get(2));
    try testing.expectEqual(10, table.get(1));
    try testing.expectEqual(30, table.get(3));
}
test "collision and robin hood" {
    var table = try AutoHashMapConcurrent(u64, u64).init(testing.allocator, 4, 4);
    defer table.deinit(testing.allocator);
    // Fill the table to its full capacity of 4 entries.
    table.put(10, 100);
    table.put(20, 200);
    table.put(30, 300);
    table.put(40, 400);
    try testing.expectEqual(100, table.get(10));
    try testing.expectEqual(400, table.get(40));
    // Backward-shift deletion must keep the remaining entries reachable.
    table.remove(10);
    try testing.expectEqual(null, table.get(10));
    try testing.expectEqual(400, table.get(40));
}
test "clear" {
    var table = try AutoHashMapConcurrent(u64, u64).init(testing.allocator, 16, 4);
    defer table.deinit(testing.allocator);
    const n = 8;
    for (0..n) |key| table.put(key, key);
    try testing.expectEqual(n, table.count.load(.monotonic));
    for (0..n) |key| try testing.expectEqual(key, table.get(key));
    // clear() empties the table: count drops to zero and every lookup misses.
    table.clear();
    try testing.expectEqual(0, table.count.load(.monotonic));
    for (0..n) |key| try testing.expectEqual(null, table.get(key));
}
test "iterators basic" {
    var table = try AutoHashMapConcurrent(u64, u64).init(testing.allocator, 16, 4);
    defer table.deinit(testing.allocator);
    const expected_entries = 10;
    for (0..expected_entries) |key| table.put(key, key);
    // Locking iterator: holds shard locks while walking; must be deinit-ed.
    {
        var iter = table.lockingIterator();
        defer iter.deinit();
        var seen: usize = 0;
        while (iter.next()) |entry| : (seen += 1) {
            try testing.expectEqual(entry.key, entry.value);
        }
        try testing.expectEqual(expected_entries, seen);
    }
    // Approximate iterator: optimistic, lock-free traversal.
    {
        var iter = table.approximateIterator();
        var seen: usize = 0;
        while (iter.next()) |entry| : (seen += 1) {
            try testing.expectEqual(entry.key, entry.value);
        }
        try testing.expectEqual(expected_entries, seen);
    }
}
test "locking iterator early break" {
    var table = try AutoHashMapConcurrent(u64, u64).init(testing.allocator, 16, 4);
    defer table.deinit(testing.allocator);
    for (0..10) |key| table.put(key, key);
    {
        // Abandon the iterator after one step; deinit must release any held locks.
        var iter = table.lockingIterator();
        defer iter.deinit();
        _ = iter.next();
    }
    // If deinit leaked a shard lock, this put would deadlock.
    table.put(99, 99);
    try testing.expectEqual(99, table.get(99).?);
}
test "single threaded fuzz" {
    const seed = 42;
    const capacity = 1024;
    const num_shards = 32;
    const num_ops = 100_000;
    var prng = std.Random.DefaultPrng.init(seed);
    const rng = prng.random();
    const gpa = testing.allocator;
    var table = try AutoHashMapConcurrent(u64, u64).init(gpa, capacity, num_shards);
    defer table.deinit(gpa);
    // Shadow every mutation in a known-good stdlib map and compare lookups against it.
    var oracle = std.AutoHashMapUnmanaged(u64, u64).empty;
    try oracle.ensureTotalCapacity(gpa, capacity * 2);
    defer oracle.deinit(gpa);
    for (0..num_ops) |op| {
        const action = rng.uintAtMostBiased(u8, 9);
        const key = rng.int(u64) & ((capacity * 8) - 1);
        switch (action) {
            // 30% puts, skipped when full (the table cannot resize).
            0...2 => if (!table.atCapacity()) {
                const val = rng.int(u64);
                table.put(key, val);
                oracle.putAssumeCapacity(key, val);
            },
            // 50% checked lookups, compared to the oracle.
            3...7 => {
                const got = table.getChecked(key);
                try testing.expectEqual(oracle.get(key), got);
            },
            // 20% removes.
            else => {
                table.remove(key);
                _ = oracle.remove(key);
            },
        }
        // Periodically verify every oracle entry is present with the right value.
        if (op % 1000 == 0) {
            var iter = oracle.iterator();
            while (iter.next()) |entry| {
                const stored = table.get(entry.key_ptr.*);
                try testing.expect(stored != null);
                try testing.expectEqual(entry.value_ptr.*, stored.?);
            }
        }
    }
}
/// Shared driver for the multithreaded fuzz/stress tests: spawns `num_threads` workers,
/// each hammering a disjoint key range with a random mix of put/get/remove while
/// mirroring its own mutations in a private reference map, then verifies table
/// integrity after all workers join. Disjoint key ranges keep each thread's
/// reference map authoritative for its keys despite concurrency.
fn stressTest(
    comptime capacity: usize,
    comptime shards: u64,
    comptime iterations: u64,
    comptime keys_per_thread: u64,
    comptime num_threads: u64,
) !void {
    const allocator = testing.allocator;
    var ht = try AutoHashMapConcurrent(u64, u64).init(allocator, capacity, shards);
    defer ht.deinit(allocator);
    // Stateless namespace for the worker body. (Previously this struct declared
    // id/allocator/ht/iterations/keys_per_thread fields that were never used —
    // everything is passed to `run` as parameters.)
    const Context = struct {
        fn run(
            id: u64,
            alloc: Allocator,
            table: *AutoHashMapConcurrent(u64, u64),
            iter: u64,
            k_per_t: u64,
        ) !void {
            // Seed with the thread id: deterministic per thread, distinct across threads.
            var prng = std.Random.DefaultPrng.init(id);
            const random = prng.random();
            var ref_map = std.AutoHashMapUnmanaged(u64, u64).empty;
            try ref_map.ensureTotalCapacity(alloc, capacity * 2);
            defer ref_map.deinit(alloc);
            // This thread exclusively owns the half-open key range [start, end).
            const start = id * k_per_t;
            const end = start + k_per_t;
            for (0..iter) |_| {
                const action = random.uintAtMostBiased(u8, 9);
                const key = random.intRangeLessThan(u64, start, end);
                if (action <= 3) { // 40% put (skipped when the table is full)
                    if (!table.atCapacity()) {
                        const val = random.int(u64);
                        table.put(key, val);
                        ref_map.putAssumeCapacity(key, val);
                    }
                } else if (action <= 4) { // 10% checked lookup vs. the reference map
                    const ht_val = table.getChecked(key);
                    const ref_val = ref_map.get(key);
                    try testing.expectEqual(ref_val, ht_val);
                } else { // 50% remove
                    table.remove(key);
                    _ = ref_map.remove(key);
                }
            }
        }
    };
    var threads: [num_threads]std.Thread = undefined;
    for (0..num_threads) |i| {
        threads[i] = try std.Thread.spawn(
            .{ .allocator = allocator },
            Context.run,
            .{ i, allocator, &ht, iterations, keys_per_thread },
        );
    }
    for (threads) |t| {
        t.join();
    }
    // With all writers joined, the table must satisfy its structural invariants.
    try ht.verifyIntegrity();
}
test "multithreaded fuzz" {
    // capacity, shards, iterations, keys per thread, threads
    try stressTest(1024, 32, 100_000, 512, 8);
}
test "multithreaded stress" {
    // Long-running soak configuration (a billion ops over a million-entry table).
    // Disabled by default; flip the condition below to run it manually.
    if (true) {
        return error.SkipZigTest;
    }
    // capacity, shards, iterations, keys per thread, threads
    try stressTest(1024 * 1024, 1024, 1_000_000_000, 1024 * 1024, 8);
}
// Races 16 threads over 8 hot keys. Each writer only ever stores its own sentinel
// pattern (high 32 bits == low 32 bits == id + 1), so a reader observing a value whose
// halves disagree has witnessed a torn 64-bit read, and a half outside [1, num_threads]
// means the slot was clobbered with bytes no writer ever produced.
test "torn reads and value clobbering" {
    const num_threads = 16;
    const capacity = 64;
    const shards = 4;
    const num_keys = 8;
    const keys = [_]u64{ 0, 1, 2, 3, 4, 5, 6, 7 };
    const time_ns = 1 * std.time.ns_per_s;
    const allocator = testing.allocator;
    var ht = try AutoHashMapConcurrent(u64, u64).init(allocator, capacity, shards);
    defer ht.deinit(allocator);
    var stop = atomic.Value(bool).init(false);
    const Context = struct {
        fn run(
            id: usize,
            _ht: *AutoHashMapConcurrent(u64, u64),
            _stop: *atomic.Value(bool),
        ) !void {
            var prng = std.Random.DefaultPrng.init(id);
            const random = prng.random();
            // Each thread has a unique pattern like 0x0000000100000001
            const thread_pattern: u32 = @intCast(id + 1);
            const thread_val: u64 = (@as(u64, thread_pattern) << 32) | thread_pattern;
            while (!_stop.load(.monotonic)) {
                const key = keys[random.uintLessThan(usize, num_keys)];
                // 50% Put, 50% Get
                if (random.boolean()) {
                    _ht.put(key, thread_val);
                } else {
                    const val = try _ht.getChecked(key);
                    if (val) |v| {
                        const high: u32 = @intCast(v >> 32);
                        const low: u32 = @intCast(v & 0xFFFFFFFF);
                        try testing.expectEqual(high, low); // torn read
                        // Both halves must identify a real writer (ids are 1-based).
                        try testing.expect(low != 0);
                        try testing.expect(low <= num_threads);
                        try testing.expect(high != 0);
                        try testing.expect(high <= num_threads);
                    }
                }
            }
        }
    };
    var threads: [num_threads]std.Thread = undefined;
    for (0..num_threads) |i| {
        threads[i] = try std.Thread.spawn(
            .{},
            Context.run,
            .{ i, &ht, &stop },
        );
    }
    // Let the race run for a fixed wall-clock duration, then signal shutdown.
    std.Thread.sleep(time_ns);
    stop.store(true, .monotonic);
    for (threads) |t| t.join();
    try ht.verifyIntegrity();
}
// Stop-the-world integrity rounds: worker threads mutate the table freely, then all
// are parked via the stop flag + done semaphore so `verifyIntegrity` can inspect a
// quiescent table. This catches structural corruption that per-operation checks miss.
test "structural integrity" {
    const cap = 512;
    const shards = 16;
    const num_threads = 16;
    const key_range = 1024;
    const num_rounds = 50;
    const time_per_round_ns = 25 * std.time.ns_per_ms;
    const allocator = testing.allocator;
    var ht = try AutoHashMapConcurrent(u64, u64).init(allocator, cap, shards);
    defer ht.deinit(allocator);
    // `stop` pauses workers at the end of a round; `exit` terminates them for good.
    var stop = atomic.Value(bool).init(false);
    var exit = atomic.Value(bool).init(false);
    // start_sem releases all workers into a round; done_sem counts parked workers.
    var start_sem = std.Thread.Semaphore{};
    var done_sem = std.Thread.Semaphore{};
    const Context = struct {
        fn run(
            _ht: *AutoHashMapConcurrent(u64, u64),
            _stop: *atomic.Value(bool),
            _exit: *atomic.Value(bool),
            _start: *std.Thread.Semaphore,
            _done: *std.Thread.Semaphore,
            seed: u64,
        ) !void {
            var prng = std.Random.DefaultPrng.init(seed);
            const random = prng.random();
            while (true) {
                // Park until the main thread starts the next round (or requests exit).
                _start.wait();
                if (_exit.load(.monotonic)) break;
                while (!_stop.load(.monotonic)) {
                    const key = random.uintLessThan(u64, key_range);
                    const action = random.uintLessThan(u8, 10);
                    if (action < 4) { // 40% Put
                        if (!_ht.atCapacity()) {
                            _ht.put(key, key);
                        }
                    } else if (action < 8) { // 40% Remove
                        _ht.remove(key);
                    } else { // 20% Get
                        _ = try _ht.getChecked(key);
                    }
                    // Yield occasionally to increase interleaving
                    if (random.uintLessThan(u8, 100) == 0) try std.Thread.yield();
                }
                // Tell the main thread this worker is parked for the round.
                _done.post();
            }
        }
    };
    var threads: [num_threads]std.Thread = undefined;
    for (0..num_threads) |i| {
        threads[i] = try std.Thread.spawn(
            .{},
            Context.run,
            .{ &ht, &stop, &exit, &start_sem, &done_sem, i },
        );
    }
    // We run multiple rounds of stress. In each round, threads cause chaos, then we stop them and
    // verify the table isn't corrupted.
    for (0..num_rounds) |_| {
        stop.store(false, .monotonic);
        for (0..num_threads) |_| start_sem.post();
        std.Thread.sleep(time_per_round_ns);
        // Stop the world
        stop.store(true, .monotonic);
        // Wait until every worker has parked before inspecting the table.
        for (0..num_threads) |_| done_sem.wait();
        try ht.verifyIntegrity();
    }
    // Release the workers one final time so they observe `exit` and terminate.
    exit.store(true, .monotonic);
    for (0..num_threads) |_| start_sem.post();
    for (threads) |t| t.join();
}
// Per-key linearizability check: every `table.put(key, val)` is paired with a
// reference store under that key's mutex, so any observer holding the same mutex must
// see the table and the reference in exact agreement (0 encodes "absent").
test "linearizability and reference matching" {
    const cap = 1024;
    const shards = 64;
    const num_keys = 128;
    const num_workers = 8;
    const time_ns = 1 * std.time.ns_per_s;
    const allocator = testing.allocator;
    var ht = try AutoHashMapConcurrent(u64, u64).init(allocator, cap, shards);
    defer ht.deinit(allocator);
    var stop = atomic.Value(bool).init(false);
    // Ground truth: one atomic value + one mutex per key; 0 means "never written".
    const Reference = struct {
        values: [num_keys]atomic.Value(u64),
        locks: [num_keys]std.Thread.Mutex,
        fn init() @This() {
            var self: @This() = undefined;
            for (0..num_keys) |i| {
                self.values[i] = atomic.Value(u64).init(0);
                self.locks[i] = .{};
            }
            return self;
        }
    };
    var ref = Reference.init();
    const Worker = struct {
        fn run(
            table: *AutoHashMapConcurrent(u64, u64),
            reference: *Reference,
            is_stop: *atomic.Value(bool),
            seed: u64,
        ) !void {
            var prng = std.Random.DefaultPrng.init(seed);
            const random = prng.random();
            while (!is_stop.load(.monotonic)) {
                const key = random.uintLessThan(usize, num_keys);
                // Values avoid 0 (reserved for "absent") and maxInt.
                const val = random.intRangeAtMost(u64, 1, math.maxInt(u64) - 1);
                // Lock the reference key to ensure the Map and the Reference are updated atomically
                reference.locks[key].lock();
                defer reference.locks[key].unlock();
                table.put(key, val);
                reference.values[key].store(val, .release);
                if (random.uintLessThan(u8, 100) == 0) try std.Thread.yield();
            }
        }
    };
    const Observer = struct {
        fn run(
            table: *AutoHashMapConcurrent(u64, u64),
            reference: *Reference,
            is_stop: *atomic.Value(bool),
        ) !void {
            while (!is_stop.load(.monotonic)) {
                for (0..num_keys) |key| {
                    // Holding the key's mutex excludes its writers, so both reads are stable.
                    reference.locks[key].lock();
                    defer reference.locks[key].unlock();
                    const map_val = try table.getChecked(key) orelse 0; // 0 instead of null
                    const ref_val = reference.values[key].load(.acquire);
                    try testing.expectEqual(ref_val, map_val);
                }
            }
        }
    };
    var workers: [num_workers]std.Thread = undefined;
    for (0..num_workers) |i| {
        workers[i] = try std.Thread.spawn(.{}, Worker.run, .{ &ht, &ref, &stop, i });
    }
    const observer_thread = try std.Thread.spawn(.{}, Observer.run, .{ &ht, &ref, &stop });
    std.Thread.sleep(time_ns);
    stop.store(true, .monotonic);
    for (workers) |t| t.join();
    observer_thread.join();
    // Final check
    // NOTE(review): `.?` below assumes every key was written at least once during the
    // run (statistically near-certain with 8 workers x 1s) — panics on a never-hit key.
    for (0..num_keys) |i| {
        const map_v = try ht.getChecked(i);
        const ref_v = ref.values[i].load(.monotonic);
        try testing.expectEqual(ref_v, map_v.?);
    }
    try ht.verifyIntegrity();
}
/// Wrapper to make standard HashMap thread-safe for comparison.
/// A single global mutex serializes every operation; this is the baseline the sharded
/// sequence-lock table is benchmarked against. Mirrors the concurrent table's
/// init/put/get/remove/atCapacity surface so `runBench` can use either interchangeably.
fn MutexHashMap(K: type, V: type, max_fill_factor: f64) type {
    const max_load_percentage: u64 = @intFromFloat(max_fill_factor * 100);
    return struct {
        const Self = @This();
        mutex: std.Thread.Mutex = .{},
        map: std.HashMapUnmanaged(K, V, std.hash_map.AutoContext(K), max_load_percentage) = .empty,
        /// Pre-allocates `capacity` slots so benchmark runs never pay for growth.
        /// The third parameter (shard count) is accepted and ignored to match the
        /// concurrent table's `init` signature.
        fn init(allocator: Allocator, capacity: u32, _: usize) !Self {
            var self = Self{};
            // Pre-allocate to be fair
            try self.map.ensureTotalCapacity(allocator, capacity);
            return self;
        }
        fn deinit(self: *Self, allocator: Allocator) void {
            self.map.deinit(allocator);
        }
        /// Unlocked peek at remaining pre-allocated slots; approximate under
        /// concurrent mutation, which is acceptable for the benchmark's use.
        fn atCapacity(self: *Self) bool {
            return self.map.available == 0;
        }
        /// Inserts or overwrites under the global mutex.
        /// Uses the generic K/V (previously hard-coded as u64, which only worked for
        /// `MutexHashMap(u64, u64, ...)` instantiations).
        fn put(self: *Self, k: K, v: V) void {
            self.mutex.lock();
            defer self.mutex.unlock();
            self.map.putAssumeCapacity(k, v);
        }
        /// Looks up `k` under the global mutex; null when absent.
        fn get(self: *Self, k: K) ?V {
            self.mutex.lock();
            defer self.mutex.unlock();
            return self.map.get(k);
        }
        /// Removes `k` under the global mutex; removing an absent key is a no-op.
        fn remove(self: *Self, k: K) void {
            self.mutex.lock();
            defer self.mutex.unlock();
            _ = self.map.remove(k);
        }
    };
}
/// Parameters for a single benchmark run.
const BenchOptions = struct {
    /// Implementation label used for reporting (e.g. "Concurrent", "Mutex").
    name: []const u8,
    /// Table capacity; worker keys are drawn uniformly from [0, 2 * size).
    size: u32,
    /// Number of worker threads; `iterations` is split evenly among them.
    num_threads: u64,
    /// Total operation count across all threads.
    iterations: u64,
    /// Percentage (0-100) of operations that are puts.
    put_prob: u8,
    /// Percentage (0-100) of operations that are removes; the remainder are gets.
    remove_prob: u8,
    /// Optional reference time for speedup computation; appears unused in this file
    /// (main computes speedup from the single-thread run instead).
    baseline_ns: ?u64 = null,
};
/// Benchmark worker body: executes this thread's even share of `options.iterations`
/// as a randomized mix of put/remove/get with keys drawn from [0, 2 * size).
/// A put that finds the map full falls through to a remove/get instead.
fn benchWorker(
    comptime MapType: type,
    map: *MapType,
    seed: u64,
    options: BenchOptions,
) void {
    var prng = std.Random.DefaultPrng.init(seed);
    const rng = prng.random();
    const ops_for_this_thread = options.iterations / options.num_threads;
    const key_range = options.size * 2;
    var remaining = ops_for_this_thread;
    while (remaining > 0) : (remaining -= 1) {
        const roll = rng.uintLessThanBiased(u8, 100);
        const key = rng.uintLessThanBiased(u64, key_range);
        if (roll < options.put_prob and !map.atCapacity()) {
            map.put(key, key);
        } else if (roll < options.put_prob + options.remove_prob) {
            map.remove(key);
        } else {
            _ = map.get(key);
        }
    }
}
/// Runs one benchmark configuration against `MapType` and returns elapsed wall time
/// in nanoseconds. The map is pre-filled to capacity so every run starts from the
/// same load factor. Note: the measured interval includes thread spawn/join overhead.
fn runBench(
    allocator: Allocator,
    comptime MapType: type,
    options: BenchOptions,
) !u64 {
    // Scale shards with thread count, clamped to [64, 1024]; MutexHashMap ignores it.
    const num_shards = @min(@max(64, options.num_threads * 16), 1024);
    var map = try MapType.init(allocator, options.size, num_shards);
    defer map.deinit(allocator);
    // `const`: the slice binding is never reassigned — only its elements are written —
    // so `var` here would trip "local variable is never mutated".
    const threads = try allocator.alloc(std.Thread, options.num_threads);
    defer allocator.free(threads);
    // Pre fill to avoid empty bias
    var k: u64 = 0;
    while (!map.atCapacity()) : (k += 1) {
        map.put(k, k);
    }
    const timer_start = std.time.nanoTimestamp();
    for (0..options.num_threads) |i| {
        threads[i] = try std.Thread.spawn(.{}, benchWorker, .{
            MapType, &map, i, options,
        });
    }
    for (threads) |t| t.join();
    return @intCast(std.time.nanoTimestamp() - timer_start);
}
/// These are by no means statistically sound benchmarks. They are just to give a rough guidance on
/// how to choose parameters.
/// Sweeps load factor x workload mix x thread count for both the concurrent table and
/// a mutex-guarded std.HashMap baseline, printing a human-readable table to stdout
/// and writing machine-readable rows to `benchmark_results.csv`.
pub fn main() !void {
    const size = 1024 * 1024;
    const iterations = 10_000_000;
    const load_factors = [_]f64{ 0.5, 0.8, 0.9, 0.95, 0.98 };
    // p = put percentage, r = remove percentage; the remainder of 100 are gets.
    const configs = [_]struct { name: []const u8, p: u8, r: u8 }{
        .{ .name = "Read-Heavy", .p = 3, .r = 2 },
        .{ .name = "Balanced", .p = 25, .r = 25 },
        .{ .name = "Write-Heavy", .p = 45, .r = 45 },
    };
    const allocator = std.heap.page_allocator;
    // Thread counts: powers of two up to the machine's logical CPU count.
    var thread_counts = std.ArrayListUnmanaged(u64).empty;
    defer thread_counts.deinit(allocator);
    const cpu_count = try std.Thread.getCpuCount();
    var t: u64 = 1;
    while (t <= cpu_count) : (t *= 2) {
        try thread_counts.append(allocator, t);
    }
    const csv_file = try std.fs.cwd().createFile("benchmark_results.csv", .{});
    defer csv_file.close();
    var csv_buffer: [1024]u8 = undefined;
    var csv_file_writer = csv_file.writer(&csv_buffer);
    const csv = &csv_file_writer.interface;
    try csv.print("implementation,load_factor,workload,threads,time_ns,ops_per_sec,speedup\n", .{});
    var stdout_buffer: [1024]u8 = undefined;
    var stdout_file_writer = std.fs.File.stdout().writer(&stdout_buffer);
    const stdout = &stdout_file_writer.interface;
    // Header
    try stdout.print(
        "{s:<14} | {s:<4} | {s:<11}",
        .{ "Implementation", "LF", "Workload" },
    );
    for (thread_counts.items) |threads| {
        try stdout.print(" | {d:>3} Threads", .{threads});
    }
    try stdout.print("\n", .{});
    // Separator
    try stdout.print("{s:-<14}-+-{s:-<4}-+-{s:-<11}", .{ "", "", "" });
    for (thread_counts.items) |_| try stdout.print("-+------------", .{});
    try stdout.print("\n", .{});
    try stdout.flush();
    // `inline for`: lf must be comptime-known because it parameterizes the table type
    // (`.max_fill_factor`) and the comptime load-percentage of MutexHashMap.
    inline for (load_factors) |lf| {
        for (configs) |cfg| {
            const impls = [_][]const u8{ "Concurrent", "Mutex" };
            inline for (impls) |impl_name| {
                // Do not spam with unnecessary mutex benchmarks.
                if (comptime mem.eql(u8, impl_name, "Mutex") and lf != 0.8) continue;
                try stdout.print("{s:<14} | {d:<4.2} | {s:<11}", .{ impl_name, lf, cfg.name });
                var baseline_ops: u64 = 0;
                for (thread_counts.items) |threads| {
                    const options = BenchOptions{
                        .name = impl_name,
                        .size = size,
                        .num_threads = threads,
                        .iterations = iterations,
                        .put_prob = cfg.p,
                        .remove_prob = cfg.r,
                    };
                    const time_ns = if (mem.eql(u8, impl_name, "Concurrent"))
                        try runBench(
                            allocator,
                            HashMapConcurrent(
                                u64,
                                u64,
                                getAutoHashFn(u64),
                                getAutoEqlFn(u64),
                                .{ .max_fill_factor = lf },
                            ),
                            options,
                        )
                    else
                        try runBench(allocator, MutexHashMap(u64, u64, lf), options);
                    try stdout.print(" | {D:>9} ", .{time_ns});
                    const total_ops: f64 = @floatFromInt(iterations);
                    const time_s: f64 = @as(f64, @floatFromInt(time_ns)) / 1e9;
                    const ops_sec: u64 = @intFromFloat(total_ops / time_s);
                    // Speedup is relative to this implementation's own 1-thread run.
                    if (threads == 1) baseline_ops = ops_sec;
                    const speedup = if (baseline_ops > 0)
                        @as(f64, @floatFromInt(ops_sec)) / @as(f64, @floatFromInt(baseline_ops))
                    else
                        0;
                    try csv.print(
                        "{s},{d:.2},{s},{d},{d},{d},{d:.2}\n",
                        .{ impl_name, lf, cfg.name, threads, time_ns, ops_sec, speedup },
                    );
                }
                try stdout.print("\n", .{});
                try stdout.flush();
            }
        }
        // Small separator between Load Factors
        try stdout.print("{s:-<14}-+-{s:-<4}-+-{s:-<11}", .{ "", "", "" });
        for (thread_counts.items) |_| try stdout.print("-+------------", .{});
        try stdout.print("\n", .{});
        try stdout.flush();
    }
    try csv.flush();
    try stdout.flush();
}