zmq_server_w2v2.py
import os
import base64
import time
import zmq
import torch
import torchaudio
import logging
import warnings
from loguru import logger
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2Processor, Wav2Vec2ForCTC
from utils import InterceptHandler, MemoryTempfile
# Configure logging
logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO, force=True)
# Turn off warning messages
warnings.simplefilter('ignore')
# Configure ZeroMQ
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind("tcp://0.0.0.0:5555")
# Load the model
USE_LM = os.getenv('USE_LM', default='no') == 'yes'
WAV2VEC2_MODEL = os.getenv('WAV2VEC2_MODEL', default='Yehor/wav2vec2-xls-r-300m-uk-with-small-lm')
logger.info('Loading the model: wav2vec2')
ts = time.time()
if USE_LM:
    processor = Wav2Vec2ProcessorWithLM.from_pretrained(WAV2VEC2_MODEL, cache_dir='./all-models')
else:
    processor = Wav2Vec2Processor.from_pretrained(WAV2VEC2_MODEL, cache_dir='./all-models')
model = Wav2Vec2ForCTC.from_pretrained(WAV2VEC2_MODEL, cache_dir='./all-models')
model.to('cpu')
logger.info(f'Loaded the model: {time.time() - ts} seconds')
logger.info('---')
while True:
    try:
        # Wait for the next request from a client
        message = socket.recv()
        logger.info(f"Received a file to recognize with len={len(message)}")

        # Convert the base64 payload back to bytes
        data_bytes = base64.b64decode(message)

        text = ''

        # Save the audio into an in-memory temporary file
        tf = MemoryTempfile()
        with tf.NamedTemporaryFile('wb') as f:
            f.write(data_bytes)
            f.flush()  # make sure all bytes are written before the file is read back

            # Convert the data to a tensor; the file must be read while the
            # temporary file still exists, i.e. inside this block
            waveform, sample_rate = torchaudio.load(f.name)
            speech = waveform.squeeze().numpy()

        # Inference (the audio is expected to already be 16 kHz mono)
        input_values = processor(speech, sampling_rate=16000, return_tensors='pt', padding='longest').input_values

        with torch.no_grad():
            logits = model(input_values).logits

        if USE_LM:
            # Decode with the language model
            prediction = processor.batch_decode(logits.numpy()).text
            text = prediction[0]
        else:
            # Greedy CTC decoding
            predicted_ids = torch.argmax(logits, dim=-1)
            prediction = processor.batch_decode(predicted_ids)
            text = prediction[0]

        logger.info(f"Recognized transcript: {text}")

        # Send a reply with the transcript back to the client
        socket.send(text.encode('utf-8'))
    except KeyboardInterrupt:
        logger.info('Exiting...')
        break
    except Exception as e:
        logger.error(e)
        socket.send('error'.encode('utf-8'))

    logger.info('---')
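
# --- Usage ---
# The server speaks a simple REQ/REP protocol: the client sends a base64-encoded
# audio file and receives the transcript (or the string 'error') back. A minimal
# client sketch, assuming the server runs on localhost:5555 and 'audio.wav' is a
# 16 kHz mono WAV file (both are illustrative, not part of this repository):
#
#     import base64
#     import zmq
#
#     context = zmq.Context()
#     socket = context.socket(zmq.REQ)
#     socket.connect('tcp://127.0.0.1:5555')
#
#     # Send the raw audio file encoded as base64
#     with open('audio.wav', 'rb') as f:
#         socket.send(base64.b64encode(f.read()))
#
#     # The reply is the recognized transcript (or 'error')
#     print(socket.recv().decode('utf-8'))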