-
环境搭建
-
客户端环境
-
录音模块
pip install pyaudio
-
-
服务器环境
- flask
-
-
客户端
- 录音模块
- 缓冲区
- 发送数据
- 缓冲区数据
- 接收数据
- 识别结果
- 录音模块
-
服务器端
- 接收缓冲区数据
- 调用识别接口
- 传入缓冲区数据
- 返回识别文字
- 发送识别文字给客户端
相关包的安装
pip install pygame playsound librosa  # wave 是 Python 标准库,无需通过 pip 安装
客户端
录音模块
获取麦克风数据以及保存音频
实际上这里不需要保存音频,这里只是为了测试录音是否可以正常运行
import pyaudio
import wave
# Number of frames requested from the microphone stream per read.
CHUNK = 1024
# Sample format handed to PyAudio when the input stream is opened.
FORMAT = pyaudio.paInt16
# Channel count used both for recording and for the WAV header.
CHANNELS = 1
SAMPALE_RATE = 44100 # 44100 is the default and gives the highest fidelity; use 16000 when recognising
# NOTE(review): not referenced by any code shown in this listing.
RECORD_SECONDS = 4
# Where the test recording is written (passed to recording() below).
temp_save_path = "Audio/temp.wav"
# NOTE(review): recording() creates its own local PyAudio instance, so this
# module-level one appears to be unused in the code shown here.
p = pyaudio.PyAudio()
# Save microphone data
def save_wav(frames, save_path, channels=None, sample_width=2,
             sample_rate=None):
    """Write raw PCM chunks to ``save_path`` as a WAV file.

    Args:
        frames: Iterable of ``bytes`` chunks holding the PCM samples, in
            playback order.
        save_path: Destination file, opened in ``'wb'`` mode.
        channels: Channel count for the WAV header; falls back to the
            module-level ``CHANNELS`` when omitted.
        sample_width: Bytes per sample for the WAV header (default 2).
        sample_rate: Sampling rate for the WAV header; falls back to the
            module-level ``SAMPALE_RATE`` when omitted.
    """
    if channels is None:
        channels = CHANNELS
    if sample_rate is None:
        sample_rate = SAMPALE_RATE

    # The context manager closes the file even if writing fails part-way.
    with wave.open(save_path, 'wb') as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sample_rate)
        wf.writeframes(b''.join(frames))

    # '\033[93m' ... '\033[0m' are the ANSI codes for bright-yellow text.
    print('\033[93m' + "已录入缓冲区" + '\033[0m')
# Capture microphone data
def recording(save_path, max_size=16 * 4):
    """Record one buffer from the default microphone and save it as WAV.

    Reads ``max_size`` chunks of ``CHUNK`` frames from an input stream
    opened with the module-level ``FORMAT``, ``CHANNELS`` and
    ``SAMPALE_RATE`` settings, then hands them to ``save_wav``.

    Args:
        save_path: Destination WAV file.
        max_size: Number of chunks that make up one buffer.
    """
    audio = pyaudio.PyAudio()
    try:
        stream = audio.open(format=FORMAT,
                            channels=CHANNELS,
                            rate=SAMPALE_RATE,
                            input=True,
                            frames_per_buffer=CHUNK)
        try:
            # '\033[93m' ... '\033[0m' are the ANSI codes for yellow text.
            print('\033[93m' + "recording" + '\033[0m')
            # Fill the buffer with exactly max_size chunks.
            frames = [stream.read(CHUNK) for _ in range(max_size)]
        finally:
            # Always release the audio device, even if a read fails.
            stream.stop_stream()
            stream.close()
    finally:
        audio.terminate()

    # Persist the buffer; saving is only needed to verify that recording
    # works.
    # TODO: send the buffer to the server instead and stop when the
    # returned text is "退出".
    save_wav(frames, save_path)


