feat: download and cache repodata.json #55

Merged · 14 commits · Feb 20, 2023

Changes from 6 commits
14 changes: 11 additions & 3 deletions crates/rattler/Cargo.toml
@@ -13,14 +13,18 @@ rustls-tls = ['reqwest/rustls-tls']
[dependencies]
anyhow = "1.0.44"
apple-codesign = "0.22.0"
- async-compression = { version = "0.3.12", features = ["gzip", "futures-bufread", "tokio", "bzip2"] }
+ async-compression = { version = "0.3.12", features = ["gzip", "tokio", "bzip2", "zstd"] }
blake2 = "0.10.6"
bytes = "1.1.0"
cache_control = "0.2.0"
chrono = { version = "0.4.23", default-features = false, features = ["std", "serde", "alloc"] }
digest = "0.10.6"
dirs = "4.0.0"
extendhash = "1.0.9"
futures = "0.3.17"
fxhash = "0.2.1"
hex = "0.4.3"
humansize = "2.1.3"
indicatif = { version = "0.17.1", features = ["improved_unicode"] }
itertools = "0.10.3"
libc = "0.2"
@@ -48,16 +52,20 @@ tracing = "0.1.29"
url = { version = "2.2.2", features = ["serde"] }
uuid = { version = "1.3.0", features = ["v4", "fast-rng"] }

[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.45.0", features = ["Win32_Storage_FileSystem", "Win32_Foundation", "Win32_System_IO"] }

[dev-dependencies]
assert_matches = "1.5.0"
axum = "0.6.2"
hex-literal = "0.3.4"
insta = { version = "1.16.0", features = ["yaml"] }
proptest = "1.0.0"
rand = "0.8.4"
rstest = "0.16.0"
tokio-test = "0.4.2"
- tower-http = { version = "0.3.5", features = ["fs", "compression-gzip"] }
- tracing-test = "0.2.4"
+ tower-http = { version = "0.3.5", features = ["fs", "compression-gzip", "trace"] }
+ tracing-test = { version = "0.2.4" }

[build-dependencies]
cc = "1"
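The dependency change swaps the `futures-bufread` feature of async-compression for `tokio` IO and enables `zstd`, which is what allows the fetcher to consume `repodata.json.zst`. A minimal sketch of that usage, not taken from this PR; the function name and path handling are illustrative:

```rust
// Hypothetical helper, not part of this diff: decompress a downloaded
// `repodata.json.zst` using the `zstd` + `tokio` features enabled above.
use async_compression::tokio::bufread::ZstdDecoder;
use tokio::io::{AsyncReadExt, BufReader};

async fn read_repodata_zst(path: &std::path::Path) -> std::io::Result<String> {
    // ZstdDecoder requires an AsyncBufRead, so wrap the file in a BufReader.
    let file = tokio::fs::File::open(path).await?;
    let mut decoder = ZstdDecoder::new(BufReader::new(file));

    // Stream the decompressed bytes into a string holding the JSON document.
    let mut json = String::new();
    decoder.read_to_string(&mut json).await?;
    Ok(json)
}
```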
14 changes: 3 additions & 11 deletions crates/rattler/resources/channels/empty/noarch/repodata.json
Git LFS file not shown
80 changes: 80 additions & 0 deletions crates/rattler/src/repo_data/cache/cache_headers.rs
@@ -0,0 +1,80 @@
use reqwest::{
    header,
    header::{HeaderMap, HeaderValue},
    Response,
};
use serde::{Deserialize, Serialize};

/// Extracted HTTP response headers that enable caching the repodata.json files.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CacheHeaders {
    /// The ETag HTTP cache header
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub etag: Option<String>,

    /// The Last-Modified HTTP cache header
    #[serde(default, skip_serializing_if = "Option::is_none", rename = "mod")]
    pub last_modified: Option<String>,

    /// The cache control configuration
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_control: Option<String>,
}

impl From<&Response> for CacheHeaders {
    fn from(response: &Response) -> Self {
        // Get the ETag from the response (if any). It can be used to revalidate the cached
        // result on a subsequent request.
        let etag = response
            .headers()
            .get(header::ETAG)
            .and_then(|header| header.to_str().ok())
            .map(ToOwned::to_owned);

        // Get the last modified time. It can also be used to revalidate the cached result on a
        // subsequent request.
        let last_modified = response
            .headers()
            .get(header::LAST_MODIFIED)
            .and_then(|header| header.to_str().ok())
            .map(ToOwned::to_owned);

        // Get the Cache-Control header so we can determine how to cache locally.
        let cache_control = response
            .headers()
            .get(header::CACHE_CONTROL)
            .and_then(|header| header.to_str().ok())
            .map(ToOwned::to_owned);

        Self {
            etag,
            last_modified,
            cache_control,
        }
    }
}

impl CacheHeaders {
    /// Adds the headers to the specified request to short-circuit if the content is still up to
    /// date.
    pub fn add_to_request(&self, headers: &mut HeaderMap) {
        // If a previous response contained an ETag header, add the If-None-Match header so the
        // server only sends us new data if the ETag is no longer valid.
        if let Some(etag) = self
            .etag
            .as_deref()
            .and_then(|etag| HeaderValue::from_str(etag).ok())
        {
            headers.insert(header::IF_NONE_MATCH, etag);
        }
        // If a previous response contained a Last-Modified header, add the If-Modified-Since
        // header so the server only sends us new data if the contents have been modified since
        // that date.
        if let Some(last_modified) = self
            .last_modified
            .as_deref()
            .and_then(|last_modified| HeaderValue::from_str(last_modified).ok())
        {
            headers.insert(header::IF_MODIFIED_SINCE, last_modified);
        }
    }
}
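To see how these headers are meant to be used together, here is a hedged sketch of a conditional fetch built on `CacheHeaders::add_to_request`; the surrounding function is illustrative and not part of this diff:

```rust
use reqwest::{header::HeaderMap, Client, StatusCode};

// Hypothetical caller, not from this PR: revalidate a cached repodata.json.
// Returns None when the server answers 304 Not Modified and the cached copy
// can be reused as-is.
async fn fetch_if_modified(
    client: &Client,
    url: &str,
    cached: &CacheHeaders,
) -> Result<Option<String>, reqwest::Error> {
    // Attach If-None-Match / If-Modified-Since from the previous response.
    let mut headers = HeaderMap::new();
    cached.add_to_request(&mut headers);

    let response = client.get(url).headers(headers).send().await?;
    if response.status() == StatusCode::NOT_MODIFIED {
        return Ok(None);
    }
    Ok(Some(response.text().await?))
}
```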
170 changes: 170 additions & 0 deletions crates/rattler/src/repo_data/cache/mod.rs
@@ -0,0 +1,170 @@
mod cache_headers;

pub use cache_headers::CacheHeaders;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::{fs::File, io::Read, path::Path, str::FromStr, time::SystemTime};
use url::Url;

/// Representation of the `.state.json` file alongside a `repodata.json` file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RepoDataState {
    /// The URL from which the repodata was downloaded. This is the URL of the `repodata.json`,
    /// `repodata.json.zst`, or another variant. It differs from the subdir URL, which does NOT
    /// include the final filename.
    pub url: Url,

    /// The HTTP cache headers sent along with the last response.
    #[serde(flatten)]
    pub cache_headers: CacheHeaders,

    /// The timestamp of the repodata.json on disk
    #[serde(
        deserialize_with = "duration_from_nanos",
        serialize_with = "duration_to_nanos",
        rename = "mtime_ns"
    )]
    pub cache_last_modified: SystemTime,

    /// The size of the repodata.json file on disk.
    #[serde(rename = "size")]
    pub cache_size: u64,

    /// The blake2 hash of the file
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_blake2_hash",
        serialize_with = "serialize_blake2_hash"
    )]
    pub blake2_hash: Option<blake2::digest::Output<blake2::Blake2s256>>,

    /// Whether or not a zst-compressed version is available for the subdirectory
    pub has_zst: Option<Expiring<bool>>,

    /// Whether or not a bz2-compressed version is available for the subdirectory
    pub has_bz2: Option<Expiring<bool>>,

    /// Whether or not JLAP is available for the subdirectory
    pub has_jlap: Option<Expiring<bool>>,
}

impl RepoDataState {
    /// Reads and parses a file from disk.
    pub fn from_path(path: &Path) -> Result<RepoDataState, std::io::Error> {
        let content = {
            let mut file = File::open(path)?;
            let mut content = Default::default();
            file.read_to_string(&mut content)?;
            content
        };
        Ok(Self::from_str(&content)?)
    }

    /// Saves the cache state to the specified file.
    pub fn to_path(&self, path: &Path) -> Result<(), std::io::Error> {
        let file = File::create(path)?;
        Ok(serde_json::to_writer_pretty(file, self)?)
    }
}

impl FromStr for RepoDataState {
    type Err = serde_json::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        serde_json::from_str(s)
    }
}

/// Represents a value and when the value was last checked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Expiring<T> {
    pub value: T,

    // #[serde(with = "chrono::serde::ts_seconds")]
    pub last_checked: chrono::DateTime<chrono::Utc>,
}

impl<T> Expiring<T> {
    /// Returns the value if it was checked within the given expiration window, `None` otherwise.
    pub fn value(&self, expiration: chrono::Duration) -> Option<&T> {
        if chrono::Utc::now().signed_duration_since(self.last_checked) >= expiration {
            None
        } else {
            Some(&self.value)
        }
    }
}

/// Deserializes a [`SystemTime`] by parsing an integer and converting it, as a nanosecond-based
/// Unix epoch timestamp, to a [`SystemTime`].
fn duration_from_nanos<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
    D: Deserializer<'de>,
{
    use serde::de::Error;
    SystemTime::UNIX_EPOCH
        .checked_add(std::time::Duration::from_nanos(Deserialize::deserialize(
            deserializer,
        )?))
        .ok_or_else(|| D::Error::custom("the time cannot be represented internally"))
}

/// Serializes a [`SystemTime`] by converting it to a nanosecond-based Unix epoch timestamp.
fn duration_to_nanos<S: Serializer>(time: &SystemTime, s: S) -> Result<S::Ok, S::Error> {
    use serde::ser::Error;
    time.duration_since(SystemTime::UNIX_EPOCH)
        .map_err(|_| S::Error::custom("duration cannot be computed for file time"))?
        .as_nanos()
        .serialize(s)
}

fn deserialize_blake2_hash<'de, D>(
    deserializer: D,
) -> Result<Option<blake2::digest::Output<blake2::Blake2s256>>, D::Error>
where
    D: Deserializer<'de>,
{
    use serde::de::Error;
    match Option::<&'de str>::deserialize(deserializer)? {
        Some(str) => {
            let mut hash = <blake2::digest::Output<blake2::Blake2s256>>::default();
            hex::decode_to_slice(str, &mut hash).map_err(D::Error::custom)?;
            Ok(Some(hash))
        }
        None => Ok(None),
    }
}

fn serialize_blake2_hash<S: Serializer>(
    hash: &Option<blake2::digest::Output<blake2::Blake2s256>>,
    s: S,
) -> Result<S::Ok, S::Error> {
    match hash.as_ref() {
        None => s.serialize_none(),
        Some(hash) => format!("{:x}", hash).serialize(s),
    }
}

#[cfg(test)]
mod test {
    use super::RepoDataState;
    use std::str::FromStr;

    #[test]
    pub fn test_parse_repo_data_state() {
        insta::assert_yaml_snapshot!(RepoDataState::from_str(
            r#"{
                "cache_control": "public, max-age=1200",
                "etag": "\"bec332621e00fc4ad87ba185171bcf46\"",
                "has_zst": {
                    "last_checked": "2023-02-13T14:08:50Z",
                    "value": true
                },
                "mod": "Mon, 13 Feb 2023 13:49:56 GMT",
                "mtime_ns": 1676297333020928000,
                "size": 156627374,
                "url": "https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst"
            }"#,
        )
        .unwrap());
    }
}
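A short usage sketch tying `RepoDataState` and `Expiring` together; the path and the 20-minute freshness window are illustrative assumptions, not values taken from this PR:

```rust
use std::path::Path;

// Hypothetical caller, not from this diff: decide whether `has_zst` must be
// probed again based on the persisted `.state.json`.
fn must_recheck_zst(state_path: &Path) -> std::io::Result<bool> {
    let state = RepoDataState::from_path(state_path)?;

    // `Expiring::value` yields None once `last_checked` falls outside the
    // given window, signalling that the availability flag is stale.
    let still_fresh = state
        .has_zst
        .as_ref()
        .and_then(|flag| flag.value(chrono::Duration::minutes(20)))
        .is_some();

    Ok(!still_fresh)
}
```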
16 changes: 16 additions & 0 deletions (new insta snapshot; file path not shown)
@@ -0,0 +1,16 @@
---
source: crates/rattler/src/repo_data/cache/mod.rs
expression: "RepoDataState::from_str(r#\"{\n \"cache_control\": \"public, max-age=1200\",\n \"etag\": \"\\\"bec332621e00fc4ad87ba185171bcf46\\\"\",\n \"has_zst\": {\n \"last_checked\": \"2023-02-13T14:08:50Z\",\n \"value\": true\n },\n \"mod\": \"Mon, 13 Feb 2023 13:49:56 GMT\",\n \"mtime_ns\": 1676297333020928000,\n \"size\": 156627374,\n \"url\": \"https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst\"\n }\"#).unwrap()"
---
url: "https://conda.anaconda.org/conda-forge/win-64/repodata.json.zst"
etag: "\"bec332621e00fc4ad87ba185171bcf46\""
mod: "Mon, 13 Feb 2023 13:49:56 GMT"
cache_control: "public, max-age=1200"
mtime_ns: 1676297333020928000
size: 156627374
has_zst:
  value: true
  last_checked: "2023-02-13T14:08:50Z"
has_bz2: ~
has_jlap: ~

5 changes: 2 additions & 3 deletions crates/rattler/src/repo_data/fetch/request/file.rs
@@ -44,12 +44,11 @@ pub async fn fetch_repodata(
#[cfg(test)]
mod test {
    use super::fetch_repodata;
-   use std::path::PathBuf;
+   use crate::get_test_data_dir;

    #[tokio::test]
    async fn test_fetch_file() {
-       let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
-       let subdir_path = manifest_dir.join("resources/channels/empty/noarch/repodata.json");
+       let subdir_path = get_test_data_dir().join("channels/empty/noarch/repodata.json");
        let _ = fetch_repodata(&subdir_path, &mut |_| {}).await.unwrap();
    }
}
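Both rewritten tests lean on a `get_test_data_dir` helper whose definition is not shown in this hunk. A hypothetical reconstruction of what such a helper could look like, purely to make the tests above readable; the directory name is an assumption:

```rust
use std::path::PathBuf;

// Hypothetical sketch; the real helper is defined elsewhere in this PR. It
// would resolve the shared test-data directory relative to the crate
// manifest so tests are independent of the working directory.
pub(crate) fn get_test_data_dir() -> PathBuf {
    // "test-data" is an assumed directory name.
    PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test-data")
}
```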
8 changes: 3 additions & 5 deletions crates/rattler/src/repo_data/fetch/request/http.rs
@@ -326,20 +326,19 @@ fn add_download_progress_listener<'s, E>(
mod test {
    use std::fs::File;
    use std::io::BufReader;
-   use std::path::PathBuf;
    use std::str::FromStr;
    use tempfile::TempDir;
    use url::Url;

    use super::{create_cache_file, fetch_repodata, read_cache_file, RepoDataMetadata};
+   use crate::get_test_data_dir;
    use crate::repo_data::fetch::request::REPODATA_CHANNEL_PATH;
    use crate::utils::simple_channel_server::SimpleChannelServer;
    use rattler_conda_types::{Channel, ChannelConfig, Platform};

    #[tokio::test]
    async fn test_fetch_http() {
-       let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
-       let channel_path = manifest_dir.join("resources/channels/empty");
+       let channel_path = get_test_data_dir().join("channels/empty");

        let server = SimpleChannelServer::new(channel_path);
        let url = server.url().to_string();
@@ -360,8 +359,7 @@

    #[tokio::test]
    async fn test_http_fetch_cache() {
-       let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
-       let channel_path = manifest_dir.join("resources/channels/empty");
+       let channel_path = get_test_data_dir().join("channels/empty");

        let server = SimpleChannelServer::new(channel_path);
        let url = server.url().to_string();