Tarive committed
Commit 49219c5 · verified · 1 Parent(s): db72e03

Update app.py

Files changed (1)
  1. app.py +94 -92
app.py CHANGED
@@ -1,107 +1,109 @@
 import gradio as gr
 import torch
 import yaml
-import os

-def load_model():
-    """Load the HRM model and config"""
-    try:
-        # Load config
-        with open('all_config.yaml', 'r') as f:
-            config = yaml.safe_load(f)
-
-        # Load checkpoint
-        checkpoint = torch.load('pytorch_model.bin', map_location='cpu')
-
-        return config, checkpoint, "✅ Model loaded successfully!"
-    except Exception as e:
-        return None, None, f"❌ Error loading model: {str(e)}"

-def test_model_info(config, checkpoint):
-    """Display model information"""
-    if config is None or checkpoint is None:
-        return "Model not loaded"
-
-    info = f"""
-    **Model Architecture**: {config['arch']['name']}
-    **Hidden Size**: {config['arch']['hidden_size']}
-    **H Layers**: {config['arch']['H_layers']}
-    **L Layers**: {config['arch']['L_layers']}
-    **Parameters in Checkpoint**: {len(checkpoint)}
-    **Model Purpose**: Grant Abstract Optimization
-
-    **Training Details**:
-    - Steps: 492,500 (final checkpoint)
-    - Batch Size: {config['global_batch_size']}
-    - Learning Rate: {config['lr']}
-    """
-    return info

-def placeholder_inference(draft_abstract, grant_type):
-    """Placeholder for actual inference (requires full training pipeline)"""
-    return f"""
-    **Input Abstract**: {draft_abstract[:100]}...
-
-    **Grant Type**: {grant_type}
-
-    **Status**: Model checkpoint loaded successfully!
-
-    ⚠️ **Note**: Full inference requires the original training pipeline with tokenizer and preprocessing code.
-    This demo shows that the model weights are accessible and the architecture is properly configured.
-
-    **Next Steps**:
-    1. Integrate with original training codebase
-    2. Load tokenizer and preprocessing pipeline
-    3. Implement full inference function
     """

-# Load model on startup
-config, checkpoint, load_status = load_model()

-# Create Gradio interface
-with gr.Blocks(title="HRM Grant Abstract Optimizer") as demo:
-    gr.Markdown("# 🎯 Hierarchical Reasoning Model for Grant Abstract Optimization")
-    gr.Markdown("A specialized 27M-parameter model for transforming draft grant abstracts into funding-worthy versions.")
-
-    with gr.Tab("Model Info"):
-        gr.Markdown("## Model Status")
-        gr.Markdown(load_status)

-        if config is not None:
-            model_info = test_model_info(config, checkpoint)
-            gr.Markdown(model_info)
-
-    with gr.Tab("Test Interface"):
-        gr.Markdown("## Abstract Optimization Demo")
-        gr.Markdown("*Note: This is a demonstration interface. Full inference requires integration with the training pipeline.*")

-        with gr.Row():
-            with gr.Column():
-                draft_input = gr.Textbox(
-                    label="Draft Abstract",
-                    placeholder="Enter your sub-optimal grant abstract here...",
-                    lines=8,
-                    value="Our study will investigate protein interactions in cancer cells. We believe this research could be important for understanding disease mechanisms."
-                )
-                grant_type = gr.Dropdown(
-                    choices=["R01", "F32", "K99", "R21", "R15"],
-                    label="Grant Type",
-                    value="R01"
-                )
-                optimize_btn = gr.Button("Optimize Abstract", variant="primary")

-            with gr.Column():
-                output = gr.Textbox(
-                    label="Optimized Abstract",
-                    lines=10,
-                    interactive=False
-                )
-
-        optimize_btn.click(
-            fn=placeholder_inference,
-            inputs=[draft_input, grant_type],
-            outputs=output
-        )

 if __name__ == "__main__":
     demo.launch()
 
 import gradio as gr
 import torch
 import yaml
+import json
+from tokenizers import Tokenizer

+# --- 1. Load Custom Model Code ---
+# This imports your corrected HRM source code
+from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1

+# --- 2. Load Artifacts ---
+print("Loading artifacts...")
+# Load the tokenizer
+tokenizer = Tokenizer.from_file("tokenizer.json")
+# Load the model configuration
+with open('config.yaml', 'r') as f:
+    config_data = yaml.safe_load(f)
+model_config = config_data['arch']
+# Load the grant type mapping
+with open('activity_code_map.json', 'r') as f:
+    activity_code_map = json.load(f)

+# --- 3. Initialize the Model ---
+print("Initializing model...")
+# The model's constructor expects a plain config dict, so we start from the
+# 'arch' section and add the other required keys from the root of the config
+model_config.update({
+    'batch_size': config_data['global_batch_size'],
+    'seq_len': 512,  # You may need to get this from your dataset metadata
+    'num_puzzle_identifiers': len(activity_code_map) + 1,
+    'vocab_size': tokenizer.get_vocab_size()
+})
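+# Note: the +1 above reserves identifier 0, which optimize_abstract() below
+# uses as the fallback for unknown grant types.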
+model = HierarchicalReasoningModel_ACTV1(config_dict=model_config)
+# Load the fine-tuned weights
+model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
+model.eval()  # Set the model to evaluation mode
+print("Model loaded successfully!")
+
+# --- 4. Define the Inference Function ---
+def optimize_abstract(draft_abstract, grant_type):
+    """
+    Takes a draft abstract and grant type, runs the model, and returns the optimized text.
     """
+    if not draft_abstract or not grant_type:
+        return "Please provide both a draft abstract and a grant type."

+    try:
+        # Prepare inputs
+        tokenizer.enable_padding(length=512)
+        tokenizer.enable_truncation(max_length=512)
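+        # Pad/truncate so every input matches the fixed seq_len of 512 set above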
+
+        input_ids = tokenizer.encode(draft_abstract).ids
+        grant_type_id = activity_code_map.get(grant_type, 0)  # Default to 0 if unknown

+        # Convert to PyTorch tensors
+        input_tensor = torch.tensor([input_ids], dtype=torch.long)
+        grant_tensor = torch.tensor([grant_type_id], dtype=torch.long)
+
+        # Create the batch dictionary that the model expects
+        batch = {
+            "inputs": input_tensor,
+            "puzzle_identifiers": grant_tensor,
+            # The model requires a 'labels' field, even for inference, so we provide a dummy one
+            "labels": torch.zeros_like(input_tensor)
+        }
+
+        # Run inference
+        with torch.no_grad():
+            carry = model.initial_carry(batch)
+            # The model runs in a loop; for inference, we run it for the max steps
+            for _ in range(model_config['halt_max_steps']):
+                carry, _ = model(carry=carry, batch=batch)
+
+            # Get the final logits from the carry state
+            final_logits = model.inner.lm_head(carry.inner_carry.z_H)[:, model.inner.puzzle_emb_len:]
+            predicted_ids = torch.argmax(final_logits, dim=-1).squeeze().tolist()
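+            # argmax over the vocab dimension is greedy decoding of each position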

+        # Decode the output
+        optimized_text = tokenizer.decode(predicted_ids, skip_special_tokens=True)

+        return optimized_text
+
+    except Exception as e:
+        print(f"An error occurred during inference: {e}")
+        return f"Error: Could not process the abstract. Details: {e}"
+
+# --- 5. Create the Gradio Interface ---
+grant_type_choices = list(activity_code_map.keys())
+
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🚀 HRM Grant Abstract Optimizer")
+    gr.Markdown("Enter a draft abstract and select the grant type to get a version optimized by the fine-tuned Hierarchical Reasoning Model.")
+
+    with gr.Row():
+        with gr.Column():
+            draft_input = gr.Textbox(label="Draft Abstract", lines=15, placeholder="Paste your draft abstract here...")
+            grant_type = gr.Dropdown(label="Grant Type", choices=grant_type_choices, value=grant_type_choices[0] if grant_type_choices else None)
+            optimize_btn = gr.Button("Optimize Abstract", variant="primary")
+        with gr.Column():
+            output_text = gr.Textbox(label="Optimized Abstract", lines=17, interactive=False)

+    optimize_btn.click(
+        fn=optimize_abstract,
+        inputs=[draft_input, grant_type],
+        outputs=output_text
+    )

 if __name__ == "__main__":
     demo.launch()
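
To sanity-check the new inference path outside the UI, here is a minimal sketch, assuming the artifacts referenced above (tokenizer.json, config.yaml, activity_code_map.json, pytorch_model.bin) sit next to app.py. Importing app builds the Blocks UI, but the __main__ guard keeps launch() from running:

    # Hypothetical smoke test, not part of this commit
    from app import optimize_abstract

    draft = "Our study will investigate protein interactions in cancer cells."
    # "R01" should be a key in activity_code_map; unknown types fall back to id 0
    print(optimize_abstract(draft, "R01"))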