diff --git a/stationapi/src/domain/repository/line_repository.rs b/stationapi/src/domain/repository/line_repository.rs index 15c7f16d..1716c41d 100644 --- a/stationapi/src/domain/repository/line_repository.rs +++ b/stationapi/src/domain/repository/line_repository.rs @@ -15,6 +15,10 @@ pub trait LineRepository: Send + Sync + 'static { &self, station_group_id_vec: &[u32], ) -> Result, DomainError>; + async fn get_by_station_group_id_vec_no_types( + &self, + station_group_id_vec: &[u32], + ) -> Result, DomainError>; async fn get_by_line_group_id(&self, line_group_id: u32) -> Result, DomainError>; async fn get_by_line_group_id_vec( &self, @@ -229,6 +233,13 @@ mod tests { Ok(result) } + async fn get_by_station_group_id_vec_no_types( + &self, + station_group_id_vec: &[u32], + ) -> Result, DomainError> { + self.get_by_station_group_id_vec(station_group_id_vec).await + } + async fn get_by_line_group_id(&self, line_group_id: u32) -> Result, DomainError> { Ok(self .lines_by_line_group_id diff --git a/stationapi/src/domain/repository/station_repository.rs b/stationapi/src/domain/repository/station_repository.rs index 40bbe19b..7b1a374d 100644 --- a/stationapi/src/domain/repository/station_repository.rs +++ b/stationapi/src/domain/repository/station_repository.rs @@ -47,6 +47,11 @@ pub trait StationRepository: Send + Sync + 'static { &self, line_group_ids: &[u32], ) -> Result, DomainError>; + async fn get_bus_stops_near_stations( + &self, + coords: &[(u32, f64, f64)], // (station_g_cd, lat, lon) + limit_per_station: u32, + ) -> Result, DomainError>; async fn get_route_stops( &self, from_station_id: u32, @@ -195,6 +200,23 @@ mod tests { Ok(result) } + async fn get_bus_stops_near_stations( + &self, + coords: &[(u32, f64, f64)], + limit_per_station: u32, + ) -> Result, DomainError> { + let mut result = Vec::new(); + for &(source_g_cd, lat, lon) in coords { + let stops = self + .get_by_coordinates(lat, lon, Some(limit_per_station), Some(TransportType::Bus)) + .await?; + for stop in stops { + result.push((source_g_cd, stop)); + } + } + Ok(result) + } + async fn get_by_name( &self, station_name: String, diff --git a/stationapi/src/infrastructure/line_repository.rs b/stationapi/src/infrastructure/line_repository.rs index 725eb42d..c01dcb43 100644 --- a/stationapi/src/infrastructure/line_repository.rs +++ b/stationapi/src/infrastructure/line_repository.rs @@ -124,6 +124,19 @@ impl LineRepository for MyLineRepository { let mut conn = self.pool.acquire().await?; InternalLineRepository::get_by_station_group_id_vec(&station_group_id_vec, &mut conn).await } + async fn get_by_station_group_id_vec_no_types( + &self, + station_group_id_vec: &[u32], + ) -> Result, DomainError> { + let station_group_id_vec: Vec = + station_group_id_vec.iter().map(|x| *x as i64).collect(); + let mut conn = self.pool.acquire().await?; + InternalLineRepository::get_by_station_group_id_vec_no_types( + &station_group_id_vec, + &mut conn, + ) + .await + } async fn get_by_line_group_id(&self, line_group_id: u32) -> Result, DomainError> { let line_group_id: i64 = line_group_id as i64; let mut conn = self.pool.acquire().await?; @@ -456,6 +469,82 @@ impl InternalLineRepository { Ok(lines) } + async fn get_by_station_group_id_vec_no_types( + station_group_id_vec: &[i64], + conn: &mut PgConnection, + ) -> Result, DomainError> { + if station_group_id_vec.is_empty() { + return Ok(vec![]); + } + + let params = (1..=station_group_id_vec.len()) + .map(|i| format!("${i}")) + .collect::>() + .join(", "); + let query_str = format!( + "SELECT DISTINCT ON (l.line_cd, s.station_g_cd) + l.line_cd, + l.company_cd, + l.line_type, + l.line_symbol1, + l.line_symbol2, + l.line_symbol3, + l.line_symbol4, + l.line_symbol1_color, + l.line_symbol2_color, + l.line_symbol3_color, + l.line_symbol4_color, + l.line_symbol1_shape, + l.line_symbol2_shape, + l.line_symbol3_shape, + l.line_symbol4_shape, + l.e_status, + l.e_sort, + COALESCE(l.average_distance, 0.0)::DOUBLE PRECISION AS average_distance, + s.station_cd, + s.station_g_cd, + NULL::int AS line_group_cd, + NULL::int AS type_cd, + COALESCE(a.line_name, l.line_name) AS line_name, + COALESCE(a.line_name_k, l.line_name_k) AS line_name_k, + COALESCE(a.line_name_h, l.line_name_h) AS line_name_h, + COALESCE(a.line_name_r, l.line_name_r) AS line_name_r, + COALESCE(a.line_name_zh, l.line_name_zh) AS line_name_zh, + COALESCE(a.line_name_ko, l.line_name_ko) AS line_name_ko, + COALESCE(a.line_color_c, l.line_color_c) AS line_color_c, + l.transport_type + FROM lines AS l + JOIN stations AS s ON s.station_g_cd IN ( {params} ) + AND s.e_status = 0 + AND s.line_cd = l.line_cd + LEFT JOIN line_aliases AS la ON la.station_cd = s.station_cd + LEFT JOIN aliases AS a ON la.alias_cd = a.id + WHERE l.e_status = 0 + AND NOT EXISTS ( + SELECT 1 FROM station_station_types sst + WHERE sst.station_cd = s.station_cd + AND sst.line_group_cd IS NOT NULL + AND sst.pass = 1 + AND NOT EXISTS ( + SELECT 1 FROM station_station_types sst2 + WHERE sst2.station_cd = s.station_cd + AND sst2.line_group_cd IS NOT NULL + AND sst2.pass <> 1 + ) + )" + ); + + let mut query = sqlx::query_as::<_, LineRow>(&query_str); + for id in station_group_id_vec { + query = query.bind(id); + } + + let rows = query.fetch_all(conn).await?; + let lines: Vec = rows.into_iter().map(|row| row.into()).collect(); + + Ok(lines) + } + async fn get_by_line_group_id( line_group_id: i64, conn: &mut PgConnection, diff --git a/stationapi/src/infrastructure/station_repository.rs b/stationapi/src/infrastructure/station_repository.rs index e64e1716..2435f908 100644 --- a/stationapi/src/infrastructure/station_repository.rs +++ b/stationapi/src/infrastructure/station_repository.rs @@ -81,6 +81,138 @@ struct StationRow { pub transport_type: Option, } +#[derive(sqlx::FromRow, Clone)] +struct BusStopWithSourceRow { + pub source_g_cd: i32, + pub station_cd: i32, + pub station_g_cd: i32, + pub station_name: String, + pub station_name_k: String, + pub station_name_r: Option, + #[allow(dead_code)] + pub station_name_rn: Option, + pub station_name_zh: Option, + pub station_name_ko: Option, + pub station_number1: Option, + pub station_number2: Option, + pub station_number3: Option, + pub station_number4: Option, + pub three_letter_code: Option, + pub line_cd: i32, + pub pref_cd: i32, + pub post: String, + pub address: String, + pub lon: f64, + pub lat: f64, + pub open_ymd: String, + pub close_ymd: String, + pub e_status: i32, + pub e_sort: i32, + pub company_cd: Option, + pub line_name: Option, + pub line_name_k: Option, + pub line_name_h: Option, + pub line_name_r: Option, + pub line_name_zh: Option, + pub line_name_ko: Option, + pub line_color_c: Option, + pub line_type: Option, + pub line_symbol1: Option, + pub line_symbol2: Option, + pub line_symbol3: Option, + pub line_symbol4: Option, + pub line_symbol1_color: Option, + pub line_symbol2_color: Option, + pub line_symbol3_color: Option, + pub line_symbol4_color: Option, + pub line_symbol1_shape: Option, + pub line_symbol2_shape: Option, + pub line_symbol3_shape: Option, + pub line_symbol4_shape: Option, + pub average_distance: Option, + pub type_id: Option, + pub sst_id: Option, + pub type_cd: Option, + pub line_group_cd: Option, + pub pass: Option, + pub type_name: Option, + pub type_name_k: Option, + pub type_name_r: Option, + pub type_name_zh: Option, + pub type_name_ko: Option, + pub color: Option, + pub direction: Option, + pub kind: Option, + pub transport_type: Option, +} + +impl From for Station { + fn from(row: BusStopWithSourceRow) -> Self { + let station_row = StationRow { + station_cd: row.station_cd, + station_g_cd: row.station_g_cd, + station_name: row.station_name, + station_name_k: row.station_name_k, + station_name_r: row.station_name_r, + station_name_rn: row.station_name_rn, + station_name_zh: row.station_name_zh, + station_name_ko: row.station_name_ko, + station_number1: row.station_number1, + station_number2: row.station_number2, + station_number3: row.station_number3, + station_number4: row.station_number4, + three_letter_code: row.three_letter_code, + line_cd: row.line_cd, + pref_cd: row.pref_cd, + post: row.post, + address: row.address, + lon: row.lon, + lat: row.lat, + open_ymd: row.open_ymd, + close_ymd: row.close_ymd, + e_status: row.e_status, + e_sort: row.e_sort, + company_cd: row.company_cd, + line_name: row.line_name, + line_name_k: row.line_name_k, + line_name_h: row.line_name_h, + line_name_r: row.line_name_r, + line_name_zh: row.line_name_zh, + line_name_ko: row.line_name_ko, + line_color_c: row.line_color_c, + line_type: row.line_type, + line_symbol1: row.line_symbol1, + line_symbol2: row.line_symbol2, + line_symbol3: row.line_symbol3, + line_symbol4: row.line_symbol4, + line_symbol1_color: row.line_symbol1_color, + line_symbol2_color: row.line_symbol2_color, + line_symbol3_color: row.line_symbol3_color, + line_symbol4_color: row.line_symbol4_color, + line_symbol1_shape: row.line_symbol1_shape, + line_symbol2_shape: row.line_symbol2_shape, + line_symbol3_shape: row.line_symbol3_shape, + line_symbol4_shape: row.line_symbol4_shape, + average_distance: row.average_distance, + type_id: row.type_id, + sst_id: row.sst_id, + type_cd: row.type_cd, + line_group_cd: row.line_group_cd, + pass: row.pass, + type_name: row.type_name, + type_name_k: row.type_name_k, + type_name_r: row.type_name_r, + type_name_zh: row.type_name_zh, + type_name_ko: row.type_name_ko, + color: row.color, + direction: row.direction, + kind: row.kind, + transport_type: row.transport_type, + }; + station_row.into() + } +} + impl From for Station { fn from(row: StationRow) -> Self { let stop_condition = match row.pass.unwrap_or(0) { @@ -291,6 +423,19 @@ impl StationRepository for MyStationRepository { InternalStationRepository::get_by_line_group_id_vec(line_group_ids, &mut conn).await } + async fn get_bus_stops_near_stations( + &self, + coords: &[(u32, f64, f64)], + limit_per_station: u32, + ) -> Result, DomainError> { + if coords.is_empty() { + return Ok(vec![]); + } + let mut conn = self.pool.acquire().await?; + InternalStationRepository::get_bus_stops_near_stations(coords, limit_per_station, &mut conn) + .await + } + async fn get_route_stops( &self, from_station_id: u32, @@ -1128,6 +1273,119 @@ impl InternalStationRepository { Ok(stations) } + async fn get_bus_stops_near_stations( + coords: &[(u32, f64, f64)], + limit_per_station: u32, + conn: &mut PgConnection, + ) -> Result, DomainError> { + if coords.is_empty() { + return Ok(vec![]); + } + + let query_str = r#"SELECT + ic.source_g_cd, + s.station_cd, + s.station_g_cd, + s.station_name, + s.station_name_k, + s.station_name_r, + s.station_name_rn, + s.station_name_zh, + s.station_name_ko, + s.station_number1, + s.station_number2, + s.station_number3, + s.station_number4, + s.three_letter_code, + s.line_cd, + s.pref_cd, + s.post, + s.address, + s.lon, + s.lat, + s.open_ymd, + s.close_ymd, + s.e_status, + s.e_sort, + l.company_cd, + COALESCE(NULLIF(COALESCE(a.line_name, l.line_name), ''), NULL) AS line_name, + COALESCE(NULLIF(COALESCE(a.line_name_k, l.line_name_k), ''), NULL) AS line_name_k, + COALESCE(NULLIF(COALESCE(a.line_name_h, l.line_name_h), ''), NULL) AS line_name_h, + COALESCE(NULLIF(COALESCE(a.line_name_r, l.line_name_r), ''), NULL) AS line_name_r, + COALESCE(NULLIF(COALESCE(a.line_name_zh, l.line_name_zh), ''), NULL) AS line_name_zh, + COALESCE(NULLIF(COALESCE(a.line_name_ko, l.line_name_ko), ''), NULL) AS line_name_ko, + COALESCE(NULLIF(COALESCE(a.line_color_c, l.line_color_c), ''), NULL) AS line_color_c, + l.line_type, + l.line_symbol1, + l.line_symbol2, + l.line_symbol3, + l.line_symbol4, + l.line_symbol1_color, + l.line_symbol2_color, + l.line_symbol3_color, + l.line_symbol4_color, + l.line_symbol1_shape, + l.line_symbol2_shape, + l.line_symbol3_shape, + l.line_symbol4_shape, + COALESCE(l.average_distance, 0.0)::DOUBLE PRECISION AS average_distance, + NULL::int AS type_id, + NULL::int AS sst_id, + NULL::int AS type_cd, + NULL::int AS line_group_cd, + NULL::int AS pass, + NULL::text AS type_name, + NULL::text AS type_name_k, + NULL::text AS type_name_r, + NULL::text AS type_name_zh, + NULL::text AS type_name_ko, + NULL::text AS color, + NULL::int AS direction, + NULL::int AS kind, + s.transport_type + FROM ( + SELECT unnest($1::int[]) AS source_g_cd, + unnest($2::float8[]) AS lat, + unnest($3::float8[]) AS lon + ) ic, + LATERAL ( + SELECT s.* + FROM stations s + WHERE s.e_status = 0 + AND COALESCE(s.transport_type, 0) = $5 + ORDER BY point(s.lat, s.lon) <-> point(ic.lat, ic.lon) + LIMIT $4 + ) s + JOIN lines AS l ON s.line_cd = l.line_cd AND l.e_status = 0 + LEFT JOIN line_aliases AS la ON la.station_cd = s.station_cd + LEFT JOIN aliases AS a ON a.id = la.alias_cd + ORDER BY ic.source_g_cd, point(s.lat, s.lon) <-> point(ic.lat, ic.lon)"#; + + let source_g_cds: Vec = coords.iter().map(|(g, _, _)| *g as i32).collect(); + let lats: Vec = coords.iter().map(|(_, lat, _)| *lat).collect(); + let lons: Vec = coords.iter().map(|(_, _, lon)| *lon).collect(); + + let rows = sqlx::query_as::<_, BusStopWithSourceRow>(query_str) + .bind(&source_g_cds) + .bind(&lats) + .bind(&lons) + .bind(limit_per_station as i32) + .bind(TransportType::Bus as i32) + .fetch_all(&mut *conn) + .await?; + + let result: Vec<(u32, Station)> = rows + .into_iter() + .map(|row| { + let source_g_cd = row.source_g_cd as u32; + let station: Station = row.into(); + (source_g_cd, station) + }) + .collect(); + + Ok(result) + } + async fn get_by_coordinates( latitude: f64, longitude: f64, diff --git a/stationapi/src/use_case/interactor/query.rs b/stationapi/src/use_case/interactor/query.rs index 84109c4a..4a310f9d 100644 --- a/stationapi/src/use_case/interactor/query.rs +++ b/stationapi/src/use_case/interactor/query.rs @@ -959,6 +959,31 @@ where Ok(stations) } + async fn get_lines_by_station_group_id_vec_no_types( + &self, + station_group_id_vec: &[u32], + ) -> Result, UseCaseError> { + let lines = self + .line_repository + .get_by_station_group_id_vec_no_types(station_group_id_vec) + .await?; + + Ok(lines) + } + + async fn get_bus_stops_near_stations( + &self, + coords: &[(u32, f64, f64)], + limit_per_station: u32, + ) -> Result, UseCaseError> { + let result = self + .station_repository + .get_bus_stops_near_stations(coords, limit_per_station) + .await?; + + Ok(result) + } + async fn update_station_vec_with_attributes( &self, mut stations: Vec, @@ -973,41 +998,109 @@ where station_group_ids.sort_unstable(); station_group_ids.dedup(); + // Determine if bus enrichment is needed + let should_include_bus_routes = transport_type == TransportTypeFilter::RailAndBus; + + // Collect unique coordinates for batch bus stop lookup + let unique_bus_coords: Vec<(u32, f64, f64)> = if should_include_bus_routes { + let mut seen = std::collections::HashSet::new(); + stations + .iter() + .filter(|s| s.transport_type == TransportType::Rail && seen.insert(s.station_g_cd)) + .map(|s| (s.station_g_cd as u32, s.lat, s.lon)) + .collect() + } else { + vec![] + }; + // Phase 1: independent queries in parallel // When skip_types_join is true, skip the expensive JOINs to // station_station_types and types tables (used by GetStationsByLineIdList) - let (stations_by_group_ids, lines) = if skip_types_join { + // Also batch-fetch bus stop candidates in parallel + let (stations_by_group_ids, lines, bus_candidates_flat) = if skip_types_join { tokio::try_join!( self.get_stations_by_group_id_vec_no_types(&station_group_ids), - self.get_lines_by_station_group_id_vec(&station_group_ids), + self.get_lines_by_station_group_id_vec_no_types(&station_group_ids), + self.get_bus_stops_near_stations(&unique_bus_coords, 50), )? } else { - tokio::try_join!( + let (s, l) = tokio::try_join!( self.get_stations_by_group_id_vec(&station_group_ids), self.get_lines_by_station_group_id_vec(&station_group_ids), - )? + )?; + let bus = self + .get_bus_stops_near_stations(&unique_bus_coords, 50) + .await?; + (s, l, bus) }; + // Build bus candidate cache from batch results + let mut bus_candidate_cache: std::collections::HashMap> = + std::collections::HashMap::new(); + for (source_g_cd, station) in bus_candidates_flat { + bus_candidate_cache + .entry(source_g_cd as i32) + .or_default() + .push(station); + } + + // Collect all bus station group IDs for batch bus lines fetch + let mut all_bus_station_group_ids: Vec = bus_candidate_cache + .values() + .flat_map(|stops| stops.iter().map(|s| s.station_g_cd as u32)) + .collect::>() + .into_iter() + .collect(); + all_bus_station_group_ids.sort_unstable(); + let station_ids = stations_by_group_ids .iter() .map(|station| station.station_cd as u32) .collect::>(); + // Collect company IDs from rail lines let mut company_ids: Vec = lines.iter().map(|l| l.company_cd as u32).collect(); - company_ids.sort_unstable(); - company_ids.dedup(); // Phase 2: dependent queries in parallel - let (companies, train_types) = tokio::try_join!( + // Fetch companies, train types, and all bus lines in one batch + let (companies, train_types, all_bus_lines) = tokio::try_join!( self.find_company_by_id_vec(&company_ids), self.get_train_types_by_station_id_vec(&station_ids, line_group_id), + self.get_lines_by_station_group_id_vec(&all_bus_station_group_ids), )?; - // Build HashMap for O(1) company lookup instead of O(n) linear search - // Owns the values so we can add bus companies later + // Collect bus company IDs and fetch any missing companies + let bus_company_ids: Vec = all_bus_lines.iter().map(|l| l.company_cd as u32).collect(); + company_ids.extend(bus_company_ids); + company_ids.sort_unstable(); + company_ids.dedup(); + + // Build HashMap for O(1) company lookup let mut company_map: std::collections::HashMap = companies.into_iter().map(|c| (c.company_cd, c)).collect(); + // Fetch any bus companies not already in the map + let missing_company_ids: Vec = company_ids + .iter() + .filter(|id| !company_map.contains_key(&(**id as i32))) + .copied() + .collect(); + if !missing_company_ids.is_empty() { + let extra_companies = self.find_company_by_id_vec(&missing_company_ids).await?; + for c in extra_companies { + company_map.insert(c.company_cd, c); + } + } + + // Pre-index bus lines by station_g_cd for O(1) lookup + let mut bus_lines_by_g_cd: std::collections::HashMap> = + std::collections::HashMap::new(); + for bus_line in &all_bus_lines { + if let Some(g_cd) = bus_line.station_g_cd { + bus_lines_by_g_cd.entry(g_cd).or_default().push(bus_line); + } + } + // Build HashMap for O(1) train_type lookup by station_cd let train_type_map: std::collections::HashMap = train_types .iter() @@ -1029,16 +1122,6 @@ where } } - // Cache nearby bus stop candidates by station_g_cd. - // Stations with the same station_g_cd are at the same physical location, - // so they share identical bus stop candidates. - let mut bus_candidate_cache: std::collections::HashMap> = - std::collections::HashMap::new(); - // Cache bus lines by station_group_ids to avoid repeated DB queries - // for the same set of bus stop groups. - let mut bus_lines_cache: std::collections::HashMap, Vec> = - std::collections::HashMap::new(); - for station in stations.iter_mut() { let mut line = self.extract_line_from_station(station); line.line_symbols = self.get_line_symbols(&line); @@ -1078,98 +1161,68 @@ where }) .unwrap_or_default(); - // For rail stations, add nearby bus routes to lines array - // Only add bus routes if transport_type is RailAndBus - let should_include_bus_routes = transport_type == TransportTypeFilter::RailAndBus; + // For rail stations, add nearby bus routes from pre-fetched data if station.transport_type == TransportType::Rail && should_include_bus_routes { - let cache_key = station.station_g_cd; - let candidates = if let Some(cached) = bus_candidate_cache.get(&cache_key) { - cached.clone() - } else { - let result = self - .station_repository - .get_by_coordinates( - station.lat, - station.lon, - Some(50), - Some(TransportType::Bus), - ) - .await?; - bus_candidate_cache.insert(cache_key, result.clone()); - result - }; - - // Apply 300m filter from this station's exact coordinates - let nearby_bus_stops: Vec<&Station> = candidates - .iter() - .filter(|bus_stop| { - haversine_distance(station.lat, station.lon, bus_stop.lat, bus_stop.lon) - <= NEARBY_BUS_STOP_RADIUS_METERS - }) - .collect(); - - if !nearby_bus_stops.is_empty() { - let mut bus_station_group_ids: Vec = nearby_bus_stops + if let Some(candidates) = bus_candidate_cache.get(&station.station_g_cd) { + // Apply 300m filter from this station's exact coordinates + let nearby_bus_stops: Vec<&Station> = candidates .iter() - .map(|s| s.station_g_cd as u32) + .filter(|bus_stop| { + haversine_distance(station.lat, station.lon, bus_stop.lat, bus_stop.lon) + <= NEARBY_BUS_STOP_RADIUS_METERS + }) .collect(); - bus_station_group_ids.sort_unstable(); - bus_station_group_ids.dedup(); - - let mut bus_lines = - if let Some(cached) = bus_lines_cache.get(&bus_station_group_ids) { - cached.clone() - } else { - let result = self - .line_repository - .get_by_station_group_id_vec(&bus_station_group_ids) - .await?; - bus_lines_cache.insert(bus_station_group_ids, result.clone()); - result - }; - let mut seen_bus_line_cds = std::collections::HashSet::new(); - bus_lines.retain(|line| { - line.transport_type == TransportType::Bus - && seen_bus_line_cds.insert(line.line_cd) - }); - - let bus_stop_by_line_cd: std::collections::HashMap = - nearby_bus_stops - .iter() - .filter(|s| seen_bus_line_cds.contains(&s.line_cd)) - .map(|s| (s.line_cd, *s)) - .collect(); - - for bus_line in bus_lines.iter_mut() { - bus_line.line_symbols = self.get_line_symbols(bus_line); - if let Some(&bus_stop) = bus_stop_by_line_cd.get(&bus_line.line_cd) { - let mut station_copy = bus_stop.clone(); - station_copy.station_numbers = self.get_station_numbers(&station_copy); - bus_line.station = Some(station_copy); + if !nearby_bus_stops.is_empty() { + // Collect bus lines from pre-fetched data, iterating + // nearby_bus_stops in order (distance-sorted) to preserve + // deterministic ordering. + let mut seen_bus_g_cds = std::collections::HashSet::new(); + let mut seen_bus_line_cds = std::collections::HashSet::new(); + let mut bus_lines: Vec = Vec::new(); + for bus_stop in &nearby_bus_stops { + if seen_bus_g_cds.insert(bus_stop.station_g_cd) { + if let Some(lines_for_g_cd) = + bus_lines_by_g_cd.get(&bus_stop.station_g_cd) + { + for &bus_line in lines_for_g_cd { + if bus_line.transport_type == TransportType::Bus + && seen_bus_line_cds.insert(bus_line.line_cd) + { + bus_lines.push(bus_line.clone()); + } + } + } + } } - } - for bus_line in bus_lines { - if seen_line_cds.insert(bus_line.line_cd) { - station_lines.push(bus_line); + // Build stop-per-line map; first (closest) stop wins + let mut bus_stop_by_line_cd: std::collections::HashMap = + std::collections::HashMap::new(); + for bus_stop in &nearby_bus_stops { + if seen_bus_line_cds.contains(&bus_stop.line_cd) { + bus_stop_by_line_cd + .entry(bus_stop.line_cd) + .or_insert(bus_stop); + } } - } - } - } - // Fetch any missing companies (e.g., bus-only operators not in initial lines) - let missing_company_ids: Vec = station_lines - .iter() - .filter(|l| !company_map.contains_key(&l.company_cd)) - .map(|l| l.company_cd as u32) - .collect::>() - .into_iter() - .collect(); - if !missing_company_ids.is_empty() { - let extra_companies = self.find_company_by_id_vec(&missing_company_ids).await?; - for c in extra_companies { - company_map.insert(c.company_cd, c); + for bus_line in bus_lines.iter_mut() { + bus_line.line_symbols = self.get_line_symbols(bus_line); + if let Some(&bus_stop) = bus_stop_by_line_cd.get(&bus_line.line_cd) { + let mut station_copy = bus_stop.clone(); + station_copy.station_numbers = + self.get_station_numbers(&station_copy); + bus_line.station = Some(station_copy); + } + } + + for bus_line in bus_lines { + if seen_line_cds.insert(bus_line.line_cd) { + station_lines.push(bus_line); + } + } + } } } @@ -1533,6 +1586,13 @@ mod tests { ) -> Result, DomainError> { Ok(vec![]) } + async fn get_bus_stops_near_stations( + &self, + _: &[(u32, f64, f64)], + _: u32, + ) -> Result, DomainError> { + Ok(vec![]) + } async fn get_by_name( &self, _: String, @@ -1600,6 +1660,12 @@ mod tests { ) -> Result, DomainError> { Ok(vec![]) } + async fn get_by_station_group_id_vec_no_types( + &self, + _: &[u32], + ) -> Result, DomainError> { + Ok(vec![]) + } } #[async_trait::async_trait] @@ -1973,6 +2039,27 @@ mod tests { Ok(vec![]) } } + async fn get_bus_stops_near_stations( + &self, + coords: &[(u32, f64, f64)], + limit_per_station: u32, + ) -> Result, DomainError> { + let mut result = Vec::new(); + for &(source_g_cd, lat, lon) in coords { + let stops = self + .get_by_coordinates( + lat, + lon, + Some(limit_per_station), + Some(TransportType::Bus), + ) + .await?; + for stop in stops { + result.push((source_g_cd, stop)); + } + } + Ok(result) + } async fn get_by_name( &self, _: String, @@ -2053,6 +2140,12 @@ mod tests { ) -> Result, DomainError> { Ok(self.lines_by_station_group.clone()) } + async fn get_by_station_group_id_vec_no_types( + &self, + station_group_id_vec: &[u32], + ) -> Result, DomainError> { + self.get_by_station_group_id_vec(station_group_id_vec).await + } } /// Configurable mock train type repository for testing