diff --git a/src/main.rs b/src/main.rs index 4a4f219..4ce0b30 100644 --- a/src/main.rs +++ b/src/main.rs @@ -204,10 +204,7 @@ fn main() -> Result<()> { } Subcommands::Lsp => { - let mut project = RustAnalyzerProject::new(); - project - .get_sysroot_src() - .expect("Couldn't find toolchain path, do you have `rustc` installed?"); + let mut project = RustAnalyzerProject::build()?; project .exercises_to_json() .expect("Couldn't parse rustlings exercises files"); diff --git a/src/project.rs b/src/project.rs index 93f941d..a7414d1 100644 --- a/src/project.rs +++ b/src/project.rs @@ -1,3 +1,4 @@ +use anyhow::{bail, Context, Result}; use glob::glob; use serde::{Deserialize, Serialize}; use std::env; @@ -22,11 +23,44 @@ pub struct Crate { } impl RustAnalyzerProject { - pub fn new() -> RustAnalyzerProject { - RustAnalyzerProject { - sysroot_src: String::new(), - crates: Vec::new(), + pub fn build() -> Result { + // check if RUST_SRC_PATH is set + if let Ok(sysroot_src) = env::var("RUST_SRC_PATH") { + return Ok(Self { + sysroot_src, + crates: Vec::new(), + }); } + + let toolchain = Command::new("rustc") + .arg("--print") + .arg("sysroot") + .output() + .context("Failed to get the sysroot from `rustc`. Do you have `rustc` installed?")? + .stdout; + + let toolchain = + String::from_utf8(toolchain).context("The toolchain path is invalid UTF8")?; + let toolchain = toolchain.trim_end(); + + println!("Determined toolchain: {toolchain}\n"); + + let Ok(sysroot_src) = Path::new(toolchain) + .join("lib") + .join("rustlib") + .join("src") + .join("rust") + .join("library") + .into_os_string() + .into_string() + else { + bail!("The sysroot path is invalid UTF8"); + }; + + Ok(Self { + sysroot_src, + crates: Vec::new(), + }) } /// Write rust-project.json to disk @@ -66,39 +100,4 @@ impl RustAnalyzerProject { } Ok(()) } - - /// Use `rustc` to determine the default toolchain - pub fn get_sysroot_src(&mut self) -> Result<(), Box> { - // check if RUST_SRC_PATH is set - if let Ok(path) = env::var("RUST_SRC_PATH") { - self.sysroot_src = path; - return Ok(()); - } - - let toolchain = Command::new("rustc") - .arg("--print") - .arg("sysroot") - .output()? - .stdout; - - let toolchain = String::from_utf8(toolchain)?; - let toolchain = toolchain.trim_end(); - - println!("Determined toolchain: {toolchain}\n"); - - let Ok(path) = Path::new(toolchain) - .join("lib") - .join("rustlib") - .join("src") - .join("rust") - .join("library") - .into_os_string() - .into_string() - else { - return Err("The sysroot path is invalid UTF8".into()); - }; - self.sysroot_src = path; - - Ok(()) - } }