---
language:
- en
- fr
- es
- zh
- hi
- ja
- ru
license: apache-2.0
library_name: flax
tags:
- jax
- flax
- tpu
- text-generation
- base-model
- custom-architecture
datasets:
- HuggingFaceFW/fineweb-edu
- bigcode/starcoderdata
- HuggingFaceFW/fineweb-2
- open-web-math/open-web-math
metrics:
- loss
- perplexity
pipeline_tag: text-generation
inference: false
---

# Zenyx-Base-220M: High-Density Nano Foundation Model

<div align="center">

![Model Architecture](https://img.shields.io/badge/Model-Zenyx_Base-blue?style=for-the-badge)
![Parameter Count](https://img.shields.io/badge/Params-220M-orange?style=for-the-badge)
![Training Tokens](https://img.shields.io/badge/Tokens-153B-green?style=for-the-badge)
![Format](https://img.shields.io/badge/Weights-Safetensors-yellow?style=for-the-badge)

</div>

**Zenyx-Base-220M** is a 220-million-parameter causal language model built from scratch with JAX/Flax on a Kaggle TPU v5e-8.

Unlike typical small models trained on limited data, Zenyx-Base was trained on **~153 billion tokens**, roughly 700 tokens per parameter and far beyond the Chinchilla-optimal budget for this parameter count. This deliberate "over-training" was intended to maximize the information density and reasoning ability of the weights, creating a robust foundation for reasoning tasks.
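
As a rough check on the over-training claim, the sketch below compares the token budget with the commonly cited Chinchilla heuristic of about 20 training tokens per parameter. The 20:1 ratio is a rule of thumb assumed here for illustration, not a figure taken from this model's training setup, and the variable names are hypothetical.

```python
# Rough comparison of the training budget with the Chinchilla rule of thumb.
PARAMS = 220e6           # parameter count stated in this card
TOKENS_TRAINED = 153e9   # approximate training tokens stated in this card
CHINCHILLA_RATIO = 20    # assumption: ~20 tokens per parameter heuristic

chinchilla_tokens = PARAMS * CHINCHILLA_RATIO
tokens_per_param = TOKENS_TRAINED / PARAMS
overtraining_factor = TOKENS_TRAINED / chinchilla_tokens

print(f"Chinchilla-style budget: {chinchilla_tokens / 1e9:.1f}B tokens")
print(f"Actual tokens per parameter: {tokens_per_param:.0f}")
print(f"Over-training factor: {overtraining_factor:.1f}x")
```

Under that heuristic the compute-optimal budget would be only a few billion tokens, so the reported run is over-trained by more than an order of magnitude.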

## 🧠 Model Description

* **Architecture:** Custom Llama-style Transformer (RoPE, SwiGLU, RMSNorm, Grouped Query Attention).
* **Tokenizer:** Qwen 2.5 tokenizer (vocabulary of roughly 151.6K tokens), chosen for high compression efficiency.
* **Context Window:** 2048 Tokens.
* **Training Hardware:** TPU v5e-8.
* **Final Validation Loss:** **~2.38** (strong convergence for a 220M-parameter model).
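
To relate the reported validation loss to the perplexity metric listed in the metadata, here is a minimal conversion. It assumes the loss is the mean cross-entropy per token measured in nats, which is the usual convention but is not stated explicitly in this card.

```python
import math

val_loss = 2.38  # final validation loss reported above (assumed nats per token)

perplexity = math.exp(val_loss)          # perplexity = e^loss
bits_per_token = val_loss / math.log(2)  # same quantity expressed in bits

print(f"perplexity ~ {perplexity:.1f}, bits/token ~ {bits_per_token:.2f}")
```

Under that assumption the loss corresponds to a perplexity of about 10.8 over the Qwen 2.5 token vocabulary.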

### Technical Specifications
| Hyperparameter | Value |
| :--- | :--- |
| **Layers** | 12 |
| **Hidden Dim** | 768 |
| **MLP Dim** | 3072 |
| **Attention Heads** | 12 |
| **KV Heads** | 4 (GQA) |
| **Vocab Size** | 151,646 |
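
The table can be sanity-checked against the 220M headline figure with a back-of-the-envelope parameter count. This is a sketch under stated assumptions that the card does not confirm: no bias terms, a head dimension of hidden dim / attention heads, a three-matrix SwiGLU MLP, two RMSNorm scales per layer plus a final norm, and input and output embeddings that share weights. The function name `estimate_params` is hypothetical.

```python
def estimate_params(vocab=151_646, d_model=768, n_layers=12,
                    n_heads=12, n_kv_heads=4, d_mlp=3072,
                    tie_embeddings=True):
    """Approximate parameter count for a Llama-style decoder with GQA."""
    head_dim = d_model // n_heads
    kv_dim = n_kv_heads * head_dim            # K and V are narrower under GQA

    attn = 2 * d_model * d_model              # query and output projections
    attn += 2 * d_model * kv_dim              # key and value projections
    mlp = 3 * d_model * d_mlp                 # SwiGLU: gate, up, down
    norms = 2 * d_model                       # two RMSNorm scales per layer
    per_layer = attn + mlp + norms

    embed = vocab * d_model
    lm_head = 0 if tie_embeddings else vocab * d_model
    final_norm = d_model
    return embed + n_layers * per_layer + final_norm + lm_head


total = estimate_params()
print(f"Estimated parameters: {total / 1e6:.1f}M")
```

With tied embeddings the estimate lands close to 220M, and more than half of it sits in the embedding matrix because of the large vocabulary. RoPE contributes no learned parameters.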

## 📚 Training Curriculum (The "Omni-Mix")

The model was trained with a four-phase curriculum designed to layer capabilities sequentially:

1.  **Phase 1: Fundamentals (FineWeb-Edu)**
    * Focus on high-quality educational English text to establish linguistic baselines.
2.  **Phase 2: Logic & Structure (StarCoder - Python)**
    * Introduction of code data to enforce logical indentation, syntax, and structured thinking.
3.  **Phase 3: Multilingualism (FineWeb-2)**
    * Exposure to 6 major languages (Hindi, Chinese, Russian, Japanese, French, Spanish) to expand the semantic embedding space.
4.  **Phase 4: The Infinite Polish (Omni-Mix)**
    * A weighted interleaving of all previous datasets plus **OpenWebMath** to converge the model's logic and language capabilities.
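
The weighted interleaving in Phase 4 can be pictured with the small sampler below. It is only an illustration of the general technique: the mixture weights, the source names, and the toy examples are hypothetical, and the actual training pipeline and its proportions are not described in this card.

```python
import random
from itertools import cycle


def interleave(sources, weights, num_samples, seed=0):
    """Yield (source_name, example) pairs, choosing the source by weight."""
    rng = random.Random(seed)
    names = list(sources)
    iterators = {name: cycle(sources[name]) for name in names}
    probs = [weights[name] for name in names]
    for _ in range(num_samples):
        name = rng.choices(names, weights=probs, k=1)[0]
        yield name, next(iterators[name])


# Toy stand-ins for the real datasets; the weights are made up for illustration.
sources = {
    "fineweb_edu": ["edu-0", "edu-1"],
    "starcoder_python": ["code-0", "code-1"],
    "fineweb_2": ["multi-0", "multi-1"],
    "open_web_math": ["math-0", "math-1"],
}
weights = {
    "fineweb_edu": 0.4,
    "starcoder_python": 0.2,
    "fineweb_2": 0.2,
    "open_web_math": 0.2,
}

batch = list(interleave(sources, weights, num_samples=8))
for name, example in batch:
    print(name, example)
```

Sampling the source per example, rather than concatenating datasets, keeps every domain present throughout the phase so earlier capabilities are revisited while new data is introduced.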

## 💻 Usage

This model is a raw **JAX/Flax** checkpoint saved in `.safetensors` format. It uses a custom architecture definition and requires `flax` and `jax` to run.

### Loading with JAX/Flax

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from huggingface_hub import hf_hub_download
from safetensors.flax import load_file
from transformers import AutoTokenizer

# 1. Define Architecture (Must match training config)
class TransformerLM(nn.Module):
    vocab_size: int
    embed_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    num_kv_heads: int = 4
    mlp_dim: int = 3072
    max_length: int = 2048
    dropout_rate: float = 0.0
    
    # ... (Insert full model class definition here from the training script) ...

# 2. Load Resources
repo_id = "Arko007/Zenyx_Base_220M"
# The Qwen 2.5 tokenizer is supported natively, so no remote code is needed
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# 3. Initialize the Model
# Requires the full class above (including __call__);
# vocab_size must match the embedding shape stored in the checkpoint
model = TransformerLM(vocab_size=len(tokenizer))
dummy_input = jnp.ones((1, 1), dtype=jnp.int32)
params = model.init(jax.random.PRNGKey(0), dummy_input)['params']

# 4. Load Safetensors
# Assumes the checkpoint is stored as model.safetensors at the root of the repo
weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
loaded_params = load_file(weights_path)  # flat dict: tensor name -> jax.Array
print(f"Loaded {len(loaded_params)} tensors from {weights_path}")
# loaded_params is flat; map it onto the structure of `params` before model.apply
```
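
`load_file` returns a flat dictionary, while `model.init` produces a nested parameter tree, so the two need to be reconciled before the weights can be used. The helper below is a dependency-free sketch of that check. It assumes the checkpoint keys join nested names with `"."`, which this card does not confirm, and both function names are hypothetical. Flax's `flax.traverse_util.unflatten_dict(flat, sep=".")` performs the same nesting.

```python
from collections.abc import Mapping


def unflatten(flat, sep="."):
    """Turn {"a.b.c": x} into {"a": {"b": {"c": x}}}."""
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


def shape_mismatches(expected, loaded, prefix=""):
    """List differences in structure or array shape between two param trees."""
    problems = []
    for name in sorted(set(expected) | set(loaded)):
        path = f"{prefix}{name}"
        if name not in loaded:
            problems.append(f"missing in checkpoint: {path}")
        elif name not in expected:
            problems.append(f"unexpected in checkpoint: {path}")
        else:
            exp, got = expected[name], loaded[name]
            exp_tree, got_tree = isinstance(exp, Mapping), isinstance(got, Mapping)
            if exp_tree and got_tree:
                problems += shape_mismatches(exp, got, path + ".")
            elif exp_tree or got_tree:
                problems.append(f"structure differs at {path}")
            elif tuple(exp.shape) != tuple(got.shape):
                problems.append(
                    f"shape differs at {path}: {tuple(exp.shape)} vs {tuple(got.shape)}"
                )
    return problems


# Intended use with the snippet above (names taken from that snippet):
#   nested = unflatten(loaded_params)
#   problems = shape_mismatches(params, nested)
#   if not problems: logits = model.apply({"params": nested}, input_ids)
```

An empty list means every tensor in the checkpoint lines up with the initialized model. A vocabulary-size mismatch would show up here as a shape difference on the embedding table.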

## ⚠️ Limitations
- **Size:** At 220M parameters, the model's knowledge retrieval capacity is limited compared to 7B+ models.
- **Base Model:** This is a pre-trained base model. It has not been fine-tuned for chat or instruction following (see Zenyx-DeepSeek-220M for the instruct version).
- **Hallucinations:** Even when its output is fluent and internally consistent, the model may generate factually incorrect statements.

## 📜 Citation

```bibtex
@misc{ZenyxBase220M,
  title = {Zenyx-Base-220M: High-Density Foundation Model},
  author = {Arko007},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/Arko007/Zenyx_Base_220M}
}
```