Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions src/easyscience/global_object/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,17 @@ def __init__(self):
self.__type_dict = {}

def vertices(self) -> List[str]:
"""returns the vertices of a map"""
return list(self._store.keys())
"""Returns the vertices of a map.

Uses a retry loop to handle RuntimeError that can occur when the
WeakValueDictionary is modified during iteration (e.g., by garbage collection).
"""
while True:
try:
return list(self._store)
except RuntimeError:
# Dictionary changed size during iteration, retry
continue

def edges(self):
"""returns the edges of a map"""
Expand All @@ -103,43 +112,50 @@ def _nested_get(self, obj_type: str) -> List[str]:
return [key for key, item in self.__type_dict.items() if obj_type in item.type]

def get_item_by_key(self, item_id: str) -> object:
if item_id in self._store.keys():
if item_id in self._store:
return self._store[item_id]
raise ValueError('Item not in map.')

def is_known(self, vertex: object) -> bool:
# All objects should have a 'unique_name' attribute
return vertex.unique_name in self._store.keys()
"""Check if a vertex is known in the map.

All objects should have a 'unique_name' attribute.
"""
return vertex.unique_name in self._store

def find_type(self, vertex: object) -> List[str]:
if self.is_known(vertex):
return self.__type_dict[vertex.unique_name].type

def reset_type(self, obj, default_type: str):
if obj.unique_name in self.__type_dict.keys():
if obj.unique_name in self.__type_dict:
self.__type_dict[obj.unique_name].reset_type(default_type)

def change_type(self, obj, new_type: str):
if obj.unique_name in self.__type_dict.keys():
if obj.unique_name in self.__type_dict:
self.__type_dict[obj.unique_name].type = new_type

def add_vertex(self, obj: object, obj_type: str = None):
name = obj.unique_name
if name in self._store.keys():
if name in self._store:
raise ValueError(f'Object name {name} already exists in the graph.')
# Clean up stale entry in __type_dict if the weak reference was collected
# but the finalizer hasn't run yet
if name in self.__type_dict:
del self.__type_dict[name]
self._store[name] = obj
self.__type_dict[name] = _EntryList() # Add objects type to the list of types
self.__type_dict[name].finalizer = weakref.finalize(self._store[name], self.prune, name)
self.__type_dict[name].type = obj_type

def add_edge(self, start_obj: object, end_obj: object):
if start_obj.unique_name in self.__type_dict.keys():
if start_obj.unique_name in self.__type_dict:
self.__type_dict[start_obj.unique_name].append(end_obj.unique_name)
else:
raise AttributeError('Start object not in map.')

def get_edges(self, start_obj) -> List[str]:
if start_obj.unique_name in self.__type_dict.keys():
if start_obj.unique_name in self.__type_dict:
return list(self.__type_dict[start_obj.unique_name])
else:
raise AttributeError
Expand All @@ -163,13 +179,14 @@ def prune_vertex_from_edge(self, parent_obj, child_obj):
return
vertex2 = child_obj.unique_name

if vertex1 in self.__type_dict.keys() and vertex2 in self.__type_dict[vertex1]:
if vertex1 in self.__type_dict and vertex2 in self.__type_dict[vertex1]:
del self.__type_dict[vertex1][self.__type_dict[vertex1].index(vertex2)]

def prune(self, key: str):
if key in self.__type_dict.keys():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And again here.

if key in self.__type_dict:
del self.__type_dict[key]
del self._store[key]
if key in self._store:
del self._store[key]

def find_isolated_vertices(self) -> list:
"""returns a list of isolated vertices."""
Expand Down Expand Up @@ -247,7 +264,7 @@ def is_connected(self, vertices_encountered=None, start_vertex=None) -> bool:
if vertices_encountered is None:
vertices_encountered = set()
graph = self.__type_dict
vertices = list(graph.keys())
vertices = list(graph)
if not start_vertex:
# chose a vertex from graph as a starting point
start_vertex = vertices[0]
Expand All @@ -262,10 +279,9 @@ def is_connected(self, vertices_encountered=None, start_vertex=None) -> bool:

