mirror of
				https://github.com/huggingface/text-generation-inference.git
				synced 2023-08-15 01:09:35 +03:00 
			
		
		
		
	feat(launcher): default num_shard to CUDA_VISIBLE_DEVICES if possible (#108)
This commit is contained in:
		@@ -115,13 +115,11 @@ fn main() -> ExitCode {
 | 
			
		||||
                    None => {
 | 
			
		||||
                        // try to default to the number of available GPUs
 | 
			
		||||
                        tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
 | 
			
		||||
                        let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES")
 | 
			
		||||
                        let n_devices = num_cuda_devices()
 | 
			
		||||
                            .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
 | 
			
		||||
                        let n_devices = cuda_visible_devices.split(",").count();
 | 
			
		||||
                        if n_devices <= 1 {
 | 
			
		||||
                            panic!("`sharded` is true but only found {n_devices} CUDA devices");
 | 
			
		||||
                        }
 | 
			
		||||
                        tracing::info!("Sharding on {n_devices} found CUDA devices");
 | 
			
		||||
                        n_devices
 | 
			
		||||
                    }
 | 
			
		||||
                    Some(num_shard) => {
 | 
			
		||||
@@ -144,9 +142,19 @@ fn main() -> ExitCode {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    } else {
 | 
			
		||||
        // default to a single shard
 | 
			
		||||
        num_shard.unwrap_or(1)
 | 
			
		||||
        match num_shard {
 | 
			
		||||
            // get num_shard from CUDA_VISIBLE_DEVICES or default to a single shard
 | 
			
		||||
            None => num_cuda_devices().unwrap_or(1),
 | 
			
		||||
            Some(num_shard) => num_shard,
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
    if num_shard < 1 {
 | 
			
		||||
        panic!("`num_shard` cannot be < 1");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if num_shard > 1 {
 | 
			
		||||
        tracing::info!("Sharding model on {num_shard} processes");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Signal handler
 | 
			
		||||
    let running = Arc::new(AtomicBool::new(true));
 | 
			
		||||
@@ -669,3 +677,11 @@ fn shutdown_shards(shutdown: Arc<Mutex<bool>>, shutdown_receiver: &mpsc::Receive
 | 
			
		||||
    // This will block till all shutdown_sender are dropped
 | 
			
		||||
    let _ = shutdown_receiver.recv();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Returns the number of devices listed in the `CUDA_VISIBLE_DEVICES` environment variable.
///
/// The variable is read as a comma-separated list. Entries that are empty after trimming
/// whitespace (for example the one produced by a trailing comma) are not counted.
///
/// Returns `None` when the variable cannot be read (unset or not valid Unicode) or when it
/// lists no device at all, e.g. when it is set to an empty string, so callers can fall back
/// to their own default instead of receiving a bogus count.
fn num_cuda_devices() -> Option<usize> {
    let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES").ok()?;

    // `"".split(',')` still yields one (empty) item, so blank entries must be filtered out
    // explicitly or an empty variable would be reported as one device.
    let n_devices = cuda_visible_devices
        .split(',')
        .filter(|device| !device.trim().is_empty())
        .count();

    if n_devices == 0 {
        None
    } else {
        Some(n_devices)
    }
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user