| from dataclasses import dataclass, field | |
| from typing import Optional | |
| from simple_parsing.helpers import Serializable | |
@dataclass
class LoraArgs(Serializable):
    """LoRA configuration values, validated when an instance is created.

    When ``enable`` is true, ``rank`` and ``scaling`` must both be strictly
    positive. ``dropout`` is stored as given and is not checked here.

    Raises:
        AssertionError: If ``enable`` is true and ``rank`` or ``scaling`` is
            not strictly positive.
    """

    enable: bool = True  # gates the checks in __post_init__
    rank: int = 16  # must be > 0 when enable is true
    dropout: float = 0.0  # not validated in this class
    scaling: float = 2.0  # must be > 0.0 when enable is true

    def __post_init__(self) -> None:
        """Check ``rank`` and ``scaling`` after the dataclass fields are set."""
        if not self.enable:
            return
        # Raised explicitly instead of via `assert` so the checks still run
        # under `python -O`; the AssertionError type is kept so existing
        # handlers keep working.
        if not self.rank > 0:
            raise AssertionError(
                f"rank must be > 0 when enable is true, got {self.rank!r}"
            )
        if not self.scaling > 0.0:
            raise AssertionError(
                f"scaling must be > 0.0 when enable is true, got {self.scaling!r}"
            )
@dataclass
class MoeArgs(Serializable):
    """Settings container with two integer fields (presumably mixture-of-experts).

    NOTE(review): the meaning of each field is inferred from its name only;
    confirm against the code that consumes this object.
    """

    num_experts: int = 8  # presumably the total number of experts -- confirm
    num_experts_per_tok: int = 2  # presumably experts used per token -- confirm
@dataclass
class ModelArgs(Serializable):
    """Top-level model configuration.

    The first eight fields have no default and must be supplied by the
    caller. ``rope_theta``, ``lora`` and ``moe`` are optional.

    NOTE(review): the semantics of the individual dimensions are not visible
    in this file; the inline comments only record what the declarations show.
    """

    # Required fields (no defaults).
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int

    # Optional fields.
    rope_theta: float = 10000.0
    # default_factory builds a fresh LoraArgs for every ModelArgs instance
    # rather than sharing one mutable default object.
    lora: LoraArgs = field(default_factory=LoraArgs)
    # None when no MoeArgs is supplied.
    moe: Optional[MoeArgs] = None