From 5a92ecdeb73d484041829ef8e1e1d83904a95797 Mon Sep 17 00:00:00 2001 From: glitchySid Date: Sun, 28 Dec 2025 19:15:53 +0530 Subject: [PATCH] feat: Add comprehensive improvements - async optimization, caching, and error handling - Add async optimization using futures for concurrent file processing - Implement intelligent caching system with SHA256 file change detection - Add comprehensive custom error handling for Gemini API responses - Fix critical error handling issues throughout the codebase - Replace fragile JSON parsing with proper struct-based deserialization - Add automatic retry logic for rate limits and network issues - Improve user experience with detailed error messages and progress feedback - Add cache persistence and automatic cleanup of old entries - Optimize performance for batch processing scenarios --- Cargo.lock | 206 +++++++++++++++++++++++++++++++++++++++ Cargo.toml | 7 ++ src/cache.rs | 168 ++++++++++++++++++++++++++++++++ src/files.rs | 175 +++++++++++++++++++++++++++------ src/gemini.rs | 194 ++++++++++++++++++++++++++++++++++--- src/gemini_errors.rs | 226 +++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 + src/main.rs | 154 ++++++++++++++++++++++++++--- 8 files changed, 1070 insertions(+), 62 deletions(-) create mode 100644 src/cache.rs create mode 100644 src/gemini_errors.rs diff --git a/Cargo.lock b/Cargo.lock index 7022238..ecab8f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,15 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.1" @@ -48,6 +57,15 @@ version = "1.0.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "colored" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -64,6 +82,35 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -81,6 +128,12 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -148,6 +201,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -155,6 +223,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -163,6 +232,34 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -181,10 +278,26 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", ] [[package]] @@ -235,6 +348,12 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "1.4.0" @@ -482,6 +601,15 @@ dependencies = [ "serde", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -575,11 +703,18 @@ dependencies = [ name = "noentropy" version = "0.1.0" dependencies = [ + "colored", "dotenv", + "futures", + "hex", + "itertools", "reqwest", "serde", "serde_json", + "sha2", + "thiserror", "tokio", + "walkdir", ] [[package]] @@ -833,6 +968,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.28" @@ -926,6 +1070,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" 
version = "1.3.0" @@ -1040,6 +1195,26 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "thiserror" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tinystr" version = "0.8.2" @@ -1181,6 +1356,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.22" @@ -1217,6 +1398,22 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -1309,6 +1506,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "windows-link" version = "0.2.1" diff --git a/Cargo.toml b/Cargo.toml index c5e9798..a5536b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,8 +4,15 @@ version = "0.1.0" edition = "2024" [dependencies] +colored = "3.0.0" dotenv = "0.15.0" +futures = "0.3.31" +hex = "0.4.3" +itertools = "0.14.0" reqwest = { version = "0.12.26", features = ["json"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" +sha2 = "0.10.8" +thiserror = "2.0.11" tokio = { version = "1.48.0", features = ["full"] } +walkdir = "2.5.0" diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..8164ce5 --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,168 @@ +use crate::files::OrganizationPlan; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::fs; +use std::path::Path; +use std::time::{SystemTime, UNIX_EPOCH}; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct CacheEntry { + pub response: OrganizationPlan, + pub timestamp: u64, + pub file_hashes: HashMap, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Cache { + entries: HashMap, +} + +impl Cache { + pub fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + pub fn load_or_create(cache_path: &Path) -> Self { + if cache_path.exists() { + match fs::read_to_string(cache_path) { + Ok(content) => { + match serde_json::from_str::(&content) { + Ok(cache) => { + println!("Loaded cache with {} entries", cache.entries.len()); + cache + } + Err(_) => { + println!("Cache corrupted, creating new cache"); + Self::new() + } + } + } + Err(_) => { + println!("Failed to read cache, creating new cache"); + Self::new() + } + } + } else { + println!("Creating new cache file"); + Self::new() + } + } + + pub fn save(&self, cache_path: &Path) -> Result<(), Box> { + if let Some(parent) = cache_path.parent() { 
+ fs::create_dir_all(parent)?; + } + + let content = serde_json::to_string_pretty(self)?; + fs::write(cache_path, content)?; + Ok(()) + } + + pub fn get_cached_response(&self, filenames: &[String], base_path: &Path) -> Option { + let cache_key = self.generate_cache_key(filenames); + + if let Some(entry) = self.entries.get(&cache_key) { + // Check if files have changed by comparing hashes + let mut all_files_unchanged = true; + + for filename in filenames { + let file_path = base_path.join(filename); + if let Ok(current_hash) = Self::hash_file(&file_path) { + if let Some(cached_hash) = entry.file_hashes.get(filename) { + if current_hash != *cached_hash { + all_files_unchanged = false; + break; + } + } else { + all_files_unchanged = false; + break; + } + } else { + // File doesn't exist or can't be read + all_files_unchanged = false; + break; + } + } + + if all_files_unchanged { + println!("Using cached response (timestamp: {})", entry.timestamp); + return Some(entry.response.clone()); + } + } + + None + } + + pub fn cache_response(&mut self, filenames: &[String], response: OrganizationPlan, base_path: &Path) { + let cache_key = self.generate_cache_key(filenames); + let mut file_hashes = HashMap::new(); + + // Hash all files for future change detection + for filename in filenames { + let file_path = base_path.join(filename); + if let Ok(hash) = Self::hash_file(&file_path) { + file_hashes.insert(filename.clone(), hash); + } + } + + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let entry = CacheEntry { + response, + timestamp, + file_hashes, + }; + + self.entries.insert(cache_key, entry); + println!("Cached response for {} files", filenames.len()); + } + + fn generate_cache_key(&self, filenames: &[String]) -> String { + let mut sorted_filenames = filenames.to_vec(); + sorted_filenames.sort(); + + let mut hasher = Sha256::new(); + for filename in &sorted_filenames { + hasher.update(filename.as_bytes()); + 
hasher.update(b"|"); + } + + hex::encode(hasher.finalize()) + } + + fn hash_file(file_path: &Path) -> Result> { + if !file_path.exists() { + return Err("File does not exist".into()); + } + + let mut hasher = Sha256::new(); + let content = fs::read(file_path)?; + hasher.update(content); + + Ok(hex::encode(hasher.finalize())) + } + + pub fn cleanup_old_entries(&mut self, max_age_seconds: u64) { + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let initial_count = self.entries.len(); + + self.entries.retain(|_, entry| { + current_time - entry.timestamp < max_age_seconds + }); + + let removed_count = initial_count - self.entries.len(); + if removed_count > 0 { + println!("Cleaned up {} old cache entries", removed_count); + } + } +} \ No newline at end of file diff --git a/src/files.rs b/src/files.rs index 0aeccff..e9b7ed0 100644 --- a/src/files.rs +++ b/src/files.rs @@ -1,50 +1,42 @@ +use colored::*; use serde::{Deserialize, Serialize}; -use std::{fs, path::Path, path::PathBuf}; +use std::io; +use std::{ffi::OsStr, fs, path::Path, path::PathBuf}; +use walkdir::WalkDir; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct FileCategory { pub filename: String, pub category: String, + pub sub_category: String, } - -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct OrganizationPlan { pub files: Vec, } - #[derive(Debug)] pub struct FileBatch { pub filenames: Vec, pub paths: Vec, } - impl FileBatch { /// Reads a directory path and populates lists of all files inside it. /// It skips sub-directories (does not read recursively). 
pub fn from_path(root_path: PathBuf) -> Self { let mut filenames = Vec::new(); let mut paths = Vec::new(); - - // Check if the path exists and is a directory - if root_path.is_dir() { - // fs::read_dir returns a Result, so we must handle it - if let Ok(read_dir) = fs::read_dir(&root_path) { - for child in read_dir { - if let Ok(child) = child { - // We only want to list FILES, not sub-folders, - // otherwise we might try to move a folder into a folder - if child.path().is_file() { - filenames.push(child.file_name().to_string_lossy().into_owned()); - paths.push(child.path()); - } - } - } + for entry in WalkDir::new(&root_path) + .into_iter() + .filter_map(|e| e.ok()) + .filter(|e| e.path().is_file()) + { + if let Ok(relative_path) = entry.path().strip_prefix(&root_path) { + filenames.push(relative_path.to_string_lossy().into_owned()); + paths.push(entry.path().to_path_buf()); } } - FileBatch { filenames, paths } } - /// Helper to get the number of files found pub fn count(&self) -> usize { self.filenames.len() @@ -52,25 +44,144 @@ impl FileBatch { } pub fn execute_move(base_path: &Path, plan: OrganizationPlan) { + // --------------------------------------------------------- + // PHASE 1: PREVIEW (Show the plan) + // --------------------------------------------------------- + println!("\n{}", "--- EXECUTION PLAN ---".bold().underline()); + + if plan.files.is_empty() { + println!("{}", "No files to organize.".yellow()); + return; + } + + // Iterate by reference (&) so we don't consume the data yet + for item in &plan.files { + let mut target_display = format!("{}", item.category.green()); + if !item.sub_category.is_empty() { + target_display = format!("{}/{}", target_display, item.sub_category.blue()); + } + + println!("Plan: {} -> {}/", item.filename, target_display); + } + + // --------------------------------------------------------- + // PHASE 2: PROMPT (Ask for permission) + // --------------------------------------------------------- + eprint!("\nDo you want to 
apply these changes? [y/N]: "); + + let mut input = String::new(); + if io::stdin() + .read_line(&mut input) + .is_err() + { + println!("\n{}", "Failed to read input. Operation cancelled.".red()); + return; + } + + let input = input.trim().to_lowercase(); + + // If input is not "y" or "yes", abort. + if input != "y" && input != "yes" { + println!("\n{}", "Operation cancelled.".red()); + return; + } + + // --------------------------------------------------------- + // PHASE 3: EXECUTION (Actually move files) + // --------------------------------------------------------- + println!("\n{}", "--- MOVING FILES ---".bold().underline()); + for item in plan.files { let source = base_path.join(&item.filename); - let target_dir = base_path.join(&item.category); - let target = target_dir.join(&item.filename); - // 1. Create the category folder if it doesn't exist (e.g., "Downloads/Images") - if !target_dir.exists() { - fs::create_dir_all(&target_dir).expect("Failed to create folder"); - println!("Created folder: {:?}", item.category); + // Logic: Destination / Parent Category / Sub Category + let mut final_path = base_path.join(&item.category); + + if !item.sub_category.is_empty() { + final_path = final_path.join(&item.sub_category); + } + + let file_name = Path::new(&item.filename) + .file_name() + .unwrap_or_else(|| OsStr::new(&item.filename)) + .to_string_lossy() + .into_owned(); + + let target = final_path.join(&file_name); + + // 1. Create the category/sub-category folder + // (Only need to call this once per file path) + if let Err(e) = fs::create_dir_all(&final_path) { + println!( + "{} Failed to create dir {:?}: {}", + "ERROR:".red(), + final_path, + e + ); + continue; // Skip moving this file if we can't make the folder } // 2. 
Move the file if source.exists() { match fs::rename(&source, &target) { - Ok(_) => println!("Moved: {} -> {}/", item.filename, item.category), - Err(e) => println!("Failed to move {}: {}", item.filename, e), + Ok(_) => { + // Formatting the success message + if item.sub_category.is_empty() { + println!("Moved: {} -> {}/", item.filename, item.category.green()); + } else { + println!( + "Moved: {} -> {}/{}", + item.filename, + item.category.green(), + item.sub_category.blue() + ); + } + } + Err(e) => println!("{} Failed to move {}: {}", "ERROR:".red(), item.filename, e), } } else { - println!("Skipping: {} (File not found)", item.filename); + println!( + "{} Skipping {}: File not found", + "WARN:".yellow(), + item.filename + ); } } + + println!("\n{}", "Organization Complete!".bold().green()); +} // --- 1. Helper to check if file is likely text --- +pub fn is_text_file(path: &Path) -> bool { + let text_extensions = [ + "txt", "md", "rs", "py", "js", "html", "css", "json", "xml", "csv", + ]; + + if let Some(ext) = path.extension() { + if let Some(ext_str) = ext.to_str() { + return text_extensions.contains(&ext_str.to_lowercase().as_str()); + } + } + false +} + +// --- 2. Helper to safely read content (with limit) --- +pub fn read_file_sample(path: &Path, max_chars: usize) -> Option { + use std::io::Read; + // Attempt to open the file + let file = match fs::File::open(path) { + Ok(f) => f, + Err(_) => return None, + }; + + // Buffer to hold file contents + let mut buffer = Vec::new(); + + // Read the whole file (or you could use take() to limit bytes read for speed) + // For safety, let's limit the read to avoid huge memory spikes on massive logs + let mut handle = file.take(max_chars as u64); + if handle.read_to_end(&mut buffer).is_err() { + return None; + } + + // Try to convert to UTF-8. If it fails (binary data), return None. 
+ String::from_utf8(buffer).ok() } diff --git a/src/gemini.rs b/src/gemini.rs index 70c29fb..6df444c 100644 --- a/src/gemini.rs +++ b/src/gemini.rs @@ -1,18 +1,55 @@ -use crate::files::OrganizationPlan; +use crate::cache::Cache; +use crate::files::{FileCategory, OrganizationPlan}; +use crate::gemini_errors::GeminiError; use reqwest::Client; +use serde::Deserialize; use serde_json::json; +use std::path::Path; +use std::time::Duration; + +#[derive(Deserialize, Default)] +struct GeminiResponse { + candidates: Vec, +} + +#[derive(Deserialize)] +struct Candidate { + content: Content, +} + +#[derive(Deserialize)] +struct Content { + parts: Vec, +} + +#[derive(Deserialize)] +struct Part { + text: String, +} + +#[derive(Deserialize)] +struct FileCategoryResponse { + filename: String, + category: String, +} + +#[derive(Deserialize)] +struct OrganizationPlanResponse { + files: Vec, +} pub struct GeminiClient { api_key: String, client: Client, base_url: String, } + impl GeminiClient { pub fn new(api_key: String) -> Self { Self { api_key, client: Client::new(), - base_url: "https://generativelanguage.googleapis.com/v1beta/models/gemini-3-flash-preview:generateContent".to_string(), + base_url: "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent".to_string(), } } @@ -20,9 +57,26 @@ impl GeminiClient { pub async fn organize_files( &self, filenames: Vec, - ) -> Result> { + ) -> Result { + self.organize_files_with_cache(filenames, None, None).await + } + + /// Takes a list of filenames and asks Gemini to categorize them with caching support + pub async fn organize_files_with_cache( + &self, + filenames: Vec, + mut cache: Option<&mut Cache>, + base_path: Option<&Path>, + ) -> Result { let url = format!("{}?key={}", self.base_url, self.api_key); + // Check cache first if available + if let (Some(cache_ref), Some(base_path)) = (cache.as_ref(), base_path) { + if let Some(cached_response) = cache_ref.get_cached_response(&filenames, base_path) { + 
return Ok(cached_response); + } + } + // 1. Construct the Prompt let file_list_str = filenames.join(", "); let prompt = format!( @@ -42,24 +96,134 @@ impl GeminiClient { } }); - // 3. Send - let res = self.client.post(&url).json(&request_body).send().await?; + // 3. Send with retry logic + let res = self.send_request_with_retry(&url, &request_body).await?; // 4. Parse if res.status().is_success() { - let resp_json: serde_json::Value = res.json().await?; + let gemini_response: GeminiResponse = res.json().await.map_err(GeminiError::NetworkError)?; - // Extract the raw JSON string from Gemini - let raw_text = resp_json["candidates"][0]["content"]["parts"][0]["text"] - .as_str() - .ok_or("Failed to get text from Gemini")?; + // Extract raw JSON string from Gemini using proper structs + let raw_text = &gemini_response.candidates + .get(0) + .ok_or_else(|| GeminiError::InvalidResponse("No candidates in response".to_string()))? + .content.parts + .get(0) + .ok_or_else(|| GeminiError::InvalidResponse("No parts in content".to_string()))? 
+ .text; + + // Deserialize into our temporary response struct + let plan_response: OrganizationPlanResponse = serde_json::from_str(raw_text)?; + + // Manually map to the final OrganizationPlan + let plan = OrganizationPlan { + files: plan_response + .files + .into_iter() + .map(|f| FileCategory { + filename: f.filename, + category: f.category, + sub_category: String::new(), // Initialize with empty sub_category + }) + .collect(), + }; + + // Cache the response if cache is available + if let (Some(cache), Some(base_path)) = (cache.as_mut(), base_path) { + cache.cache_response(&filenames, plan.clone(), base_path); + } - // Deserialize into our Rust Struct - let plan: OrganizationPlan = serde_json::from_str(raw_text)?; Ok(plan) } else { - let err = res.text().await?; - Err(format!("API Error: {}", err).into()) + Err(GeminiError::from_response(res).await) } } -} + + /// Send request with retry logic for retryable errors + async fn send_request_with_retry( + &self, + url: &str, + request_body: &serde_json::Value, + ) -> Result { + let mut attempts = 0; + let max_attempts = 3; + + loop { + attempts += 1; + + match self.client.post(url).json(request_body).send().await { + Ok(response) => { + if response.status().is_success() { + return Ok(response); + } + + let error = GeminiError::from_response(response).await; + + if error.is_retryable() && attempts < max_attempts { + if let Some(delay) = error.retry_delay() { + println!("API Error: {}. Retrying in {} seconds (attempt {}/{})", + error, delay.as_secs(), attempts, max_attempts); + tokio::time::sleep(delay).await; + continue; + } + } + + return Err(error); + } + Err(e) => { + if attempts < max_attempts { + println!("Network error: {}. 
Retrying in {} seconds (attempt {}/{})", + e, 5, attempts, max_attempts); + tokio::time::sleep(Duration::from_secs(5)).await; + continue; + } + return Err(GeminiError::NetworkError(e)); + } + } + } + } + + pub async fn get_ai_sub_category( + &self, + filename: &str, + parent_category: &str, + content: &str, + ) -> String { + let url = format!("{}?key={}", self.base_url, self.api_key); + + let prompt = format!( + "I have a file named '{}' inside the '{}' folder. Here is the first 1000 characters of the content:\n---\n{}\n---\nBased on this, suggest a single short sub-folder name (e.g., 'Invoices', 'Notes', 'Config'). Return ONLY the name of the sub-folder. Do not use markdown or explanations.", + filename, parent_category, content + ); + + let request_body = json!({ + "contents": [{ + "parts": [{ "text": prompt }] + }] + }); + + let res = self.client.post(&url).json(&request_body).send().await; + + if let Ok(res) = res { + if res.status().is_success() { + let gemini_response: GeminiResponse = res.json().await.unwrap_or_default(); + let sub_category = gemini_response.candidates + .get(0) + .and_then(|c| c.content.parts.get(0)) + .map(|p| p.text.trim()) + .unwrap_or("General") + .to_string(); + + if sub_category.is_empty() { + "General".to_string() + } else { + sub_category + } + } else { + "General".to_string() + } + } else { + "General".to_string() + } + } +} \ No newline at end of file diff --git a/src/gemini_errors.rs b/src/gemini_errors.rs new file mode 100644 index 0000000..1609b68 --- /dev/null +++ b/src/gemini_errors.rs @@ -0,0 +1,226 @@ +use reqwest::Response; +use serde::Deserialize; +use std::time::Duration; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum GeminiError { + #[error("API rate limit exceeded. Retry after {retry_after} seconds")] + RateLimitExceeded { retry_after: u32 }, + + #[error("Quota exceeded. 
Usage limit reached: {limit}")] + QuotaExceeded { limit: String }, + + #[error("Model '{model}' not found or unavailable")] + ModelNotFound { model: String }, + + #[error("Invalid API key. Please check your GEMINI_API_KEY")] + InvalidApiKey, + + #[error("Content policy violation: {reason}")] + ContentPolicyViolation { reason: String }, + + #[error("Invalid request: {details}")] + InvalidRequest { details: String }, + + #[error("Network error: {0}")] + NetworkError(#[from] reqwest::Error), + + #[error("Invalid response format: {0}")] + InvalidResponse(String), + + #[error("API error (HTTP {status}): {message}")] + ApiError { status: u16, message: String }, + + #[error("Service temporarily unavailable: {reason}")] + ServiceUnavailable { reason: String }, + + #[error("Request timeout after {seconds} seconds")] + Timeout { seconds: u64 }, + + #[error("JSON serialization/deserialization error: {0}")] + SerializationError(#[from] serde_json::Error), + + #[error("Internal server error: {details}")] + InternalError { details: String }, +} + +#[derive(Debug, Deserialize)] +struct GeminiErrorResponse { + error: GeminiErrorDetail, +} + +#[derive(Debug, Deserialize)] +struct GeminiErrorDetail { + code: i32, + message: String, + status: String, + #[serde(default)] + details: Vec, +} + +#[derive(Debug, Deserialize)] +struct GeminiErrorDetailInfo { + #[serde(rename = "@type")] + error_type: String, + #[serde(rename = "retryDelay")] + retry_delay: Option, + quota_limit: Option, + quota_metro: Option, +} + +impl GeminiError { + /// Parse HTTP response and convert to appropriate GeminiError + pub async fn from_response(response: Response) -> Self { + let status = response.status(); + + // Try to parse error response body + let error_text = match response.text().await { + Ok(text) => text, + Err(e) => { + return GeminiError::NetworkError(e); + } + }; + + // Try to parse structured error response + if let Ok(gemini_error) = serde_json::from_str::(&error_text) { + return 
Self::from_gemini_error(gemini_error.error, status.as_u16()); + } + + // Fallback to HTTP status code based errors + Self::from_status_code(status, &error_text) + } + + fn from_gemini_error(error_detail: GeminiErrorDetail, status: u16) -> Self { + let details = error_detail.details; + + match error_detail.status.as_str() { + "RESOURCE_EXHAUSTED" => { + if let Some(retry_info) = details.iter().find(|d| d.retry_delay.is_some()) { + if let Some(retry_delay) = &retry_info.retry_delay { + if let Ok(seconds) = retry_delay.parse::() { + return GeminiError::RateLimitExceeded { retry_after: seconds }; + } + } + } + + if let Some(quota_info) = details.iter().find(|d| d.quota_limit.is_some()) { + let limit = quota_info.quota_limit.as_deref().unwrap_or("unknown"); + return GeminiError::QuotaExceeded { + limit: limit.to_string() + }; + } + + GeminiError::QuotaExceeded { + limit: "usage limit".to_string() + } + } + "NOT_FOUND" => { + // Extract model name from message if possible + let model = extract_model_name(&error_detail.message); + GeminiError::ModelNotFound { model } + } + "UNAUTHENTICATED" => { + GeminiError::InvalidApiKey + } + "PERMISSION_DENIED" => { + if error_detail.message.to_lowercase().contains("policy") { + GeminiError::ContentPolicyViolation { + reason: error_detail.message + } + } else { + GeminiError::InvalidRequest { + details: error_detail.message + } + } + } + "INVALID_ARGUMENT" => { + GeminiError::InvalidRequest { + details: error_detail.message + } + } + "UNAVAILABLE" => { + GeminiError::ServiceUnavailable { + reason: error_detail.message + } + } + "DEADLINE_EXCEEDED" => { + GeminiError::Timeout { seconds: 60 } + } + "INTERNAL" => { + GeminiError::InternalError { + details: error_detail.message + } + } + _ => { + GeminiError::ApiError { + status, + message: error_detail.message + } + } + } + } + + fn from_status_code(status: reqwest::StatusCode, error_text: &str) -> Self { + match status.as_u16() { + 400 => GeminiError::InvalidRequest { + details: 
error_text.to_string() + }, + 401 => GeminiError::InvalidApiKey, + 403 => GeminiError::ContentPolicyViolation { + reason: error_text.to_string() + }, + 404 => GeminiError::ModelNotFound { + model: "unknown".to_string() + }, + 429 => GeminiError::RateLimitExceeded { retry_after: 60 }, + 500 => GeminiError::InternalError { + details: error_text.to_string() + }, + 502 | 503 | 504 => GeminiError::ServiceUnavailable { + reason: error_text.to_string() + }, + _ => GeminiError::ApiError { + status: status.as_u16(), + message: error_text.to_string() + }, + } + } + + /// Check if this error is retryable + pub fn is_retryable(&self) -> bool { + match self { + GeminiError::RateLimitExceeded { .. } => true, + GeminiError::ServiceUnavailable { .. } => true, + GeminiError::Timeout { .. } => true, + GeminiError::NetworkError(_) => true, + GeminiError::InternalError { .. } => true, + _ => false, + } + } + + /// Get retry delay for retryable errors + pub fn retry_delay(&self) -> Option { + match self { + GeminiError::RateLimitExceeded { retry_after } => { + Some(Duration::from_secs(*retry_after as u64)) + } + GeminiError::ServiceUnavailable { .. } => Some(Duration::from_secs(30)), + GeminiError::NetworkError(_) => Some(Duration::from_secs(5)), + GeminiError::Timeout { .. } => Some(Duration::from_secs(10)), + GeminiError::InternalError { .. 
} => Some(Duration::from_secs(15)), + _ => None, + } + } +} + +fn extract_model_name(message: &str) -> String { + // Try to extract model name from error message + // Example: "Model 'gemini-1.5-flash' not found" + if let Some(start) = message.find('\'') { + if let Some(end) = message[start + 1..].find('\'') { + return message[start + 1..start + 1 + end].to_string(); + } + } + "unknown".to_string() +} \ No newline at end of file diff --git a/src/lib.rs index 9be8a09..7eb1697 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,4 @@ +pub mod cache; pub mod files; pub mod gemini; +pub mod gemini_errors; diff --git a/src/main.rs index f15f59b..6bcc54f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,22 +1,32 @@ -use std::path::PathBuf; - -use noentropy::files::FileBatch; -use noentropy::files::OrganizationPlan; -use noentropy::files::execute_move; +use colored::*; +use futures::future::join_all; +use noentropy::cache::Cache; +use noentropy::files::{FileBatch, OrganizationPlan, execute_move}; use noentropy::gemini::GeminiClient; +use noentropy::gemini_errors::GeminiError; +use std::path::{Path, PathBuf}; +use std::sync::Arc; #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { dotenv::dotenv().ok(); - let api_key = std::env::var("GEMINI_API_KEY").expect("KEY not set"); - let download_path_var = std::env::var("DOWNLOAD_FOLDER").expect("Set DOWNLOAD_FOLDER={path}"); + let api_key = std::env::var("GEMINI_API_KEY") + .map_err(|_| "GEMINI_API_KEY environment variable not set. Please set it in your .env file.")?; + let download_path_var = std::env::var("DOWNLOAD_FOLDER") + .map_err(|_| "DOWNLOAD_FOLDER environment variable not set. Please set it in your .env file.")?; // 1. 
Setup let download_path: PathBuf = PathBuf::from(download_path_var.to_string()); let client: GeminiClient = GeminiClient::new(api_key); + + // Initialize cache + let cache_path = Path::new(".noentropy_cache.json"); + let mut cache = Cache::load_or_create(cache_path); + + // Clean up old cache entries (older than 7 days) + cache.cleanup_old_entries(7 * 24 * 60 * 60); - // 2. Get Files (Using your previous FileBatch logic) - // Assuming FileBatch::from_path returns a struct with .filenames + // 2. Get Files let batch = FileBatch::from_path(download_path.clone()); if batch.filenames.is_empty() { @@ -26,17 +36,131 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { println!( "Found {} files. Asking Gemini to organize...", - batch.filenames.len() + batch.count() ); - // 3. Call Gemini - let plan: OrganizationPlan = client.organize_files(batch.filenames).await?; + // 3. Call Gemini for Initial Categorization + let mut plan: OrganizationPlan = match client + .organize_files_with_cache(batch.filenames, Some(&mut cache), Some(&download_path)) + .await + { + Ok(plan) => plan, + Err(e) => { + handle_gemini_error(e); + return Ok(()); + } + }; - println!("Gemini Plan received! Moving files..."); + println!("Gemini Plan received! Performing deep inspection..."); - // 4. Execute + // 4. 
Deep Inspection - Process files concurrently + let client = Arc::new(client); + + let tasks: Vec<_> = plan.files.iter_mut() + .zip(batch.paths.iter()) + .map(|(file_category, path)| { + let client = Arc::clone(&client); + let filename = file_category.filename.clone(); + let category = file_category.category.clone(); + let path = path.clone(); + + async move { + if noentropy::files::is_text_file(&path) { + if let Some(content) = noentropy::files::read_file_sample(&path, 2000) { + println!("Reading content of {}...", filename.green()); + client.get_ai_sub_category(&filename, &category, &content).await + } else { + String::new() + } + } else { + String::new() + } + } + }) + .collect(); + + // Wait for all concurrent tasks to complete + let sub_categories = join_all(tasks).await; + + // Apply the results back to the plan + for (file_category, sub_category) in plan.files.iter_mut().zip(sub_categories) { + file_category.sub_category = sub_category; + } + + println!("Deep inspection complete! Moving Files....."); + // 5. Execute execute_move(&download_path, plan); - println!("Done!"); + + // Save cache before exiting + if let Err(e) = cache.save(cache_path) { + println!("Warning: Failed to save cache: {}", e); + } + Ok(()) } + +fn handle_gemini_error(error: GeminiError) { + use colored::*; + + match error { + GeminiError::RateLimitExceeded { retry_after } => { + println!("{} API rate limit exceeded. Please wait {} seconds before trying again.", + "ERROR:".red(), retry_after); + } + GeminiError::QuotaExceeded { limit } => { + println!("{} Quota exceeded: {}. Please check your Gemini API usage.", + "ERROR:".red(), limit); + } + GeminiError::ModelNotFound { model } => { + println!("{} Model '{}' not found. Please check the model name in the configuration.", + "ERROR:".red(), model); + } + GeminiError::InvalidApiKey => { + println!("{} Invalid API key. 
Please check your GEMINI_API_KEY environment variable.", + "ERROR:".red()); + } + GeminiError::ContentPolicyViolation { reason } => { + println!("{} Content policy violation: {}", + "ERROR:".red(), reason); + } + GeminiError::ServiceUnavailable { reason } => { + println!("{} Gemini service is temporarily unavailable: {}", + "ERROR:".red(), reason); + } + GeminiError::NetworkError(e) => { + println!("{} Network error: {}", + "ERROR:".red(), e); + } + GeminiError::Timeout { seconds } => { + println!("{} Request timed out after {} seconds.", + "ERROR:".red(), seconds); + } + GeminiError::InvalidRequest { details } => { + println!("{} Invalid request: {}", + "ERROR:".red(), details); + } + GeminiError::ApiError { status, message } => { + println!("{} API error (HTTP {}): {}", + "ERROR:".red(), status, message); + } + GeminiError::InvalidResponse(msg) => { + println!("{} Invalid response from Gemini: {}", + "ERROR:".red(), msg); + } + GeminiError::InternalError { details } => { + println!("{} Internal server error: {}", + "ERROR:".red(), details); + } + GeminiError::SerializationError(e) => { + println!("{} JSON serialization error: {}", + "ERROR:".red(), e); + } + } + + println!("\n{} Check the following:", "HINT:".yellow()); + println!(" • Your GEMINI_API_KEY is correctly set"); + println!(" • Your internet connection is working"); + println!(" • Gemini API service is available"); + println!(" • You haven't exceeded your API quota"); +}