diff --git a/Cargo.lock b/Cargo.lock index 69e8299..1631af8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2548,6 +2548,7 @@ version = "0.1.0" dependencies = [ "anyhow", "blake3", + "clap", "colored", "futures", "migration", diff --git a/config_template.toml b/config_template.toml new file mode 100644 index 0000000..9c95eea --- /dev/null +++ b/config_template.toml @@ -0,0 +1,15 @@ +[db] +url = "postgres://srdown:srdown@localhost:5432/srdown" +schema = "public" +max_connections = 10 +sqlx_logging = false + +[sync] +max_timeout = 1.0 +start_id = 76859 + +[sync.fast] +start_id = 76859 +end_id = 1321469 +worker_count = 10 +worker_size = 10 diff --git a/sr_download/Cargo.toml b/sr_download/Cargo.toml index 55662d1..7149a7f 100644 --- a/sr_download/Cargo.toml +++ b/sr_download/Cargo.toml @@ -2,6 +2,7 @@ name = "sr_download" version = "0.1.0" edition = "2021" +default-run = "sr_download" [[bin]] name = "get-save" @@ -26,3 +27,4 @@ toml = "0.8.15" blake3 = "1.5.3" futures = "0.3.30" colored = "2.1.0" +clap = { version = "4.5.9", features = ["derive"] } diff --git a/sr_download/src/config.rs b/sr_download/src/config.rs index 833ed0c..b55ba2d 100644 --- a/sr_download/src/config.rs +++ b/sr_download/src/config.rs @@ -1,36 +1,71 @@ +use migration::SaveId; use std::path::Path; +use colored::Colorize; use serde::{Deserialize, Serialize}; +use tracing::{event, Level}; -use migration::SaveId; - -#[derive(Serialize, Deserialize)] -pub struct ConfigFile { - pub db_url: String, - pub db_schema: String, +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename = "db")] +pub struct DbConfig { + pub url: String, + pub schema: String, pub max_connections: u32, pub sqlx_logging: bool, - pub worker_count: u32, - pub worker_size: u32, - pub start_id: SaveId, - pub max_timeout: f32, } -impl Default for ConfigFile { +impl Default for DbConfig { fn default() -> Self { Self { - db_url: "postgres://srdown:srdown@192.168.3.22:10001/srdown".to_string(), - db_schema: "public".to_string(), + 
url: "postgres://srdown:srdown@localhost:5432/srdown".to_string(), + schema: "public".to_string(), max_connections: 10, sqlx_logging: false, - worker_count: 10, - worker_size: 10, - start_id: 173860, - max_timeout: 1.0, } } } +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename = "fast")] +pub struct FastSyncConfig { + pub start_id: SaveId, + pub end_id: SaveId, + pub worker_count: u32, + pub worker_size: u32, +} + +impl Default for FastSyncConfig { + fn default() -> Self { + Self { + start_id: 76859, + end_id: 1321469, + worker_count: 10, + worker_size: 10, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename = "sync")] +pub struct SyncConfig { + pub max_timeout: f32, + pub fast: FastSyncConfig, +} +impl Default for SyncConfig { + fn default() -> Self { + Self { + max_timeout: 1.0, + fast: FastSyncConfig::default(), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, Default)] +pub struct ConfigFile { + pub db: DbConfig, + pub sync: SyncConfig, +} + impl ConfigFile { pub fn read_from_file(file_path: &Path) -> anyhow::Result { let data = std::fs::read_to_string(file_path)?; @@ -46,6 +81,26 @@ impl ConfigFile { } pub fn timeout_as_duration(&self) -> std::time::Duration { - std::time::Duration::from_secs_f32(self.max_timeout) + std::time::Duration::from_secs_f32(self.sync.max_timeout) + } + + pub fn read_or_panic() -> Self { + match Self::read_from_file(Path::new("config.toml")) { + Ok(conf) => conf, + Err(e) => { + let _ = tracing_subscriber::fmt::try_init(); + event!(Level::ERROR, "{}", "Please Fix the config.toml file".red()); + event!(Level::ERROR, "Error: {:?}", e); + if let Err(e) = Self::write_default_to_file(Path::new("config_template.toml")) { + event!(Level::ERROR, "Error while writing file: {:?}", e); + event!( + Level::ERROR, + "template file like this: {}", + toml::to_string(&Self::default()).unwrap() + ); + }; + panic!("Please Fix the config.toml file"); + } + } } } diff --git 
a/sr_download/src/db_part.rs b/sr_download/src/db_part.rs index 36d5726..a9ba53e 100644 --- a/sr_download/src/db_part.rs +++ b/sr_download/src/db_part.rs @@ -76,10 +76,10 @@ impl DbData { } pub async fn connect(conf: &ConfigFile) -> anyhow::Result { - let mut opt = ConnectOptions::new(conf.db_url.clone()); - opt.max_connections(conf.max_connections) - .set_schema_search_path(conf.db_schema.clone()) - .sqlx_logging(conf.sqlx_logging); + let mut opt = ConnectOptions::new(conf.db.url.clone()); + opt.max_connections(conf.db.max_connections) + .set_schema_search_path(conf.db.schema.clone()) + .sqlx_logging(conf.db.sqlx_logging); event!(Level::INFO, "Connecting to database"); let db: DatabaseConnection = Database::connect(opt).await?; db.ping().await?; @@ -116,17 +116,6 @@ pub async fn find_max_id(db: &DatabaseConnection) -> SaveId { } } -#[allow(non_snake_case)] -pub fn SaveType_from_str(str: String) -> Option { - match str.to_lowercase().as_str() { - "save" => Some(SaveType::Save), - "ship" => Some(SaveType::Ship), - "none" => Some(SaveType::None), - "unknown" => Some(SaveType::Unknown), - _ => None - } -} - /// 直接从数据库中查询数据, 这里数据库已经准备好了根据长度区分过的数据 /// 可以从 full view 里直接选数据 pub async fn get_raw_data(save_id: SaveId, db: &DatabaseConnection) -> Option { @@ -142,7 +131,7 @@ pub async fn get_raw_data(save_id: SaveId, db: &DatabaseConnection) -> Option anyhow::Result<()> { - let config = config::ConfigFile::read_from_file(Path::new("config.toml")).unwrap(); + let config = config::ConfigFile::read_or_panic(); let db = db_part::connect(&config).await.unwrap(); - - let want_get_id = std::env::args().nth(1).ok_or(anyhow::anyhow!("No input"))?.parse::()?; - - let data = db_part::get_raw_data(want_get_id, &db).await.ok_or(anyhow::anyhow!("No data"))?; + + let want_get_id = std::env::args() + .nth(1) + .ok_or(anyhow::anyhow!("No input"))? 
+ .parse::()?; + + let data = db_part::get_raw_data(want_get_id, &db) + .await + .ok_or(anyhow::anyhow!("No data"))?; println!("{}", data.text.ok_or(anyhow::anyhow!("No text"))?); diff --git a/sr_download/src/main.rs b/sr_download/src/main.rs index 621e2c6..69b933b 100644 --- a/sr_download/src/main.rs +++ b/sr_download/src/main.rs @@ -1,6 +1,6 @@ use colored::Colorize; use futures::future::select_all; -use std::{ops::Range, path::Path}; +use std::ops::Range; use tokio::sync::oneshot::Receiver; use tracing::{event, Level}; @@ -78,13 +78,7 @@ async fn big_worker( } async fn main_works(mut stop_receiver: Receiver<()>) -> anyhow::Result<()> { - let conf = match config::ConfigFile::read_from_file(Path::new("config.toml")) { - Ok(conf) => conf, - Err(_) => { - config::ConfigFile::write_default_to_file(Path::new("config.toml"))?; - config::ConfigFile::default() - } - }; + let conf = config::ConfigFile::read_or_panic(); let db_connect = db_part::connect(&conf).await?; db_part::migrate(&db_connect).await?; @@ -98,18 +92,27 @@ async fn main_works(mut stop_receiver: Receiver<()>) -> anyhow::Result<()> { tokio::time::sleep(std::time::Duration::from_secs(3)).await; - // 1321469 end - let end_id: SaveId = 1321469; + if stop_receiver.try_recv().is_ok() { + event!(Level::INFO, "{}", "Stop download".red()); + // 结束 db + db_connect.close().await?; + return Ok(()); + } - let mut current_id = conf.start_id; - - let batch_size = conf.worker_size; - // 10 works - let mut works = Vec::with_capacity(conf.worker_count as usize); - let max_works = conf.worker_count as usize; + let end_id: SaveId = conf.sync.fast.end_id; + let worker_size = conf.sync.fast.worker_size; + let mut current_id = conf.sync.fast.start_id; + let mut works = Vec::with_capacity(conf.sync.fast.worker_count as usize); + let max_works = conf.sync.fast.worker_count as usize; for _ in 0..works.len() { + if stop_receiver.try_recv().is_ok() { + event!(Level::INFO, "{}", "Stop download".red()); + // 结束 db + 
db_connect.close().await?; + return Ok(()); + } let client = net::Downloader::new(conf.timeout_as_duration()); - let end = current_id + batch_size; + let end = current_id + worker_size; works.push(tokio::spawn(big_worker( db_connect.clone(), client, @@ -119,9 +122,15 @@ async fn main_works(mut stop_receiver: Receiver<()>) -> anyhow::Result<()> { } while current_id < end_id || !works.is_empty() { + if stop_receiver.try_recv().is_ok() { + event!(Level::INFO, "{}", "Stop download".red()); + // 结束 db + db_connect.close().await?; + return Ok(()); + } while current_id < end_id && works.len() < max_works { let client = net::Downloader::new(conf.timeout_as_duration()); - let end = current_id + batch_size; + let end = current_id + worker_size; works.push(tokio::spawn(big_worker( db_connect.clone(), client, @@ -134,12 +143,6 @@ async fn main_works(mut stop_receiver: Receiver<()>) -> anyhow::Result<()> { let (_result, _index, remain) = select_all(works).await; works = remain; } - if stop_receiver.try_recv().is_ok() { - event!(Level::INFO, "{}", "Stop download".red()); - // 结束 db - db_connect.close().await?; - break; - } } Ok(()) }