From 76958f6ef901f8df985d37cc478d6d094dea67ae Mon Sep 17 00:00:00 2001
From: TimVan
Date: Mon, 4 Sep 2023 10:05:43 +0800
Subject: [PATCH] Automatically update to the latest model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore     |  1 +
 app/app.py     | 58 ++++++++++++++++++++++++++++++++++++++++++++++----
 app/app_win.py | 54 +++++++++++++++++++++++++++++++++++++++++++---
 3 files changed, 106 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index cd315bc..15ffce2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 # Custom
 test_*.py
 app/test/output
+sha.txt
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/app/app.py b/app/app.py
index 36c9b9a..8d216b9 100644
--- a/app/app.py
+++ b/app/app.py
@@ -4,6 +4,7 @@
 import torch
 from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
 import deepspeed
+import requests
 
 app = Flask(__name__)
 app.config.from_pyfile('config.py')
@@ -28,10 +29,58 @@
 torch.set_num_threads(app.config["NUM_THREADS"])
 
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-config = AutoConfig.from_pretrained(model_name)
-orgin_model = AutoModelForCausalLM.from_pretrained(model_name)
+# Read the local SHA file, or return None if it does not exist yet
+def read_or_create_sha_file():
+    if os.path.exists('sha.txt'):
+        with open('sha.txt', 'r') as f:
+            return f.read().strip()
+    else:
+        return None
+
+# Write the new SHA to the local file
+
+
+def write_sha_to_file(new_sha):
+    with open('sha.txt', 'w') as f:
+        f.write(new_sha)
+
+# Fetch the latest SHA from the Hugging Face API
+
+
+def get_latest_sha(model_name):
+    response = requests.get(f"https://huggingface.co/api/models/{model_name}")
+    if response.status_code == 200:
+        remote_sha = response.json().get("sha")
+        return remote_sha.strip()
+    else:
+        print(f"Failed to check for updates: {response.content}")
+        return None
+
+
+# Read the local SHA
+local_sha = read_or_create_sha_file()
+# Fetch the remote SHA
+remote_sha = get_latest_sha(model_name)
+
+# Decide whether an update is needed
+should_update = local_sha is None or local_sha != remote_sha
+
+# Download the model if an update is needed or this is the first run
+if should_update:
+    print("Downloading model... this might take a while")
+    print(f"- Model name: {model_name}")
+
+orgin_model = AutoModelForCausalLM.from_pretrained(
+    model_name, force_download=should_update)
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name, force_download=should_update)
+
+# Update the local SHA file
+if remote_sha:
+    write_sha_to_file(remote_sha)
+
+
 model = deepspeed.init_inference(
     model=orgin_model,  # the Transformers model
     mp_size=1,  # number of GPUs
@@ -153,4 +202,5 @@ def immersive_translation():
 
     # Launch command: deepspeed --num_gpus 1 app.py
     port = app.config["DEFAULT_PORT"]
-    app.run(host="0.0.0.0", port=port, debug=False, use_reloader=False, threaded=False)
+    app.run(host="0.0.0.0", port=port, debug=False,
+            use_reloader=False, threaded=False)
diff --git a/app/app_win.py b/app/app_win.py
index 114a265..02d73b6 100644
--- a/app/app_win.py
+++ b/app/app_win.py
@@ -3,6 +3,7 @@
 import os
 import torch
 from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
+import requests
 
 app = Flask(__name__)
 app.config.from_pyfile('config.py')
@@ -15,9 +16,55 @@
 os.environ["CUDA_VISIBLE_DEVICES"] = app.config["CUDA_VISIBLE_DEVICES"]
 torch.set_num_threads(app.config["NUM_THREADS"])
 
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# Read the local SHA file, or return None if it does not exist yet
+def read_or_create_sha_file():
+    if os.path.exists('sha.txt'):
+        with open('sha.txt', 'r') as f:
+            return f.read().strip()
+    else:
+        return None
+
+# Write the new SHA to the local file
+
+
+def write_sha_to_file(new_sha):
+    with open('sha.txt', 'w') as f:
+        f.write(new_sha)
+
+# Fetch the latest SHA from the Hugging Face API
+def get_latest_sha(model_name):
+    response = requests.get(f"https://huggingface.co/api/models/{model_name}")
+    if response.status_code == 200:
+        remote_sha = response.json().get("sha")
+        return remote_sha.strip()
+    else:
+        print(f"Failed to check for updates: {response.content}")
+        return None
+
+
+# Read the local SHA
+local_sha = read_or_create_sha_file()
+# Fetch the remote SHA
+remote_sha = get_latest_sha(model_name)
+
+# Decide whether an update is needed
+should_update = local_sha is None or local_sha != remote_sha
+
+# Download the model if an update is needed or this is the first run
+if should_update:
+    print("Downloading model... this might take a while")
+    print(f"- Model name: {model_name}")
+
 model = AutoModelForCausalLM.from_pretrained(
-    model_name, torch_dtype=torch.bfloat16)
+    model_name, force_download=should_update)
+tokenizer = AutoTokenizer.from_pretrained(
+    model_name, force_download=should_update)
+
+# Update the local SHA file
+if remote_sha:
+    write_sha_to_file(remote_sha)
+
 model.eval()
 model.cuda()  # Ensure the model uses GPU
@@ -135,4 +182,5 @@ def immersive_translation():
 
     # Launch command: python app_win.py
     port = app.config["DEFAULT_PORT"]
-    app.run(host="0.0.0.0", port=port, debug=False, use_reloader=False, threaded=False)
+    app.run(host="0.0.0.0", port=port, debug=False,
+            use_reloader=False, threaded=False)
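The update check this patch introduces can be exercised in isolation, before either Flask app loads any weights. The following is a minimal standalone sketch that mirrors the helper functions added above and prints the resulting decision; it is not part of the patch: MODEL_NAME is a placeholder, and the guard for a missing "sha" field is an extra safety net the patch itself lacks (the patch calls .strip() unconditionally, which would raise if the field were ever absent from the API response).

# sha_check.py - standalone dry run of the SHA-based update check
import os

import requests

SHA_FILE = "sha.txt"
MODEL_NAME = "gpt2"  # placeholder; substitute the model the app actually serves


def read_local_sha(path=SHA_FILE):
    """Return the SHA cached on disk, or None on a first run."""
    if os.path.exists(path):
        with open(path, 'r') as f:
            return f.read().strip()
    return None


def get_latest_sha(model_name):
    """Ask the Hugging Face model API for the SHA of the current revision."""
    response = requests.get(f"https://huggingface.co/api/models/{model_name}")
    if response.status_code == 200:
        sha = response.json().get("sha")
        return sha.strip() if sha else None  # guard: "sha" may be missing
    print(f"Failed to check for updates: {response.content}")
    return None


if __name__ == "__main__":
    local_sha = read_local_sha()
    remote_sha = get_latest_sha(MODEL_NAME)
    should_update = local_sha is None or local_sha != remote_sha
    print(f"local={local_sha} remote={remote_sha} should_update={should_update}")

Because should_update is passed straight to force_download=, an unchanged SHA means from_pretrained serves the weights from the local Hugging Face cache, so the only startup cost is the single API request. Note that requests.get raises on connection errors rather than returning a status code, so a machine that starts offline would need a try/except around the check to fall back to the cached model.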