From 7a7472162568f6a7b844744a47a77c6f8de740ed Mon Sep 17 00:00:00 2001
From: Yasyf Mohamedali
Date: Sun, 24 Mar 2024 01:54:58 -0700
Subject: [PATCH] Prevent duplicate `torch_dtype` kwargs (#115)

Currently, specifying a `torch_dtype` in the `model_config` throws a
`TypeError`, because the dtype is passed to `from_pretrained` both as an
explicit keyword argument and a second time via `**model_config`:

```pycon
TypeError: transformers.models.auto.auto_factory._BaseAutoModelClass.from_pretrained() got multiple values for keyword argument 'torch_dtype'
```
---
 llmlingua/prompt_compressor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index a26a569..9625065 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -138,7 +138,7 @@ def load_model(
         if "cuda" in device_map or "cpu" in device_map:
             model = MODEL_CLASS.from_pretrained(
                 model_name,
-                torch_dtype=model_config.get(
+                torch_dtype=model_config.pop(
                     "torch_dtype", "auto" if device_map == "cuda" else torch.float32
                 ),
                 device_map=device_map,
@@ -150,7 +150,7 @@
             model = MODEL_CLASS.from_pretrained(
                 model_name,
                 device_map=device_map,
-                torch_dtype=model_config.get("torch_dtype", "auto"),
+                torch_dtype=model_config.pop("torch_dtype", "auto"),
                 pad_token_id=tokenizer.pad_token_id,
                 **model_config,
             )
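
Why swapping `.get` for `.pop` fixes this, as a minimal sketch (the `load` function below is a hypothetical stand-in for `from_pretrained`): `.get` leaves the key in `model_config`, so the subsequent `**model_config` unpacking passes `torch_dtype` a second time, while `.pop` removes the key before the dict is unpacked.

```python
def load(model_name, torch_dtype=None, **kwargs):
    """Hypothetical stand-in for `from_pretrained`."""
    return model_name, torch_dtype, kwargs

model_config = {"torch_dtype": "float16"}

# Before: .get() leaves "torch_dtype" in model_config, so **model_config
# passes it again and Python raises the duplicate-kwarg TypeError.
try:
    load("m", torch_dtype=model_config.get("torch_dtype", "auto"), **model_config)
except TypeError as e:
    print(e)  # load() got multiple values for keyword argument 'torch_dtype'

# After: .pop() removes the key before **model_config is unpacked,
# so the dtype reaches the call exactly once.
load("m", torch_dtype=model_config.pop("torch_dtype", "auto"), **model_config)
```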