Commit 4e2e022a authored by nanahira

first


.dockerignore:
/ChatYuan-large-v1
__pycache__
.git*
Dockerfile
.dockerignore
/docker-compose.yml
/README.md
/LICENSE

.gitignore:
/ChatYuan-large-v1
__pycache__

Dockerfile:
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
RUN apt update && apt -y install python3-pip python-is-python3 && \
pip3 install -U pip && \
rm -rf /var/lib/apt/lists/* /var/log* /tmp/* /var/tmp/*
WORKDIR /app
COPY ./requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt
COPY . ./
CMD ["./gunicorn_starter.sh"]

LICENSE:
(diff collapsed in the web view; license text not shown)

README.md:
# chatyuan-api

An HTTP API for the ChatYuan-large-v1 conversational model, served with Flask and gunicorn. The model weights are expected in ./ChatYuan-large-v1 (mounted read-only by docker-compose), and requests can optionally be protected with an ACCESS_TOKEN bearer token.
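
Example: a minimal client sketch (hypothetical, not part of this commit). It assumes the service is published on 127.0.0.1:4000, as in docker-compose.yml, and that a bearer token is only needed when the server sets ACCESS_TOKEN:

```python
# client_example.py -- hypothetical helper, not shipped in this repository.
import os

import requests  # third-party; client-side only, not in requirements.txt

API_URL = "http://127.0.0.1:4000/"      # docker-compose publishes port 4000
token = os.environ.get("ACCESS_TOKEN")  # same variable the server reads

headers = {"Authorization": "Bearer " + token} if token else {}
resp = requests.post(API_URL, json={"text": "用户:你好!\n花糖:"}, headers=headers)

body = resp.json()
# Every reply uses the same envelope:
# {"success": bool, "statusCode": int, "message": str, "data": {"text": ...}}
if body["success"]:
    print(body["data"]["text"])
else:
    print("error:", body["statusCode"], body["message"])
```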

app.py:
import os

from flask import Flask, request

from model import answer

app = Flask(__name__)

# Optional bearer token; when unset, the endpoint is open.
accessToken = os.environ.get("ACCESS_TOKEN")


def makeResponse(statusCode, message, data):
    # Uniform JSON envelope shared by every response.
    return {'success': statusCode < 400, 'statusCode': statusCode, 'message': message, 'data': data}, statusCode


@app.route('/', methods=['POST'])
def conversation():
    if accessToken and request.headers.get("Authorization") != "Bearer " + accessToken:
        return makeResponse(403, "Access denied", None)
    prompt = request.json
    if not prompt or 'text' not in prompt:
        return makeResponse(400, 'Empty text', None)
    result = answer(prompt['text'])
    return makeResponse(200, 'success', {'text': result})
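
A smoke-test sketch (hypothetical, not part of this commit) using Flask's built-in test client. It assumes ACCESS_TOKEN is unset in the environment; note that importing app also loads the ChatYuan weights via model.py:

```python
# smoke_test.py -- hypothetical, not shipped in this repository.
# Importing app pulls in model.py, so the weights are loaded first.
from app import app

client = app.test_client()

# Missing 'text' field -> 400 with the standard envelope.
bad = client.post('/', json={})
assert bad.status_code == 400
assert bad.get_json()['success'] is False

# Well-formed request -> 200, generated reply under data.text.
ok = client.post('/', json={'text': '用户:你好!\n花糖:'})
assert ok.status_code == 200
print(ok.get_json()['data']['text'])
```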

docker-compose.yml:
version: '2.4'
services:
  yuan:
    build: '.'
    runtime: nvidia
    volumes:
      # Weights are excluded from the image by .dockerignore and mounted
      # read-only at run time instead.
      - ./ChatYuan-large-v1:/app/ChatYuan-large-v1:ro
    ports:
      - '127.0.0.1:4000:80'

gunicorn_starter.sh:
#!/bin/bash
# Bind address and port are overridable through the environment.
if [ -z "$HOST" ]; then
  HOST=0.0.0.0
fi
if [ -z "$PORT" ]; then
  PORT=80
fi
# One worker keeps a single copy of the model in memory; threads scale
# with CPU cores. exec lets gunicorn receive container signals directly.
exec gunicorn app:app -w 1 --threads $(( $(nproc) * 2 + 1 )) -b "$HOST:$PORT"

model.py:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

from test_gpu import deviceName

print('Loading model.')
tokenizer = T5Tokenizer.from_pretrained("./ChatYuan-large-v1")
model = T5ForConditionalGeneration.from_pretrained("./ChatYuan-large-v1")
device = torch.device(deviceName())
model.to(device)
print('Loaded model.')


def preprocess(text):
    # Escape real newlines/tabs so they survive tokenization;
    # postprocess() reverses this on the generated output.
    return text.replace("\n", "\\n").replace("\t", "\\t")


def postprocess(text):
    return text.replace("\\n", "\n").replace("\\t", "\t")


def answer(text, sample=True, top_p=1, temperature=0.7):
    '''sample: whether to sample; for generation tasks it can be set to True.
    top_p: between 0 and 1; the higher, the more diverse the output.'''
    text = preprocess(text)
    encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
    if not sample:
        # Deterministic decoding (num_beams=1) with a length penalty.
        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
    else:
        # Nucleus sampling; no_repeat_ngram_size curbs repetition.
        out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    return postprocess(out_text[0])
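
A usage sketch (hypothetical, not part of this commit) showing how answer()'s decoding parameters interact; it assumes the ChatYuan-large-v1 weights are present so that importing model succeeds:

```python
# decoding_demo.py -- hypothetical, not shipped in this repository.
from model import answer

# Same style of prompt as the repository's own test.
prompt = '用户:您好!今天中午吃点什么比较好呢?\n花糖:'  # "User: Hello! What should I have for lunch?"

# Deterministic output: sampling off, single-beam decoding.
print(answer(prompt, sample=False))

# Stochastic output: nucleus sampling with a tighter top_p and lower
# temperature than the defaults (top_p=1, temperature=0.7).
print(answer(prompt, sample=True, top_p=0.9, temperature=0.5))
```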

requirements.txt:
--extra-index-url https://download.pytorch.org/whl/cu118
torch
transformers
flask==2.1.2
gunicorn
sentencepiece

test_gpu.py:
import os

import torch


def deviceName():
    # DEVICE can override the default (e.g. 'cuda:1'); fall back to CPU
    # when no GPU is visible to torch.
    if torch.cuda.is_available():
        return os.getenv('DEVICE', 'cuda')
    else:
        return 'cpu'


if __name__ == '__main__':
    print(deviceName())
    # Imported here rather than at module level: model.py imports this
    # module, so a top-level import would be circular.
    from model import answer
    # Prompt: "User: Hello! What would be good for lunch today?\nHuatang:"
    print(answer('用户:您好!今天中午吃点什么比较好呢?\n花糖:'))