Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3e76a38
typing: contact related modules
skim0119 Jun 15, 2024
8a056a6
remove AllowedContactType union alias to distinguish cases
skim0119 Jun 15, 2024
7484502
typing: add basic properties for surface_base definition
skim0119 Jun 15, 2024
64e39a6
typing: include interaction module
skim0119 Jun 15, 2024
32130fb
typing: other minor update for different operators
skim0119 Jun 16, 2024
ee57acc
Merge branch 'typing/rods' into typing/contact
skim0119 Jun 16, 2024
3a06dc9
remove deprecated joint functions
skim0119 Jun 16, 2024
0b38d36
refactor: mv memory related methods into separate file
skim0119 Jun 16, 2024
7fb82a9
Merge branch 'update/mypy' into typing/contact
skim0119 Jun 16, 2024
8f5af01
resolve circular import
skim0119 Jun 16, 2024
4df6150
test: remove unittest for ExternalContact and SelfContact
skim0119 Jun 16, 2024
0e1c9ab
typing: collision utils
skim0119 Jun 16, 2024
83ec4fe
wip: typing symplectic stepper
skim0119 Jun 16, 2024
267bca2
fix tests with refactoring symplectic system
skim0119 Jun 16, 2024
e4d9a13
remove deprecated functions in interaction
skim0119 Jun 16, 2024
090ac1e
remove deprecated anisotropic friction msg
skim0119 Jun 16, 2024
f6738e9
wip: fixing type disagreements
skim0119 Jun 16, 2024
5553400
resolve liskov error
skim0119 Jun 16, 2024
77952bb
wip: finalizing typing system and system collection relation
skim0119 Jun 17, 2024
6b39aea
refactor memory block related utility functions
skim0119 Jun 17, 2024
cdce69e
type: finish organizing prototype
skim0119 Jun 17, 2024
1a00700
tests: use tmp_path to save test files
skim0119 Jun 17, 2024
c1d4bdc
fix: formatting
skim0119 Jun 17, 2024
c323e56
fix: directly access number of elements
skim0119 Jun 17, 2024
6c20c3a
remove indexing feature for connection
skim0119 Jun 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions elastica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,8 @@
)
from elastica.joint import (
FreeJoint,
ExternalContact,
FixedJoint,
HingeJoint,
SelfContact,
)
from elastica.contact_forces import (
NoContact,
Expand Down
6 changes: 3 additions & 3 deletions elastica/_synchronize_periodic_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from numpy.typing import NDArray
from elastica.boundary_conditions import ConstraintBase
from elastica.typing import SystemType
from elastica.typing import RodType


@njit(cache=True) # type: ignore
Expand Down Expand Up @@ -92,15 +92,15 @@ class _ConstrainPeriodicBoundaries(ConstraintBase):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

def constrain_values(self, rod: SystemType, time: np.floating) -> None:
def constrain_values(self, rod: RodType, time: np.floating) -> None:
_synchronize_periodic_boundary_of_vector_collection(
rod.position_collection, rod.periodic_boundary_nodes_idx
)
_synchronize_periodic_boundary_of_matrix_collection(
rod.director_collection, rod.periodic_boundary_elems_idx
)

def constrain_rates(self, rod: SystemType, time: np.floating) -> None:
def constrain_rates(self, rod: RodType, time: np.floating) -> None:
_synchronize_periodic_boundary_of_vector_collection(
rod.velocity_collection, rod.periodic_boundary_nodes_idx
)
Expand Down
61 changes: 42 additions & 19 deletions elastica/boundary_conditions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__doc__ = """ Built-in boundary condition implementationss """

import warnings
from typing import Any, Optional
from typing import Any, Optional, TypeVar, Generic

import numpy as np
from numpy.typing import NDArray
Expand All @@ -12,10 +12,13 @@

from elastica._linalg import _batch_matvec, _batch_matrix_transpose
from elastica._rotations import _get_rotation_matrix
from elastica.typing import SystemType, RodType
from elastica.typing import SystemType, RodType, RigidBodyType


class ConstraintBase(ABC):
S = TypeVar("S")


class ConstraintBase(ABC, Generic[S]):
"""Base class for constraint and displacement boundary condition implementation.

Notes
Expand All @@ -31,7 +34,7 @@ class ConstraintBase(ABC):

"""

_system: SystemType
_system: S
_constrained_position_idx: np.ndarray
_constrained_director_idx: np.ndarray

Expand All @@ -51,24 +54,24 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
)

