loocorez committed
Commit 3d31b21 · verified · Parent: c29ec17

Upload app.py with huggingface_hub

Files changed (1): app.py (+53 −18)

app.py CHANGED
@@ -8,7 +8,13 @@ from tokenizers import Tokenizer as HFTokenizer
 from gpt_infer import GPT, GPTConfig
 
 DEFAULT_MODEL = os.environ.get('NANOCHAT_DEFAULT_MODEL', 'loocorez/nanochat-sft-d20-step650')
+ALL_MODELS = [
+    'loocorez/nanochat-sft-d20-step650',
+    'loocorez/nanochat-mid-d20-step765',
+    'loocorez/nanochat-base-d20-step21400',
+]
 
+@torch.inference_mode()
 def load_model(repo_id: str):
     cfg_path = hf_hub_download(repo_id, 'config.json')
     with open(cfg_path, 'r') as f:
@@ -43,10 +49,14 @@ def get_model(repo_id: str):
     return model_cache[repo_id]
 
 @torch.inference_mode()
-def generate(repo_id: str, prompt: str, max_tokens: int, temperature: float, top_k: int|None):
+def generate(repo_id: str, system_prompt: str, prompt: str, max_tokens: int, temperature: float, top_k: int|None):
     model, tok = get_model(repo_id)
     bos_id = tok.token_to_id('<|bos|>')
-    ids = tok.encode(prompt).ids
+    # Combine system + user prompt
+    text = prompt if not system_prompt else f"{system_prompt.strip()}\n\n{prompt}"
+    ids = tok.encode(text).ids
     if bos_id is not None:
         ids = [bos_id] + ids
     out_tokens = []
@@ -55,24 +65,49 @@ def generate(repo_id: str, prompt: str, max_tokens: int, temperature: float, top_k: int|None):
     text = tok.decode(out_tokens, skip_special_tokens=False)
     return text
 
+@torch.inference_mode()
+def compare_three(system_prompt: str, prompt: str, max_tokens: int, temperature: float, top_k: int|None):
+    outputs = []
+    for repo_id in ALL_MODELS:
+        outputs.append(generate(repo_id, system_prompt, prompt, max_tokens, temperature, top_k))
+    return tuple(outputs)
+
 with gr.Blocks() as demo:
     gr.Markdown('# nanochat (ZeroGPU)')
-    gr.Markdown('Select a model and generate text.')
-    repo = gr.Dropdown(choices=[
-        'loocorez/nanochat-sft-d20-step650',
-        'loocorez/nanochat-mid-d20-step765',
-        'loocorez/nanochat-base-d20-step21400',
-    ], value=DEFAULT_MODEL, label='Model Repo')
-    prompt = gr.Textbox(label='Prompt', lines=6)
-    with gr.Row():
-        max_tokens = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
-        temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
-        top_k = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
-    btn = gr.Button('Generate')
-    output = gr.Textbox(label='Output', lines=10)
-    btn.click(fn=lambda r,p,m,t,k: generate(r,p,int(m),float(t),int(k) if int(k)>0 else None),
-              inputs=[repo, prompt, max_tokens, temperature, top_k],
-              outputs=output)
+    gr.Markdown('Run a single model or compare SFT/MID/BASE side by side.')
+    with gr.Tabs():
+        with gr.Tab('Single'):
+            repo = gr.Dropdown(choices=ALL_MODELS, value=DEFAULT_MODEL, label='Model Repo')
+            system = gr.Textbox(label='System prompt (optional)', value='You are a helpful assistant.', lines=2)
+            prompt = gr.Textbox(label='User prompt', lines=6)
+            with gr.Row():
+                max_tokens = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
+                temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
+                top_k = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
+            btn = gr.Button('Generate')
+            output = gr.Textbox(label='Output', lines=10)
+            btn.click(
+                fn=lambda r,s,p,m,t,k: generate(r,s,p,int(m),float(t),int(k) if int(k)>0 else None),
+                inputs=[repo, system, prompt, max_tokens, temperature, top_k],
+                outputs=output
+            )
+        with gr.Tab('Compare 3'):
+            system_c = gr.Textbox(label='System prompt (optional)', value='You are a helpful assistant.', lines=2)
+            prompt_c = gr.Textbox(label='User prompt', lines=6)
+            with gr.Row():
+                max_tokens_c = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
+                temperature_c = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
+                top_k_c = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
+            btn_c = gr.Button('Run on all three')
+            with gr.Row():
+                out_sft = gr.Textbox(label='SFT', lines=10)
+                out_mid = gr.Textbox(label='MID', lines=10)
+                out_base = gr.Textbox(label='BASE', lines=10)
+            btn_c.click(
+                fn=lambda s,p,m,t,k: compare_three(s,p,int(m),float(t),int(k) if int(k)>0 else None),
+                inputs=[system_c, prompt_c, max_tokens_c, temperature_c, top_k_c],
+                outputs=[out_sft, out_mid, out_base]
+            )
 
 if __name__ == '__main__':
     demo.launch()
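Note: the second hunk references a model_cache that the diff never defines. A minimal sketch of what get_model presumably wraps around load_model, assuming load_model returns a (model, tokenizer) pair; only its final line appears in the diff, so the body here is an assumption, not the file's code:

model_cache = {}

def get_model(repo_id: str):
    # Load each checkpoint at most once per process; repeated generate() calls
    # (and compare_three's fan-out over ALL_MODELS) then reuse the cached pair.
    if repo_id not in model_cache:
        model_cache[repo_id] = load_model(repo_id)  # assumed to return (model, tokenizer)
    return model_cache[repo_id]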
 
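The sampling loop that consumes max_tokens, temperature, and top_k falls between the hunks shown, so it never appears in the diff. A typical per-step implementation, offered as a hedged sketch (sample_next is a hypothetical helper, not the file's actual code):

import torch

def sample_next(logits: torch.Tensor, temperature: float, top_k: int | None) -> int:
    # Greedy decode when temperature is 0 (the slider's minimum).
    if temperature == 0.0:
        return int(torch.argmax(logits, dim=-1))
    logits = logits / temperature
    if top_k is not None:
        # Keep only the k highest-scoring logits; mask the rest.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = logits.masked_fill(logits < v[..., -1], float('-inf'))
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))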
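The two entry points the commit adds can also be exercised directly, using only the signatures visible in the diff; the prompt text and parameter values below are illustrative:

single = generate(
    'loocorez/nanochat-sft-d20-step650',          # repo_id
    'You are a helpful assistant.',               # system_prompt
    'Summarize top-k sampling in one sentence.',  # prompt
    max_tokens=128, temperature=0.8, top_k=40,
)

# top_k=None disables the filter, mirroring the UI's "0=disabled" slider.
sft_out, mid_out, base_out = compare_three(
    'You are a helpful assistant.', 'Hello!', 64, 0.8, None,
)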