diff --git a/myogen/__init__.py b/myogen/__init__.py index c71d3369..94795c8e 100644 --- a/myogen/__init__.py +++ b/myogen/__init__.py @@ -1,5 +1,9 @@ import warnings +from myogen._cuda_env import setup as _setup_cuda +_setup_cuda() +del _setup_cuda + import numpy as np from numpy.random import Generator diff --git a/myogen/_cuda_env.py b/myogen/_cuda_env.py new file mode 100644 index 00000000..56a05a4d --- /dev/null +++ b/myogen/_cuda_env.py @@ -0,0 +1,37 @@ +""" +Windows CUDA DLL discovery for pip-installed nvidia-* packages. + +CuPy 13 cannot locate DLLs shipped by ``pip install nvidia-cuda-nvrtc-cu12`` +and similar wheels because Python 3.8+ no longer searches PATH for DLLs. +This module registers those directories via ``os.add_dll_directory`` and +pre-loads the NVRTC builtins library required for JIT compilation. + +No-op on Linux / macOS and when the nvidia packages are absent. +""" + +import sys + + +def setup() -> None: + """Register CUDA DLL paths from pip-installed nvidia-* packages.""" + if sys.platform != "win32": + return + + import ctypes + import os + import pathlib + + for site_dir in [pathlib.Path(p) for p in sys.path if "site-packages" in p]: + # Register every nvidia/*/bin so cublas, cusolver, cufft etc. 
are found + for bin_dir in site_dir.glob("nvidia/*/bin"): + if bin_dir.is_dir(): + os.add_dll_directory(str(bin_dir)) + + # Pre-load nvrtc-builtins (required by CuPy for JIT kernel compilation) + for nvrtc_dll in site_dir.glob( + "nvidia/cuda_nvrtc/bin/nvrtc-builtins*.dll" + ): + try: + ctypes.WinDLL(str(nvrtc_dll)) + except OSError: + pass diff --git a/myogen/simulator/core/emg/intramuscular/bioelectric.py b/myogen/simulator/core/emg/intramuscular/bioelectric.py index 77f7f0a0..ba6edb68 100644 --- a/myogen/simulator/core/emg/intramuscular/bioelectric.py +++ b/myogen/simulator/core/emg/intramuscular/bioelectric.py @@ -47,7 +47,7 @@ def get_tm_current(z: np.ndarray, D1: float = 96.0, D2: float = -90.0) -> np.nda return Vm -def get_tm_current_dz(z: np.ndarray, D1: float = 96.0) -> np.ndarray: +def get_tm_current_dz(z: np.ndarray, D1: float = 96.0, xp=np) -> np.ndarray: """ Calculate first derivative of transmembrane current (Rosenfalck model). @@ -60,16 +60,18 @@ def get_tm_current_dz(z: np.ndarray, D1: float = 96.0) -> np.ndarray: Spatial coordinates along fiber in mm D1 : float, default=96.0 Current amplitude parameter in mV/mm³ + xp : module, default=np + Array backend (numpy or cupy) Returns ------- np.ndarray First derivative of transmembrane current """ - Vm = np.zeros_like(z, dtype=np.float64) + Vm = xp.zeros_like(z, dtype=xp.float64) pos_mask = z > 0 z_pos = z[pos_mask] - Vm[pos_mask] = D1 * (3 * z_pos**2 - z_pos**3) * np.exp(-z_pos) + Vm[pos_mask] = D1 * (3 * z_pos**2 - z_pos**3) * xp.exp(-z_pos) return Vm @@ -102,6 +104,7 @@ def get_elementary_current_response( r: np.ndarray, sigma_r: float = 63.0, # S/m sigma_z: float = 330.0, # S/m + xp=np, ) -> np.ndarray: """ Calculate elementary current response for volume conductor. 
@@ -122,6 +125,8 @@ def get_elementary_current_response( Radial conductivity in S/m (from Andreassen & Rosenfalck 1980) sigma_z : float, default=330.0 Longitudinal conductivity in S/m (from Andreassen & Rosenfalck 1980) + xp : module, default=np + Array backend (numpy or cupy) Returns ------- @@ -133,13 +138,17 @@ def get_elementary_current_response( sigma_r_S_per_mm = sigma_r / 1000.0 # CORRECTED: convert S/m → S/mm sigma_z_S_per_mm = sigma_z / 1000.0 # CORRECTED: convert S/m → S/mm - return np.divide( - 1 / 4 / np.pi / sigma_r_S_per_mm, - np.sqrt(sigma_z_S_per_mm / sigma_r_S_per_mm * r**2 + (z - z_electrode) ** 2), + # Normalize inputs to computation backend (prevents numpy/cupy mixing) + z = xp.asarray(z) + z_electrode = float(z_electrode) + + return xp.divide( + 1 / 4 / xp.pi / sigma_r_S_per_mm, + xp.sqrt(sigma_z_S_per_mm / sigma_r_S_per_mm * r**2 + (z - z_electrode) ** 2), ) -def shift_padding(vec, sh, axis): +def shift_padding(vec, sh, axis, xp=np): """ Circularly shifts 'vec' by 'sh' positions along the specified 'axis' and then pads the shifted region with zeros. @@ -152,21 +161,25 @@ def shift_padding(vec, sh, axis): Shift amount (positive means downward/rightward like MATLAB). axis : int Axis along which to shift. + xp : module, default=np + Array backend (numpy or cupy). Returns ------- ndarray Shifted and zero-padded array. """ - vec = np.roll(vec, sh, axis=axis) + vec = xp.roll(vec, sh, axis=axis) - n = len(vec) + n = vec.shape[0] # Equivalent of vec(1:sh) = 0 if sh > 0: vec[:sh] = 0 # Equivalent of vec(end+sh+1:end) = 0 + # Note: when sh > 0, both head AND tail are zeroed — this is the + # original MATLAB semantics (suppress wrap-around on both sides). 
if sh < 0: start = n + sh # because end+sh+1 in MATLAB is 1-based if start < n: @@ -215,7 +228,8 @@ def hr_shift_template(x, delay): def get_current_density( - t, z, zi, L1, L2, v, d=55e-3, suppress_endplate_density=True, endplate_width=0.5 + t, z, zi, L1, L2, v, d=55e-3, suppress_endplate_density=True, endplate_width=0.5, + xp=np, ): """ Model the individual action potential (IAP) or single fiber action potential (SFAP) in space and time. @@ -241,12 +255,25 @@ def get_current_density( Whether to suppress density at endplate region (default: True) endplate_width : float, optional Width around endplate where density is suppressed (mm) + xp : module, default=np + Array backend (numpy or cupy). Pass cupy for GPU acceleration. """ - dz = np.mean(np.diff(z, axis=0)) - z = np.concatenate([z, z[[-1]] + dz], axis=0) + # Normalize inputs to computation backend (prevents numpy/cupy mixing + # when callers pass numpy arrays with xp=cupy) + t = xp.asarray(t) + z = xp.asarray(z) + zi = float(zi) + L1 = float(L1) + L2 = float(L2) + v = float(v) + d = float(d) + + dz = xp.mean(xp.diff(z, axis=0)) + z = xp.concatenate([z, z[[-1]] + dz], axis=0) - T, Z = np.meshgrid(t, z) + # ravel() needed: t,z arrive as (N,1) column vectors; meshgrid expects 1-D + T, Z = xp.meshgrid(xp.ravel(t), xp.ravel(z)) # Tendon terminator function def tendon_terminator(z_inline, L_inline): @@ -254,20 +281,23 @@ def tendon_terminator(z_inline, L_inline): # Compute psi (transmembrane current derivative) if L1 >= L2: - psi = -4 * get_tm_current_dz(-2 * (Z - zi - v * T)) - longest_wave = np.diff(psi, axis=0) / dz + psi = -4 * get_tm_current_dz(-2 * (Z - zi - v * T), xp=xp) + longest_wave = xp.diff(psi, axis=0) / dz longest_wave *= tendon_terminator(Z[:-1, :] - zi - L1 / 2, L1) - longest_wave *= (Z[:-1, :] - zi) / v > 0 # negative time suppression + # Explicit bool→float64 cast required: CuPy does not support + # implicit multiplication of bool arrays with float arrays. 
+ longest_wave *= ((Z[:-1, :] - zi) / v > 0).astype(xp.float64) else: - psi = 4 * get_tm_current_dz(-2 * (-Z + zi - v * T)) - longest_wave = np.diff(psi, axis=0) / dz + psi = 4 * get_tm_current_dz(-2 * (-Z + zi - v * T), xp=xp) + longest_wave = xp.diff(psi, axis=0) / dz longest_wave *= tendon_terminator(Z[:-1, :] - zi + L2 / 2, L2) - longest_wave *= (-Z[:-1, :] + zi) / v > 0 + longest_wave *= ((-Z[:-1, :] + zi) / v > 0).astype(xp.float64) # bool→float64 # Shortest wave (reversed) shortest_wave = longest_wave[::-1].copy() - shift_amount = int(np.round((L1 + L2 - max(z) + L2 - L1) / dz)) - shortest_wave = shift_padding(shortest_wave, shift_amount, 0) + # float() copies each 0-d backend result to the host (a device-to-host sync on CuPy) so shift_amount is a plain Python int + shift_amount = round(float((L1 + L2 - float(z.max()) + L2 - L1) / dz)) + shortest_wave = shift_padding(shortest_wave, shift_amount, axis=0, xp=xp) if L1 >= L2: shortest_wave *= tendon_terminator(Z[:-1, :] - zi + L2 / 2, L2) @@ -278,11 +308,8 @@ def tendon_terminator(z_inline, L_inline): # Suppress endplate density if required if suppress_endplate_density: - - def endplate_terminator(z_inline): - return (z_inline <= (zi - endplate_width)) | (z_inline >= (zi + endplate_width)) - - iap *= endplate_terminator(Z[:-1, :]) + iap *= ((Z[:-1, :] <= (zi - endplate_width)) | + (Z[:-1, :] >= (zi + endplate_width))).astype(xp.float64) # ---- FIXED UNIT CONVERSIONS ---- # Intracellular conductivity: 1.01 S/m → convert to S/mm @@ -291,7 +318,7 @@ def endplate_terminator(z_inline): # Fiber diameter is already in mm (default d=55e-3 mm = 55 um) # Compute cross-sectional area in mm² - area_mm2 = np.pi * (d / 2) ** 2 # CORRECTED: removed extra /4 + area_mm2 = xp.pi * (d / 2) ** 2 # CORRECTED: removed extra /4 # Scale current density by intracellular conductivity and fiber cross-section area iap *= sigma_i * area_mm2 diff --git a/myogen/simulator/core/emg/intramuscular/motor_unit_sim.py b/myogen/simulator/core/emg/intramuscular/motor_unit_sim.py index 
387fa851..cf3bd6fb 100644 --- a/myogen/simulator/core/emg/intramuscular/motor_unit_sim.py +++ b/myogen/simulator/core/emg/intramuscular/motor_unit_sim.py @@ -8,6 +8,8 @@ Based on the MU_Sim class from the MATLAB iemg_simulator. """ +import os +from contextlib import nullcontext from typing import Optional, List import numpy as np @@ -15,6 +17,13 @@ from sklearn.cluster import KMeans from tqdm import tqdm +try: + import cupy as cp + + HAS_CUPY = cp.cuda.runtime.getDeviceCount() > 0 +except Exception: + HAS_CUPY = False + from myogen import derive_subseed, get_random_generator from myogen.utils.decorators import beartowertype from .bioelectric import ( @@ -240,6 +249,7 @@ def calc_sfaps( electrode_normals: Optional[np.ndarray] = None, min_radial_dist: Optional[float] = None, verbose: bool = True, + use_gpu: Optional[bool] = None, ): """ Calculate single fiber action potentials (SFAPs) for all fibers. @@ -258,6 +268,14 @@ def calc_sfaps( Minimum radial distance for stability (default: mean diameter * 1000) verbose : bool, default=True If True, display progress bars. Set to False to disable. + use_gpu : bool or None, default=None + GPU acceleration control: + - None → auto: use GPU if CuPy is available and MYOGEN_DISABLE_GPU + is not set. + - True → require GPU; raises RuntimeError if unavailable. + - False → force CPU execution. + Note: CuPy only supports NVIDIA GPUs (CUDA). AMD/ROCm is not + supported. """ self.dt = dt self.dz = dz @@ -295,69 +313,161 @@ def calc_sfaps( )[..., None] self.sfaps = np.zeros((len(t), self.Npt, self._number_of_muscle_fibers)) - for fiber_idx in tqdm( - range(self._number_of_muscle_fibers), - desc=f"MU {index}: Calculating SFAPs", - unit="fiber", - disable=not verbose, - ): - z_left = np.arange( - start=self._neuromuscular_z_coordinates__mm[fiber_idx], - step=-dz, - stop=self._muscle_fiber_left_ends__mm[fiber_idx] - dz, + # GPU acceleration: CUDA streams + direct xp dispatch for zero-overhead + # kernel execution. 
Model-agnostic: original functions called as black + # boxes with xp parameter, works with any upstream model change. + # + # Tri-state use_gpu logic: + # None → auto (GPU if available and env var not set) + # True → require GPU (raise if unavailable) + # False → force CPU + if use_gpu is True and not HAS_CUPY: + raise RuntimeError( + "use_gpu=True but CuPy is not available or no CUDA GPU detected. " + "Install with: pip install cupy-cuda12x" ) - z_right = np.arange( - start=self._neuromuscular_z_coordinates__mm[fiber_idx], - step=dz, - stop=self._muscle_fiber_right_ends__mm[fiber_idx] + dz, + if use_gpu is None: + _gpu_disabled_env = os.environ.get("MYOGEN_DISABLE_GPU", "").lower() in ( + "1", "true", "yes", ) - z = np.concatenate((z_left[::-1], z_right[1:]))[:, None] - mf_coord_3d = np.concatenate( - [ - np.matlib.repmat( - a=self.muscle_fiber_centers__mm[fiber_idx], m=len(z), n=1 - ), - z, - ], - axis=1, + _run_on_gpu = HAS_CUPY and not _gpu_disabled_env + else: + _run_on_gpu = use_gpu + if _run_on_gpu: + t_dev = cp.asarray(t) + _N_STREAMS = 4 + _streams = [cp.cuda.Stream(non_blocking=True) for _ in range(_N_STREAMS)] + sfaps_dev = cp.zeros( + (len(t), self.Npt, self._number_of_muscle_fibers), dtype=cp.float64 ) - current_density = get_current_density( - t, - z, - self._neuromuscular_z_coordinates__mm[fiber_idx], - self._muscle_fiber_right_ends__mm[fiber_idx] - - self._neuromuscular_z_coordinates__mm[fiber_idx], - self._neuromuscular_z_coordinates__mm[fiber_idx] - - self._muscle_fiber_left_ends__mm[fiber_idx], - self.muscle_fiber_conduction_velocity__mm_per_s[fiber_idx], - self.muscle_fiber_diameters__mm[fiber_idx], - ) + for fiber_idx in tqdm( + range(self._number_of_muscle_fibers), + desc=f"MU {index}: Calculating SFAPs (GPU)", + unit="fiber", + disable=not verbose, + ): + stream = _streams[fiber_idx % _N_STREAMS] + with stream: + z_left = np.arange( + start=self._neuromuscular_z_coordinates__mm[fiber_idx], + step=-dz, + 
stop=self._muscle_fiber_left_ends__mm[fiber_idx] - dz, + ) + z_right = np.arange( + start=self._neuromuscular_z_coordinates__mm[fiber_idx], + step=dz, + stop=self._muscle_fiber_right_ends__mm[fiber_idx] + dz, + ) + z = np.concatenate((z_left[::-1], z_right[1:]))[:, None] + + z_dev = cp.asarray(z) + + current_density = get_current_density( + t_dev, + z_dev, + self._neuromuscular_z_coordinates__mm[fiber_idx], + self._muscle_fiber_right_ends__mm[fiber_idx] + - self._neuromuscular_z_coordinates__mm[fiber_idx], + self._neuromuscular_z_coordinates__mm[fiber_idx] + - self._muscle_fiber_left_ends__mm[fiber_idx], + self.muscle_fiber_conduction_velocity__mm_per_s[fiber_idx], + self.muscle_fiber_diameters__mm[fiber_idx], + xp=cp, + ) - for electrode_idx in range(self.Npt): - # Calculate radial distance from fiber to electrode - radial_distance = np.sqrt( - np.sum( - ( - electrode_positions[electrode_idx, :2] - - self.muscle_fiber_centers__mm[fiber_idx] + ecr_list = [] + for electrode_idx in range(self.Npt): + radial_distance = np.sqrt( + np.sum( + ( + electrode_positions[electrode_idx, :2] + - self.muscle_fiber_centers__mm[fiber_idx] + ) + ** 2, + keepdims=True, + ) ) - ** 2, - keepdims=True, - ) + if radial_distance < min_radial_dist: + radial_distance = min_radial_dist + + r_dev = cp.asarray(radial_distance) + ecr_list.append(get_elementary_current_response( + z_dev, + electrode_positions[electrode_idx, 2], + r_dev, + xp=cp, + )) + + ecr_all = cp.column_stack(ecr_list) + sfaps_dev[:, :, fiber_idx] = current_density.T @ ecr_all + + cp.cuda.Device(0).synchronize() + self.sfaps = cp.asnumpy(sfaps_dev) + else: + t_dev = t + for fiber_idx in tqdm( + range(self._number_of_muscle_fibers), + desc=f"MU {index}: Calculating SFAPs", + unit="fiber", + disable=not verbose, + ): + z_left = np.arange( + start=self._neuromuscular_z_coordinates__mm[fiber_idx], + step=-dz, + stop=self._muscle_fiber_left_ends__mm[fiber_idx] - dz, + ) + z_right = np.arange( + 
start=self._neuromuscular_z_coordinates__mm[fiber_idx], + step=dz, + stop=self._muscle_fiber_right_ends__mm[fiber_idx] + dz, + ) + z = np.concatenate((z_left[::-1], z_right[1:]))[:, None] + mf_coord_3d = np.concatenate( + [ + np.matlib.repmat( + a=self.muscle_fiber_centers__mm[fiber_idx], m=len(z), n=1 + ), + z, + ], + axis=1, ) - if radial_distance < min_radial_dist: - radial_distance = min_radial_dist - response_to_elem_current = get_elementary_current_response( + current_density = get_current_density( + t_dev, z, - electrode_positions[electrode_idx, 2], - radial_distance, + self._neuromuscular_z_coordinates__mm[fiber_idx], + self._muscle_fiber_right_ends__mm[fiber_idx] + - self._neuromuscular_z_coordinates__mm[fiber_idx], + self._neuromuscular_z_coordinates__mm[fiber_idx] + - self._muscle_fiber_left_ends__mm[fiber_idx], + self.muscle_fiber_conduction_velocity__mm_per_s[fiber_idx], + self.muscle_fiber_diameters__mm[fiber_idx], ) - self.sfaps[:, electrode_idx, fiber_idx] = ( - current_density.T @ response_to_elem_current - )[:, 0] + ecr_list = [] + for electrode_idx in range(self.Npt): + radial_distance = np.sqrt( + np.sum( + ( + electrode_positions[electrode_idx, :2] + - self.muscle_fiber_centers__mm[fiber_idx] + ) + ** 2, + keepdims=True, + ) + ) + if radial_distance < min_radial_dist: + radial_distance = min_radial_dist + + ecr_list.append(get_elementary_current_response( + z, + electrode_positions[electrode_idx, 2], + radial_distance, + )) + + ecr_all = np.column_stack(ecr_list) + self.sfaps[:, :, fiber_idx] = current_density.T @ ecr_all self.shift_sfaps(dt) diff --git a/pyproject.toml b/pyproject.toml index 719f4b14..7fd04ddb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,9 @@ nwb = [ "pynwb>=2.8.0", "nwbinspector>=0.5.0", ] +gpu = [ + "cupy-cuda12x>=13.0,<14", +] [project.urls] Homepage = "https://nsquaredlab.github.io/MyoGen/" diff --git a/tests/test_gpu_parity.py b/tests/test_gpu_parity.py new file mode 100644 index 00000000..2335c401 --- 
/dev/null +++ b/tests/test_gpu_parity.py @@ -0,0 +1,145 @@ +""" +GPU parity test: verify the CuPy backend reproduces NumPy SFAPs within floating-point tolerance (xcorr > 0.999999, RMSE < 1e-12). + +Skipped automatically when CuPy is not installed or no CUDA GPU is available. +""" + +import numpy as np +import pytest + +from myogen.simulator.core.emg.intramuscular.bioelectric import ( + get_current_density, + get_elementary_current_response, + shift_padding, +) + +# Graceful skip: pytest.importorskip handles missing CuPy, then we +# additionally check that a real CUDA device is present (e.g. Docker +# images may have CuPy installed without a GPU). +cupy = pytest.importorskip("cupy", reason="CuPy not installed") + + +@pytest.fixture(autouse=True) +def _require_cuda(): + """Skip all tests in this module if no CUDA device is available.""" + if not cupy.cuda.is_available(): + pytest.skip("CUDA device not available") + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +def _make_inputs(): + """Return canonical test inputs matching calc_sfaps conventions.""" + t = np.linspace(0, 0.01, 200)[:, None] # (Nt, 1) + z = np.linspace(0, 120, 500)[:, None] # (Nz, 1) + zi = 60.0 # neuromuscular junction position (mm) + L1 = 60.0 # right half-fiber length (mm) + L2 = 60.0 # left half-fiber length (mm) + v = 4000.0 # conduction velocity (mm/s) + d = 0.055 # fiber diameter (mm) + z_elec = 60.0 # electrode z-position (mm) + r = np.array([2.0]) # radial distance (mm) + return t, z, zi, L1, L2, v, d, z_elec, r + + +# ── GPU Parity Tests ───────────────────────────────────────────────────────── + +class TestGPUParity: + """Verify that CuPy-accelerated bioelectric kernels match CPU output.""" + + def test_current_density_parity(self): + """get_current_density: GPU vs CPU xcorr ≈ 1.0.""" + cp = cupy + + t, z, zi, L1, L2, v, d, _, _ = _make_inputs() + + cd_cpu = get_current_density(t, z, zi, L1, L2, v, d) + cd_gpu = get_current_density( + cp.asarray(t), cp.asarray(z), zi, L1, L2, v, d, xp=cp, + ) + cd_back = 
cp.asnumpy(cd_gpu) + + assert cd_cpu.shape == cd_back.shape + xcorr = np.corrcoef(cd_cpu.ravel(), cd_back.ravel())[0, 1] + rmse = float(np.sqrt(np.mean((cd_cpu - cd_back) ** 2))) + assert xcorr > 0.999999, f"xcorr={xcorr}" + assert rmse < 1e-12, f"RMSE={rmse}" + + def test_elementary_response_parity(self): + """get_elementary_current_response: GPU vs CPU.""" + cp = cupy + + _, z, _, _, _, _, _, z_elec, r = _make_inputs() + + ecr_cpu = get_elementary_current_response(z, z_elec, r) + ecr_gpu = get_elementary_current_response( + cp.asarray(z), z_elec, cp.asarray(r), xp=cp, + ) + ecr_back = cp.asnumpy(ecr_gpu) + + assert ecr_cpu.shape == ecr_back.shape + xcorr = np.corrcoef(ecr_cpu.ravel(), ecr_back.ravel())[0, 1] + assert xcorr > 0.999999, f"xcorr={xcorr}" + + def test_full_sfap_pipeline_parity(self): + """Full SFAP chain: current_density.T @ response → same on CPU and GPU.""" + cp = cupy + + t, z, zi, L1, L2, v, d, z_elec, r = _make_inputs() + + # CPU + cd_cpu = get_current_density(t, z, zi, L1, L2, v, d) + ecr_cpu = get_elementary_current_response(z, z_elec, r) + sfap_cpu = (cd_cpu.T @ ecr_cpu)[:, 0] + + # GPU + t_g, z_g, r_g = cp.asarray(t), cp.asarray(z), cp.asarray(r) + cd_gpu = get_current_density(t_g, z_g, zi, L1, L2, v, d, xp=cp) + ecr_gpu = get_elementary_current_response(z_g, z_elec, r_g, xp=cp) + sfap_gpu = cp.asnumpy((cd_gpu.T @ ecr_gpu)[:, 0]) + + xcorr = np.corrcoef(sfap_cpu, sfap_gpu)[0, 1] + rmse = float(np.sqrt(np.mean((sfap_cpu - sfap_gpu) ** 2))) + assert xcorr > 0.999999, f"xcorr={xcorr}" + assert rmse < 1e-12, f"RMSE={rmse}" + + +# ── shift_padding contract tests ───────────────────────────────────────────── + +class TestShiftPadding: + """Verify shift_padding semantics for positive, negative, and zero shifts.""" + + def test_positive_shift(self): + """Positive shift: head and tail zeroed (original MATLAB semantics).""" + vec = np.arange(1, 11, dtype=float) # [1..10] + result = shift_padding(vec.copy(), 3, axis=0) + # roll(3) wraps 
[8,9,10,1,2,3,4,5,6,7], then [:3]=0 and [-3:]=0 + assert np.all(result[:3] == 0), f"head not zeroed: {result[:3]}" + assert np.all(result[-3:] == 0), f"tail not zeroed: {result[-3:]}" + # middle values preserved + np.testing.assert_array_equal(result[3:7], [1, 2, 3, 4]) + + def test_negative_shift(self): + """Negative shift: only tail zeroed.""" + vec = np.arange(1, 11, dtype=float) + result = shift_padding(vec.copy(), -3, axis=0) + # roll(-3) wraps [4,5,6,7,8,9,10,1,2,3], then [7:]=0 + assert np.all(result[-3:] == 0), f"tail not zeroed: {result[-3:]}" + np.testing.assert_array_equal(result[:7], [4, 5, 6, 7, 8, 9, 10]) + + def test_zero_shift(self): + """Zero shift: array unchanged.""" + vec = np.arange(1, 11, dtype=float) + result = shift_padding(vec.copy(), 0, axis=0) + np.testing.assert_array_equal(result, vec) + + def test_gpu_parity(self): + """shift_padding: GPU vs CPU produce identical output.""" + cp = cupy + vec = np.random.default_rng(42).random(100) + for sh in [-5, 0, 5, 20]: + cpu_result = shift_padding(vec.copy(), sh, axis=0) + gpu_result = cp.asnumpy( + shift_padding(cp.asarray(vec.copy()), sh, axis=0, xp=cp) + ) + np.testing.assert_array_equal(cpu_result, gpu_result, err_msg=f"sh={sh}") diff --git a/uv.lock b/uv.lock index bb0cbc26..041c23a0 100644 --- a/uv.lock +++ b/uv.lock @@ -1479,7 +1479,7 @@ wheels = [ [[package]] name = "myogen" -version = "0.8.5" +version = "0.9.0" source = { editable = "." 
} dependencies = [ { name = "beartype" }, @@ -1514,10 +1514,12 @@ nwb = [ [package.dev-dependencies] dev = [ + { name = "elephant" }, { name = "pandas-stubs" }, { name = "poethepoet" }, { name = "pytest" }, { name = "scipy-stubs" }, + { name = "viziphant" }, ] docs = [ { name = "elephant" }, @@ -1569,10 +1571,12 @@ provides-extras = ["elephant", "nwb"] [package.metadata.requires-dev] dev = [ + { name = "elephant", specifier = ">=1.1.1" }, { name = "pandas-stubs", specifier = ">=2.3.0.250703" }, { name = "poethepoet", specifier = ">=0.37.0" }, { name = "pytest", specifier = ">=8.0" }, { name = "scipy-stubs", specifier = ">=1.16.1.0" }, + { name = "viziphant", specifier = ">=0.4.0" }, ] docs = [ { name = "elephant", specifier = ">=1.1.1" },