@property
def system(self) -> SystemType:
def system(self) -> S:
"""get system (rod or rigid body) reference"""
return self._system

@property
def constrained_position_idx(self) -> Optional[np.ndarray]:
def constrained_position_idx(self) -> np.ndarray:
"""get position-indices passed to "using" """
# TODO: This should be immutable somehow
return self._constrained_position_idx

@property
def constrained_director_idx(self) -> Optional[np.ndarray]:
def constrained_director_idx(self) -> np.ndarray:
"""get director-indices passed to "using" """
# TODO: This should be immutable somehow
return self._constrained_director_idx

@abstractmethod
def constrain_values(self, system: SystemType, time: np.floating) -> None:
def constrain_values(self, system: S, time: np.floating) -> None:
# TODO: In the future, we can remove rod and use self.system
"""
Constrain values (position and/or directors) of a rod object.
Expand All @@ -83,7 +86,7 @@ def constrain_values(self, system: SystemType, time: np.floating) -> None:
pass

@abstractmethod
def constrain_rates(self, system: SystemType, time: np.floating) -> None:
def constrain_rates(self, system: S, time: np.floating) -> None:
# TODO: In the future, we can remove rod and use self.system
"""
Constrain rates (velocity and/or omega) of a rod object.
Expand All @@ -107,11 +110,15 @@ class FreeBC(ConstraintBase):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

def constrain_values(self, system: SystemType, time: np.floating) -> None:
def constrain_values(
self, system: "RodType | RigidBodyType", time: np.floating
) -> None:
"""In FreeBC, this routine simply passes."""
pass

def constrain_rates(self, system: SystemType, time: np.floating) -> None:
def constrain_rates(
self, system: "RodType | RigidBodyType", time: np.floating
) -> None:
"""In FreeBC, this routine simply passes."""
pass