def _clear(self):
"""Reset the map to an empty state. Only to be used for testing"""
for vertex in self.vertices():
self.prune(vertex)
self._store.clear()
self.__type_dict.clear()
gc.collect()
self.__type_dict = {}

def __repr__(self) -> str:
return f'Map object of {len(self._store)} vertices.'
245 changes: 245 additions & 0 deletions tests/unit_tests/global_object/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,248 @@ def test_map_initialization(self):
assert len(test_map.vertices()) == 0
assert test_map.edges() == []

def test_vertices_retry_on_runtime_error(self, clear):
"""Test that vertices() retries when RuntimeError occurs during iteration.

This tests the thread-safety fix for WeakValueDictionary modification
during iteration (e.g., by garbage collection).
"""
# Given
test_map = Map()

# Create a mock _store that raises RuntimeError on first iteration attempt
call_count = 0
original_store = test_map._store

class MockWeakValueDict:
def __init__(self):
self.data = {}
self.iteration_count = 0

def __iter__(self):
self.iteration_count += 1
if self.iteration_count == 1:
# First iteration raises RuntimeError (simulates GC interference)
raise RuntimeError("dictionary changed size during iteration")
# Subsequent iterations succeed
return iter(self.data)

def __len__(self):
return len(self.data)

mock_store = MockWeakValueDict()
test_map._store = mock_store

# When
vertices = test_map.vertices()

# Then
assert vertices == []
assert mock_store.iteration_count == 2 # Should have retried once

def test_add_vertex_cleans_stale_type_dict_entry(self, clear):
"""Test that add_vertex cleans up stale __type_dict entries.

This can happen when a weak reference was collected but the finalizer
hasn't run yet, and a new object is created with the same unique_name.
"""
# Given
test_map = Map()

# Manually add a stale entry to __type_dict (simulating GC collected but finalizer not run)
stale_name = "StaleObject_0"
test_map._Map__type_dict[stale_name] = _EntryList()

# Create a mock object with the same unique_name
mock_obj = MagicMock()
mock_obj.unique_name = stale_name

# When - Adding the object should clean up the stale entry first
test_map.add_vertex(mock_obj, 'created')

# Then - Object should be added successfully
assert stale_name in test_map._store
assert stale_name in test_map._Map__type_dict
assert test_map._Map__type_dict[stale_name].type == ['created']

def test_prune_key_not_in_store(self, clear):
"""Test that prune handles case when key is not in _store.

This defensive check prevents KeyError when the weak reference has
already been garbage collected but __type_dict entry remains.
"""
# Given
test_map = Map()

# Manually add entry to __type_dict without corresponding _store entry
orphan_key = "OrphanObject_0"
test_map._Map__type_dict[orphan_key] = _EntryList()

# When - Pruning should not raise error
test_map.prune(orphan_key)

# Then - Entry should be removed from __type_dict
assert orphan_key not in test_map._Map__type_dict

def test_prune_key_in_both_dicts(self, clear, base_object):
"""Test that prune removes key from both _store and __type_dict."""
# Given
unique_name = base_object.unique_name
assert unique_name in global_object.map._store
assert unique_name in global_object.map._Map__type_dict

# When
global_object.map.prune(unique_name)

# Then
assert unique_name not in global_object.map._Map__type_dict
# Note: _store entry may or may not exist depending on weak ref state

def test_prune_nonexistent_key(self, clear):
"""Test that prune handles nonexistent key gracefully."""
# When/Then - Should not raise error
global_object.map.prune("nonexistent_key")

def test_reset_type_unknown_object(self, clear):
"""Test reset_type with object not in map."""
# Given
unknown_obj = MagicMock()
unknown_obj.unique_name = "unknown"

