diff --git a/src/http.rs b/src/http.rs index 74c58e727e..2c96c77f1d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -148,29 +148,16 @@ impl Client { } pub async fn get_text(&self, url: U) -> Result { - self.get_text_with_headers(url, &HeaderMap::new()).await + self.get_text_request(url).send().await } - pub async fn get_text_with_headers( - &self, - url: U, - extra_headers: &HeaderMap, - ) -> Result { - let mut url = url.into_url().unwrap(); - // Merge GitHub headers with any extra headers provided - let mut headers = host_auth_headers(&url); - headers.extend(extra_headers.clone()); - let resp = self.get_async_with_headers(url.clone(), &headers).await?; - let text = resp.text().await?; - if text.starts_with("") { - if url.scheme() == "http" { - // try with https since http may be blocked - url.set_scheme("https").unwrap(); - return Box::pin(self.get_text_with_headers(url, extra_headers)).await; - } - bail!("Got HTML instead of text from {}", url); + pub fn get_text_request(&self, url: U) -> TextRequest<'_> { + TextRequest { + client: self, + url: url.into_url().unwrap(), + extra_headers: HeaderMap::new(), + retries: Settings::get().http_retries, } - Ok(text) } /// Like get_text but caches results in memory for the duration of the process. @@ -376,7 +363,25 @@ impl Client { headers: &HeaderMap, verb_label: &str, ) -> Result { - retry_async(verb_label, &url, || async { + self.send_with_https_fallback_with_retries( + method, + url, + headers, + verb_label, + Settings::get().http_retries, + ) + .await + } + + async fn send_with_https_fallback_with_retries( + &self, + method: Method, + url: Url, + headers: &HeaderMap, + verb_label: &str, + retries: i64, + ) -> Result { + retry_async_with_retries(verb_label, &url, retries, || async { self.send_once_with_https_fallback(method.clone(), url.clone(), headers, verb_label) .await }) @@ -499,6 +504,52 @@ impl Client { } } +pub struct TextRequest<'a> { + client: &'a Client, + url: Url, + extra_headers: HeaderMap, + retries: i64, +} + +impl TextRequest<'_> { + pub fn headers(mut self, headers: &HeaderMap) -> Self { + self.extra_headers.extend(headers.clone()); + self + } + + pub fn retries(mut self, retries: i64) -> Self { + self.retries = retries; + self + } + + pub async fn send(mut self) -> Result { + ensure!(!Settings::get().offline(), "offline mode is enabled"); + // Merge GitHub headers with any extra headers provided + let mut headers = host_auth_headers(&self.url); + headers.extend(self.extra_headers.clone()); + let resp = self + .client + .send_with_https_fallback_with_retries( + Method::GET, + self.url.clone(), + &headers, + "GET", + self.retries, + ) + .await?; + let text = resp.text().await?; + if text.starts_with("") { + if self.url.scheme() == "http" { + // try with https since http may be blocked + self.url.set_scheme("https").unwrap(); + return Box::pin(self.send()).await; + } + bail!("Got HTML instead of text from {}", self.url); + } + Ok(text) + } +} + fn is_authenticated_github_forbidden(url: &Url, headers: &HeaderMap, resp: &Response) -> bool { resp.status() == StatusCode::FORBIDDEN && url.host_str() == Some("api.github.com") @@ -705,12 +756,25 @@ pub(crate) fn is_transient(err: &Report) -> bool { /// infrastructure as it's happening, instead of waiting through the backoff /// schedule. Successful rescues and final exhaustion don't get extra warnings /// — the caller surfaces the outcome. -pub(crate) async fn retry_async(verb_label: &str, url: &Url, mut f: F) -> Result +pub(crate) async fn retry_async(verb_label: &str, url: &Url, f: F) -> Result where F: FnMut() -> Fut, Fut: std::future::Future>, { - let mut backoff = default_backoff_strategy(Settings::get().http_retries); + retry_async_with_retries(verb_label, url, Settings::get().http_retries, f).await +} + +pub(crate) async fn retry_async_with_retries( + verb_label: &str, + url: &Url, + retries: i64, + mut f: F, +) -> Result +where + F: FnMut() -> Fut, + Fut: std::future::Future>, +{ + let mut backoff = default_backoff_strategy(retries); let mut attempt: usize = 1; loop { match f().await { @@ -830,6 +894,13 @@ mod tests { crate::config::Settings::reset(Some(settings)); SettingsGuard { _lock: lock } } + fn set_test_offline() -> SettingsGuard { + let lock = TEST_SETTINGS_LOCK.lock().unwrap(); + let mut settings = crate::config::settings::SettingsPartial::empty(); + settings.offline = Some(true); + crate::config::Settings::reset(Some(settings)); + SettingsGuard { _lock: lock } + } // A tiny in-process HTTP/1.1 responder. Each accepted connection consumes // the next response from `responses` and writes it back. Returns the bound @@ -933,6 +1004,39 @@ mod tests { assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 2); } + #[tokio::test(flavor = "current_thread")] + async fn test_text_request_can_override_retry_count() { + let _guard = set_test_http_retries(3); + let (port, count) = spawn_canned_server(vec![ + bad_gateway_response(), + bad_gateway_response(), + ok_response(), + ]) + .await; + let url: Url = format!("http://127.0.0.1:{}/", port).parse().unwrap(); + let client = Client::new(Duration::from_secs(2), ClientKind::Http).unwrap(); + let err = client + .get_text_request(url) + .retries(1) + .send() + .await + .unwrap_err(); + assert!(format!("{err:?}").contains("502")); + // Should stop after the initial request plus the single overridden retry. + assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 2); + } + + #[tokio::test(flavor = "current_thread")] + async fn test_text_request_respects_offline_mode() { + let _guard = set_test_offline(); + let (port, count) = spawn_canned_server(vec![ok_response()]).await; + let url: Url = format!("http://127.0.0.1:{}/", port).parse().unwrap(); + let client = Client::new(Duration::from_secs(2), ClientKind::Http).unwrap(); + let err = client.get_text_request(url).send().await.unwrap_err(); + assert_eq!(err.to_string(), "offline mode is enabled"); + assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 0); + } + #[test] fn test_backoff_strategy_yields_requested_count_beyond_schedule() { // Regression: a fixed-length schedule used to silently cap retries at 4. diff --git a/src/versions_host.rs b/src/versions_host.rs index 05ab604987..377a9ebe9b 100644 --- a/src/versions_host.rs +++ b/src/versions_host.rs @@ -52,6 +52,8 @@ struct VersionsResponse { versions: indexmap::IndexMap, } +const VERSION_LIST_RETRIES: i64 = 1; + #[derive(serde::Deserialize)] struct VersionEntry { created_at: toml::value::Datetime, @@ -97,7 +99,10 @@ pub async fn list_versions(tool: &str) -> eyre::Result>> // Use TOML format which includes created_at timestamps let url = format!("https://mise-versions.jdx.dev/tools/{}.toml", tool); let versions: Vec = match HTTP_FETCH - .get_text_with_headers(&url, &VERSIONS_HOST_HEADERS) + .get_text_request(&url) + .headers(&VERSIONS_HOST_HEADERS) + .retries(VERSION_LIST_RETRIES) + .send() .await { Ok(body) => {