Expand Down Expand Up @@ -165,7 +172,9 @@ def __init__(
self.fixed_position_collection = np.array(fixed_position)
self.fixed_directors_collection = np.array(fixed_directors)

def constrain_values(self, system: SystemType, time: np.floating) -> None:
def constrain_values(
self, system: "RodType | RigidBodyType", time: np.floating
) -> None:
# system.position_collection[..., 0] = self.fixed_position
# system.director_collection[..., 0] = self.fixed_directors
self.compute_constrain_values(
Expand All @@ -175,7 +184,9 @@ def constrain_values(self, system: SystemType, time: np.floating) -> None:
self.fixed_directors_collection,
)

def constrain_rates(self, system: SystemType, time: np.floating) -> None:
def constrain_rates(
self, system: "RodType | RigidBodyType", time: np.floating
) -> None:
# system.velocity_collection[..., 0] = 0.0
# system.omega_collection[..., 0] = 0.0
self.compute_constrain_rates(
Expand Down Expand Up @@ -340,7 +351,9 @@ def __init__(
)
self.rotational_constraint_selector = rotational_constraint_selector.astype(int)

def constrain_values(self, system: SystemType, time: np.floating) -> None:
def constrain_values(
self, system: "RodType | RigidBodyType", time: np.floating
) -> None:
if self.constrained_position_idx.size:
self.nb_constrain_translational_values(
system.position_collection,
Expand All @@ -349,7 +362,9 @@ def constrain_values(self, system: SystemType, time: np.floating) -> None:
self.translational_constraint_selector,
)

def constrain_rates(self, system: SystemType, time: np.floating) -> None:
def constrain_rates(
self, system: "RodType | RigidBodyType", time: np.floating
) -> None:
if self.constrained_position_idx.size:
self.nb_constrain_translational_rates(
system.velocity_collection,
Expand Down Expand Up @@ -525,7 +540,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
**kwargs,
)

def constrain_values(self, system: SystemType, time: np.floating) -> None:
def constrain_values(
self, system: "RodType | RigidBodyType", time: np.floating
) -> None:
if self.constrained_position_idx.size:
self.nb_constrain_translational_values(
system.position_collection,
Expand All @@ -539,7 +556,9 @@ def constrain_values(self, system: SystemType, time: np.floating) -> None:
self.constrained_director_idx,
)

def constrain_rates(self, system: SystemType, time: np.floating) -> None:
def constrain_rates(
self, system: "RodType | RigidBodyType", time: np.floating
) -> None:
if self.constrained_position_idx.size:
self.nb_constrain_translational_rates(
system.velocity_collection,
Expand Down Expand Up @@ -743,15 +762,19 @@ def __init__(
@ director_end
) # rotation_matrix wants vectors 3,1

def constrain_values(self, rod: SystemType, time: np.floating) -> None:
def constrain_values(
self, rod: "RodType | RigidBodyType", time: np.floating
) -> None:
if time > self.twisting_time:
rod.position_collection[..., 0] = self.final_start_position
rod.position_collection[..., -1] = self.final_end_position

rod.director_collection[..., 0] = self.final_start_directors
rod.director_collection[..., -1] = self.final_end_directors

def constrain_rates(self, rod: SystemType, time: np.floating) -> None:
def constrain_rates(
self, rod: "RodType | RigidBodyType", time: np.floating
) -> None:
if time > self.twisting_time:
rod.velocity_collection[..., 0] = 0.0
rod.omega_collection[..., 0] = 0.0
Expand Down
22 changes: 12 additions & 10 deletions elastica/callback_functions.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
__doc__ = """ Module contains callback classes to save simulation data for rod-like objects """
from elastica.typing import SystemType
from typing import Any, Optional, TypeVar, Generic
from elastica.typing import RodType, RigidBodyType, SystemType

import os
import sys
import numpy as np
from numpy.typing import NDArray
import logging
from typing import Any, Optional


from collections import defaultdict

from elastica.typing import RodType, SystemType

T = TypeVar("T")

class CallBackBaseClass:

class CallBackBaseClass(Generic[T]):
"""
This is the base class for callbacks for rod-like objects.

Expand All @@ -30,9 +32,7 @@ def __init__(self) -> None:
"""
pass

def make_callback(
self, system: SystemType, time: np.floating, current_step: int
) -> None:
def make_callback(self, system: T, time: np.floating, current_step: int) -> None:
"""
This method is called every time step. Users can define
which parameters are called back and recorded. Also users
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(self, step_skip: int, callback_params: dict) -> None:
self.callback_params = callback_params

def make_callback(
self, system: SystemType, time: np.floating, current_step: int
self, system: "RodType | RigidBodyType", time: np.floating, current_step: int
) -> None:

if current_step % self.sample_every == 0:
Expand Down Expand Up @@ -176,7 +176,9 @@ def __init__(
self.file_save_interval = file_save_interval

# Data collector
self.buffer = defaultdict(list)
self.buffer: dict[str, list[NDArray[np.floating] | np.floating | int]] = (
defaultdict(list)
)
self.buffer_size = 0

# Module
Expand All @@ -199,7 +201,7 @@ def __init__(
self._ext = "pkl"

def make_callback(
self, system: SystemType, time: np.floating, current_step: int
self, system: "RodType | RigidBodyType", time: np.floating, current_step: int
) -> None:
"""

Expand Down
Loading