diff --git a/abel/__init__.py b/abel/__init__.py index b039e05f..d9efb147 100644 --- a/abel/__init__.py +++ b/abel/__init__.py @@ -51,6 +51,9 @@ from .classes.source.impl.source_flattop import SourceFlatTop from .classes.source.impl.source_capsule import SourceCapsule from .classes.stage.impl.stage_basic import StageBasic + +from .classes.stage.impl.stage_spin import StageSpin + from .classes.stage.impl.stage_hipace import StageHipace from .classes.stage.impl.stage_wake_t import StageWakeT from .classes.stage.impl.stage_quasistatic_2d import StageQuasistatic2d diff --git a/abel/classes/source/source.py b/abel/classes/source/source.py index d4d4829c..e50a573c 100644 --- a/abel/classes/source/source.py +++ b/abel/classes/source/source.py @@ -74,7 +74,7 @@ class Source(Trackable, CostModeled): #TODO: Why is accel_gradient needed?? @abstractmethod - def __init__(self, length=0, charge=None, energy=None, accel_gradient=None, wallplug_efficiency=1, x_offset=0, y_offset=0, x_angle=0, y_angle=0, norm_jitter_emittance_x=None, norm_jitter_emittance_y=None, waist_shift_x=0, waist_shift_y=0, rep_rate_trains=None, num_bunches_in_train=None, bunch_separation=None): + def __init__(self, length=0, charge=None, energy=None, accel_gradient=None, wallplug_efficiency=1, x_offset=0, y_offset=0, x_angle=0, y_angle=0, norm_jitter_emittance_x=None, norm_jitter_emittance_y=None, waist_shift_x=0, waist_shift_y=0, rep_rate_trains=None, num_bunches_in_train=None, bunch_separation=None, spin_polarization=0.0, spin_polarization_direction='z'): super().__init__(num_bunches_in_train=num_bunches_in_train, bunch_separation=bunch_separation, rep_rate_trains=rep_rate_trains) @@ -83,6 +83,9 @@ def __init__(self, length=0, charge=None, energy=None, accel_gradient=None, wall self.charge = charge self.accel_gradient = accel_gradient self.wallplug_efficiency = wallplug_efficiency + + self.spin_polarization = spin_polarization + self.spin_polarization_direction = spin_polarization_direction self.x_offset = x_offset self.y_offset = y_offset @@ -104,10 +107,7 @@ def __init__(self, length=0, charge=None, energy=None, accel_gradient=None, wall self.jitter.yp = 0 self.jitter.E = 0 - self.spin_polarization = 0 - self.spin_polarization_direction = 'z' - - self.is_polarized = False #TODO shouldn't this rather be a function that checks whether spin_polarization > 0? + self.is_polarized = False # TODO write a function instead to check this. @abstractmethod diff --git a/abel/classes/stage/impl/stage_spin.py b/abel/classes/stage/impl/stage_spin.py new file mode 100644 index 00000000..059eeb6e --- /dev/null +++ b/abel/classes/stage/impl/stage_spin.py @@ -0,0 +1,429 @@ +from abel.classes.stage.stage import Stage +from abel.classes.source.source import Source +from abel.physics_models.spin_tracking import * +from abel.utilities.plasma_physics import k_p +from abel.utilities.relativity import energy2proper_velocity, proper_velocity2energy, momentum2proper_velocity, proper_velocity2momentum, proper_velocity2gamma, energy2gamma, gamma2momentum +import numpy as np +import scipy.constants as SI +import copy +import warnings +from scipy.signal import hilbert + +SI.r_e = SI.physical_constants['classical electron radius'][0] + + +class StageSpin(Stage): + + def __init__(self, nom_accel_gradient=None, nom_energy_gain=None, plasma_density=None, driver_source=None, ramp_beta_mag=None, transformer_ratio=1, calc_evolution=False): + + super().__init__(nom_accel_gradient=nom_accel_gradient, nom_energy_gain=nom_energy_gain, plasma_density=plasma_density, driver_source=driver_source, ramp_beta_mag=ramp_beta_mag) + + self.transformer_ratio = transformer_ratio + self.calc_evolution = calc_evolution + + + + + def track(self, beam_incoming, savedepth=0, runnable=None, verbose=False): + + driver_incoming = self.driver_source.track() + + if self.plasma_density is None: + self.optimize_plasma_density(beam_incoming) + + beam0 = beam_incoming # TODO: check this later... + + self._prepare_ramps() + + if self.upramp is not None: + beam0, driver0 = self.track_upramp(beam_incoming, driver_incoming) + else: + beam0 = copy.deepcopy(beam_incoming) + driver0 = copy.deepcopy(driver_incoming) + if self.ramp_beta_mag is not None: + beam0.magnify_beta_function(1/self.ramp_beta_mag, axis_defining_beam=driver_incoming) + driver0.magnify_beta_function(1/self.ramp_beta_mag, axis_defining_beam=driver_incoming) + + beam = copy.deepcopy(beam0) + driver = copy.deepcopy(driver0) + + # compute plasma quantities + omega_p = np.sqrt(self.plasma_density * SI.e**2 / (SI.epsilon_0 * SI.m_e)) + kp = omega_p / SI.c + + # Initial particle properties + sx = beam.spxs() + sy = beam.spys() + sz = beam.spzs() + x0 = beam.xs() + y0 = beam.ys() + ux0 = beam.uxs() + uy0 = beam.uys() + E0 = np.array(beam.Es()) + gamma0s = energy2gamma(E0) + + # set initial radius and length using paper normalized units + #r0_norm = 0.1 # 0.1 * c/omega_p + #L_norm = 2.4e4 # 2.4e4 * c/omega_p + + #r0_m = r0_norm * SI.c / omega_p # meters (use this for initial x,y spread) + #L_m = L_norm * SI.c / omega_p # meters (use this for stage length) + + L = self.length_flattop #0.1275 + n_particles = len(beam) + + + final_spins = np.zeros((n_particles, 3), dtype=float) + all_spins = [] + all_spin_norms = [] + all_ss = [] + + # Uniform energy gain for all particles + #deltaEs = np.full(len(E0), self.nom_energy_gain) +#np.full(len(E0), self.nom_energy_gain_flattop) + #Ef = E0 + deltaEs + deltaE_eV = np.full_like(E0, self.nom_energy_gain) + Ef = E0 + deltaE_eV + gammaf = energy2gamma(Ef) + dgamma_ds = (gammaf - gamma0s) / L #(gammaf - gamma0s) / L + + last_ss = None + last_gamma = None + + for i in range(n_particles): #loop over all particles in the beam + + # Initial spin vector + S0 = np.array([sx[i], sy[i], sz[i]], dtype=float) + S0 /= np.linalg.norm(S0) if np.linalg.norm(S0) > 0 else 1.0 + + # Transverse motion evolution using Hill's equation + x, ux, gamma_arr = evolve_hills_equation_analytic_evolution(x0[i], ux0[i], L, gamma0s[i], dgamma_ds[i], k_p(self.plasma_density)) + + y, uy, _ = evolve_hills_equation_analytic_evolution(y0[i], uy0[i], L, gamma0s[i], dgamma_ds[i], k_p(self.plasma_density)) + + + N = len(x) + ss = np.linspace(0, L, N) # Different positions along the particle's path through the stage + ds = np.diff(ss) + + gamma = np.ravel(gamma_arr) if gamma_arr is not None else (gamma0s[i] + dgamma_ds[i] * ss) + ux = np.ravel(ux) + uy = np.ravel(uy) + beta_x = ux / gamma + beta_y = uy / gamma + beta_z = np.sqrt(np.clip(1.0 - beta_x**2 - beta_y**2, 1e-16, 1.0)) + + """ + "" + #Time steps + ds = np.diff(ss) + dt_steps = ds / (beta_z[:-1] * SI.c) # N-1 values? + dt_arr = np.concatenate(([dt_steps[0]], dt_steps)) #N-values? + + #Compute fields along trajectory + #Es = np.empty((N,3), dtype=float) + #Bs = np.empty((N,3), dtype=float) + #for j in range(N): + #r_vec = np.array([x[j], y[j], ss[j]], dtype=float) + #Es[j] = plasma_E_field(r_vec, j, k_p(self.plasma_density)) + #Bs[j] = plasma_B_field(r_vec, j, k_p(self.plasma_density)) + + #Track spin along the trajectory + S_hist = np.empty((N, 3), dtype=float) + S_hist[0,:] = S0 + spin_norms = [] + + + for j in range(1, N): + beta_vec = np.array([beta_x[j - 1], beta_y[j - 1], beta_z[j - 1]], dtype=float) + dt = dt_arr[j] #time spent between step j-1 and j + gamma_val = gamma[j - 1] + #E = Es[j - 1] + #B = Bs[j - 1] + r_vec = np.array([x[j], y[j], ss[j]], dtype=float) + E = plasma_E_field(r_vec, j, k_p(self.plasma_density), Ez0=Ez0) + B = plasma_B_field(r_vec, j, k_p(self.plasma_density), beta_z=beta_z[j - 1]) + + S_next = tbmt_boris_spin_update(S_hist[j -1], E, B, beta_vec, gamma_val, dt) + + if not np.all(np.isfinite(S_next)): + # revert to previous normalized spin + S_next = S_hist[j - 1].copy() + nrm = np.linalg.norm(S_next) + + if nrm < 1e-12 or not np.isfinite(nrm): + S_next = S_hist[j -1].copy() + else: + S_next /= nrm + + spin_norms.append(np.linalg.norm(S_next)) + + S_hist[j, :] = S_next + """ + + + #dt_steps = ds / (beta_z[:-1] * SI.c) + #dt_steps[~np.isfinite(dt_steps)] = 1e-20 + + # Ensure dt_steps has the correct length + #if len(dt_steps) != N - 1: + #dt_steps = np.full( + #N - 1, + #np.mean(dt_steps[np.isfinite(dt_steps)]) if np.any(np.isfinite(dt_steps)) else 1e-6 + #) + + S_hist = np.empty((N,3)) + S_hist[0, :] = S0 + a_e = 0.00115965218128 + tiny = 1e-20 + theta_max = 0.5 + + for j in range(1, N): + gamma_t = max(gamma[j-1], tiny) + x_t = x[j-1] # meters (ensure evolve_hills returns meters) + y_t = y[j-1] + + alpha_E = (SI.m_e * omega_p**2) / (2.0 * -SI.e) + #pref_omega_s = (e / m_e) * (a_e + 1.0 / gamma_t) * (alpha_E / c*gamma[j-1]) + #alpha_loc = (kp**2) / (2.0 * gamma_t) #0.5 + pref_omega_s = (SI.e / SI.m_e) * (a_e + 1.0 / gamma_t) * (alpha_E / SI.c) + # pref dimensionless * 1/m^2 => 1/m^2; multiplied by x (m) -> 1/m + #pref = alpha_loc * (a_e + 1.0 / gamma_t) + + Omega = np.array([-pref_omega_s * x_t, pref_omega_s * y_t, 0.0]) + Omega_mag = np.linalg.norm(Omega) + + ds_j = max(ds[j-1], 0.0) + dt = ds_j / (beta_z[j-1] * SI.c) + + if Omega_mag < 1e-18 or ds_j <= 0: + S_hist[j,:] = S_hist[j-1, :] + continue + + delta = Omega_mag * dt + d = (Omega / Omega_mag) * np.tan(0.5 * delta) # dimensionless + s_prev = S_hist[j-1,:] + s_prime = s_prev + np.cross(s_prev, d) + s_next = s_prev + 2.0*np.cross(s_prime, d)/(1.0 + np.dot(d,d)) + + s_next /= np.linalg.norm(s_next) + S_hist[j,:] = s_next / np.linalg.norm(s_next) + + # after loop + # after finishing the j-loop for particle i + spin_norms = np.linalg.norm(S_hist, axis=1) + all_spin_norms.append(spin_norms) + all_spins.append(S_hist.copy()) + S_last = S_hist[-1, :].copy() # last spin for this particle + final_spins[i, :] = S_last # assign into row i + + last_ss = ss + last_gamma = gamma + + + + print(f"Final spins stds: sx_std, sy_std, sz_std =", np.std(final_spins[:, 0]), np.std(final_spins[:, 1]), np.std(final_spins[:, 2])) + # ensure shape and finiteness + print("alpha_E [V/m]:", alpha_E) + print("pref_omega_s [1/s/m]:", pref_omega_s) + print(f"beta {beta_z}") + assert final_spins.shape == (n_particles, 3) + if not np.all(np.isfinite(final_spins)): + warnings.warn("NaN/Inf found in final_spins", RuntimeWarning) + + # optional plotting (first particle only to inspect) + if len(all_spins) > 0 and last_ss is not None and last_gamma is not None: + try: + plot_spin_tracking(all_spins, last_ss) + plot_spin_tracking_gamma(all_spins, last_gamma) + except Exception: + warnings.warn("Spin plots could not be generated.", RuntimeWarning) + + plt.figure() + for norms in all_spin_norms: + plt.plot(ss/1e3, norms, alpha=0.3) # one curve per particle + plt.title("Spin vector norm along stage") + plt.xlabel("s [m]") + plt.ylabel("|S|") + + + min_steps = min(sp.shape[0] for sp in all_spins) + S_arr = np.array([sp[:min_steps] for sp in all_spins]).transpose(1, 0, 2) + P_vec = np.mean(S_arr, axis=1) + P_evol = np.linalg.norm(P_vec, axis=1) + Pz_evol = np.mean(S_arr[:, :, 2], axis=1) + P0_vec = np.mean(S_arr[0, :, :], axis=0) + P0_mag = np.linalg.norm(P0_vec) + P0z = P0_vec[2] + Dz = ((P0z - Pz_evol) / max(abs(P0z), tiny)) * 100.0 + D = ((P0_mag - P_evol) / max(P0_mag, tiny)) * 100.0 + final_D = (P0_mag - P_evol[-1]) / P0_mag * 100 + + analytic_signal = hilbert(D) + envelope = np.abs(analytic_signal) + analytic_signal_z = hilbert(Dz) + envelope_z = np.abs(analytic_signal_z) + ss_plot = last_ss[:min_steps] + plt.figure() + #plt.plot(ss_plot, Dz, label=r'$\Delta P_z/P_{z0}$') + #plt.plot(ss_plot, D, label=r'$\Delta P/P_0$') + plt.plot(last_gamma/1e5, Dz, label=r'$\Delta P_z/P_{z0}$') + plt.plot(last_gamma/1e5, envelope, label=r'Envelope $|\Delta P/P_0|$', linewidth=2) + + plt.plot(last_gamma/1e5, D, label=r'$\Delta P/P_0$') + plt.plot(last_gamma/1e5, envelope_z, label=r'Envelope $|\Delta P_z/P_{z0}|$', linewidth=2, linestyle='--') + + plt.xlabel("Stage length [m]") + plt.xlabel("Gamma [10^5]") + plt.ylabel("Depolarization [%]") + #plt.ylim(min(D) + 0.25*min(D), 0) + plt.legend() + plt.grid(True) + + + data = D + #np.savetxt("depolarization_vs_gradient.txt", data, header="plasma_density[m^-3] depolarization[%]", fmt="%.5e") + #print("✅ Saved results to depolarization_vs_density.txt") + # --- prepare arrays (safe for variable step counts) --- + min_steps = min(sp.shape[0] for sp in all_spins) + ss_cut = ss[:min_steps] # positions + S_arr = np.array([sp[:min_steps] for sp in all_spins]) # shape (n_particles, min_steps, 3) + S_arr = S_arr.transpose(1,0,2) # shape (min_steps, n_particles, 3) + n_steps, n_part, _ = S_arr.shape + + # --- polarization vector evolution (ensemble mean) --- + P_vec = np.mean(S_arr, axis=1) # shape (n_steps, 3) + P_mag = np.linalg.norm(P_vec, axis=1) # ensemble polarization magnitude vs s + + # --- per-particle norms and deviations --- + norms = np.linalg.norm(S_arr, axis=2) # shape (n_steps, n_particles) + # mean norm deviation from 1.0 (absolute), and std across ensemble + mean_norm = np.mean(norms, axis=1) + std_norm = np.std(norms, axis=1) + dev_from_one = mean_norm - 1.0 + + # --- envelope of ensemble polarization decay (optional) --- + analytic = hilbert(P_mag - P_mag.mean()) # remove mean for envelope + envelope = np.abs(analytic) + + # --- Depolarization metrics --- + P0_vec = np.mean(S_arr[0,:,:], axis=0) + P0_mag = np.linalg.norm(P0_vec) + final_P_vec = P_vec[-1,:] + final_P_mag = np.linalg.norm(final_P_vec) + final_depol_percent = (P0_mag - final_P_mag) / (P0_mag + 1e-30) * 100.0 + self.final_depolarization = final_D + + print("P0_mag:", P0_mag) + print("final_P_mag:", final_P_mag) + print("final depolarization [%]:", final_depol_percent) + + # --- Plot 1: mean norm deviation in ppm (clear visualization) --- + plt.figure(figsize=(8,4)) + plt.plot(ss_cut, (dev_from_one)*1e6, label='mean Δ|S| (ppm)') + plt.fill_between(ss_cut, (dev_from_one - std_norm)*1e6, (dev_from_one + std_norm)*1e6, + color='C0', alpha=0.2, label='±1σ') + plt.xlabel("s [m]") + plt.ylabel("Δ|S| [ppm]") + plt.title("Spin norm deviation (mean ± σ) along stage") + plt.grid(True) + plt.legend() + plt.tight_layout() + + # --- Plot 2: Ensemble polarization magnitude + envelope and percent depol --- + plt.figure(figsize=(8,4)) + plt.plot(ss_cut, P_mag, label='|⟨S⟩|(s)') + plt.plot(ss_cut, envelope + P_mag.mean(), '--', label='Envelope (abs)') + plt.xlabel("s [m]") + plt.ylabel("Polarization magnitude |⟨S⟩|") + plt.title(f"Ensemble polarization (final depol = {final_depol_percent:.3e} %)") + plt.grid(True) + plt.legend() + plt.tight_layout() + + # --- Plot 3: fractional depolarization vs s in ppm --- + frac_depol = (P0_mag - P_mag) / (P0_mag + 1e-30) # fraction + plt.figure(figsize=(8,4)) + plt.plot(ss_cut, frac_depol*1e6, label='ΔP/P0 [ppm]') + plt.xlabel("s [m]") + plt.ylabel("Depolarization [ppm]") + plt.title("Fractional depolarization along stage") + plt.grid(True) + plt.legend() + plt.tight_layout() + min_steps = min(sp.shape[0] for sp in all_spins) + S_init = np.array([sp[0,:] for sp in all_spins]) # (n_part, 3) + S_final = np.array([sp[min_steps-1,:] for sp in all_spins]) + + P0_vec = np.mean(S_init, axis=0) + Pf_vec = np.mean(S_final, axis=0) + P0_mag = np.linalg.norm(P0_vec) + Pf_mag = np.linalg.norm(Pf_vec) + delta = Pf_mag - P0_mag + delta_percent = (P0_mag - Pf_mag)/ (P0_mag + 1e-30) * 100.0 + + print("P0_mag:", P0_mag) + print("Pf_mag:", Pf_mag) + print("Delta (Pf - P0) :", delta) + print("Delta percent (signed):", delta_percent, "%") + + # Bootstrap uncertainty on delta_percent + npart = S_init.shape[0] + nboot = 2000 + rng = np.random.default_rng(12345) + boot_vals = np.empty(nboot) + idx = np.arange(npart) + for k in range(nboot): + sel = rng.choice(idx, size=npart, replace=True) + P0_b = np.linalg.norm(np.mean(S_init[sel], axis=0)) + Pf_b = np.linalg.norm(np.mean(S_final[sel], axis=0)) + boot_vals[k] = (P0_b - Pf_b) / (P0_b + 1e-30) * 100.0 + + mean_boot = np.mean(boot_vals) + ci_low, ci_high = np.percentile(boot_vals, [2.5, 97.5]) + print(f"Bootstrap mean depol [%]: {mean_boot:.6e}, 95% CI = [{ci_low:.6e}, {ci_high:.6e}]") + return final_D + + beam.plot_spins() + beam.set_spxs(final_spins[:, 0]) + beam.set_spys(final_spins[:, 1]) + beam.set_spzs(final_spins[:, 2]) + beam.plot_spins() + # ========== Betatron oscillations ========== + deltaEs = np.full(len(beam.Es()), self.nom_energy_gain_flattop) + if self.calc_evolution: + _, evol = beam.apply_betatron_motion(self.length_flattop, self.plasma_density, deltaEs, x0_driver=driver0.x_offset(), y0_driver=driver0.y_offset(), calc_evolution=self.calc_evolution) + self.evolution.beam = evol + else: + beam.apply_betatron_motion(self.length_flattop, self.plasma_density, deltaEs, x0_driver=driver0.x_offset(), y0_driver=driver0.y_offset()) + + beam.set_Es(beam.Es() + self.nom_energy_gain_flattop) + + if isinstance(self.driver_source, Source) and (self.driver_source.jitter.xp != 0 or self.driver_source.x_angle != 0 or self.driver_source.jitter.yp != 0 or self.driver_source.y_angle != 0): + self._rotate_beams_back_to_original(beam, driver0) + + if self.downramp is not None: + beam_outgoing, driver_outgoing = self.track_downramp(beam, driver) + else: + beam_outgoing = copy.deepcopy(beam) + driver_outgoing = copy.deepcopy(driver) + if self.ramp_beta_mag is not None: + beam_outgoing.magnify_beta_function(self.ramp_beta_mag, axis_defining_beam=driver) + driver_outgoing.magnify_beta_function(self.ramp_beta_mag, axis_defining_beam=driver) + + if self._return_tracked_driver: + return super().track(beam_outgoing, savedepth, runnable, verbose), driver_outgoing + else: + return super().track(beam_outgoing, savedepth, runnable, verbose), all_spins + + def optimize_plasma_density(self, source): + extraction_efficiency = (self.transformer_ratio / 0.75) * abs(source.abs_charge() / self.driver_source.get_charge()) + + + energy_density_z_extracted = abs(source.get_charge()*self.nom_accel_gradient) + energy_density_z_wake = energy_density_z_extracted/extraction_efficiency + norm_blowout_radius = ((32*SI.r_e/(SI.m_e*SI.c**2))*energy_density_z_wake)**(1/4) + norm_accel_gradient = 1/3 * (norm_blowout_radius)**1.15 + wavebreaking_field = self.nom_accel_gradient / norm_accel_gradient + plasma_wavenumber = wavebreaking_field/(SI.m_e*SI.c**2/SI.e) + self.plasma_density = plasma_wavenumber**2*SI.m_e*SI.c**2*SI.epsilon_0/SI.e**2 diff --git a/abel/physics_models/spin_tracking.py b/abel/physics_models/spin_tracking.py new file mode 100644 index 00000000..a29d45d1 --- /dev/null +++ b/abel/physics_models/spin_tracking.py @@ -0,0 +1,161 @@ +from abel import * +import numpy as np +import scipy.constants as SI +import matplotlib.pyplot as plt +from scipy.integrate import odeint +from abel.physics_models import * +from abel.utilities.relativity import * +import scipy.special as scispec + +SI.r_e = SI.physical_constants['classical electron radius'][0] + +""" +Simulates particle motion and tracks spin motion using the T-BMT equation +""" + + +def evolve_hills_equation_analytic_evolution(x0, ux0, L, gamma0, dgamma_ds, kp=None, g=None, N=500): + s = np.linspace(0, L, N) + + xp0 = ux0 / gamma2proper_velocity(gamma0) + + if dgamma_ds == 0: + if g is None: + g = kp ** 2 * SI.m_e * SI.c / (2 * SI.e) + + k = g * SI.c / gamma2energy(gamma0) + + if k == 0: + x = x0 + xp0 * s + xp = np.full_like(s, xp0) + else: + x = np.real( + x0 * np.cos(k * s) + (xp0 / k) * np.sin(k * s) + ) + xp = np.real( + xp0 * np.cos(k * s) - x0 * k * np.sin(k * s) + ) + + gamma = np.full_like(s, gamma0) + + else: + if kp is None: + kp = np.sqrt(2 * g * SI.e / (SI.m_e * SI.c)) + + gamma = gamma0 + dgamma_ds * s + C = np.sqrt(2) * kp / dgamma_ds + A0 = C * np.sqrt(gamma0) + A = C * np.sqrt(gamma) + + Di = (kp**2 * x0 * scispec.iv(1, A0 * 1j) + A0 * 1j * dgamma_ds * xp0 * scispec.iv(0, A0 * 1j)) + Dk = (kp**2 * x0 * scispec.kv(1, A0 * 1j) - A0 * 1j * dgamma_ds * xp0 * scispec.kv(0, A0 * 1j)) + E = kp**2 * (scispec.iv(1, A0 * 1j) * scispec.kv(0, A0 * 1j) + + scispec.iv(0, A0 * 1j) * scispec.kv(1, A0 * 1j)) * 1j + + x = np.real(1j * (Di * scispec.kv(0, A * 1j) + Dk * scispec.iv(0, A * 1j)) / E) + xp = -np.real((dgamma_ds * C**2 / (2 * A * E)) * + (Dk * scispec.iv(1, A * 1j) - Di * scispec.kv(1, A * 1j))) + + ux = xp * gamma2proper_velocity(gamma) + + return x, ux, gamma + + + +def plasma_E_field(r_vec, t, k_p, Ez0=0.0): + x, y, _ = r_vec + Er_coeff = SI.m_e * (k_p**2) * SI.c**2 / (2*SI.e) + Ex = Er_coeff * x + Ey = Er_coeff * y + Ez = Ez0 # set from the accelerating gradient below + return np.array([Ex, Ey, Ez], dtype=float) + +def plasma_B_field(r_vec, t, k_p, beta_z=1.0): + x, y, _ = r_vec + Br_coeff = (SI.m_e * k_p**2 * SI.c / (2*SI.e)) * beta_z + Bx = -Br_coeff * y + By = Br_coeff * x + Bz = 0.0 + return np.array([Bx, By, Bz], dtype=float) + +def tbmt_boris_spin_update(S, E, B, beta_vec, gamma, dt): + q = - SI.e + m = SI.m_e + a = 0.00115965218128 + + # Ensure all inputs are proper scalars or 1D arrays + dt = float(np.ravel(dt)[0]) + + S = np.asarray(S, dtype=np.float64).reshape(3) + E = np.asarray(E, dtype=np.float64).reshape(3) + B = np.asarray(B, dtype=np.float64).reshape(3) + beta_vec = np.asarray(beta_vec, dtype=np.float64).reshape(3) + gamma = float(gamma) + + term1 = (a + 1.0 / gamma)*SI.c * B + term2 = - (a + 1.0/(gamma+1.0)) * np.cross(beta_vec, E) + term3 = - a*(gamma / (gamma + 1.0)) * (np.dot(beta_vec*SI.c , B) * beta_vec) + Omega = - (q / m) * (term1 + term2 + term3)/SI.c + + w = np.linalg.norm(Omega) + if w == 0 or dt == 0: + return S.copy() + + n = Omega / w + theta = w * dt + c, s = np.cos(theta), np.sin(theta) + + # Rodrigues rotation of S about axis n by angle theta + S_rot = S * c + np.cross(n, S) * s + n * np.dot(n, S) * (1 - c) + return S_rot + + + #return S+np.cross(Omega, S) * dt +""" +def plot_spin_tracking(all_spins, ss): + plt.figure(figsize=(10, 6)) + for i, spins in enumerate(all_spins): + plt.plot(ss, spins[:, 0], label=f"Spin X (Particle {i+1})") + plt.plot(ss, spins[:, 1], label=f"Spin Y (Particle {i+1})") + plt.plot(ss, spins[:, 2], label=f"Spin Z (Particle {i+1})") + break + + plt.title("Spin Tracking of Particles") + plt.xlabel("Stage Length (m)") + plt.ylabel("Spin Components") + plt.legend(loc="upper right") + plt.grid(True) + plt.show() +""" +def plot_spin_tracking(all_spins, ss): + if len(all_spins) == 0: + return + plt.figure(figsize=(10, 6)) + spins = all_spins[10] + #plt.plot(ss, spins[:, 0], label="Spin X") + #plt.plot(ss, spins[:, 1], label="Spin Y") + plt.plot(ss, spins[:, 2], label="Spin Z") + plt.title("Spin Tracking (first particle)") + plt.xlabel("Stage Length (m)") + plt.ylabel("Spin Components") + plt.legend(loc="upper right") + plt.grid(True) + plt.show() + +def plot_spin_tracking_gamma(all_spins, gammas): + plt.figure(figsize=(10, 6)) + for i, spins in enumerate(all_spins): + #plt.plot(gammas, spins[:, 0], label=f"Spin X (Particle {i+1})") + #plt.plot(gammas, spins[:, 1], label=f"Spin Y (Particle {i+1})") + plt.plot(gammas, spins[:, 2], label=f"Spin Z (Particle {i+1})") + break + + plt.title("Spin Tracking of Particles Over Gamma") + plt.xlabel("Lorentz Factor (γ)") + plt.ylabel("Spin Components") + plt.legend(loc="upper right") + plt.grid(True) + plt.show() + + + \ No newline at end of file