Improve automatic device map #1076

Merged · 4 commits · Jan 22, 2025
8 changes: 4 additions & 4 deletions docs/DEVICE_MAPPING.md
@@ -9,10 +9,10 @@ components on the GPU.

To control the mapping across devices, you can set the following maximum parameters which the model should expect in a prompt.

- - maximum sequence length
- - maximum batch size
- - (vision models) maximum image length (length refers to the edge length)
- - (vision models) maximum number of images
+ - maximum sequence length (default: 4096)
+ - maximum batch size (default: 1)
+ - (vision models) maximum image length (length refers to the edge length) (default: 1024)
+ - (vision models) maximum number of images (default: 1)

These parameters do not translate to hard limits at runtime; they only control the mapping.
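
For readers who want to set these maxima programmatically, the sketch below shows the Rust side that this PR touches. It relies only on `AutoDeviceMapParams::default_text()` and the `DEFAULT_*` constants that appear later in this diff; the import path is an assumption, so adjust it to wherever the crate actually re-exports the type.

```rust
// Minimal sketch. The constants and default_text() come from the diff below;
// the import path is an assumption, not verified against the crate layout.
use mistralrs_core::AutoDeviceMapParams; // assumed re-export location

fn main() {
    // Text defaults after this PR: 4096-token prompts, batch size 1.
    let params = AutoDeviceMapParams::default_text();
    println!("{params}"); // AutoDeviceMapParams implements Display

    // The same defaults are exposed as constants.
    assert_eq!(AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN, 4 * 1024);
    assert_eq!(AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE, 1);
    assert_eq!(AutoDeviceMapParams::DEFAULT_MAX_IMAGE_LENGTH, 1024);
    assert_eq!(AutoDeviceMapParams::DEFAULT_MAX_NUM_IMAGES, 1);
}
```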

37 changes: 33 additions & 4 deletions mistralrs-core/src/pipeline/loaders/mod.rs
@@ -13,6 +13,7 @@ use std::{
use anyhow::{Context, Result};
use as_any::AsAny;
use candle_core::{DType, Device};
use itertools::Itertools;
use mistralrs_quant::IsqType;
use tokio::sync::Mutex;

@@ -22,7 +23,7 @@ pub use normal_loaders::{
Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Starcoder2Loader,
};

- use tracing::warn;
+ use tracing::{info, warn};
pub use vision_loaders::{
Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, Phi3VLoader, Qwen2VLLoader,
VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
@@ -406,10 +407,10 @@ impl Display for AutoDeviceMapParams {
}

impl AutoDeviceMapParams {
- pub const DEFAULT_MAX_SEQ_LEN: usize = 16 * 1024;
+ pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
- pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 2 * 1024;
+ pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;

pub fn default_text() -> Self {
Self::Text {
@@ -431,6 +432,19 @@
}
}

#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
Vision,
}

impl Display for NonMappedSubModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Vision => write!(f, "vision"),
}
}
}

fn calculate_key_block_shape(
model_config: &dyn ModelConfigLike,
dtype: DType,
@@ -485,6 +499,9 @@ pub trait DeviceMappedModelLoader {
dtype: DType,
weight_pack_factor: usize,
) -> Result<Vec<usize>>;
fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
None
}
fn num_layers(&self, config: &str) -> Result<usize>;
fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

@@ -585,6 +602,16 @@

let mut device_layers = Vec::new();

info!("Using automatic device mapping parameters: {params}.");
if let Some(sub_models) = self.non_mapped_sub_models() {
let (_, last) = per_layer_avail.last().unwrap();
info!(
"The following sub-models will not be device mapped and will be loaded on {}: {}",
last.device_pretty_repr(),
sub_models.iter().map(|x| x.to_string()).join(", ")
);
}

let mut current_ordinal = 0;
let mut current_layer = 0;
let per_layer_avail_cpy = per_layer_avail.clone();
@@ -627,7 +654,9 @@

// Device w/ ordinal 0 carries the non-mapped things
if current_ordinal == 0 {
- used_capacity += non_mapped_size_in_bytes + non_mapped_max_act_size_in_bytes;
+ // Ensure the activations are properly handled
+ used_capacity = used_capacity.max(non_mapped_max_act_size_in_bytes);
+ used_capacity += non_mapped_size_in_bytes;
}

while let Some(&last) = layer_sizes_in_bytes.last() {
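The `used_capacity` change above is the behavioral core of this file: the non-mapped activation peak used to be added on top of the non-mapped weight size and is now treated as a floor via `max`, presumably because activations are transient and reuse the same head-room rather than stacking on it. A self-contained sketch with made-up numbers (every name and size here is illustrative, not taken from the code) shows why the new estimate is never larger than the old one:

```rust
// Toy model of the device-0 capacity accounting before and after this PR.
// All sizes are invented for illustration.
fn main() {
    let mapped_act_reserved: usize = 768 << 20; // head-room already reserved on device 0
    let non_mapped_weights: usize = 2 << 30;    // e.g. embeddings + a vision tower
    let non_mapped_peak_act: usize = 512 << 20; // transient activation peak of those parts

    // Old accounting: both activation terms were stacked on top of each other.
    let old_used = mapped_act_reserved + non_mapped_peak_act + non_mapped_weights;

    // New accounting: the non-mapped peak only matters if it exceeds what is
    // already reserved, then the weights are added on top.
    let mut new_used = mapped_act_reserved;
    new_used = new_used.max(non_mapped_peak_act);
    new_used += non_mapped_weights;

    println!("old = {} MiB, new = {} MiB", old_used >> 20, new_used >> 20);
    assert!(new_used <= old_used); // the new estimate never exceeds the old one
}
```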
30 changes: 29 additions & 1 deletion mistralrs-core/src/pipeline/loaders/vision_loaders.rs
@@ -12,7 +12,7 @@ use pyo3::pyclass;
use regex::Regex;
use serde::Deserialize;

use super::{DeviceMappedModelLoader, NormalLoadingMetadata};
use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
@@ -462,6 +462,10 @@ impl DeviceMappedModelLoader for Phi3VLoader {

Ok(Box::new(cfg))
}

fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
Some(vec![NonMappedSubModel::Vision])
}
}

// ======================== Idefics 2 loader
@@ -791,6 +795,10 @@ impl DeviceMappedModelLoader for Idefics2Loader {

Ok(Box::new(cfg))
}

fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
Some(vec![NonMappedSubModel::Vision])
}
}

// ======================== LLaVANext Loader
@@ -1039,6 +1047,10 @@ impl DeviceMappedModelLoader for LLaVANextLoader {

Ok(Box::new(cfg))
}

fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
Some(vec![NonMappedSubModel::Vision])
}
}

// ======================== LLaVA Loader
@@ -1279,6 +1291,10 @@ impl DeviceMappedModelLoader for LLaVALoader {

Ok(Box::new(cfg))
}

fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
Some(vec![NonMappedSubModel::Vision])
}
}

// ======================== MLlama Loader
@@ -1653,6 +1669,10 @@ impl DeviceMappedModelLoader for VLlamaLoader {

Ok(Box::new(cfg))
}

fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
Some(vec![NonMappedSubModel::Vision])
}
}

// ======================== Qwen2VL Loader
@@ -1936,6 +1956,10 @@ impl DeviceMappedModelLoader for Qwen2VLLoader {

Ok(Box::new(cfg))
}

fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
Some(vec![NonMappedSubModel::Vision])
}
}

// ======================== Idefics 3 loader
@@ -2213,4 +2237,8 @@ impl DeviceMappedModelLoader for Idefics3Loader {

Ok(Box::new(cfg))
}

fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
Some(vec![NonMappedSubModel::Vision])
}
}
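
Each of the seven vision loaders above adds the same four-line override, which makes the underlying design easy to miss: the trait ships a `non_mapped_sub_models` default that returns `None`, and only multimodal loaders opt their vision tower out of device mapping. Below is a self-contained sketch of that pattern; the enum, trait method, and return value are copied from the diff, while the loader structs are hypothetical stand-ins so the snippet compiles on its own.

```rust
use std::fmt;

// Copied from the diff: the sub-models that are never split across devices.
#[derive(Clone, Debug)]
enum NonMappedSubModel {
    Vision,
}

impl fmt::Display for NonMappedSubModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Vision => write!(f, "vision"),
        }
    }
}

// Reduced stand-in for DeviceMappedModelLoader: only the method added by this PR.
trait DeviceMappedModelLoader {
    // Text-only loaders keep the default: everything is device mapped.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
}

struct TextOnlyLoader; // hypothetical
impl DeviceMappedModelLoader for TextOnlyLoader {}

struct SomeVisionLoader; // hypothetical
impl DeviceMappedModelLoader for SomeVisionLoader {
    // Vision loaders pin their vision tower to a single device.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

fn main() {
    assert!(TextOnlyLoader.non_mapped_sub_models().is_none());
    let subs = SomeVisionLoader.non_mapped_sub_models().unwrap();
    println!("{}", subs[0]); // prints "vision"
}
```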
11 changes: 6 additions & 5 deletions mistralrs-core/src/pipeline/vision.rs
@@ -292,11 +292,6 @@ impl Loader for VisionLoader {
mapper = DeviceMapSetting::Map(new);
}

- info!(
- "Model config: {:?}",
- self.inner
- .get_config_repr(&config, self.config.use_flash_attn)?
- );
let pipeline_mapper = mapper.into_mapper(
self.inner.get_total_device_mapping_num_layers(&config)?,
device,
@@ -322,6 +317,12 @@
paged_attn_config = None;
}

+ info!(
+ "Model config: {:?}",
+ self.inner
+ .get_config_repr(&config, self.config.use_flash_attn)?
+ );

let mut loading_isq = in_situ_quant.is_some() || self.config.from_uqff.is_some();
if let Some(ref topology) = self.config.topology {
loading_isq |= topology
6 changes: 3 additions & 3 deletions mistralrs-pyo3/mistralrs.pyi
@@ -120,7 +120,7 @@ class TextAutoMapParams:
These affect automatic device mapping but are not hard limits.
"""

- max_seq_len: int = 16 * 1024
+ max_seq_len: int = 4 * 1024
max_batch_size: int = 1

@dataclass
@@ -130,10 +130,10 @@ class VisionAutoMapParams:
These affect automatic device mapping but are not hard limits.
"""

- max_seq_len: int = 16 * 1024
+ max_seq_len: int = 4 * 1024
max_batch_size: int = 1
max_num_images: int = 1
- max_image_length: int = 2 * 1024
+ max_image_length: int = 1024

class Which(Enum):
"""