libobs_wrapper\bootstrap/
download.rs1use std::{env::temp_dir, path::PathBuf};
2
3use anyhow::Context;
4use async_stream::stream;
5use futures_core::Stream;
6use futures_util::StreamExt;
7use libobs::{LIBOBS_API_MAJOR_VER, LIBOBS_API_MINOR_VER};
8use semver::Version;
9use sha2::{Digest, Sha256};
10use tokio::{fs::File, io::AsyncWriteExt};
11use uuid::Uuid;
12
13use super::{github_types, LIBRARY_OBS_VERSION};
14
15pub enum DownloadStatus {
16 Error(anyhow::Error),
17 Progress(f32, String),
18 Done(PathBuf),
19}
20
21pub(crate) async fn download_obs(repo: &str) -> anyhow::Result<impl Stream<Item = DownloadStatus>> {
22 let client = reqwest::ClientBuilder::new()
24 .user_agent("clipture-rs")
25 .build()?;
26
27 let releases_url = format!("https://api.github.com/repos/{}/releases", repo);
28 let releases: github_types::Root = client.get(&releases_url).send().await?.json().await?;
29
30 let mut possible_versions = vec![];
31 for release in releases {
32 let tag = release.tag_name.replace("obs-build-", "");
33 let version = Version::parse(&tag).context("Parsing version")?;
34
35 if version.major == LIBOBS_API_MAJOR_VER as u64
37 && version.minor == LIBOBS_API_MINOR_VER as u64
38 {
39 possible_versions.push(release);
40 }
41 }
42
43 let latest_version = possible_versions
44 .iter()
45 .max_by_key(|r| &r.published_at)
46 .context(format!(
47 "Finding a matching obs version for {}",
48 *LIBRARY_OBS_VERSION
49 ))?;
50
51 let archive_url = latest_version
52 .assets
53 .iter()
54 .find(|a| a.name.ends_with(".7z"))
55 .context("Finding 7z asset")?
56 .browser_download_url
57 .clone();
58
59 let hash_url = latest_version
60 .assets
61 .iter()
62 .find(|a| a.name.ends_with(".sha256"))
63 .context("Finding sha256 asset")?
64 .browser_download_url
65 .clone();
66
67 let res = client.get(archive_url).send().await?;
68 let length = res.content_length().unwrap_or(0);
69
70 let mut bytes_stream = res.bytes_stream();
71
72 let path = PathBuf::new()
73 .join(temp_dir())
74 .join(format!("{}.7z", Uuid::new_v4()));
75 let mut tmp_file = File::create_new(&path)
76 .await
77 .context("Creating temporary file")?;
78
79 let mut curr_len = 0;
80 let mut hasher = Sha256::new();
81 Ok(stream! {
82 yield DownloadStatus::Progress(0.0, "Downloading OBS".to_string());
83 while let Some(chunk) = bytes_stream.next().await {
84 let chunk = chunk.context("Retrieving data from stream");
85 if let Err(e) = chunk {
86 yield DownloadStatus::Error(e);
87 return;
88 }
89
90 let chunk = chunk.unwrap();
91 hasher.update(&chunk);
92 let r = tmp_file.write_all(&chunk).await.context("Writing to temporary file");
93 if let Err(e) = r {
94 yield DownloadStatus::Error(e);
95 return;
96 }
97
98 curr_len = std::cmp::min(curr_len + chunk.len() as u64, length);
99 yield DownloadStatus::Progress(curr_len as f32 / length as f32, "Downloading OBS".to_string());
100 }
101
102 let remote_hash = client.get(hash_url).send().await.context("Fetching hash");
104 if let Err(e) = remote_hash {
105 yield DownloadStatus::Error(e);
106 return;
107 }
108
109 let remote_hash = remote_hash.unwrap().text().await.context("Reading hash");
110 if let Err(e) = remote_hash {
111 yield DownloadStatus::Error(e);
112 return;
113 }
114
115 let remote_hash = remote_hash.unwrap();
116 let remote_hash = hex::decode(remote_hash.trim()).context("Decoding hash");
117 if let Err(e) = remote_hash {
118 yield DownloadStatus::Error(e);
119 return;
120 }
121
122 let remote_hash = remote_hash.unwrap();
123
124 let local_hash = hasher.finalize();
126 if local_hash.as_slice() != remote_hash {
127 yield DownloadStatus::Error(anyhow::anyhow!("Hash mismatch"));
128 return;
129 }
130
131 log::info!("Hashes match");
132 yield DownloadStatus::Done(path);
133 })
134}