---
language:
- en
- fr
- es
- zh
- hi
- ja
- ru
license: apache-2.0
library_name: flax
tags:
- jax
- flax
- tpu
- text-generation
- base-model
- custom-architecture
datasets:
- HuggingFaceFW/fineweb-edu
- bigcode/starcoderdata
- HuggingFaceFW/fineweb-2
- open-web-math/open-web-math
metrics:
- loss
- perplexity
pipeline_tag: text-generation
inference: false
---
# Zenyx-Base-220M: High-Density Nano Foundation Model
**Zenyx-Base-220M** is a 220-million-parameter causal language model built from scratch with JAX/Flax on a Kaggle TPU v5e-8.
Unlike typical small models trained on limited data, Zenyx-Base was trained on **~153 billion tokens**, far beyond the Chinchilla-optimal token budget for this parameter count. This "over-training" strategy was chosen to pack as much linguistic and logical capability as possible into the weights, with the aim of providing a robust foundation for reasoning tasks.
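
To put the over-training claim in numbers, here is a minimal sketch of the arithmetic. It assumes the commonly cited Chinchilla rule of thumb of roughly 20 training tokens per parameter; that ratio is an assumption brought in for illustration, not a figure stated in this card.

```python
# Back-of-the-envelope comparison of the training budget with a
# Chinchilla-style budget. The 20 tokens/parameter ratio is a rule of
# thumb (an assumption), not a number taken from this model card.
params = 220e6                    # model size stated above
tokens_trained = 153e9            # approximate training tokens stated above
chinchilla_tokens_per_param = 20  # assumed heuristic

optimal_tokens = params * chinchilla_tokens_per_param
tokens_per_param = tokens_trained / params
overtrain_factor = tokens_trained / optimal_tokens

print(f"Chinchilla-style budget: {optimal_tokens / 1e9:.1f}B tokens")
print(f"Actual tokens per parameter: {tokens_per_param:.0f}")
print(f"Over-training factor: {overtrain_factor:.1f}x")
```

Under that heuristic the budget would be about 4.4B tokens, so the stated 153B tokens is roughly 35 times larger.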
## 🧠 Model Description
* **Architecture:** Custom Llama-style Transformer (RoPE, SwiGLU, RMSNorm, Grouped Query Attention).
* **Tokenizer:** Qwen 2.5 tokenizer (roughly 151K-token vocabulary; the configured size is listed in the table below) for high compression efficiency.
* **Context Window:** 2048 Tokens.
* **Training Hardware:** TPU v5e-8.
* **Final Validation Loss:** **~2.38**.
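
The card lists perplexity as a metric alongside loss. As a small sketch, assuming the reported validation loss is the mean cross-entropy in nats per token (the card does not state the unit, so this is an assumption), perplexity is the exponential of the loss:

```python
import math

val_loss = 2.38  # approximate final validation loss reported above

# Assumes the loss is mean cross-entropy measured in nats per token.
perplexity = math.exp(val_loss)

print(f"perplexity = {perplexity:.1f}")
```

Under that assumption the reported loss corresponds to a perplexity of about 10.8.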
### Technical Specifications
| Hyperparameter | Value |
| :--- | :--- |
| **Layers** | 12 |
| **Hidden Dim** | 768 |
| **MLP Dim** | 3072 |
| **Attention Heads** | 12 |
| **KV Heads** | 4 (GQA) |
| **Vocab Size** | 151,646 |
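
The table values can be sanity-checked against the 220M figure with a rough parameter count. The sketch below is an estimate under stated assumptions, not the training script's own accounting: it assumes no bias terms, a head dimension of `d_model // n_heads`, three SwiGLU projection matrices per MLP, two RMSNorm scale vectors per layer plus a final one, and input/output embeddings that share weights. The function name `estimate_params` and the `tie_embeddings` flag are hypothetical.

```python
def estimate_params(vocab=151_646, d_model=768, n_layers=12,
                    n_heads=12, n_kv_heads=4, d_mlp=3072,
                    tie_embeddings=True):
    """Rough parameter count for a Llama-style decoder (assumptions in the text above)."""
    head_dim = d_model // n_heads           # 64 with the table values
    kv_dim = n_kv_heads * head_dim          # GQA: K and V use fewer heads than Q
    attn = 2 * d_model * d_model            # query and output projections
    attn += 2 * d_model * kv_dim            # key and value projections
    mlp = 3 * d_model * d_mlp               # SwiGLU: gate, up, and down projections
    norms = 2 * d_model                     # two RMSNorm scale vectors per layer
    per_layer = attn + mlp + norms
    embed = vocab * d_model
    lm_head = 0 if tie_embeddings else vocab * d_model
    return embed + lm_head + n_layers * per_layer + d_model  # + final norm

total = estimate_params()
queries_per_kv_head = 12 // 4  # each KV head is shared by 3 query heads

print(f"estimated parameters: {total / 1e6:.1f}M")
print(f"query heads per KV head: {queries_per_kv_head}")
```

With tied embeddings the estimate lands at about 220M, while untying them would add roughly 116M more. This is consistent with, though does not confirm, a shared embedding matrix.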
## 📚 Training Curriculum (The "Omni-Mix")
The model was trained using a rigorous 4-stage curriculum designed to layer capabilities sequentially:
1. **Phase 1: Fundamentals (FineWeb-Edu)**
   * Focus on high-quality educational English text to establish linguistic baselines.
2. **Phase 2: Logic & Structure (StarCoder - Python)**
   * Introduction of code data to enforce logical indentation, syntax, and structured thinking.
3. **Phase 3: Multilingualism (FineWeb-2)**
   * Exposure to 6 major languages (Hindi, Chinese, Russian, Japanese, French, Spanish) to expand the semantic embedding space.
4. **Phase 4: The Infinite Polish (Omni-Mix)**
   * A weighted interleaving of all previous datasets plus **OpenWebMath** to converge the model's logic and language capabilities.
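
The weighted interleaving in Phase 4 can be illustrated with a small sampler. This is a sketch of the general technique only: the function `interleave`, the toy sources, and the mixing weights are all hypothetical, since the card does not state the actual proportions or data pipeline.

```python
import random
from itertools import cycle, islice

def interleave(sources, weights, seed=0):
    """Yield (source_name, example) pairs, choosing a source at each step
    with probability proportional to its weight."""
    rng = random.Random(seed)
    names = list(sources)
    iterators = {name: iter(src) for name, src in sources.items()}
    probs = [weights[name] for name in names]
    while True:
        name = rng.choices(names, weights=probs, k=1)[0]
        try:
            yield name, next(iterators[name])
        except StopIteration:
            return  # stop as soon as any source runs out

# Toy stand-ins for the four datasets; weights are illustrative, not the real mix.
toy_sources = {
    "fineweb-edu": cycle(["educational text"]),
    "starcoder-python": cycle(["def f(): return 1"]),
    "fineweb-2": cycle(["texte multilingue"]),
    "open-web-math": cycle(["x^2 + 1 = 0"]),
}
toy_weights = {"fineweb-edu": 0.4, "starcoder-python": 0.2,
               "fineweb-2": 0.2, "open-web-math": 0.2}

batch = list(islice(interleave(toy_sources, toy_weights), 1000))
counts = {name: sum(1 for n, _ in batch if n == name) for name in toy_sources}
print(counts)
```

Sampling per step, rather than concatenating the datasets, keeps every batch window close to the target mix, which is one common way to avoid the model drifting toward whichever dataset it saw last.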
## 💻 Usage
This model is a raw **JAX/Flax** checkpoint saved in `.safetensors` format. It uses a custom architecture definition and requires `flax` and `jax` to run.
### Loading with JAX/Flax
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from safetensors.flax import load_file
from transformers import AutoTokenizer

# 1. Define the architecture (must match the training config)
class TransformerLM(nn.Module):
    vocab_size: int
    embed_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    num_kv_heads: int = 4
    mlp_dim: int = 3072
    max_length: int = 2048
    dropout_rate: float = 0.0
    # ... (Insert full model class definition here from the training script) ...

# 2. Load resources
repo_id = "Arko007/Zenyx_Base_220M"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)

# 3. Initialize the model and load the weights
model = TransformerLM(vocab_size=len(tokenizer))
dummy_input = jnp.ones((1, 1), dtype=jnp.int32)
# Random initialization; this yields the parameter tree the checkpoint has to match
params = model.init(jax.random.PRNGKey(0), dummy_input)['params']

# Load the safetensors checkpoint (download model.safetensors locally first).
# load_file returns a flat dict that maps tensor names to arrays.
loaded_params = load_file("model.safetensors")
print("Weights loaded successfully!")
```
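
`load_file` returns a flat dictionary of tensors, whereas Flax modules expect a nested parameter tree. The sketch below shows one way to rebuild the nesting, assuming the checkpoint keys are path components joined by `"."`. That separator and the example key names are hypothetical, so inspect `loaded_params.keys()` on the real checkpoint before relying on them.

```python
def unflatten(flat, sep="."):
    """Turn {'a.b.c': x} into {'a': {'b': {'c': x}}}."""
    nested = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate dicts as needed
        node[leaf] = value
    return nested

# Hypothetical key names, with strings standing in for arrays.
flat_example = {
    "embed.embedding": "E",
    "layer_0.attn.q_proj.kernel": "Wq",
    "layer_0.attn.k_proj.kernel": "Wk",
}
nested_example = unflatten(flat_example)
print(nested_example["layer_0"]["attn"].keys())
```

Flax ships an equivalent helper, `flax.traverse_util.unflatten_dict`, which accepts a `sep` argument. Whichever is used, the resulting tree should be compared against the structure of `params` from `model.init` before being passed to `model.apply`.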
## ⚠️ Limitations
- Size: At 220M parameters, the model's knowledge retrieval capacity is limited compared to 7B+ models.
- Base Model: This is a pre-trained base. It has not been fine-tuned for chat or instruction following (see Zenyx-DeepSeek-220M for the instruct version).
- Hallucinations: The model may generate fluent, plausible-sounding statements that are factually incorrect.
## 📜 Citation
```bibtex
@misc{ZenyxBase220M,
  title = {Zenyx-Base-220M: High-Density Foundation Model},
  author = {Arko007},
  year = {2025},
  publisher = {Hugging Face},
  url = {https://huggingface.co/Arko007/Zenyx_Base_220M}
}
```