Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions elastica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,4 @@
)
from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody
from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod
from elastica.restart import save_state, load_state
13 changes: 9 additions & 4 deletions elastica/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import os
from itertools import groupby
from .memory_block import MemoryBlockCosseratRod, MemoryBlockRigidBody


def all_equal(iterable):
Expand Down Expand Up @@ -41,6 +42,10 @@ def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False):
"""
os.makedirs(directory, exist_ok=True)
for idx, rod in enumerate(simulator):
if isinstance(rod, MemoryBlockCosseratRod) or isinstance(
rod, MemoryBlockRigidBody
):
continue
path = os.path.join(directory, "system_{}.npz".format(idx))
np.savez(path, time=time, **rod.__dict__)

Expand Down Expand Up @@ -69,6 +74,10 @@ def load_state(simulator, directory: str = "", verbose: bool = False):
"""
time_list = [] # Simulation time of rods when they are saved.
for idx, rod in enumerate(simulator):
if isinstance(rod, MemoryBlockCosseratRod) or isinstance(
rod, MemoryBlockRigidBody
):
continue
path = os.path.join(directory, "system_{}.npz".format(idx))
data = np.load(path, allow_pickle=True)
for key, value in data.items():
Expand All @@ -88,10 +97,6 @@ def load_state(simulator, directory: str = "", verbose: bool = False):
"Restart time of loaded rods are different, check your inputs!"
)

# Apply boundary conditions, after loading the systems.
simulator.constrain_values(0.0)
simulator.constrain_rates(0.0)

if verbose:
print("Load complete: {}".format(directory))

Expand Down
93 changes: 93 additions & 0 deletions tests/test_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
CallBacks,
)
from elastica.restart import save_state, load_state
import elastica as ea


class GenericSimulatorClass(
Expand Down Expand Up @@ -78,6 +79,98 @@ def test_restart_save_load(self, load_collection):

assert_allclose(test_value, correct_value)

def run_sim(self, final_time, load_from_restart, save_data_restart):
class BaseSimulatorClass(
BaseSystemCollection, Constraints, Forcing, Connections, CallBacks
):
pass

simulator_class = BaseSimulatorClass()

rod_list = []
for _ in range(5):
rod = ea.CosseratRod.straight_rod(
n_elements=10,
start=np.zeros((3)),
direction=np.array([0, 1, 0.0]),
normal=np.array([1, 0, 0.0]),
base_length=1,
base_radius=1,
density=1,
youngs_modulus=1,
)
# Bypass check, but its fine for testing
simulator_class._systems.append(rod)

# Also add rods to a separate list
rod_list.append(rod)

for rod in rod_list:
simulator_class.add_forcing_to(rod).using(
ea.EndpointForces,
start_force=np.zeros(
3,
),
end_force=np.array([0, 0.1, 0]),
ramp_up_time=0.1,
)

# Finalize simulator
simulator_class.finalize()

directory = "restart_test_data/"

time_step = 1e-4
total_steps = int(final_time / time_step)

if load_from_restart:
restart_time = ea.load_state(simulator_class, directory, True)

else:
restart_time = np.float64(0.0)

timestepper = ea.PositionVerlet()
time = ea.integrate(
timestepper,
simulator_class,
final_time,
total_steps,
restart_time=restart_time,
)

if save_data_restart:
ea.save_state(simulator_class, directory, time, True)

# Compute final time accelerations
recorded_list = np.zeros((len(rod_list), rod_list[0].n_elems + 1))
for i, rod in enumerate(rod_list):
recorded_list[i, :] = rod.acceleration_collection[1, :]

return recorded_list

@pytest.mark.parametrize("final_time", [0.2, 1.0])
def test_save_restart_run_sim(self, final_time):

# First half of simulation
_ = self.run_sim(
final_time / 2, load_from_restart=False, save_data_restart=True
)

# Second half of simulation
recorded_list = self.run_sim(
final_time / 2, load_from_restart=True, save_data_restart=False
)
recorded_list_second_half = recorded_list.copy()

# Full simulation
recorded_list = self.run_sim(
final_time, load_from_restart=False, save_data_restart=False
)
recorded_list_full_sim = recorded_list.copy()

# Compare final accelerations of rods
assert_allclose(recorded_list_second_half, recorded_list_full_sim)


class TestRestartFunctionsWithFeaturesUsingRigidBodies:
@pytest.fixture(scope="function")
Expand Down