recording(temp_save_path)
服务器端
需要配置wenet使用环境,请参考我写的这篇文章
服务端环境搭建文档结构
├── cache # 缓冲区
│ ├── temp.wav
│ ├── temp1.wav
│ └── temp2.wav
├── client # 测试客户端
│ ├── __init__.py
│ └── client.py
├── decoder # 识别模块
│ ├── create_data_list.sh
│ ├── datalist
│ │ ├── temp
│ │ ├── temp1
│ │ └── temp2
│ ├── online_data.list
│ ├── recognize.py
│ └── wenet -> /mnt/f/data/wenet/wenet/
├── model # 模型存放
│ ├── 20210618_u2pp_conformer_exp.tar.gz
│ ├── 20210815_unified_conformer_exp
│ │ ├── final.pt
│ │ ├── global_cmvn
│ │ ├── train.yaml
│ │ └── words.txt
│ └── 20210815_unified_conformer_exp.tar.gz
└── server # 测试服务端
    ├── __init__.py
    └── server.py
pip install flask
from flask import Flask, request

app = Flask(__name__)

# File the uploaded audio buffer is written to before recognition.
SAVE_PATH = "../cache/temp.wav"


# Load the model / pass in its parameters here.
@app.route("/", methods=["GET", "POST"])
def getdata():
    """Store the audio buffer sent by the client so it can be recognised.

    The raw request body is treated as one chunk of PCM data and written
    to ``SAVE_PATH`` through ``save_wav``.

    Returns:
        ``"ok"`` once the buffer is saved, or a 400 response when the
        request carries no body.
    """
    data = request.get_data()
    if not data:
        return "no audio data received", 400
    # save_wav expects an iterable of byte chunks.
    save_wav([data], SAVE_PATH)
    # TODO: call the recogniser here and return the recognised text.
    return "ok"


if __name__ == "__main__":
    app.run()
接收数据
# Save microphone data
def save_wav(frames, save_path, channels=None, sample_width=2,
             sample_rate=None):
    """Write raw PCM chunks to ``save_path`` as a WAV file.

    Args:
        frames: Iterable of ``bytes`` chunks holding the PCM samples, in
            playback order.
        save_path: Destination file, opened in ``'wb'`` mode.
        channels: Channel count for the WAV header; falls back to the
            module-level ``CHANNELS`` when omitted.
        sample_width: Bytes per sample for the WAV header (default 2).
        sample_rate: Sampling rate for the WAV header; falls back to the
            module-level ``SAMPALE_RATE`` when omitted.
    """
    if channels is None:
        channels = CHANNELS
    if sample_rate is None:
        sample_rate = SAMPALE_RATE

    # The context manager closes the file even if writing fails part-way.
    with wave.open(save_path, 'wb') as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sample_rate)
        wf.writeframes(b''.join(frames))

    # '\033[93m' ... '\033[0m' are the ANSI codes for bright-yellow text.
    print('\033[93m' + "已录入缓冲区" + '\033[0m')
