Skip to content

Commit

Permalink
feat(candle): better cuda error
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene committed Jun 21, 2024
1 parent 901a0d4 commit 4601286
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::models::{
use crate::models::{
FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashNomicBertModel,
};
use anyhow::Context;
use candle::{DType, Device};
use candle_nn::VarBuilder;
use models::BertConfig;
Expand Down Expand Up @@ -55,9 +56,11 @@ impl CandleBackend {
) -> Result<Self, BackendError> {
// Load config
let config: String = std::fs::read_to_string(model_path.join("config.json"))
.map_err(|err| BackendError::Start(err.to_string()))?;
.context("Unable to read config file")
.map_err(|err| BackendError::Start(format!("{err:?}")))?;
let config: Config = serde_json::from_str(&config)
.map_err(|err| BackendError::Start(format!("Model is not supported: {}", err)))?;
.context("Model is not supported")
.map_err(|err| BackendError::Start(format!("{err:?}")))?;

// Get candle device
let device = if candle::utils::cuda_is_available() {
Expand All @@ -72,7 +75,7 @@ impl CandleBackend {
)))
}
Err(err) => {
tracing::warn!("Could not find a compatible CUDA device on host: {err}");
tracing::warn!("Could not find a compatible CUDA device on host: {err:?}");
tracing::warn!("Using CPU instead");
Ok(Device::Cpu)
}
Expand Down

0 comments on commit 4601286

Please sign in to comment.