change completion model interface
- app.py +14 -1
- backend/query_llm.py +5 -6
app.py CHANGED
@@ -39,6 +39,14 @@ def bot(history, chunk_table, embedding_model, llm_model, cross_encoder, top_k_p
     top_k_param = int(top_k_param)
     query = history[-1][0]
 
+    logger.info("bot launched ...")
+    logger.info(f"embedding model: {embedding_model}")
+    logger.info(f"LLM model: {llm_model}")
+    logger.info(f"Cross encoder model: {cross_encoder}")
+    logger.info(f"TopK: {top_k_param}")
+    logger.info(f"ReRank TopK: {rerank_topk}")
+
+
     if not query:
         raise gr.Warning("Please submit a non-empty string as a prompt")
 
@@ -48,9 +56,13 @@ def bot(history, chunk_table, embedding_model, llm_model, cross_encoder, top_k_p
 
     #documents = retrieve(query, TOP_K)
     documents = retrieve(query, top_k_param, chunk_table, embedding_model)
+    logger.info('Retrived document count:', len(documents))
+
     if cross_encoder != "None" and len(documents) > 1:
         documents = rerank_documents(cross_encoder, documents, query, top_k_rerank=rerank_topk)
         #"cross-encoder/ms-marco-MiniLM-L-6-v2"
+        logger.info('ReRank done, document count:', len(documents))
+
 
 
 
@@ -79,7 +91,8 @@ def bot(history, chunk_table, embedding_model, llm_model, cross_encoder, top_k_p
     # generate_fn = generate_openai
     #else:
     #    raise gr.Error(f"API {api_kind} is not supported")
-
+
+    logger.info(f'Complition started. llm_model: {llm_model}, prompt: {prompt}')
     history[-1][1] = ""
     for character in generate_fn(prompt, history[:-1], llm_model):
         history[-1][1] = character
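A note on the new logging calls: Python's logging methods interpolate extra positional arguments with %-style placeholders rather than print-style concatenation, so a call like logger.info('Retrived document count:', len(documents)) will not include the count in the message (the logging module reports a formatting error instead). A minimal sketch of equivalent calls, assuming logger is an ordinary logging.Logger configured elsewhere in app.py:

import logging

logging.basicConfig(level=logging.INFO)   # assumption: app.py sets up logging roughly like this
logger = logging.getLogger(__name__)

documents = ["chunk-1", "chunk-2"]        # stand-in for the retrieved chunks

# %-style placeholders are interpolated lazily by the logging module:
logger.info("Retrieved document count: %d", len(documents))

# or an f-string, matching the style of the other new log lines:
logger.info(f"ReRank done, document count: {len(documents)}")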
backend/query_llm.py CHANGED
@@ -10,12 +10,12 @@ from transformers import AutoTokenizer
 
 OPENAI_KEY = os.getenv("OPENAI_API_KEY")
 HF_TOKEN = os.getenv("HF_TOKEN")
-TOKENIZER = AutoTokenizer.from_pretrained(os.getenv("HF_MODEL"))
+#TOKENIZER = AutoTokenizer.from_pretrained(os.getenv("HF_MODEL"))
 
-HF_CLIENT = InferenceClient(
-    os.getenv("HF_MODEL"),
-    token=HF_TOKEN
-)
+#HF_CLIENT = InferenceClient(
+#    os.getenv("HF_MODEL"),
+#    token=HF_TOKEN
+#)
 OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)
 
 HF_GENERATE_KWARGS = {
@@ -115,7 +115,6 @@ def generate_openai(prompt: str, history: str, model_name: str) -> Generator[str
 
     try:
         stream = OAI_CLIENT.chat.completions.create(
-            #model=os.getenv("OPENAI_MODEL"),
             model = model_name,
             messages=formatted_prompt,
             **OAI_GENERATE_KWARGS,
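Taken together, the commit comments out the HF_MODEL-based tokenizer and InferenceClient and has generate_openai read the model from its model_name argument instead of the OPENAI_MODEL environment variable, so the completion model is chosen by the caller (the llm_model value passed through from the UI in app.py). A rough sketch of the resulting call path with the openai>=1.x client; the message construction and streaming loop below are illustrative assumptions, not the repo's exact code:

import os
from typing import Generator

import openai

OAI_CLIENT = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))

def generate_openai(prompt: str, history: list, model_name: str) -> Generator[str, None, None]:
    # Illustrative: the real module builds formatted_prompt from prompt + history.
    formatted_prompt = [{"role": "user", "content": prompt}]
    stream = OAI_CLIENT.chat.completions.create(
        model=model_name,      # model now comes from the caller, not os.getenv("OPENAI_MODEL")
        messages=formatted_prompt,
        stream=True,           # assumed: streamed deltas, to match the character loop in app.py
    )
    text = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            text += delta
            yield text         # app.py assigns each yielded value to history[-1][1]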