Skip to content

Commit e2bd725

Browse files
authored
py : fix oai proxy (#3972)
* fix oai proxy fix generation not stoped while bot stop talking in chat mode fix possible `slot_id` not exist response for cors (and pre flight) * oai proxy: workaround for some client (such as Chatbox) * use stop as separator to replace hardcoded `\n`
1 parent 1f5cd83 commit e2bd725

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

examples/server/api_like_OAI.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
slot_id = -1
1212

1313
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
14-
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
15-
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
16-
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
17-
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
14+
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.')
15+
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: 'USER: ')", default="USER: ")
16+
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: 'ASSISTANT: ')", default="ASSISTANT: ")
17+
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: 'ASSISTANT's RULE: ')", default="ASSISTANT's RULE: ")
1818
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
1919
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
2020
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
@@ -34,19 +34,19 @@ def is_present(json, key):
3434

3535
#convert chat to prompt
3636
def convert_chat(messages):
37-
prompt = "" + args.chat_prompt.replace("\\n", "\n")
3837

39-
system_n = args.system_name.replace("\\n", "\n")
40-
user_n = args.user_name.replace("\\n", "\n")
41-
ai_n = args.ai_name.replace("\\n", "\n")
42-
stop = args.stop.replace("\\n", "\n")
38+
system_n = args.system_name
39+
user_n = args.user_name
40+
ai_n = args.ai_name
41+
stop = args.stop
4342

43+
prompt = "" + args.chat_prompt + stop
4444

4545
for line in messages:
4646
if (line["role"] == "system"):
47-
prompt += f"{system_n}{line['content']}"
47+
prompt += f"{system_n}{line['content']}{stop}"
4848
if (line["role"] == "user"):
49-
prompt += f"{user_n}{line['content']}"
49+
prompt += f"{user_n}{line['content']}{stop}"
5050
if (line["role"] == "assistant"):
5151
prompt += f"{ai_n}{line['content']}{stop}"
5252
prompt += ai_n.rstrip()
@@ -130,7 +130,7 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
130130
}
131131
]
132132
}
133-
slot_id = data["slot_id"]
133+
slot_id = data.get("slot_id")
134134
if (chat):
135135
if (start):
136136
resData["choices"][0]["delta"] = {
@@ -150,11 +150,13 @@ def make_resData_stream(data, chat=False, time_now = 0, start=False):
150150
return resData
151151

152152

153-
@app.route('/chat/completions', methods=['POST'])
154-
@app.route('/v1/chat/completions', methods=['POST'])
153+
@app.route('/chat/completions', methods=['POST', 'OPTIONS'])
154+
@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
155155
def chat_completions():
156156
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
157157
return Response(status=403)
158+
if request.method == 'OPTIONS':
159+
return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
158160
body = request.get_json()
159161
stream = False
160162
tokenize = False
@@ -177,20 +179,22 @@ def generate():
177179
data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
178180
time_now = int(time.time())
179181
resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
180-
yield 'data: {}\n'.format(json.dumps(resData))
182+
yield 'data: {}\n\n'.format(json.dumps(resData))
181183
for line in data.iter_lines():
182184
if line:
183185
decoded_line = line.decode('utf-8')
184186
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
185-
yield 'data: {}\n'.format(json.dumps(resData))
186-
return Response(generate(), mimetype='text/event-stream')
187+
yield 'data: {}\n\n'.format(json.dumps(resData))
188+
return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
187189

188190

189-
@app.route('/completions', methods=['POST'])
190-
@app.route('/v1/completions', methods=['POST'])
191+
@app.route('/completions', methods=['POST', 'OPTIONS'])
192+
@app.route('/v1/completions', methods=['POST', 'OPTIONS'])
191193
def completion():
192194
if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
193195
return Response(status=403)
196+
if request.method == 'OPTIONS':
197+
return Response(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
194198
body = request.get_json()
195199
stream = False
196200
tokenize = False
@@ -216,8 +220,8 @@ def generate():
216220
if line:
217221
decoded_line = line.decode('utf-8')
218222
resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
219-
yield 'data: {}\n'.format(json.dumps(resData))
220-
return Response(generate(), mimetype='text/event-stream')
223+
yield 'data: {}\n\n'.format(json.dumps(resData))
224+
return Response(generate(), mimetype='text/event-stream', headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*"})
221225

222226
if __name__ == '__main__':
223227
app.run(args.host, port=args.port)

0 commit comments

Comments
 (0)