fix model name
- app.py +1 -1
- model.py +1 -1
- pipline.py +2 -2
app.py
CHANGED
@@ -22,7 +22,7 @@ def __run_pipline():
 def __run_model():
     st.text(f"input_text: {state.input_text}")
     st.markdown(":green[Running model]")
-    st.text(model.
+    st.text(model.run(state.input_text))
 
 
 st.text_area("input_text", key="input_text")
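For context, a minimal sketch of how app.py plausibly fits together after this change. The state alias for st.session_state and the button that triggers __run_model are assumptions, since only the lines above appear in the diff:

    import streamlit as st
    import model

    state = st.session_state  # assumed alias; the diff reads state.input_text

    def __run_model():
        st.text(f"input_text: {state.input_text}")
        st.markdown(":green[Running model]")
        st.text(model.run(state.input_text))

    st.text_area("input_text", key="input_text")
    st.button("Run model", on_click=__run_model)  # hypothetical trigger, not shown in the diff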
model.py
CHANGED
@@ -12,7 +12,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto").to(device=device)
 
 
-def
+def run(text,**kargs):
     inputs = tokenizer.encode(text=text, return_tensors="pt").to(device=device)
     outputs = model.generate(inputs,**kargs)
     return tokenizer.decode(outputs[0])
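The new run function gives app.py and pipline.py a single entry point into the model. Below is a sketch of the whole module as the visible lines imply; model_name is a placeholder and the device handling is an assumption, since the top of model.py is outside this hunk. The sketch also drops device_map="auto", because combining it with an explicit .to(device) can conflict once accelerate dispatches the model across devices:

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_name = "gpt2"  # placeholder; the real name is set earlier in model.py
    device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").to(device=device)

    def run(text, **kargs):
        # Encode the prompt, generate, and decode; **kargs is forwarded to model.generate()
        inputs = tokenizer.encode(text=text, return_tensors="pt").to(device=device)
        outputs = model.generate(inputs, **kargs)
        return tokenizer.decode(outputs[0])

Generation options pass straight through, e.g. run("Hello", max_new_tokens=32).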
pipline.py
CHANGED
@@ -2,7 +2,7 @@ import langchain as lc
 from langchain import PromptTemplate
 from langchain.prompts import load_prompt
 import wikipedia
-
+import model
 
 # save templates to a file
 

@@ -24,7 +24,7 @@ def pipeline(text, word):
     model_output = ""
     input_text = prompt.format(adjective="funny", content=text)
     while word not in model_output:
-        model_output = model(input_text)
+        model_output = model.run(input_text)
         wikipedia_entry = wikipedia.search(word)[1]
         wiki = wikipedia.summary(wikipedia_entry, auto_suggest=False, redirect=True)
         input_text += model_output + wiki
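A sketch of the surrounding pipeline function as the visible lines suggest. The template text and the return value are assumptions: the diff only shows that the prompt takes adjective and content, and the comment above hints the real template is saved to and loaded from a file via load_prompt. Note the loop only exits once word appears in the model output, so it can run indefinitely for words the model never produces:

    import wikipedia
    import model
    from langchain import PromptTemplate

    # Assumed template; the repo loads its real one from a file with load_prompt
    prompt = PromptTemplate.from_template("Write a {adjective} text about: {content}")

    def pipeline(text, word):
        model_output = ""
        input_text = prompt.format(adjective="funny", content=text)
        while word not in model_output:
            model_output = model.run(input_text)
            wikipedia_entry = wikipedia.search(word)[1]  # second search result
            wiki = wikipedia.summary(wikipedia_entry, auto_suggest=False, redirect=True)
            input_text += model_output + wiki
        return model_output  # assumed return value

    print(pipeline("a story about space travel", "Mars"))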