生成data_list
wenet 识别时需要这个文件,它会在内部读取该文件中的数据;文件格式必须如下:
{"key":"temp","wav":"/home/sunao/data/StreamAIzimu/cache/temp.wav","txt":""}
#!/usr/bin/bash
# Write the one-line data list that points the recogniser at the cached
# recording.
root=/home/sunao/data/StreamAIzimu
data="${root}/cache/temp.wav"
printf '{"key":"temp","wav":"%s","txt":""}\n' "${data}" > online_data.list
修改 wenet 的识别文件 recognize.py:改为只加载一次模型,并取消默认通过 bash 脚本传入参数的方式
recognize.py
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import time
import torch
import yaml
from torch.utils.data import DataLoader
from wenet.dataset.dataset import Dataset
from wenet.transformer.asr_model import init_asr_model
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.config import override_config
class recognize():
    """WeNet decoder wrapper that loads its model only once.

    The decoding options are fixed as instance attributes in ``__init__``
    instead of being parsed from the command line. Constructing an
    instance reads the training config and the symbol table, prepares the
    dataset configuration and loads the checkpoint; ``get_recognize`` can
    then be called repeatedly without reloading the model.
    """

    def __init__(self):
        self.root_path = os.pardir
        self.batch_size = 1
        self.beam_size = 10
        self.bpe_model = None
        self.checkpoint = '../model/20210815_unified_conformer_exp/final.pt'
        self.config = '../model/20210815_unified_conformer_exp/train.yaml'
        self.ctc_weight = 0.5
        self.data_type = 'raw'
        self.decoding_chunk_size = -1
        self.dict = '../model/20210815_unified_conformer_exp/words.txt'
        self.gpu = -1
        self.mode = 'attention_rescoring'
        self.non_lang_syms = None
        self.num_decoding_left_chunks = -1
        self.override_config = []
        self.penalty = 0.0
        self.result_file = 'online_text'
        self.reverse_weight = 0.0
        # A plain bool: a trailing comma here would create the truthy
        # tuple ``(False,)`` and silently switch streaming simulation on.
        self.simulate_streaming = False
        self.test_data = 'online_data.list'
        self.use_cuda = self.gpu >= 0 and torch.cuda.is_available()
        # Follow use_cuda so that setting ``gpu`` actually selects the GPU.
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')
        self.load_configs()  # load the configuration
        self.test_data_conf()
        self.loadmodel()  # load the model

    def load_configs(self):
        """Set up logging, validate the mode, read the config and dictionary.

        Exits the process with status 1 when a single-utterance decoding
        mode is combined with ``batch_size > 1``.
        """
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s %(levelname)s %(message)s')
        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpu)

        if self.mode in ['ctc_prefix_beam_search', 'attention_rescoring'
                         ] and self.batch_size > 1:
            logging.fatal(
                'decoding mode {} must be running with batch_size == 1'.format(
                    self.mode))
            sys.exit(1)

        # NOTE(review): the config is assumed to be a trusted local file;
        # prefer yaml.safe_load if it can ever come from outside.
        with open(self.config, 'r') as fin:
            self.configs = yaml.load(fin, Loader=yaml.FullLoader)
        if len(self.override_config) > 0:
            self.configs = override_config(self.configs, self.override_config)

        # Load the dictionary.
        self.symbol_table = read_symbol_table(self.dict)

    def loadmodel(self):
        """Build the ASR model from the config and load the checkpoint."""
        # Init asr model from configs
        model = init_asr_model(self.configs)

        # Invert the symbol table so token ids map back to symbols.
        self.char_dict = {v: k for k, v in self.symbol_table.items()}
        # The last id is treated as the end-of-sentence marker.
        self.eos = len(self.char_dict) - 1

        load_checkpoint(model, self.checkpoint)
        self.model = model.to(self.device)
        self.model.eval()

    def test_data_conf(self):
        """Derive the test-time dataset configuration from the train config."""
        self.test_conf = copy.deepcopy(self.configs['dataset_conf'])

        # Widen the length filters so no utterance is dropped.
        filter_conf = self.test_conf['filter_conf']
        filter_conf['max_length'] = 102400
        filter_conf['min_length'] = 0
        filter_conf['token_max_length'] = 102400
        filter_conf['token_min_length'] = 0
        filter_conf['max_output_input_ratio'] = 102400
        filter_conf['min_output_input_ratio'] = 0

        # Switch off augmentation and reordering for decoding.
        self.test_conf['speed_perturb'] = False
        self.test_conf['spec_aug'] = False
        self.test_conf['shuffle'] = False
        self.test_conf['sort'] = False
        if 'fbank_conf' in self.test_conf:
            self.test_conf['fbank_conf']['dither'] = 0.0
        elif 'mfcc_conf' in self.test_conf:
            self.test_conf['mfcc_conf']['dither'] = 0.0

        self.test_conf['batch_conf']['batch_type'] = "static"
        self.test_conf['batch_conf']['batch_size'] = self.batch_size
        self.non_lang_syms = read_non_lang_symbols(self.non_lang_syms)

    def get_test_data_loader(self):
        """Return a DataLoader over the data list in ``self.test_data``."""
        test_dataset = Dataset(self.data_type,
                               self.test_data,
                               self.symbol_table,
                               self.test_conf,
                               self.bpe_model,
                               self.non_lang_syms,
                               partition=False)
        return DataLoader(test_dataset, batch_size=None, num_workers=0)

    def get_recognize(self):
        """Decode the utterance listed in ``self.test_data``.

        The data list is expected to hold a single entry; only the first
        batch yielded by the loader is decoded.

        Returns:
            The recognised text, or ``''`` when the loader yields nothing.

        Raises:
            ValueError: if ``self.mode`` is not a supported decoding mode.
        """
        test_data_loader = self.get_test_data_loader()
        with torch.no_grad():
            for batch in test_data_loader:
                _, feats, _, feats_lengths, _ = batch
                feats = feats.to(self.device)
                feats_lengths = feats_lengths.to(self.device)
                assert feats.size(0) == 1
                hyps = self._decode_batch(feats, feats_lengths)
                return self._ids_to_text(hyps[0])
        return ''

    def _decode_batch(self, feats, feats_lengths):
        """Run the configured decoding mode and return a list of id lists."""
        streaming_args = dict(
            decoding_chunk_size=self.decoding_chunk_size,
            num_decoding_left_chunks=self.num_decoding_left_chunks,
            simulate_streaming=self.simulate_streaming)

        if self.mode == 'attention':
            hyps, _ = self.model.recognize(feats,
                                           feats_lengths,
                                           beam_size=self.beam_size,
                                           **streaming_args)
            return [hyp.tolist() for hyp in hyps]
        if self.mode == 'ctc_greedy_search':
            hyps, _ = self.model.ctc_greedy_search(feats,
                                                   feats_lengths,
                                                   **streaming_args)
            return hyps
        # ctc_prefix_beam_search and attention_rescoring only return one
        # result in List[int]; wrap it in a list to match the other modes.
        if self.mode == 'ctc_prefix_beam_search':
            hyp, _ = self.model.ctc_prefix_beam_search(feats,
                                                       feats_lengths,
                                                       self.beam_size,
                                                       **streaming_args)
            return [hyp]
        if self.mode == 'attention_rescoring':
            hyp, _ = self.model.attention_rescoring(
                feats,
                feats_lengths,
                self.beam_size,
                ctc_weight=self.ctc_weight,
                reverse_weight=self.reverse_weight,
                **streaming_args)
            return [hyp]
        raise ValueError('unsupported decoding mode: {}'.format(self.mode))

    def _ids_to_text(self, token_ids):
        """Join the symbols for ``token_ids``, stopping at the first eos."""
        pieces = []
        for token_id in token_ids:
            if token_id == self.eos:
                break
            pieces.append(self.char_dict[token_id])
        return ''.join(pieces)
