NGen 2
When using NGen 2 with the Hugging Face Transformers library, only the 15M variant is currently supported (see the loading sketch below).
NGen 2 is an advanced Transformer model training pipeline that supports multiple model variants, ranging from a nano variant (approximately 120M parameters) to a foundational variant (approximately 1B parameters). The pipeline incorporates modern architectural improvements such as rotary positional embeddings, RMSNorm, and GEGLU activations to boost performance and training efficiency.
Note: Although NGen 2 is designed to train models up to 1B parameters, its advanced architecture pushes its performance closer to that of much larger models. For better performance, consider NGen 3.
The NGen 2 series was only released publicly up to the 170M variant; the other variants were produced but were never made public.
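For quick experimentation, a published checkpoint can be loaded through the Transformers library. The sketch below is illustrative only: the repo ID is an assumption, and trust_remote_code may or may not be needed depending on how the architecture is packaged.

# Minimal loading sketch; the repo ID is a placeholder, not a confirmed checkpoint name.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "TNSA/NGen2-15M"  # assumption: replace with the actual published checkpoint
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello, NGen 2!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))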
Model Variants
NGen 2 supports the following variants via the --variant flag:
- nano: ~120M parameters
- small: ~300M parameters
- medium: ~500M parameters
- large: ~700M parameters
- foundational: ~1B parameters
Each variant adjusts key hyperparameters such as the number of layers, model dimension (d_model), number of attention heads (n_heads), and the feed-forward dimension (d_ff).
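As a mental model, the variant-to-hyperparameter mapping can be pictured as a small lookup table like the sketch below. The exact values are assumptions chosen to land near the stated parameter counts; they are not the official NGen 2 configuration.

# Illustrative only: plausible hyperparameters per variant, not the shipped values.
VARIANT_CONFIGS = {
    "nano":         dict(n_layers=12, d_model=768,  n_heads=12, d_ff=3072),  # ~120M
    "small":        dict(n_layers=24, d_model=1024, n_heads=16, d_ff=4096),  # ~300M
    "medium":       dict(n_layers=24, d_model=1280, n_heads=20, d_ff=5120),  # ~500M
    "large":        dict(n_layers=24, d_model=1536, n_heads=16, d_ff=6144),  # ~700M
    "foundational": dict(n_layers=24, d_model=1792, n_heads=14, d_ff=7168),  # ~1B
}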
Requirements
- Python 3.8+
- PyTorch
- Transformers
- Datasets
- DeepSpeed (optional, for efficient training)
- Azure ML SDK (for distributed training on Azure)
Install dependencies using pip (adjust as needed):
pip install torch transformers datasets deepspeed azureml-core
Usage
1. Data Preparation
First, download and preprocess the OpenWebText dataset:
python prepare.py --output_dir ./_data_ --max_length 4096
This script downloads, tokenizes, and saves the dataset in Arrow format to the directory given by --output_dir (./_data_ in the example above).
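A rough sketch of the flow inside prepare.py is shown below. The GPT-2 tokenizer and the "text" column name are assumptions; the actual tokenizer, columns, and options may differ.

# Sketch of the prepare.py flow (tokenizer choice and column names are assumptions).
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: may not match the real tokenizer

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=4096)

ds = load_dataset("openwebtext", split="train")
ds = ds.map(tokenize, batched=True, remove_columns=["text"])
ds.save_to_disk("./_data_")  # Arrow format, matching --output_dir in the example above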
2. Local Training
The main training script is train.py. It loads the processed dataset (./data by default, or the directory passed via --dataset_dir), instantiates the desired model variant, and starts training. A sketch of the data-loading side appears after the example commands below.
Example CLI Commands
- Train the nano (120M) variant:
python train.py --dataset_dir ./_data_ --output_dir ./checkpoints_nano --batch_size 4 --epochs 3 --variant nano
- Train the small (300M) variant:
python train.py --dataset_dir ./_data_ --output_dir ./checkpoints_small --batch_size 4 --epochs 3 --variant small
- Train the medium (500M) variant:
python train.py --dataset_dir ./_data_ --output_dir ./checkpoints_medium --batch_size 4 --epochs 3 --variant medium
- Train the large (700M) variant:
python train.py --dataset_dir ./_data_ --output_dir ./checkpoints_large --batch_size 4 --epochs 3 --variant large
- Train the foundational (1B) variant with rotary embeddings enabled:
python train.py --dataset_dir ./_data_ --output_dir ./checkpoints_foundational --batch_size 4 --epochs 3 --variant foundational --use_rotary
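The model class itself is project code, so it is only indicated by a comment in the sketch below; the data-loading side of train.py can, however, be pictured with standard datasets/PyTorch utilities. The tokenizer choice is an assumption and should match whatever prepare.py uses.

# Sketch of the data-loading side of train.py (model construction is project-specific).
from torch.utils.data import DataLoader
from datasets import load_from_disk
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

dataset = load_from_disk("./_data_")               # Arrow dataset written by prepare.py
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: should match prepare.py
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 has no pad token by default
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
loader = DataLoader(dataset, batch_size=4, collate_fn=collator)

# model = <the NGen 2 variant selected via --variant>, then the usual loop:
# for batch in loader: loss = model(**batch).loss; loss.backward(); optimizer.step()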
3. Training on Azure ML
- Step 1: Set Up Azure ML Resources
Use azure_setup.py to create or connect to your Azure ML workspace and set up a compute cluster:
python azure_setup.py \
--workspace_name MyWorkspace \
--resource_group MyResourceGroup \
--subscription_id YOUR_SUBSCRIPTION_ID \
--location eastus \
--compute_name gpu-cluster \
--vm_size Standard_NC6 \
--max_nodes 4 \
--min_nodes 0
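For reference, a minimal sketch of what azure_setup.py likely does with the azureml-core SDK is shown below; the names mirror the flags in the command above, and error handling is omitted.

# Sketch of workspace and compute-cluster setup with azureml-core.
from azureml.core import Workspace
from azureml.core.compute import AmlCompute, ComputeTarget

ws = Workspace.create(
    name="MyWorkspace",
    subscription_id="YOUR_SUBSCRIPTION_ID",
    resource_group="MyResourceGroup",
    location="eastus",
    exist_ok=True,  # reuse the workspace if it already exists
)

compute_config = AmlCompute.provisioning_configuration(
    vm_size="Standard_NC6", min_nodes=0, max_nodes=4
)
cluster = ComputeTarget.create(ws, "gpu-cluster", compute_config)
cluster.wait_for_completion(show_output=True)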
- Step 2: Submit a Training Job to Azure ML
Use submit_train.py to submit your training script to Azure ML:
python submit_train.py \
--experiment_name ngen2-experiment \
--compute_target gpu-cluster \
--script train.py \
--dataset_dir ./_data_ \
--output_dir ./checkpoints_foundational \
--batch_size 4 \
--epochs 3 \
--variant foundational \
--use_rotary
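Under the hood, submit_train.py presumably wraps the azureml-core ScriptRunConfig and Experiment APIs. The sketch below is a simplified illustration (environment and dependency setup omitted), with arguments mirroring the command above.

# Sketch of a training-job submission with azureml-core.
from azureml.core import Workspace, Experiment, ScriptRunConfig

ws = Workspace.from_config()  # assumes a config.json saved during workspace setup
run_config = ScriptRunConfig(
    source_directory=".",
    script="train.py",
    arguments=[
        "--dataset_dir", "./_data_",
        "--output_dir", "./checkpoints_foundational",
        "--batch_size", "4",
        "--epochs", "3",
        "--variant", "foundational",
        "--use_rotary",
    ],
    compute_target="gpu-cluster",
)
run = Experiment(ws, "ngen2-experiment").submit(run_config)
run.wait_for_completion(show_output=True)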
4. DeepSpeed Integration
The deepspeed.json file configures mixed-precision training and ZeRO optimizations. To leverage DeepSpeed, ensure it is installed and adjust your training script or submission command to enable DeepSpeed support.
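The exact contents of deepspeed.json are not reproduced in this README; the snippet below writes an illustrative configuration with fp16 mixed precision and ZeRO stage 2 to give a sense of what such a file typically contains. All values are assumptions.

# Writes an illustrative DeepSpeed config; values are assumptions, not the shipped file.
import json

ds_config = {
    "train_micro_batch_size_per_gpu": 4,  # keep in sync with --batch_size
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},            # mixed-precision training
    "zero_optimization": {"stage": 2},    # partition optimizer state and gradients
}

with open("deepspeed.json", "w") as f:
    json.dump(ds_config, f, indent=2)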
License
The NGen 2 project is developed and maintained by TNSA AI. The licensing model is dual:
- The nano and small variants are open source and released under the MIT License.
- The medium, large, and foundational variants are proprietary and are not open source. Use of these proprietary components is subject to TNSA AI's proprietary licensing terms.
Copyright
© 2023 TNSA AI. All rights reserved. For terms of use, see the LICENSE file.