Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2023-08-15 01:09:35 +03:00)

@@ -26,7 +26,9 @@ try:
 
     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
-    logger.opt(exception=True).warning("Could not import Flash Attention enabled models")
+    logger.opt(exception=True).warning(
+        "Could not import Flash Attention enabled models"
+    )
     FLASH_ATTENTION = False
 
 __all__ = [
@@ -88,10 +90,10 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision=revision)
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize)
+            return santacoder_cls(model_id, revision, quantize=quantize)
 
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
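
With these two hunks, `get_model` forwards its `quantize` flag to the flash Santacoder classes as an explicit keyword argument instead of dropping it. A minimal usage sketch follows; the `(model_id, revision, sharded, quantize)` signature of `get_model` is an assumption (it is not shown in this diff), and the model id is a placeholder:

# Hypothetical usage sketch: the get_model signature below is assumed, not
# taken from this diff; it only illustrates where the quantize flag enters.
from text_generation_server.models import get_model

model = get_model(
    "bigcode/santacoder",  # placeholder model id
    revision=None,
    sharded=False,
    quantize=True,  # forwarded as quantize=quantize to FlashSantacoder(Sharded)
)
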
@@ -33,6 +33,12 @@ import dropout_layer_norm
 
 from flash_attn.layers.rotary import RotaryEmbedding
 
+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class LlamaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -94,14 +100,44 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.quantized = False
+        self.bnb_linear = None
 
-    def transpose_weight(self):
-        self.weight = nn.Parameter(self.weight.T)
+    def prepare_weights(self, quantize: bool = False):
+        if quantize:
+            if not HAS_BITS_AND_BYTES:
+                raise ImportError(
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
+                )
+
+            self.quantized = True
+            self.bnb_linear = Linear8bitLt(
+                self.in_features,
+                self.out_features,
+                has_fp16_weights=False,
+                threshold=6.0,
+                bias=False,
+            )
+            # Copy data to bnb_linear
+            self.bnb_linear.weight.data = self.weight.data
+            if self.bias is not None:
+                self.bnb_linear.bias = nn.Parameter(self.bias)
+
+            # Delete reference to data
+            self.weight = None
+            self.bias = None
+        else:
+            self.weight = nn.Parameter(self.weight.T)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.bias is not None:
-            return torch.addmm(self.bias, input, self.weight)
-        return torch.matmul(input, self.weight)
+        if self.quantized:
+            return self.bnb_linear(input)
+        else:
+            if self.bias is not None:
+                return torch.addmm(self.bias, input, self.weight)
+            return torch.matmul(input, self.weight)
 
 
 class TensorParallelColumnLinear(FastLinear):
@@ -502,15 +538,15 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
 
-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit: bool = False):
        if isinstance(self.embed_tokens, TensorParallelEmbedding):
             self.embed_tokens.add_null_idx()
         for layer in self.layers:
             layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.transpose_weight()
-            layer.self_attn.o_proj.transpose_weight()
-            layer.mlp.gate_up_proj.transpose_weight()
-            layer.mlp.down_proj.transpose_weight()
+            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
+            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
+            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
+            layer.mlp.down_proj.prepare_weights(load_in_8bit)
 
     def forward(
         self,
@@ -592,9 +628,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
 
-    def post_load_weights(self):
-        self.model.post_load_weights()
-        self.lm_head.transpose_weight()
+    def post_load_weights(self, load_in_8bit: bool = False):
+        self.model.post_load_weights(load_in_8bit)
+        self.lm_head.prepare_weights()
 
     def forward(
         self,
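
The new `prepare_weights(quantize=True)` path above replaces the plain transposed weight with a bitsandbytes `Linear8bitLt` module, and `forward` dispatches to it when `self.quantized` is set. A standalone sketch of the same wrapping, assuming bitsandbytes and a CUDA device are available (layer sizes are illustrative):

# Standalone sketch (not part of the diff): wrap an fp16 nn.Linear into
# bitsandbytes Linear8bitLt, mirroring FastLinear.prepare_weights(quantize=True).
import torch
import torch.nn as nn
from bitsandbytes.nn import Linear8bitLt

fp16_linear = nn.Linear(4096, 4096, bias=False).half()

int8_linear = Linear8bitLt(
    fp16_linear.in_features,
    fp16_linear.out_features,
    has_fp16_weights=False,  # keep int8 weights; outlier columns above threshold stay fp16
    threshold=6.0,
    bias=False,
)
int8_linear.weight.data = fp16_linear.weight.data

# bitsandbytes quantizes the weights when the module is moved to the GPU.
int8_linear = int8_linear.to("cuda")

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = int8_linear(x)  # mixed int8/fp16 matmul, approximately fp16_linear.cuda()(x)
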
@@ -35,6 +35,12 @@ import dropout_layer_norm
 
 from flash_attn.layers.rotary import RotaryEmbedding
 
+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
@@ -82,14 +88,44 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.quantized = False
+        self.bnb_linear = None
 
-    def transpose_weight(self):
-        self.weight = nn.Parameter(self.weight.T)
+    def prepare_weights(self, quantize: bool = False):
+        if quantize:
+            if not HAS_BITS_AND_BYTES:
+                raise ImportError(
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
+                )
+
+            self.quantized = True
+            self.bnb_linear = Linear8bitLt(
+                self.in_features,
+                self.out_features,
+                has_fp16_weights=False,
+                threshold=6.0,
+                bias=False,
+            )
+            # Copy data to bnb_linear
+            self.bnb_linear.weight.data = self.weight.data
+            if self.bias is not None:
+                self.bnb_linear.bias = nn.Parameter(self.bias)
+
+            # Delete reference to data
+            self.weight = None
+            self.bias = None
+        else:
+            self.weight = nn.Parameter(self.weight.T)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.bias is not None:
-            return torch.addmm(self.bias, input, self.weight)
-        return torch.matmul(input, self.weight)
+        if self.quantized:
+            return self.bnb_linear(input)
+        else:
+            if self.bias is not None:
+                return torch.addmm(self.bias, input, self.weight)
+            return torch.matmul(input, self.weight)
 
 
 class TensorParallelColumnLinear(FastLinear):
@@ -552,23 +588,27 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads
 
-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit=False):
         if isinstance(self.embed_in, TensorParallelEmbedding):
             self.embed_in.add_null_idx()
         for layer in self.layers:
             layer: FlashNeoXLayer
             layer.attention.shuffle_qkv_dims()
-            layer.attention.query_key_value.transpose_weight()
-            layer.attention.dense.transpose_weight()
-            layer.mlp.dense_h_to_4h.transpose_weight()
-            layer.mlp.dense_4h_to_h.transpose_weight()
+            layer.attention.query_key_value.prepare_weights(load_in_8bit)
+            layer.attention.dense.prepare_weights(load_in_8bit)
+            layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
+            layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+        # to do it for us
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
         model = super(FlashGPTNeoXModel, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
+            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
         )
-        model.post_load_weights()
+
+        model.post_load_weights(load_in_8bit)
         return model
 
     def forward(
@@ -653,16 +693,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
                 config.hidden_size, config.vocab_size, bias=False
             )
 
-    def post_load_weights(self):
-        self.gpt_neox.post_load_weights()
-        self.embed_out.transpose_weight()
+    def post_load_weights(self, load_in_8bit=False):
+        self.gpt_neox.post_load_weights(load_in_8bit)
+        self.embed_out.prepare_weights()
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+        # to do it for us
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
         model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
+            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
        )
-        model.post_load_weights()
+        model.post_load_weights(load_in_8bit)
         return model
 
     def forward(
@@ -10,6 +10,12 @@ from transformers.activations import ACT2FN
 import flash_attn_cuda
 import dropout_layer_norm
 
+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
@@ -57,14 +63,44 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.quantized = False
+        self.bnb_linear = None
 
-    def transpose_weight(self):
-        self.weight = nn.Parameter(self.weight.T)
+    def prepare_weights(self, quantize: bool = False):
+        if quantize:
+            if not HAS_BITS_AND_BYTES:
+                raise ImportError(
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
+                )
+
+            self.quantized = True
+            self.bnb_linear = Linear8bitLt(
+                self.in_features,
+                self.out_features,
+                has_fp16_weights=False,
+                threshold=6.0,
+                bias=False,
+            )
+            # Copy data to bnb_linear
+            self.bnb_linear.weight.data = self.weight.data
+            if self.bias is not None:
+                self.bnb_linear.bias = nn.Parameter(self.bias)
+
+            # Delete reference to data
+            self.weight = None
+            self.bias = None
+        else:
+            self.weight = nn.Parameter(self.weight.T)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.bias is not None:
-            return torch.addmm(self.bias, input, self.weight)
-        return torch.matmul(input, self.weight)
+        if self.quantized:
+            return self.bnb_linear(input)
+        else:
+            if self.bias is not None:
+                return torch.addmm(self.bias, input, self.weight)
+            return torch.matmul(input, self.weight)
 
 
 class TensorParallelColumnLinear(FastLinear):
@@ -431,16 +467,16 @@ class FlashSantacoderModel(nn.Module):
         self.head_size = self.h[0].attn.head_size
         self.num_heads = self.h[0].attn.num_heads
 
-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit: bool = False):
         if self.tp_embeddings:
             self.wte.add_null_idx()
             self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.c_attn.transpose_weight()
-            layer.attn.c_proj.transpose_weight()
-            layer.mlp.c_fc.transpose_weight()
-            layer.mlp.c_proj.transpose_weight()
+            layer.attn.c_attn.prepare_weights(load_in_8bit)
+            layer.attn.c_proj.prepare_weights(load_in_8bit)
+            layer.mlp.c_fc.prepare_weights(load_in_8bit)
+            layer.mlp.c_proj.prepare_weights(load_in_8bit)
 
     def forward(
         self,
@@ -508,9 +544,9 @@ class FlashSantacoderForCausalLM(nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
 
-    def post_load_weights(self):
-        self.transformer.post_load_weights()
-        self.lm_head.transpose_weight()
+    def post_load_weights(self, load_in_8bit: bool = False):
+        self.transformer.post_load_weights(load_in_8bit)
+        self.lm_head.prepare_weights()
 
     def forward(
         self,
@@ -221,9 +221,6 @@ class FlashCausalLM(Model):
         else:
             raise NotImplementedError("FlashCausalLM is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashCausalLM does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
@@ -232,9 +229,10 @@ class FlashCausalLM(Model):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
+                load_in_8bit=quantize,
             )
             .eval()
-            .cuda()
+            .to(device)
         )
 
         super(FlashCausalLM, self).__init__(
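
For the generic `FlashCausalLM` path, 8-bit loading is delegated entirely to `transformers`: the quantization guard is removed, the flag is passed through as `load_in_8bit`, and the final `.to(device)` replaces `.cuda()`. A condensed sketch of what the constructor now does; the model id is a placeholder and bitsandbytes plus accelerate are assumed to be installed:

# Condensed sketch of the FlashCausalLM constructor after this change;
# the model id is a placeholder and error handling is omitted.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "EleutherAI/gpt-neox-20b"  # placeholder
device = torch.device("cuda:0")
quantize = True

tokenizer = AutoTokenizer.from_pretrained(
    model_id, padding_side="left", truncation_side="left"
)
model = (
    AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        load_in_8bit=quantize,  # requires bitsandbytes + accelerate
    )
    .eval()
    .to(device)
)
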
@@ -35,9 +35,6 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashLlama does not support quantization")
-
         tokenizer = LlamaTokenizer.from_pretrained(
             model_id,
             revision=revision,
@@ -61,8 +58,8 @@ class FlashLlama(FlashCausalLM):
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config)
 
-        self.load_weights(model, filenames, device, dtype)
-        self.model = model.eval()
+        self.load_weights(model, filenames, quantize, device, dtype)
+        self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -73,13 +70,14 @@ class FlashLlama(FlashCausalLM):
     def load_weights(
         model,
         filenames: List[Path],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device).to(dtype)
+                value = value.to(device if not quantize else "cpu").to(dtype)
 
                 layer_name = ".".join(key.split(".")[:4])
 
@@ -139,7 +137,7 @@ class FlashLlama(FlashCausalLM):
                 del value
 
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)
 
 
 class FlashLlamaSharded(FlashLlama):
@@ -154,9 +152,6 @@ class FlashLlamaSharded(FlashLlama):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashLlama does not support quantization")
-
         tokenizer = LlamaTokenizer.from_pretrained(
             model_id,
             revision=revision,
@@ -185,7 +180,7 @@ class FlashLlamaSharded(FlashLlama):
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -300,4 +295,4 @@ class FlashLlamaSharded(FlashLlama):
                     else:
                         module._buffers[param_name] = tensor
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)
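
The Llama loaders now keep checkpoint tensors on the CPU while quantizing (`device if not quantize else "cpu"`) and only move the model to the GPU after `post_load_weights(quantize)` has built the `Linear8bitLt` modules. This ordering matters because bitsandbytes performs the actual int8 conversion when its parameters are transferred to a CUDA device, as the small check below illustrates (sizes are illustrative and a CUDA GPU is assumed):

# Small check of the ordering used above: Linear8bitLt weights stay fp16 on CPU
# and are only quantized to int8 when the module is moved to the GPU.
import torch
from bitsandbytes.nn import Linear8bitLt

layer = Linear8bitLt(256, 256, has_fp16_weights=False, threshold=6.0, bias=False)
layer.weight.data = torch.randn(256, 256, dtype=torch.float16)

print(layer.weight.dtype)  # torch.float16 while the module is still on CPU
layer = layer.to("cuda")
print(layer.weight.dtype)  # torch.int8 after the transfer
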
@@ -41,9 +41,6 @@ class FlashNeoXSharded(FlashNeoX):
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashNeoX does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
@@ -63,13 +60,13 @@ class FlashNeoXSharded(FlashNeoX):
         self.load_weights(
             model,
             filenames,
+            quantize=quantize,
             device=device,
             dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        model.post_load_weights()
-        self.model = model.eval()
+        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +77,7 @@ class FlashNeoXSharded(FlashNeoX):
     def load_weights(
         model,
         filenames: List[str],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -87,7 +85,9 @@ class FlashNeoXSharded(FlashNeoX):
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
-            with safe_open(file, framework="pt", device=str(device)) as f:
+            with safe_open(
+                file, framework="pt", device=str(device) if not quantize else "cpu"
+            ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
@@ -146,3 +146,4 @@ class FlashNeoXSharded(FlashNeoX):
                         module._parameters[param_name] = tensor
                     else:
                         module._buffers[param_name] = tensor
+        model.post_load_weights(quantize)
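
The sharded NeoX loader applies the same idea to safetensors files: `safe_open` maps tensors to the CPU whenever quantization is requested, so the int8 conversion can happen later on the single `.to(device)` call. A minimal sketch of that device selection; the shard file name is a placeholder:

# Minimal sketch of the safe_open device choice above; the file name is a placeholder.
import torch
from safetensors import safe_open

quantize = True
device = torch.device("cuda:0")

with safe_open(
    "model-00001-of-00002.safetensors",  # placeholder shard name
    framework="pt",
    device=str(device) if not quantize else "cpu",
) as f:
    for name in f.keys():
        tensor = f.get_tensor(name)  # stays on CPU while quantize=True
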
@@ -34,9 +34,6 @@ class FlashSantacoder(FlashCausalLM):
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashSantacoder does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
@@ -58,9 +55,14 @@ class FlashSantacoder(FlashCausalLM):
             model = FlashSantacoderForCausalLM(config)
 
         self.load_weights(
-            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
+            model,
+            filenames,
+            quantize,
+            device,
+            dtype,
+            config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
@@ -70,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
     def load_weights(
         model: FlashSantacoderForCausalLM,
         filenames: List[Path],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         transpose: bool,
@@ -77,7 +80,7 @@ class FlashSantacoder(FlashCausalLM):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device).to(dtype)
+                value = value.to(device if not quantize else "cpu").to(dtype)
 
                 layer_name = ".".join(key.split(".")[:4])
 
@@ -152,7 +155,7 @@ class FlashSantacoder(FlashCausalLM):
                 del value
 
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)
 
     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text
@@ -173,11 +176,6 @@ class FlashSantacoderSharded(FlashSantacoder):
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError(
-                "FlashSantacoderSharded does not support quantization"
-            )
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
@@ -197,13 +195,14 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.load_weights(
             model,
             filenames,
+            quantize=quantize,
             device=device,
             dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -214,6 +213,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def load_weights(
         model,
         filenames: List[str],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -221,7 +221,9 @@ class FlashSantacoderSharded(FlashSantacoder):
         transpose: bool,
     ):
         for file in filenames:
-            with safe_open(file, framework="pt", device=str(device)) as f:
+            with safe_open(
+                file, framework="pt", device=str(device) if not quantize else "cpu"
+            ) as f:
                 for key in f.keys():
                     slice_ = f.get_slice(key)
 
@@ -363,4 +365,4 @@ class FlashSantacoderSharded(FlashSantacoder):
                     else:
                         module._buffers[param_name] = tensor
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)