diff --git a/elastica/__init__.py b/elastica/__init__.py index 7622bf260..f60bc902b 100644 --- a/elastica/__init__.py +++ b/elastica/__init__.py @@ -71,3 +71,4 @@ ) from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod +from elastica.restart import save_state, load_state diff --git a/elastica/restart.py b/elastica/restart.py index 7e97cb83c..1b5fab267 100644 --- a/elastica/restart.py +++ b/elastica/restart.py @@ -3,6 +3,7 @@ import numpy as np import os from itertools import groupby +from .memory_block import MemoryBlockCosseratRod, MemoryBlockRigidBody def all_equal(iterable): @@ -41,6 +42,10 @@ def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False): """ os.makedirs(directory, exist_ok=True) for idx, rod in enumerate(simulator): + if isinstance(rod, MemoryBlockCosseratRod) or isinstance( + rod, MemoryBlockRigidBody + ): + continue path = os.path.join(directory, "system_{}.npz".format(idx)) np.savez(path, time=time, **rod.__dict__) @@ -69,6 +74,10 @@ def load_state(simulator, directory: str = "", verbose: bool = False): """ time_list = [] # Simulation time of rods when they are saved. for idx, rod in enumerate(simulator): + if isinstance(rod, MemoryBlockCosseratRod) or isinstance( + rod, MemoryBlockRigidBody + ): + continue path = os.path.join(directory, "system_{}.npz".format(idx)) data = np.load(path, allow_pickle=True) for key, value in data.items(): @@ -88,10 +97,6 @@ def load_state(simulator, directory: str = "", verbose: bool = False): "Restart time of loaded rods are different, check your inputs!" ) - # Apply boundary conditions, after loading the systems. - simulator.constrain_values(0.0) - simulator.constrain_rates(0.0) - if verbose: print("Load complete: {}".format(directory)) diff --git a/tests/test_restart.py b/tests/test_restart.py index 0ba52df5d..0f814b21a 100644 --- a/tests/test_restart.py +++ b/tests/test_restart.py @@ -12,6 +12,7 @@ CallBacks, ) from elastica.restart import save_state, load_state +import elastica as ea class GenericSimulatorClass( @@ -78,6 +79,98 @@ def test_restart_save_load(self, load_collection): assert_allclose(test_value, correct_value) + def run_sim(self, final_time, load_from_restart, save_data_restart): + class BaseSimulatorClass( + BaseSystemCollection, Constraints, Forcing, Connections, CallBacks + ): + pass + + simulator_class = BaseSimulatorClass() + + rod_list = [] + for _ in range(5): + rod = ea.CosseratRod.straight_rod( + n_elements=10, + start=np.zeros((3)), + direction=np.array([0, 1, 0.0]), + normal=np.array([1, 0, 0.0]), + base_length=1, + base_radius=1, + density=1, + youngs_modulus=1, + ) + # Bypass check, but its fine for testing + simulator_class._systems.append(rod) + + # Also add rods to a separate list + rod_list.append(rod) + + for rod in rod_list: + simulator_class.add_forcing_to(rod).using( + ea.EndpointForces, + start_force=np.zeros( + 3, + ), + end_force=np.array([0, 0.1, 0]), + ramp_up_time=0.1, + ) + + # Finalize simulator + simulator_class.finalize() + + directory = "restart_test_data/" + + time_step = 1e-4 + total_steps = int(final_time / time_step) + + if load_from_restart: + restart_time = ea.load_state(simulator_class, directory, True) + + else: + restart_time = np.float64(0.0) + + timestepper = ea.PositionVerlet() + time = ea.integrate( + timestepper, + simulator_class, + final_time, + total_steps, + restart_time=restart_time, + ) + + if save_data_restart: + ea.save_state(simulator_class, directory, time, True) + + # Compute final time accelerations + recorded_list = np.zeros((len(rod_list), rod_list[0].n_elems + 1)) + for i, rod in enumerate(rod_list): + recorded_list[i, :] = rod.acceleration_collection[1, :] + + return recorded_list + + @pytest.mark.parametrize("final_time", [0.2, 1.0]) + def test_save_restart_run_sim(self, final_time): + + # First half of simulation + _ = self.run_sim( + final_time / 2, load_from_restart=False, save_data_restart=True + ) + + # Second half of simulation + recorded_list = self.run_sim( + final_time / 2, load_from_restart=True, save_data_restart=False + ) + recorded_list_second_half = recorded_list.copy() + + # Full simulation + recorded_list = self.run_sim( + final_time, load_from_restart=False, save_data_restart=False + ) + recorded_list_full_sim = recorded_list.copy() + + # Compare final accelerations of rods + assert_allclose(recorded_list_second_half, recorded_list_full_sim) + class TestRestartFunctionsWithFeaturesUsingRigidBodies: @pytest.fixture(scope="function")