# When/Then - Should not raise error
global_object.map.reset_type(unknown_obj, 'argument')

def test_change_type_unknown_object(self, clear):
"""Test change_type with object not in map."""
# Given
unknown_obj = MagicMock()
unknown_obj.unique_name = "unknown"

# When/Then - Should not raise error
global_object.map.change_type(unknown_obj, 'argument')

def test_find_path_start_not_in_graph(self, clear):
"""Test find_path when start vertex is not in graph."""
# When
path = global_object.map.find_path("nonexistent", "also_nonexistent")

# Then
assert path == []

def test_find_all_paths_start_not_in_graph(self, clear):
"""Test find_all_paths when start vertex is not in graph."""
# When
paths = global_object.map.find_all_paths("nonexistent", "also_nonexistent")

# Then
assert paths == []

def test_find_isolated_vertices(self, clear, base_object, parameter_object):
"""Test finding isolated vertices (vertices with no outgoing edges)."""
# Given - No edges added, both objects are isolated

# When
isolated = global_object.map.find_isolated_vertices()

# Then
assert base_object.unique_name in isolated
assert parameter_object.unique_name in isolated

def test_find_isolated_vertices_with_edges(self, clear, base_object, parameter_object):
"""Test finding isolated vertices when some have edges."""
# Given
global_object.map.add_edge(base_object, parameter_object)

# When
isolated = global_object.map.find_isolated_vertices()

# Then
# base_object has an edge, so it's not isolated
assert base_object.unique_name not in isolated
# parameter_object has no outgoing edges, so it's isolated
assert parameter_object.unique_name in isolated

def test_prune_vertex_from_edge_edge_not_exists(self, clear, base_object, parameter_object):
"""Test pruning edge that doesn't exist."""
# Given - No edge added between objects

# When/Then - Should not raise error
global_object.map.prune_vertex_from_edge(base_object, parameter_object)

def test_prune_vertex_from_edge_parent_not_in_map(self, clear, parameter_object):
"""Test pruning edge when parent is not in map."""
# Given
unknown_obj = MagicMock()
unknown_obj.unique_name = "unknown"

# When/Then - Should not raise error (vertex1 not in type_dict)
global_object.map.prune_vertex_from_edge(unknown_obj, parameter_object)

def test_created_internal_property(self, clear):
"""Test created_internal property."""
# Given
obj = ObjBase(name="internal_obj")
global_object.map.change_type(obj, 'created_internal')

# When
internal_objs = global_object.map.created_internal

# Then
assert obj.unique_name in internal_objs

def test_clear_empties_both_dicts(self, clear, base_object, parameter_object):
"""Test that _clear() properly empties both _store and __type_dict."""
# Given
assert len(global_object.map._store) == 2
assert len(global_object.map._Map__type_dict) == 2

# When
global_object.map._clear()

# Then
assert len(global_object.map._store) == 0
assert len(global_object.map._Map__type_dict) == 0

def test_entry_list_delitem(self):
"""Test _EntryList __delitem__ method."""
# Given
entry = _EntryList()
entry.append("item1")
entry.append("item2")
entry.append("item3")

# When
del entry[1]

# Then
assert len(entry) == 2
assert "item2" not in entry
assert "item1" in entry
assert "item3" in entry

def test_entry_list_repr_with_finalizer(self):
"""Test _EntryList repr when finalizer is set."""
# Given
entry = _EntryList()
entry.type = 'created'
entry.finalizer = MagicMock() # Non-None finalizer

# When
repr_str = str(entry)

# Then
assert 'created' in repr_str
assert 'With a finalizer' in repr_str

def test_entry_list_remove_type_unknown(self):
"""Test removing a type that's not in known types."""
# Given
entry = _EntryList()
entry.type = 'created'

# When - Try to remove unknown type
entry.remove_type('unknown_type')

# Then - Should not change anything
assert 'created' in entry.type

Loading