if __name__ == '__main__':
    # Load the model once, decode the utterance named in the data list and
    # print the recognised text (this is where a live service would
    # receive audio and send the result back).
    print(recognize().get_recognize())
修改支持指定音频识别
这样就可以支持多路翻译
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import copy
import logging
import os
import sys
import time
import torch
import yaml
from torch.utils.data import DataLoader
from wenet.dataset.dataset import Dataset
from wenet.transformer.asr_model import init_asr_model
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.config import override_config
class recognize():
    """WeNet decoder wrapper that loads its model once and decodes any file.

    The decoding options are fixed as instance attributes in ``__init__``
    instead of being parsed from the command line. Constructing an
    instance reads the training config and the symbol table, prepares the
    dataset configuration and loads the checkpoint; ``get_recognize`` can
    then be called with different audio paths without reloading the model.
    """

    def __init__(self):
        self.root_path = os.pardir
        self.batch_size = 1
        self.beam_size = 10
        self.bpe_model = None
        self.checkpoint = '../model/20210815_unified_conformer_exp/final.pt'
        self.config = '../model/20210815_unified_conformer_exp/train.yaml'
        self.ctc_weight = 0.5
        self.data_type = 'raw'
        self.decoding_chunk_size = -1
        self.dict = '../model/20210815_unified_conformer_exp/words.txt'
        self.gpu = -1
        self.mode = 'attention_rescoring'
        self.non_lang_syms = None
        self.num_decoding_left_chunks = -1
        self.override_config = []
        self.penalty = 0.0
        self.result_file = 'online_text'
        self.reverse_weight = 0.0
        # A plain bool: a trailing comma here would create the truthy
        # tuple ``(False,)`` and silently switch streaming simulation on.
        self.simulate_streaming = False
        self.test_data = 'online_data.list'
        self.use_cuda = self.gpu >= 0 and torch.cuda.is_available()
        # Follow use_cuda so that setting ``gpu`` actually selects the GPU.
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')
        self.load_configs()  # load the configuration
        self.test_data_conf()
        self.loadmodel()  # load the model

    def load_configs(self):
        """Set up logging, validate the mode, read the config and dictionary.

        Exits the process with status 1 when a single-utterance decoding
        mode is combined with ``batch_size > 1``.
        """
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s %(levelname)s %(message)s')
        os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpu)

        if self.mode in ['ctc_prefix_beam_search', 'attention_rescoring'
                         ] and self.batch_size > 1:
            logging.fatal(
                'decoding mode {} must be running with batch_size == 1'.format(
                    self.mode))
            sys.exit(1)

        # NOTE(review): the config is assumed to be a trusted local file;
        # prefer yaml.safe_load if it can ever come from outside.
        with open(self.config, 'r') as fin:
            self.configs = yaml.load(fin, Loader=yaml.FullLoader)
        if len(self.override_config) > 0:
            self.configs = override_config(self.configs, self.override_config)

        # Load the dictionary.
        self.symbol_table = read_symbol_table(self.dict)

    def loadmodel(self):
        """Build the ASR model from the config and load the checkpoint."""
        # Init asr model from configs
        model = init_asr_model(self.configs)

        # Invert the symbol table so token ids map back to symbols.
        self.char_dict = {v: k for k, v in self.symbol_table.items()}
        # The last id is treated as the end-of-sentence marker.
        self.eos = len(self.char_dict) - 1

        load_checkpoint(model, self.checkpoint)
        self.model = model.to(self.device)
        self.model.eval()

    def test_data_conf(self):
        """Derive the test-time dataset configuration from the train config."""
        self.test_conf = copy.deepcopy(self.configs['dataset_conf'])

        # Widen the length filters so no utterance is dropped.
        filter_conf = self.test_conf['filter_conf']
        filter_conf['max_length'] = 102400
        filter_conf['min_length'] = 0
        filter_conf['token_max_length'] = 102400
        filter_conf['token_min_length'] = 0
        filter_conf['max_output_input_ratio'] = 102400
        filter_conf['min_output_input_ratio'] = 0

        # Switch off augmentation and reordering for decoding.
        self.test_conf['speed_perturb'] = False
        self.test_conf['spec_aug'] = False
        self.test_conf['shuffle'] = False
        self.test_conf['sort'] = False
        if 'fbank_conf' in self.test_conf:
            self.test_conf['fbank_conf']['dither'] = 0.0
        elif 'mfcc_conf' in self.test_conf:
            self.test_conf['mfcc_conf']['dither'] = 0.0

        self.test_conf['batch_conf']['batch_type'] = "static"
        self.test_conf['batch_conf']['batch_size'] = self.batch_size
        self.non_lang_syms = read_non_lang_symbols(self.non_lang_syms)

    def get_test_data_loader(self, path):
        """Return a DataLoader over the data list stored at ``path``.

        ``path`` is also remembered in ``self.test_data``.
        """
        self.test_data = path
        test_dataset = Dataset(self.data_type,
                               self.test_data,
                               self.symbol_table,
                               self.test_conf,
                               self.bpe_model,
                               self.non_lang_syms,
                               partition=False)
        return DataLoader(test_dataset, batch_size=None, num_workers=0)

    def create_data_list(self, path):
        """Write the one-entry data list for the audio file at ``path``.

        The list is stored as ``./datalist/<stem>`` where ``<stem>`` is the
        file name without its extension, and holds a single JSON object
        with the keys ``key``, ``wav`` (absolute path of ``path``) and
        ``txt``. The file is rewritten on every call so it always points
        at the audio file that was actually requested.

        Args:
            path: Path of the WAV file to recognise.

        Returns:
            The path of the data list file.
        """
        import json  # local import: keeps this block self-contained

        file_name = os.path.splitext(os.path.basename(path))[0]
        os.makedirs("./datalist", exist_ok=True)
        filepath = "./datalist/" + file_name
        entry = {"key": file_name, "wav": os.path.abspath(path), "txt": ""}
        with open(filepath, 'w', encoding="utf-8") as file:
            # json.dumps escapes quotes/backslashes that plain string
            # formatting would leave as invalid JSON.
            file.write(json.dumps(entry, ensure_ascii=False) + "\n")
        return filepath

    def get_recognize(self, path):
        """Decode the audio file at ``path`` and return the recognised text.

        A one-entry data list is generated for ``path`` and only the first
        batch yielded by the loader is decoded.

        Args:
            path: Path of the WAV file to recognise.

        Returns:
            The recognised text, or ``''`` when the loader yields nothing.

        Raises:
            ValueError: if ``self.mode`` is not a supported decoding mode.
        """
        data_list = self.create_data_list(path)
        test_data_loader = self.get_test_data_loader(data_list)
        with torch.no_grad():
            for batch in test_data_loader:
                _, feats, _, feats_lengths, _ = batch
                feats = feats.to(self.device)
                feats_lengths = feats_lengths.to(self.device)
                assert feats.size(0) == 1
                hyps = self._decode_batch(feats, feats_lengths)
                return self._ids_to_text(hyps[0])
        return ''

    def _decode_batch(self, feats, feats_lengths):
        """Run the configured decoding mode and return a list of id lists."""
        streaming_args = dict(
            decoding_chunk_size=self.decoding_chunk_size,
            num_decoding_left_chunks=self.num_decoding_left_chunks,
            simulate_streaming=self.simulate_streaming)

        if self.mode == 'attention':
            hyps, _ = self.model.recognize(feats,
                                           feats_lengths,
                                           beam_size=self.beam_size,
                                           **streaming_args)
            return [hyp.tolist() for hyp in hyps]
        if self.mode == 'ctc_greedy_search':
            hyps, _ = self.model.ctc_greedy_search(feats,
                                                   feats_lengths,
                                                   **streaming_args)
            return hyps
        # ctc_prefix_beam_search and attention_rescoring only return one
        # result in List[int]; wrap it in a list to match the other modes.
        if self.mode == 'ctc_prefix_beam_search':
            hyp, _ = self.model.ctc_prefix_beam_search(feats,
                                                       feats_lengths,
                                                       self.beam_size,
                                                       **streaming_args)
            return [hyp]
        if self.mode == 'attention_rescoring':
            hyp, _ = self.model.attention_rescoring(
                feats,
                feats_lengths,
                self.beam_size,
                ctc_weight=self.ctc_weight,
                reverse_weight=self.reverse_weight,
                **streaming_args)
            return [hyp]
        raise ValueError('unsupported decoding mode: {}'.format(self.mode))

    def _ids_to_text(self, token_ids):
        """Join the symbols for ``token_ids``, stopping at the first eos."""
        pieces = []
        for token_id in token_ids:
            if token_id == self.eos:
                break
            pieces.append(self.char_dict[token_id])
        return ''.join(pieces)
if __name__ == '__main__':
    # Load the model once and reuse it for every cached recording.
    recog = recognize()
    wav_paths = ("../cache/temp.wav", "../cache/temp1.wav", "../cache/temp2.wav")
    # Decode all files first, then print the transcripts in the same order
    # (this is where a live service would return results to its clients).
    transcripts = [recog.get_recognize(wav_path) for wav_path in wav_paths]
    for transcript in transcripts:
        print(transcript)
返回客户识别结果
服务启动时加载一次模型,
当有数据传入时调用识别接口
from flask import Flask
from recognize import recognize

app = Flask(__name__)

# Cached recording that is decoded whenever the endpoint is requested.
WAV_PATH = "../cache/temp.wav"

# Load the model once at start-up so requests do not pay the loading cost.
model = recognize()


@app.route("/")
def run_recognize():
    """Recognise the cached recording and return the text as the response."""
    # get_recognize needs the audio path in the version of recognize.py
    # that supports choosing the file to decode.
    return model.get_recognize(WAV_PATH)


if __name__ == "__main__":
    app.run()
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)