1. Export the model in SavedModel format

Add a serving_input_fn function to run_classifier.py:

```python
def serving_input_fn():
    # Export the model in SavedModel format, feeding raw feature Tensors.
    # If build_parsing_serving_input_receiver_fn were used instead,
    # the input would be serialized tf.Examples.
    label_ids = tf.placeholder(tf.int32, [None], name='label_ids')  # the element-recognition task has 2 classes
    input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn
```

Then modify the evaluation branch in the main function so the model is exported after evaluation:

```python
if FLAGS.do_eval:
    estimator._export_to_tpu = False
    estimator.export_savedmodel(FLAGS.trans_model_dir, serving_input_fn)
```
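Before wiring the export into Docker, it can be sanity-checked locally. A minimal sketch, assuming TensorFlow 1.x and a hypothetical timestamp subdirectory under the export path:

```python
import numpy as np
from tensorflow.contrib import predictor  # TF 1.x API

max_seq_length = 64  # must match FLAGS.max_seq_length used at export time

# Load one exported version; the timestamp directory name is hypothetical.
predict_fn = predictor.from_saved_model('/opt/fuzzy_model/output/1583480000')

# Feed all-zero features just to confirm the signature accepts our inputs.
dummy = {
    'input_ids': np.zeros((1, max_seq_length), dtype=np.int32),
    'input_mask': np.zeros((1, max_seq_length), dtype=np.int32),
    'segment_ids': np.zeros((1, max_seq_length), dtype=np.int32),
    'label_ids': np.zeros((1,), dtype=np.int32),
}
print(predict_fn(dummy))  # prints the output tensors, e.g. class probabilities
```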
2. Deploy the BERT model with Docker

- Check the CentOS version
cat /etc/redhat-release
- Install Docker
yum install docker
- Start the Docker service
service docker start
- List the locally installed Docker images
docker images
- Pull the tensorflow/serving image
docker pull tensorflow/serving
- Start a container from the image, mounting the export directory and mapping host port 5000 to the container's REST port 8501 (a quick health check follows below)
docker run -t --rm -p 5000:8501 -v /opt/fuzzy_model/output:/models/bert-model -e MODEL_NAME=bert-model tensorflow/serving &
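Once the container is running, the model's availability can be verified through the serving REST API. A minimal health-check sketch, assuming the container runs on the local machine with the port mapping above:

```python
import requests

# TensorFlow Serving exposes a model-status endpoint next to :predict.
# Host port 5000 is mapped to the container's REST port 8501 above.
resp = requests.get('http://localhost:5000/v1/models/bert-model')
print(resp.json())  # the model state should read "AVAILABLE" once loaded
```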
The exported directory structure is as follows:
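export_savedmodel names each export directory with a Unix timestamp, and TensorFlow Serving loads every numeric subdirectory under the mounted path as a model version. Assuming the exports sit under /opt/fuzzy_model/output as mounted above, a typical layout (timestamp illustrative) is:

```
/opt/fuzzy_model/output/
└── 1583480000/
    ├── saved_model.pb
    └── variables/
        ├── variables.data-00000-of-00001
        └── variables.index
```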
The following client script converts input texts into BERT features and calls the prediction endpoint:

```python
# -*- coding: utf-8 -*-
"""
Created on Fri Mar  6 14:41:13 2020

@author:
"""
import json
import os

import numpy as np
import requests

from fuzzy_query import tokenization


def get_work_dir():
    return os.path.dirname(
        os.path.dirname(
            os.path.dirname(
                os.path.dirname(__file__)
            )
        )
    )


vocal_file_path = 'vocab.txt'
# vocal_file_path = os.path.join(get_work_dir(), 'var', 'vocab.txt')
tokenizer = tokenization.FullTokenizer(vocab_file=vocal_file_path, do_lower_case=True)
max_seq_length = 64


def text2ids(text_list):
    """Convert raw texts into padded BERT input features."""
    input_ids_list = []
    input_mask_list = []
    label_ids_list = []
    segment_ids_list = []
    for text in text_list:
        label = 0  # placeholder label; the exported signature requires label_ids
        tokens_a = tokenizer.tokenize(text)
        # Truncate to leave room for [CLS] and [SEP].
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        # Zero-pad up to the maximum sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
        input_ids_list.append(input_ids)
        input_mask_list.append(input_mask)
        label_ids_list.append(label)
        segment_ids_list.append(segment_ids)
    return input_ids_list, input_mask_list, label_ids_list, segment_ids_list


def bert_class(text_list):
    """
    Call the tf-serving endpoint to classify the input texts.
    :param text_list: list of input texts
    :return: list of predicted class indices
    """
    input_ids_list, input_mask_list, label_ids_list, segment_ids_list = text2ids(text_list)
    data = json.dumps({
        "name": 'bert',
        "signature_name": 'serving_default',
        "inputs": {
            'input_ids': input_ids_list,
            'input_mask': input_mask_list,
            'label_ids': label_ids_list,
            'segment_ids': segment_ids_list
        }})
    headers = {"content-type": "application/json"}
    url = 'http://10.30.239.205:5000/v1/models/bert-model:predict'
    json_response = requests.post(url, data=data, headers=headers)
    predictions = json.loads(json_response.text)['outputs']
    pre_list = [np.argmax(result) for result in predictions]
    print(pre_list)
    return pre_list


if __name__ == "__main__":
    name_list = ['饰品']
    bert_class(name_list)
```
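Two details of the request body are worth noting: it uses TensorFlow Serving's columnar "inputs" format (the row-oriented "instances" format would work equally well), and label_ids must be sent even though it does not influence the prediction, because the exported signature declares it as an input.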