Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 127 additions & 23 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,29 +148,16 @@ impl Client {
}

pub async fn get_text<U: IntoUrl>(&self, url: U) -> Result<String> {
self.get_text_with_headers(url, &HeaderMap::new()).await
self.get_text_request(url).send().await
}

pub async fn get_text_with_headers<U: IntoUrl>(
&self,
url: U,
extra_headers: &HeaderMap,
) -> Result<String> {
let mut url = url.into_url().unwrap();
// Merge GitHub headers with any extra headers provided
let mut headers = host_auth_headers(&url);
headers.extend(extra_headers.clone());
let resp = self.get_async_with_headers(url.clone(), &headers).await?;
let text = resp.text().await?;
if text.starts_with("<!DOCTYPE html>") {
if url.scheme() == "http" {
// try with https since http may be blocked
url.set_scheme("https").unwrap();
return Box::pin(self.get_text_with_headers(url, extra_headers)).await;
}
bail!("Got HTML instead of text from {}", url);
pub fn get_text_request<U: IntoUrl>(&self, url: U) -> TextRequest<'_> {
TextRequest {
client: self,
url: url.into_url().unwrap(),
extra_headers: HeaderMap::new(),
retries: Settings::get().http_retries,
}
Ok(text)
}

/// Like get_text but caches results in memory for the duration of the process.
Expand Down Expand Up @@ -376,7 +363,25 @@ impl Client {
headers: &HeaderMap,
verb_label: &str,
) -> Result<Response> {
retry_async(verb_label, &url, || async {
self.send_with_https_fallback_with_retries(
method,
url,
headers,
verb_label,
Settings::get().http_retries,
)
.await
}

async fn send_with_https_fallback_with_retries(
&self,
method: Method,
url: Url,
headers: &HeaderMap,
verb_label: &str,
retries: i64,
) -> Result<Response> {
retry_async_with_retries(verb_label, &url, retries, || async {
self.send_once_with_https_fallback(method.clone(), url.clone(), headers, verb_label)
.await
})
Expand Down Expand Up @@ -499,6 +504,52 @@ impl Client {
}
}

pub struct TextRequest<'a> {
client: &'a Client,
url: Url,
extra_headers: HeaderMap,
retries: i64,
}

impl TextRequest<'_> {
pub fn headers(mut self, headers: &HeaderMap) -> Self {
self.extra_headers.extend(headers.clone());
self
}

pub fn retries(mut self, retries: i64) -> Self {
self.retries = retries;
self
}

pub async fn send(mut self) -> Result<String> {
ensure!(!Settings::get().offline(), "offline mode is enabled");
// Merge GitHub headers with any extra headers provided
let mut headers = host_auth_headers(&self.url);
headers.extend(self.extra_headers.clone());
let resp = self
.client
.send_with_https_fallback_with_retries(
Method::GET,
self.url.clone(),
&headers,
"GET",
self.retries,
)
.await?;
let text = resp.text().await?;
if text.starts_with("<!DOCTYPE html>") {
if self.url.scheme() == "http" {
// try with https since http may be blocked
self.url.set_scheme("https").unwrap();
return Box::pin(self.send()).await;
}
bail!("Got HTML instead of text from {}", self.url);
}
Ok(text)
}
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
greptile-apps[bot] marked this conversation as resolved.
}

fn is_authenticated_github_forbidden(url: &Url, headers: &HeaderMap, resp: &Response) -> bool {
resp.status() == StatusCode::FORBIDDEN
&& url.host_str() == Some("api.github.com")
Expand Down Expand Up @@ -705,12 +756,25 @@ pub(crate) fn is_transient(err: &Report) -> bool {
/// infrastructure as it's happening, instead of waiting through the backoff
/// schedule. Successful rescues and final exhaustion don't get extra warnings
/// — the caller surfaces the outcome.
pub(crate) async fn retry_async<F, Fut, T>(verb_label: &str, url: &Url, mut f: F) -> Result<T>
pub(crate) async fn retry_async<F, Fut, T>(verb_label: &str, url: &Url, f: F) -> Result<T>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let mut backoff = default_backoff_strategy(Settings::get().http_retries);
retry_async_with_retries(verb_label, url, Settings::get().http_retries, f).await
}

pub(crate) async fn retry_async_with_retries<F, Fut, T>(
verb_label: &str,
url: &Url,
retries: i64,
mut f: F,
) -> Result<T>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let mut backoff = default_backoff_strategy(retries);
let mut attempt: usize = 1;
loop {
match f().await {
Expand Down Expand Up @@ -830,6 +894,13 @@ mod tests {
crate::config::Settings::reset(Some(settings));
SettingsGuard { _lock: lock }
}
fn set_test_offline() -> SettingsGuard {
let lock = TEST_SETTINGS_LOCK.lock().unwrap();
let mut settings = crate::config::settings::SettingsPartial::empty();
settings.offline = Some(true);
crate::config::Settings::reset(Some(settings));
SettingsGuard { _lock: lock }
}

// A tiny in-process HTTP/1.1 responder. Each accepted connection consumes
// the next response from `responses` and writes it back. Returns the bound
Expand Down Expand Up @@ -933,6 +1004,39 @@ mod tests {
assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 2);
}

#[tokio::test(flavor = "current_thread")]
async fn test_text_request_can_override_retry_count() {
let _guard = set_test_http_retries(3);
let (port, count) = spawn_canned_server(vec![
bad_gateway_response(),
bad_gateway_response(),
ok_response(),
])
.await;
let url: Url = format!("http://127.0.0.1:{}/", port).parse().unwrap();
let client = Client::new(Duration::from_secs(2), ClientKind::Http).unwrap();
let err = client
.get_text_request(url)
.retries(1)
.send()
.await
.unwrap_err();
assert!(format!("{err:?}").contains("502"));
// Should stop after the initial request plus the single overridden retry.
assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 2);
}

#[tokio::test(flavor = "current_thread")]
async fn test_text_request_respects_offline_mode() {
let _guard = set_test_offline();
let (port, count) = spawn_canned_server(vec![ok_response()]).await;
let url: Url = format!("http://127.0.0.1:{}/", port).parse().unwrap();
let client = Client::new(Duration::from_secs(2), ClientKind::Http).unwrap();
let err = client.get_text_request(url).send().await.unwrap_err();
assert_eq!(err.to_string(), "offline mode is enabled");
assert_eq!(count.load(std::sync::atomic::Ordering::SeqCst), 0);
}

#[test]
fn test_backoff_strategy_yields_requested_count_beyond_schedule() {
// Regression: a fixed-length schedule used to silently cap retries at 4.
Expand Down
7 changes: 6 additions & 1 deletion src/versions_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ struct VersionsResponse {
versions: indexmap::IndexMap<String, VersionEntry>,
}

const VERSION_LIST_RETRIES: i64 = 1;

#[derive(serde::Deserialize)]
struct VersionEntry {
created_at: toml::value::Datetime,
Expand Down Expand Up @@ -97,7 +99,10 @@ pub async fn list_versions(tool: &str) -> eyre::Result<Option<Vec<VersionInfo>>>
// Use TOML format which includes created_at timestamps
let url = format!("https://mise-versions.jdx.dev/tools/{}.toml", tool);
let versions: Vec<VersionInfo> = match HTTP_FETCH
.get_text_with_headers(&url, &VERSIONS_HOST_HEADERS)
.get_text_request(&url)
.headers(&VERSIONS_HOST_HEADERS)
.retries(VERSION_LIST_RETRIES)
.send()
.await
{
Ok(body) => {
Expand Down
Loading