{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working\n",
"Cloning into 'mistral-finetune'...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/working/venv/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
" self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"remote: Enumerating objects: 472, done.\u001b[K\n",
"remote: Counting objects: 100% (249/249), done.\u001b[K\n",
"remote: Compressing objects: 100% (90/90), done.\u001b[K\n",
"remote: Total 472 (delta 211), reused 159 (delta 159), pack-reused 223 (from 2)\u001b[K\n",
"Receiving objects: 100% (472/472), 243.32 KiB | 5.29 MiB/s, done.\n",
"Resolving deltas: 100% (251/251), done.\n"
]
}
],
"source": [
"%cd /workspace/working\n",
"!git clone https://github.com/mistralai/mistral-finetune.git"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# %pip install -r /workspace/working/mistral-finetune/requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# %pip install huggingface_hub"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# !pip install -q llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"trusted": true
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "704817ec3d4c4b92b77002db54fe485a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 3 files: 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0f336ef4b9dc48a2885292f467ceda14",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"consolidated.safetensors: 0%| | 0.00/14.5G [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a9fdfaf041084286b750b9ed2f5c24be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.model.v3: 0%| | 0.00/587k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cc61fc51fce543febb4075f2f64cd25e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"params.json: 0%| | 0.00/202 [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'/workspace/working/mistral_models'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from huggingface_hub import snapshot_download\n",
"from pathlib import Path\n",
"\n",
"# Set the path to the /workspace/working directory\n",
"mistral_models_path = Path('/workspace/working/', 'mistral_models')\n",
"mistral_models_path.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# Download model snapshot to the desired path\n",
"snapshot_download(repo_id=\"mistralai/Mistral-7B-Instruct-v0.3\", allow_patterns=[\"params.json\", \"consolidated.safetensors\", \"tokenizer.model.v3\"], local_dir=mistral_models_path)\n",
"# snapshot_download(repo_id=\"unsloth/mistral-7b-instruct-v0.3-bnb-4bit\", allow_patterns=[\"config.json\", \"model.safetensors\", \"tokenizer.model.v3\"], local_dir=mistral_models_path)\n",
"\n",
"# import json\n",
"\n",
"# # Create a dictionary (or any other data structure)\n",
"# data = {\n",
"# \"dim\": 4096,\n",
"# \"n_layers\": 32,\n",
"# \"head_dim\": 128,\n",
"# \"hidden_dim\": 14336,\n",
"# \"n_heads\": 32,\n",
"# \"n_kv_heads\": 8,\n",
"# \"norm_eps\": 1e-05,\n",
"# \"vocab_size\": 32768,\n",
"# \"rope_theta\": 1000000.0\n",
"# }\n",
"\n",
"# Specify the path where the file will be saved\n",
"# file_path = '/workspace/working/mistral_models/params.json'\n",
"\n",
"# # Write the dictionary to a JSON file\n",
"# with open(file_path, 'w') as json_file:\n",
"# json.dump(data, json_file, indent=4)\n",
"\n",
"# print(f\"JSON file created at {file_path}\")\n",
"\n",
"# os.rename('/workspace/working/mistral_models/model.safetensors', '/workspace/working/mistral_models/consolidated.safetensors')\n",
"\n",
"# Copy the model files to the /workspace/working directory (if necessary)\n",
"# !mv /workspace/working/mistral_models/7B-instruct-v0.3 /workspace/working/mistral_models\n",
"# !rm -r /workspace/working/mistral_models/7B-instruct-v0.3\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"consolidated.safetensors params.json tokenizer.model.v3\n"
]
}
],
"source": [
"%ls /workspace/working/mistral_models/"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/working/venv/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
" self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
]
}
],
"source": [
"%cd /workspace/working/"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"%mkdir -p data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working/data\n"
]
}
],
"source": [
"%cd /workspace/working/data/"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"trusted": true
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" text_query | \n",
" language | \n",
" sparql_query | \n",
" knowledge_graphs | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" Welche Bücher von Kerouac wurden von Viking Pr... | \n",
" de | \n",
" PREFIX dbo: <http://dbpedia.org/ontology/>\\nPR... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | 1 | \n",
" ¿Qué libros de Kerouac han sido publicados por... | \n",
" es | \n",
" PREFIX dbo: <http://dbpedia.org/ontology/>\\nPR... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | 2 | \n",
" Quali libri di Kerouac furono pubblicati da Vi... | \n",
" it | \n",
" PREFIX dbo: <http://dbpedia.org/ontology/>\\nPR... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | 3 | \n",
" Quels lives de Kerouac ont été publiés par Vik... | \n",
" fr | \n",
" PREFIX dbo: <http://dbpedia.org/ontology/>\\nPR... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | 4 | \n",
" Welke boeken van Jack Kerouac werden uitgegeve... | \n",
" nl | \n",
" PREFIX dbo: <http://dbpedia.org/ontology/>\\nPR... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" | 893476 | \n",
" Was sind die Arten von Speisen oder Gerichten,... | \n",
" de | \n",
" SELECT DISTINCT ?sbj ?sbj_label { ?statement1 ... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | 893477 | \n",
" Welche Art von Essen oder Gericht enthält das ... | \n",
" de | \n",
" SELECT DISTINCT ?sbj ?sbj_label { ?statement1 ... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | 893478 | \n",
" Welche Art von Quantenteilchen hat das höchste... | \n",
" de | \n",
" select ?ent where {\\n?ent <http://www.wikidata... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | 893479 | \n",
" Welches Quantenteilchen hat das höchste Gyroma... | \n",
" de | \n",
" select ?ent where {\\n?ent <http://www.wikidata... | \n",
" DBpedia | \n",
"
\n",
" \n",
" | 893480 | \n",
" Sage mir die literarische Gattung, deren Name ... | \n",
" de | \n",
" SELECT DISTINCT ?sbj ?sbj_label { ?statement1 ... | \n",
" DBpedia | \n",
"
\n",
" \n",
"
\n",
"
893481 rows × 4 columns
\n",
"
"
],
"text/plain": [
" text_query language \\\n",
"0 Welche Bücher von Kerouac wurden von Viking Pr... de \n",
"1 ¿Qué libros de Kerouac han sido publicados por... es \n",
"2 Quali libri di Kerouac furono pubblicati da Vi... it \n",
"3 Quels lives de Kerouac ont été publiés par Vik... fr \n",
"4 Welke boeken van Jack Kerouac werden uitgegeve... nl \n",
"... ... ... \n",
"893476 Was sind die Arten von Speisen oder Gerichten,... de \n",
"893477 Welche Art von Essen oder Gericht enthält das ... de \n",
"893478 Welche Art von Quantenteilchen hat das höchste... de \n",
"893479 Welches Quantenteilchen hat das höchste Gyroma... de \n",
"893480 Sage mir die literarische Gattung, deren Name ... de \n",
"\n",
" sparql_query knowledge_graphs \n",
"0 PREFIX dbo: \\nPR... DBpedia \n",
"1 PREFIX dbo: \\nPR... DBpedia \n",
"2 PREFIX dbo: \\nPR... DBpedia \n",
"3 PREFIX dbo: \\nPR... DBpedia \n",
"4 PREFIX dbo: \\nPR... DBpedia \n",
"... ... ... \n",
"893476 SELECT DISTINCT ?sbj ?sbj_label { ?statement1 ... DBpedia \n",
"893477 SELECT DISTINCT ?sbj ?sbj_label { ?statement1 ... DBpedia \n",
"893478 select ?ent where {\\n?ent \n",
"\n",
"\n",
" \n",
" \n",
" | \n",
" text_query | \n",
" language | \n",
" sparql_query | \n",
" knowledge_graphs | \n",
"
\n",
" \n",
" \n",
" \n",
"
\n",
""
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [text_query, language, sparql_query, knowledge_graphs]\n",
"Index: []"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dftemp = dftemp[dftemp[\"language\"] == \"en\"]\n",
"dftemp = dftemp[dftemp[\"knowledge_graphs\"] == \"DBpedia\"]\n",
"dftemp[dftemp[\"language\"] != \"en\"]\n",
"# dftemp"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"trusted": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" text_query | \n",
" language | \n",
" sparql_query | \n",
" knowledge_graphs | \n",
"
\n",
" \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [text_query, language, sparql_query, knowledge_graphs]\n",
"Index: []"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dftemp[dftemp[\"knowledge_graphs\"] != \"DBpedia\"]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"dftemp.reset_index(drop = True, inplace = True)\n",
"dftemp.drop(columns=[\"language\", \"knowledge_graphs\"], axis = 1, inplace = True)\n",
"tbremovedrows = dftemp[\"sparql_query\"].str.contains(\"\\n\")\n",
"dftemp = dftemp[~tbremovedrows]\n",
"dftemp.reset_index(drop = True, inplace = True)\n",
"dftemp[\"messages\"] = dftemp.apply(lambda row: [{\"role\": \"user\", \"content\": row[\"text_query\"]}, {\"role\": \"assistant\", \"content\": row[\"sparql_query\"]}], axis=1)\n",
"dftemp.rename(columns = {\"text_query\" : \"NL_Query\"}, inplace = True)\n",
"dftemp.drop(columns = \"sparql_query\", inplace = True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"trusted": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" NL_Query | \n",
" messages | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" How high is the lighthouse in Colombo? | \n",
" [{'role': 'user', 'content': 'How high is the ... | \n",
"
\n",
" \n",
" | 1 | \n",
" Who is the host of the BBC Wildlife Specials? | \n",
" [{'role': 'user', 'content': 'Who is the host ... | \n",
"
\n",
" \n",
" | 2 | \n",
" How much did Pulp Fiction cost? | \n",
" [{'role': 'user', 'content': 'How much did Pul... | \n",
"
\n",
" \n",
" | 3 | \n",
" In what city is the Heineken brewery? | \n",
" [{'role': 'user', 'content': 'In what city is ... | \n",
"
\n",
" \n",
" | 4 | \n",
" When did Operation Overlord commence? | \n",
" [{'role': 'user', 'content': 'When did Operati... | \n",
"
\n",
" \n",
" | ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" | 42233 | \n",
" Name a gene that begins with a T. | \n",
" [{'role': 'user', 'content': 'Name a gene that... | \n",
"
\n",
" \n",
" | 42234 | \n",
" What is antonym of of spore print color of Ple... | \n",
" [{'role': 'user', 'content': 'What is antonym ... | \n",
"
\n",
" \n",
" | 42235 | \n",
" Tell me mixture whose name has the word spirom... | \n",
" [{'role': 'user', 'content': 'Tell me mixture ... | \n",
"
\n",
" \n",
" | 42236 | \n",
" Let me know blend whose title has the word spi... | \n",
" [{'role': 'user', 'content': 'Let me know blen... | \n",
"
\n",
" \n",
" | 42237 | \n",
" What Theoi Project ID does Manticore has? | \n",
" [{'role': 'user', 'content': 'What Theoi Proje... | \n",
"
\n",
" \n",
"
\n",
"
42238 rows × 2 columns
\n",
"
"
],
"text/plain": [
" NL_Query \\\n",
"0 How high is the lighthouse in Colombo? \n",
"1 Who is the host of the BBC Wildlife Specials? \n",
"2 How much did Pulp Fiction cost? \n",
"3 In what city is the Heineken brewery? \n",
"4 When did Operation Overlord commence? \n",
"... ... \n",
"42233 Name a gene that begins with a T. \n",
"42234 What is antonym of of spore print color of Ple... \n",
"42235 Tell me mixture whose name has the word spirom... \n",
"42236 Let me know blend whose title has the word spi... \n",
"42237 What Theoi Project ID does Manticore has? \n",
"\n",
" messages \n",
"0 [{'role': 'user', 'content': 'How high is the ... \n",
"1 [{'role': 'user', 'content': 'Who is the host ... \n",
"2 [{'role': 'user', 'content': 'How much did Pul... \n",
"3 [{'role': 'user', 'content': 'In what city is ... \n",
"4 [{'role': 'user', 'content': 'When did Operati... \n",
"... ... \n",
"42233 [{'role': 'user', 'content': 'Name a gene that... \n",
"42234 [{'role': 'user', 'content': 'What is antonym ... \n",
"42235 [{'role': 'user', 'content': 'Tell me mixture ... \n",
"42236 [{'role': 'user', 'content': 'Let me know blen... \n",
"42237 [{'role': 'user', 'content': 'What Theoi Proje... \n",
"\n",
"[42238 rows x 2 columns]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dftemp"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"434\n",
"2\n",
"826\n"
]
}
],
"source": [
"print(dftemp[\"NL_Query\"].str.len().max())\n",
"print(dftemp[\"messages\"].str.len().max())\n",
"print(\n",
" dftemp[\"messages\"]\n",
" .apply(lambda x: \" \".join([msg.get(\"content\", \"\") for msg in x if msg.get(\"role\") == \"assistant\"]) if isinstance(x, list) else \"\")\n",
" .str.len()\n",
" .max()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"trusted": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" NL_Query | \n",
" messages | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" What is the sister city of the birthplace of N... | \n",
" [{'role': 'user', 'content': 'What is the sist... | \n",
"
\n",
" \n",
" | 1 | \n",
" Is the oxidation state of phosphorus equal to 3? | \n",
" [{'role': 'user', 'content': 'Is the oxidation... | \n",
"
\n",
" \n",
" | 2 | \n",
" Where is the official residence and the office... | \n",
" [{'role': 'user', 'content': 'Where is the off... | \n",
"
\n",
" \n",
" | 3 | \n",
" WHICH IS THE SUBSIDIARY COMPANY OF SHAREHOLDER... | \n",
" [{'role': 'user', 'content': 'WHICH IS THE SUB... | \n",
"
\n",
" \n",
" | 4 | \n",
" What is in the performer of Vivien Leigh ? | \n",
" [{'role': 'user', 'content': 'What is in the ... | \n",
"
\n",
" \n",
" | ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" | 33785 | \n",
" What are the national symbol which start with ... | \n",
" [{'role': 'user', 'content': 'What are the nat... | \n",
"
\n",
" \n",
" | 33786 | \n",
" What is the Xbox One Connector? | \n",
" [{'role': 'user', 'content': 'What is the Xbox... | \n",
"
\n",
" \n",
" | 33787 | \n",
" Is 0 the the number of houses of the Schuleroch? | \n",
" [{'role': 'user', 'content': 'Is 0 the the num... | \n",
"
\n",
" \n",
" | 33788 | \n",
" Whichs {house} of {characters} of {Benvenuto C... | \n",
" [{'role': 'user', 'content': 'Whichs {house} o... | \n",
"
\n",
" \n",
" | 33789 | \n",
" List all the locations of the companies whose ... | \n",
" [{'role': 'user', 'content': 'List all the loc... | \n",
"
\n",
" \n",
"
\n",
"
33790 rows × 2 columns
\n",
"
"
],
"text/plain": [
" NL_Query \\\n",
"0 What is the sister city of the birthplace of N... \n",
"1 Is the oxidation state of phosphorus equal to 3? \n",
"2 Where is the official residence and the office... \n",
"3 WHICH IS THE SUBSIDIARY COMPANY OF SHAREHOLDER... \n",
"4 What is in the performer of Vivien Leigh ? \n",
"... ... \n",
"33785 What are the national symbol which start with ... \n",
"33786 What is the Xbox One Connector? \n",
"33787 Is 0 the the number of houses of the Schuleroch? \n",
"33788 Whichs {house} of {characters} of {Benvenuto C... \n",
"33789 List all the locations of the companies whose ... \n",
"\n",
" messages \n",
"0 [{'role': 'user', 'content': 'What is the sist... \n",
"1 [{'role': 'user', 'content': 'Is the oxidation... \n",
"2 [{'role': 'user', 'content': 'Where is the off... \n",
"3 [{'role': 'user', 'content': 'WHICH IS THE SUB... \n",
"4 [{'role': 'user', 'content': 'What is in the ... \n",
"... ... \n",
"33785 [{'role': 'user', 'content': 'What are the nat... \n",
"33786 [{'role': 'user', 'content': 'What is the Xbox... \n",
"33787 [{'role': 'user', 'content': 'Is 0 the the num... \n",
"33788 [{'role': 'user', 'content': 'Whichs {house} o... \n",
"33789 [{'role': 'user', 'content': 'List all the loc... \n",
"\n",
"[33790 rows x 2 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_train=dftemp.sample(frac=0.80,random_state=200)\n",
"df_eval=dftemp.drop(df_train.index)\n",
"\n",
"df_train.reset_index(drop = True, inplace = True)\n",
"df_eval.reset_index(drop = True, inplace = True)\n",
"df_train"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"trusted": true
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" NL_Query | \n",
" messages | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" How much did Pulp Fiction cost? | \n",
" [{'role': 'user', 'content': 'How much did Pul... | \n",
"
\n",
" \n",
" | 1 | \n",
" In what city is the Heineken brewery? | \n",
" [{'role': 'user', 'content': 'In what city is ... | \n",
"
\n",
" \n",
" | 2 | \n",
" When did Operation Overlord commence? | \n",
" [{'role': 'user', 'content': 'When did Operati... | \n",
"
\n",
" \n",
" | 3 | \n",
" In which city does the Chile Route 68 end? | \n",
" [{'role': 'user', 'content': 'In which city do... | \n",
"
\n",
" \n",
" | 4 | \n",
" Where does Piccadilly start? | \n",
" [{'role': 'user', 'content': 'Where does Picca... | \n",
"
\n",
" \n",
" | ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" | 8443 | \n",
" True or False: Kenya's Human Development Index... | \n",
" [{'role': 'user', 'content': 'True or False: K... | \n",
"
\n",
" \n",
" | 8444 | \n",
" Who wrote the bibiiography of Natja Brunckhorst? | \n",
" [{'role': 'user', 'content': 'Who wrote the bi... | \n",
"
\n",
" \n",
" | 8445 | \n",
" What are the jurisdiction which start with the... | \n",
" [{'role': 'user', 'content': 'What are the jur... | \n",
"
\n",
" \n",
" | 8446 | \n",
" what are the emirate of the united arab emirat... | \n",
" [{'role': 'user', 'content': 'what are the emi... | \n",
"
\n",
" \n",
" | 8447 | \n",
" Name a gene that begins with a T. | \n",
" [{'role': 'user', 'content': 'Name a gene that... | \n",
"
\n",
" \n",
"
\n",
"
8448 rows × 2 columns
\n",
"
"
],
"text/plain": [
" NL_Query \\\n",
"0 How much did Pulp Fiction cost? \n",
"1 In what city is the Heineken brewery? \n",
"2 When did Operation Overlord commence? \n",
"3 In which city does the Chile Route 68 end? \n",
"4 Where does Piccadilly start? \n",
"... ... \n",
"8443 True or False: Kenya's Human Development Index... \n",
"8444 Who wrote the bibiiography of Natja Brunckhorst? \n",
"8445 What are the jurisdiction which start with the... \n",
"8446 what are the emirate of the united arab emirat... \n",
"8447 Name a gene that begins with a T. \n",
"\n",
" messages \n",
"0 [{'role': 'user', 'content': 'How much did Pul... \n",
"1 [{'role': 'user', 'content': 'In what city is ... \n",
"2 [{'role': 'user', 'content': 'When did Operati... \n",
"3 [{'role': 'user', 'content': 'In which city do... \n",
"4 [{'role': 'user', 'content': 'Where does Picca... \n",
"... ... \n",
"8443 [{'role': 'user', 'content': 'True or False: K... \n",
"8444 [{'role': 'user', 'content': 'Who wrote the bi... \n",
"8445 [{'role': 'user', 'content': 'What are the jur... \n",
"8446 [{'role': 'user', 'content': 'what are the emi... \n",
"8447 [{'role': 'user', 'content': 'Name a gene that... \n",
"\n",
"[8448 rows x 2 columns]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_eval"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"What is the sister city of the birthplace of Nikita Telenkov?\n"
]
},
{
"data": {
"text/plain": [
"[{'role': 'user',\n",
" 'content': 'What is the sister city of the birthplace of Nikita Telenkov?'},\n",
" {'role': 'assistant',\n",
" 'content': 'SELECT ?answer WHERE { ?statement1 . ?statement1 . ?statement1 ?X . ?statement2 ?X. ?statement2 . ?statement2 ?answer . }'}]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(df_train.iloc[0][\"NL_Query\"])\n",
"df_train.iloc[0][\"messages\"]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# save data into .jsonl files\n",
"df_train.to_json(\"NL_to_SPARQL_train.jsonl\", orient=\"records\", lines=True)\n",
"df_eval.to_json(\"NL_to_SPARQL_eval.jsonl\", orient=\"records\", lines=True)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NL_to_SPARQL_eval.jsonl NL_to_SPARQL_train.jsonl\n"
]
}
],
"source": [
"%ls /workspace/working/data/"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working/mistral-finetune\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/working/venv/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
" self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
]
}
],
"source": [
"%cd /workspace/working/mistral-finetune/"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"# # Read the existing YAML file\n",
"# with open('example/7B.yaml', 'r') as file:\n",
"# data = yaml.safe_load(file)\n",
"\n",
"# # Print the loaded data to see its contents\n",
"# print(data)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# # Modify the data\n",
"# data['data']['instruct_data'] = '/workspace/working/data/NL_to_SPARQL_train.jsonl'\n",
"# data['data']['eval_instruct_data'] = '/workspace/working/data/NL_to_SPARQL_eval.jsonl'\n",
"# data['model_id_or_path'] = '/workspace/working/mistral_models'\n",
"# data['run_dir'] = '/workspace/working/outputs'\n",
"# data['seq_len'] = 8192\n",
"# data['lora']['rank'] = 16\n",
"# data['log_freq'] = 50\n",
"# data['num_microbatches'] = 8\n",
"# data['optim']['lr'] = 1.e-5\n",
"# data['lora']['scaling'] = 16\n",
"# if 'wandb' in data:\n",
"# del data['wandb']\n",
"# # data['wandb']['offline'] = True\n",
"\n",
"\n",
"# # Print the modified data\n",
"# print(data)\n",
"config = \"\"\"\n",
"# data\n",
"data:\n",
" instruct_data: \"/workspace/working/data/NL_to_SPARQL_train.jsonl\" # Fill\n",
" data: \"\" # Optionally fill with pretraining data\n",
" eval_instruct_data: \"/workspace/working/data/NL_to_SPARQL_eval.jsonl\" # Optionally fill\n",
"\n",
"# model\n",
"model_id_or_path: \"/workspace/working/mistral_models\" # Change to downloaded path\n",
"lora:\n",
" rank: 16\n",
" scaling: 16\n",
"# optim\n",
"# tokens per training steps = batch_size x num_GPUs x seq_len\n",
"# we recommend sequence length of 32768\n",
"# If you run into memory error, you can try reduce the sequence length\n",
"seq_len: 2048\n",
"batch_size: 1\n",
"num_microbatches: 8\n",
"max_steps: 300\n",
"optim:\n",
" lr: 1.e-5\n",
" weight_decay: 0.1\n",
" pct_start: 0.05\n",
"\n",
"# other\n",
"seed: 0\n",
"log_freq: 10\n",
"eval_freq: 100\n",
"no_eval: False\n",
"ckpt_freq: 100\n",
"\n",
"save_adapters: True # save only trained LoRA adapters. Set to `False` to merge LoRA adapter into the base model and save full fine-tuned model\n",
"\n",
"run_dir: \"outputs\" # Fill\n",
"\"\"\"\n",
"\n",
"# save the same file locally into the example.yaml file\n",
"import yaml\n",
"with open('example/7B.yaml', 'w') as file:\n",
" yaml.dump(yaml.safe_load(config), file)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# config = \"\"\"\n",
"# wandb:\n",
"# offline: True\n",
"# \"\"\"\n",
"# # Write the modified data back to the file\n",
"# with open('example/7B.yaml', 'w') as file:\n",
"# yaml.dump(data, file, default_flow_style=False)\n",
"# file.write(config)\n",
"\n",
"# print(\"YAML file updated successfully.\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_size: 1\n",
"ckpt_freq: 100\n",
"data:\n",
" data: ''\n",
" eval_instruct_data: /workspace/working/data/NL_to_SPARQL_eval.jsonl\n",
" instruct_data: /workspace/working/data/NL_to_SPARQL_train.jsonl\n",
"eval_freq: 100\n",
"log_freq: 10\n",
"lora:\n",
" rank: 16\n",
" scaling: 16\n",
"max_steps: 300\n",
"model_id_or_path: /workspace/working/mistral_models\n",
"no_eval: false\n",
"num_microbatches: 8\n",
"optim:\n",
" lr: 1.0e-05\n",
" pct_start: 0.05\n",
" weight_decay: 0.1\n",
"run_dir: outputs\n",
"save_adapters: true\n",
"seed: 0\n",
"seq_len: 2048\n"
]
}
],
"source": [
"! cat example/7B.yaml"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"import os\n",
"\n",
"print(os.path.exists(\"example/7B.yaml\"))\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working/mistral-finetune\n"
]
}
],
"source": [
"%cd /workspace/working/mistral-finetune/"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"!python -m utils.reformat_data /workspace/working/data/NL_to_SPARQL_train.jsonl"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"!python -m utils.reformat_data /workspace/working/data/NL_to_SPARQL_eval.jsonl"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]Validating /workspace/working/data/NL_to_SPARQL_train.jsonl ...\n",
"\n",
" 0%| | 0/33790 [00:00, ?it/s]\u001b[A\n",
" 1%|▍ | 398/33790 [00:00<00:08, 3974.46it/s]\u001b[A\n",
" 2%|▉ | 818/33790 [00:00<00:08, 4105.68it/s]\u001b[A\n",
" 4%|█▎ | 1230/33790 [00:00<00:07, 4108.17it/s]\u001b[A\n",
" 5%|█▊ | 1646/33790 [00:00<00:07, 4128.18it/s]\u001b[A\n",
" 6%|██▏ | 2071/33790 [00:00<00:07, 4170.27it/s]\u001b[A\n",
" 7%|██▋ | 2494/33790 [00:00<00:07, 4189.40it/s]\u001b[A\n",
" 9%|███ | 2913/33790 [00:00<00:07, 4178.51it/s]\u001b[A\n",
" 10%|███▌ | 3331/33790 [00:00<00:07, 4167.07it/s]\u001b[A\n",
" 11%|███▉ | 3749/33790 [00:00<00:07, 4170.99it/s]\u001b[A\n",
" 12%|████▍ | 4177/33790 [00:01<00:07, 4201.57it/s]\u001b[A\n",
" 14%|████▉ | 4601/33790 [00:01<00:06, 4212.62it/s]\u001b[A\n",
" 15%|█████▎ | 5023/33790 [00:01<00:06, 4195.22it/s]\u001b[A\n",
" 16%|█████▊ | 5443/33790 [00:01<00:06, 4152.50it/s]\u001b[A\n",
" 17%|██████▏ | 5859/33790 [00:01<00:06, 4148.96it/s]\u001b[A\n",
" 19%|██████▋ | 6274/33790 [00:01<00:06, 4139.69it/s]\u001b[A\n",
" 20%|███████▏ | 6693/33790 [00:01<00:06, 4154.68it/s]\u001b[A\n",
" 21%|███████▌ | 7109/33790 [00:01<00:06, 4142.47it/s]\u001b[A\n",
" 22%|████████ | 7530/33790 [00:01<00:06, 4161.44it/s]\u001b[A\n",
" 24%|████████▍ | 7953/33790 [00:01<00:06, 4181.75it/s]\u001b[A\n",
" 25%|████████▉ | 8372/33790 [00:02<00:06, 4157.47it/s]\u001b[A\n",
" 26%|█████████▎ | 8788/33790 [00:02<00:06, 4155.90it/s]\u001b[A\n",
" 27%|█████████▊ | 9210/33790 [00:02<00:05, 4173.02it/s]\u001b[A\n",
" 28%|██████████▎ | 9628/33790 [00:02<00:05, 4146.54it/s]\u001b[A\n",
" 30%|██████████▍ | 10047/33790 [00:02<00:05, 4159.22it/s]\u001b[A\n",
" 31%|██████████▊ | 10463/33790 [00:02<00:05, 4159.04it/s]\u001b[A\n",
" 32%|███████████▎ | 10879/33790 [00:02<00:05, 4157.94it/s]\u001b[A\n",
" 33%|███████████▋ | 11295/33790 [00:02<00:05, 4141.22it/s]\u001b[A\n",
" 35%|████████████▏ | 11725/33790 [00:02<00:05, 4186.44it/s]\u001b[A\n",
" 36%|████████████▌ | 12144/33790 [00:02<00:05, 4178.55it/s]\u001b[A\n",
" 37%|█████████████ | 12562/33790 [00:03<00:05, 4167.60it/s]\u001b[A\n",
" 38%|█████████████▍ | 12979/33790 [00:03<00:05, 4158.30it/s]\u001b[A\n",
" 40%|█████████████▉ | 13399/33790 [00:03<00:04, 4168.21it/s]\u001b[A\n",
" 41%|██████████████▎ | 13816/33790 [00:03<00:04, 4164.57it/s]\u001b[A\n",
" 42%|██████████████▋ | 14233/33790 [00:03<00:04, 4162.97it/s]\u001b[A\n",
" 43%|███████████████▏ | 14650/33790 [00:03<00:04, 4154.27it/s]\u001b[A\n",
" 45%|███████████████▌ | 15069/33790 [00:03<00:04, 4164.54it/s]\u001b[A\n",
" 46%|████████████████ | 15486/33790 [00:03<00:04, 4164.95it/s]\u001b[A\n",
" 47%|████████████████▍ | 15903/33790 [00:03<00:04, 4151.58it/s]\u001b[A\n",
" 48%|████████████████▉ | 16319/33790 [00:03<00:04, 4137.88it/s]\u001b[A\n",
" 50%|█████████████████▎ | 16744/33790 [00:04<00:04, 4168.93it/s]\u001b[A\n",
" 51%|█████████████████▊ | 17172/33790 [00:04<00:03, 4201.18it/s]\u001b[A\n",
" 52%|██████████████████▏ | 17593/33790 [00:04<00:03, 4188.07it/s]\u001b[A\n",
" 53%|██████████████████▋ | 18017/33790 [00:04<00:03, 4201.49it/s]\u001b[A\n",
" 55%|███████████████████ | 18438/33790 [00:04<00:03, 4161.22it/s]\u001b[A\n",
" 56%|███████████████████▌ | 18855/33790 [00:04<00:03, 4159.03it/s]\u001b[A\n",
" 57%|███████████████████▉ | 19271/33790 [00:04<00:03, 4140.24it/s]\u001b[A\n",
" 58%|████████████████████▍ | 19691/33790 [00:04<00:03, 4156.64it/s]\u001b[A\n",
" 60%|████████████████████▊ | 20107/33790 [00:04<00:03, 4135.66it/s]\u001b[A\n",
" 61%|█████████████████████▎ | 20526/33790 [00:04<00:03, 4149.47it/s]\u001b[A\n",
" 62%|█████████████████████▋ | 20943/33790 [00:05<00:03, 4154.18it/s]\u001b[A\n",
" 63%|██████████████████████▏ | 21366/33790 [00:05<00:02, 4173.67it/s]\u001b[A\n",
" 64%|██████████████████████▌ | 21788/33790 [00:05<00:02, 4185.03it/s]\u001b[A\n",
" 66%|███████████████████████ | 22211/33790 [00:05<00:02, 4197.80it/s]\u001b[A\n",
" 67%|███████████████████████▍ | 22631/33790 [00:05<00:02, 4178.59it/s]\u001b[A\n",
" 68%|███████████████████████▊ | 23049/33790 [00:05<00:02, 4176.12it/s]\u001b[A\n",
" 69%|████████████████████████▎ | 23467/33790 [00:05<00:02, 4169.36it/s]\u001b[A\n",
" 71%|████████████████████████▋ | 23884/33790 [00:05<00:02, 4131.98it/s]\u001b[A\n",
" 72%|█████████████████████████▏ | 24305/33790 [00:05<00:02, 4152.65it/s]\u001b[A\n",
" 73%|█████████████████████████▌ | 24726/33790 [00:05<00:02, 4167.09it/s]\u001b[A\n",
" 74%|██████████████████████████ | 25143/33790 [00:06<00:02, 4154.75it/s]\u001b[A\n",
" 76%|██████████████████████████▍ | 25561/33790 [00:06<00:01, 4160.52it/s]\u001b[A\n",
" 77%|██████████████████████████▉ | 25978/33790 [00:06<00:01, 4143.77it/s]\u001b[A\n",
" 78%|███████████████████████████▎ | 26393/33790 [00:06<00:01, 4141.87it/s]\u001b[A\n",
" 79%|███████████████████████████▊ | 26813/33790 [00:06<00:01, 4157.12it/s]\u001b[A\n",
" 81%|████████████████████████████▏ | 27231/33790 [00:06<00:01, 4163.25it/s]\u001b[A\n",
" 82%|████████████████████████████▋ | 27648/33790 [00:06<00:01, 4162.48it/s]\u001b[A\n",
" 83%|█████████████████████████████ | 28065/33790 [00:06<00:01, 4156.25it/s]\u001b[A\n",
" 84%|█████████████████████████████▌ | 28481/33790 [00:06<00:01, 4152.43it/s]\u001b[A\n",
" 86%|█████████████████████████████▉ | 28898/33790 [00:06<00:01, 4156.01it/s]\u001b[A\n",
" 87%|██████████████████████████████▎ | 29323/33790 [00:07<00:01, 4181.91it/s]\u001b[A\n",
" 88%|██████████████████████████████▊ | 29742/33790 [00:07<00:00, 4174.55it/s]\u001b[A\n",
" 89%|███████████████████████████████▏ | 30160/33790 [00:07<00:00, 4161.90it/s]\u001b[A\n",
" 90%|███████████████████████████████▋ | 30577/33790 [00:07<00:00, 4158.03it/s]\u001b[A\n",
" 92%|████████████████████████████████ | 30993/33790 [00:07<00:00, 4128.54it/s]\u001b[A\n",
" 93%|████████████████████████████████▌ | 31406/33790 [00:07<00:00, 4088.34it/s]\u001b[A\n",
" 94%|████████████████████████████████▉ | 31815/33790 [00:07<00:00, 4075.49it/s]\u001b[A\n",
" 95%|█████████████████████████████████▍ | 32229/33790 [00:07<00:00, 4093.59it/s]\u001b[A\n",
" 97%|█████████████████████████████████▊ | 32650/33790 [00:07<00:00, 4127.49it/s]\u001b[A\n",
" 98%|██████████████████████████████████▎| 33066/33790 [00:07<00:00, 4135.18it/s]\u001b[A\n",
"100%|███████████████████████████████████| 33790/33790 [00:08<00:00, 4158.60it/s]\u001b[A\n",
"1it [00:08, 8.15s/it]\n",
"No errors! Data is correctly formatted!\n",
"Stats for /workspace/working/data/NL_to_SPARQL_train.jsonl \n",
" -------------------- \n",
" {\n",
" \"expected\": {\n",
" \"eta\": \"00:06:47\",\n",
" \"data_tokens\": 7990855,\n",
" \"train_tokens\": 4915200,\n",
" \"epochs\": \"0.62\",\n",
" \"max_steps\": 300,\n",
" \"data_tokens_per_dataset\": {\n",
" \"/workspace/working/data/NL_to_SPARQL_train.jsonl\": \"7990855.0\"\n",
" },\n",
" \"train_tokens_per_dataset\": {\n",
" \"/workspace/working/data/NL_to_SPARQL_train.jsonl\": \"4915200.0\"\n",
" },\n",
" \"epochs_per_dataset\": {\n",
" \"/workspace/working/data/NL_to_SPARQL_train.jsonl\": \"0.6\"\n",
" }\n",
" }\n",
"}\n",
"0it [00:00, ?it/s]Validating /workspace/working/data/NL_to_SPARQL_eval.jsonl ...\n",
"\n",
" 0%| | 0/8448 [00:00, ?it/s]\u001b[A\n",
" 8%|██▉ | 644/8448 [00:00<00:01, 6439.32it/s]\u001b[A\n",
" 15%|█████▋ | 1288/8448 [00:00<00:01, 5663.12it/s]\u001b[A\n",
" 22%|████████▏ | 1861/8448 [00:00<00:01, 4764.56it/s]\u001b[A\n",
" 28%|██████████▎ | 2352/8448 [00:00<00:01, 4449.87it/s]\u001b[A\n",
" 33%|████████████▎ | 2805/8448 [00:00<00:01, 4281.36it/s]\u001b[A\n",
" 38%|██████████████▏ | 3238/8448 [00:00<00:01, 4183.43it/s]\u001b[A\n",
" 43%|████████████████ | 3659/8448 [00:00<00:01, 4112.39it/s]\u001b[A\n",
" 48%|█████████████████▊ | 4071/8448 [00:00<00:01, 4063.24it/s]\u001b[A\n",
" 53%|███████████████████▌ | 4478/8448 [00:01<00:00, 4028.90it/s]\u001b[A\n",
" 58%|█████████████████████▍ | 4881/8448 [00:01<00:00, 4002.25it/s]\u001b[A\n",
" 63%|███████████████████████▏ | 5282/8448 [00:01<00:00, 3987.83it/s]\u001b[A\n",
" 67%|████████████████████████▉ | 5681/8448 [00:01<00:00, 3955.84it/s]\u001b[A\n",
" 72%|██████████████████████████▋ | 6082/8448 [00:01<00:00, 3968.80it/s]\u001b[A\n",
" 77%|████████████████████████████▍ | 6486/8448 [00:01<00:00, 3987.55it/s]\u001b[A\n",
" 81%|██████████████████████████████▏ | 6885/8448 [00:01<00:00, 3976.96it/s]\u001b[A\n",
" 86%|███████████████████████████████▉ | 7283/8448 [00:01<00:00, 3972.89it/s]\u001b[A\n",
" 91%|█████████████████████████████████▋ | 7681/8448 [00:01<00:00, 3973.57it/s]\u001b[A\n",
"100%|█████████████████████████████████████| 8448/8448 [00:02<00:00, 4153.81it/s]\u001b[A\n",
"1it [00:02, 2.04s/it]\n",
"No errors! Data is correctly formatted!\n"
]
}
],
"source": [
"# Check that the train/eval JSONL files referenced in example/7B.yaml are well\n",
"# formatted; the tool also prints token counts and an estimated training ETA.\n",
"!python -m utils.validate_data --train_yaml example/7B.yaml"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_size: 1\n",
"ckpt_freq: 100\n",
"data:\n",
" data: ''\n",
" eval_instruct_data: /workspace/working/data/NL_to_SPARQL_eval.jsonl\n",
" instruct_data: /workspace/working/data/NL_to_SPARQL_train.jsonl\n",
"eval_freq: 100\n",
"log_freq: 10\n",
"lora:\n",
" rank: 16\n",
" scaling: 16\n",
"max_steps: 300\n",
"model_id_or_path: /workspace/working/mistral_models\n",
"no_eval: false\n",
"num_microbatches: 8\n",
"optim:\n",
" lr: 1.0e-05\n",
" pct_start: 0.05\n",
" weight_decay: 0.1\n",
"run_dir: outputs\n",
"save_adapters: true\n",
"seed: 0\n",
"seq_len: 2048\n"
]
}
],
"source": [
"# Show the training config (data paths, LoRA settings, max_steps, ...).\n",
"! cat example/7B.yaml"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# Run mistral-finetune's reformat utility on the training set.\n",
"# NOTE(review): this runs after validate_data has already checked the data; if it\n",
"# rewrites the file, re-run validation so the stats match the data trained on.\n",
"!python -m utils.reformat_data /workspace/working/data/NL_to_SPARQL_train.jsonl"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# Run mistral-finetune's reformat utility on the evaluation set.\n",
"# NOTE(review): this runs after validate_data has already checked the data; if it\n",
"# rewrites the file, re-run validation so the stats match the data evaluated on.\n",
"!python -m utils.reformat_data /workspace/working/data/NL_to_SPARQL_eval.jsonl"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of GPUs Available: 1\n",
"GPU 0: NVIDIA A800 80GB PCIe\n",
"NVIDIA A800 80GB PCIe\n",
"True\n",
"True\n",
"Thu Mar 27 23:57:52 2025 \n",
"+-----------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4 |\n",
"|-----------------------------------------+------------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+========================+======================|\n",
"| 0 NVIDIA A800 80GB PCIe On | 00000000:61:00.0 Off | 0 |\n",
"| N/A 34C P0 42W / 300W | 4MiB / 81920MiB | 0% Default |\n",
"| | | Disabled |\n",
"+-----------------------------------------+------------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=========================================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"import torch\n",
"\n",
"# Enumerate the GPUs visible to this kernel.\n",
"n_gpus = torch.cuda.device_count()\n",
"print(\"Number of GPUs Available:\", n_gpus)\n",
"for i in range(n_gpus):\n",
"    print(f\"GPU {i}: {torch.cuda.get_device_name(i)}\")\n",
"\n",
"# Name of the current default CUDA device.\n",
"print(torch.cuda.get_device_name())\n",
"# Should print True: train.py hard-codes param_dtype = torch.bfloat16.\n",
"# NOTE(review): recent PyTorch versions may also report True when BF16 is only\n",
"# emulated; native BF16 needs compute capability >= 8.0 - confirm on older GPUs.\n",
"print(torch.cuda.is_bf16_supported())\n",
"# This is a matmul *setting* (allow reduced-precision reductions in FP16 GEMMs),\n",
"# not a hardware capability check; True is PyTorch's default.\n",
"print(torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction)\n",
"\n",
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import torch\n",
"\n",
"# GPU selection for the training subprocess launched from a later cell: child\n",
"# processes inherit os.environ. This kernel has most likely already initialised\n",
"# CUDA (get_device_name() above), so its own device list is not affected.\n",
"os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n",
"# Release cached-but-unused GPU memory held by this kernel before training.\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import dataclasses\n",
"import logging\n",
"import os\n",
"import pprint\n",
"from contextlib import ExitStack\n",
"from pathlib import Path\n",
"from typing import TYPE_CHECKING\n",
"\n",
"import fire\n",
"import torch.cuda\n",
"import torch.distributed as dist\n",
"from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
"from torch.optim import AdamW, lr_scheduler\n",
"\n",
"from finetune.args import TrainArgs\n",
"from finetune.checkpointing import Checkpointer\n",
"from finetune.data.data_loader import build_data_loader\n",
"from finetune.distributed import (\n",
" BACKEND,\n",
" avg_aggregate,\n",
" get_rank,\n",
" get_world_size,\n",
" is_torchrun,\n",
" set_device,\n",
")\n",
"from finetune.eval import evaluate\n",
"from finetune.loss import compute_loss_with_mask\n",
"from finetune.mixed_precision import (\n",
" downcast_mixed_precision,\n",
" prepare_mixed_precision,\n",
" upcast_mixed_precision,\n",
")\n",
"from finetune.monitoring.metrics_logger import (\n",
" MetricsLogger,\n",
" eval_log_msg,\n",
" get_eval_logs,\n",
" get_train_logs,\n",
" train_log_msg,\n",
")\n",
"from finetune.monitoring.utils import set_logger\n",
"from finetune.utils import (\n",
" TrainState,\n",
" logged_closing,\n",
" set_random_seed,\n",
")\n",
"from finetune.wrapped_model import load_model, load_args\n",
"\n",
"if TYPE_CHECKING:\n",
" from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase\n",
"\n",
"logger = logging.getLogger(\"train\")\n",
"\n",
"\n",
"def main_logger_info(message: str) -> None:\n",
" if get_rank() == 0:\n",
" logger.info(message)\n",
"\n",
"\n",
"def train(config: str):\n",
" args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False)\n",
" print(f\"args: {args}\")\n",
" set_logger(logging.INFO)\n",
"\n",
" with ExitStack() as exit_stack:\n",
" _train(args, exit_stack)\n",
" logger.info(\"Closed everything!\")\n",
"\n",
"\n",
"def _train(\n",
" args: TrainArgs,\n",
" exit_stack: ExitStack,\n",
"):\n",
" # 1. Initial setup and checks\n",
" set_random_seed(args.seed)\n",
"\n",
" # Init NCCL\n",
" if \"LOCAL_RANK\" in os.environ:\n",
" set_device()\n",
" logger.info(\"Going to init comms...\")\n",
"\n",
" dist.init_process_group(backend=BACKEND)\n",
" else:\n",
" logger.error(\n",
" \"PyTorch environment is not correctly initialized. This message should only be displayed when testing.\"\n",
" )\n",
"\n",
" # 2. Init run dir\n",
" main_logger_info(f\"Run dir: {args.run_dir}\")\n",
" run_dir = Path(args.run_dir)\n",
"\n",
" if is_torchrun():\n",
" if run_dir.exists():\n",
" raise RuntimeError(\n",
" f\"Run dir {run_dir} already exists. Make sure to either rename `run_dir` or remove {run_dir}.\"\n",
" )\n",
"\n",
" dist.barrier()\n",
" run_dir.mkdir(exist_ok=True, parents=True)\n",
"\n",
" args_path = run_dir / \"args.yaml\"\n",
" if not args_path.exists():\n",
" args.save(args_path)\n",
"\n",
" main_logger_info(f\"TrainArgs: {pprint.pformat(dataclasses.asdict(args))}\")\n",
"\n",
" # 3. Get loggers\n",
" metrics_logger: MetricsLogger = MetricsLogger(\n",
" run_dir,\n",
" tag=\"train\",\n",
" is_master=get_rank() == 0,\n",
" wandb_args=args.wandb,\n",
" mlflow_args=args.mlflow,\n",
" config=dataclasses.asdict(args),\n",
" )\n",
" exit_stack.enter_context(logged_closing(metrics_logger, \"metrics_logger\"))\n",
"\n",
" eval_logger: MetricsLogger = MetricsLogger(\n",
" run_dir,\n",
" tag=\"eval\",\n",
" is_master=get_rank() == 0,\n",
" wandb_args=args.wandb,\n",
" mlflow_args=args.mlflow,\n",
" config=dataclasses.asdict(args),\n",
" )\n",
" exit_stack.enter_context(logged_closing(eval_logger, \"eval_logger\"))\n",
"\n",
" # 5. Potentially download model\n",
" if Path(args.model_id_or_path).is_dir():\n",
" model_folder = Path(args.model_id_or_path)\n",
" else:\n",
" raise ValueError(\n",
" \"Invalid folder path. Please set `args.initial_model` to a valid folder path.\"\n",
" )\n",
"\n",
" # 6. Load function calling instruct tokenizer\n",
" vocab_size = load_args(model_folder, args.lora).vocab_size\n",
" is_tekken = vocab_size > 32768\n",
"\n",
" instruct_tokenizer: InstructTokenizerBase = MistralTokenizer.v3(\n",
" is_tekken=is_tekken\n",
" ).instruct_tokenizer # type: ignore\n",
"\n",
" # 7. Load data loaders\n",
" data_loader = build_data_loader(\n",
" instruct_tokenizer=instruct_tokenizer,\n",
" args=args.data,\n",
" seq_len=args.seq_len,\n",
" batch_size=args.batch_size,\n",
" seed=args.seed,\n",
" rank=get_rank(), # DDP rank\n",
" world_size=get_world_size(), # DDP world_size\n",
" is_eval=False,\n",
" )\n",
"\n",
" if not args.no_eval:\n",
" assert (\n",
" args.data.eval_instruct_data != \"\"\n",
" ), \"Either set `no_eval` to True or provide evaluation samples under `data.eval_instruct_data`\"\n",
"\n",
" eval_data_loader = build_data_loader(\n",
" instruct_tokenizer=instruct_tokenizer,\n",
" args=args.data,\n",
" seq_len=args.seq_len,\n",
" batch_size=args.batch_size,\n",
" seed=None,\n",
" rank=get_rank(), # DDP rank\n",
" world_size=get_world_size(), # DDP world_size\n",
" is_eval=True,\n",
" )\n",
" # pre-load all eval tokens\n",
" eval_batches = list(eval_data_loader)\n",
"\n",
" # 8. Load model\n",
" # Define mixed precision\n",
" param_dtype = torch.bfloat16\n",
" optim_dtype = torch.float32\n",
"\n",
" assert args.lora is not None, \"`args.lora` should be set to a valid value.\"\n",
"\n",
" model = load_model(\n",
" folder=model_folder,\n",
" lora=args.lora,\n",
" checkpoint=args.checkpoint,\n",
" param_dtype=param_dtype,\n",
" )\n",
"\n",
" # 9. Load optimizer\n",
" optimizer = AdamW(\n",
" model.parameters(),\n",
" lr=args.optim.lr,\n",
" betas=(0.9, 0.95),\n",
" eps=1e-08,\n",
" weight_decay=args.optim.weight_decay,\n",
" )\n",
"\n",
" scheduler = lr_scheduler.OneCycleLR(\n",
" optimizer,\n",
" max_lr=args.optim.lr,\n",
" total_steps=args.max_steps,\n",
" pct_start=args.optim.pct_start,\n",
" )\n",
"\n",
" state = TrainState(args.max_steps)\n",
"\n",
" # 10. Initialize checkpointer\n",
" checkpointer = Checkpointer(\n",
" model=model,\n",
" state=state,\n",
" run_dir=run_dir,\n",
" optimizer=optimizer,\n",
" num_ckpt_keep=args.num_ckpt_keep,\n",
" )\n",
" # 11. Prepare mixed precision\n",
" prepare_mixed_precision(\n",
" model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype\n",
" )\n",
"\n",
" # 12. train!\n",
" model.train()\n",
" torch.cuda.empty_cache()\n",
"\n",
" while state.step < args.max_steps:\n",
" state.start_step()\n",
" is_last_step = state.step == args.max_steps\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" loss = torch.tensor([0.0], device=\"cuda\")\n",
" n_batch_tokens: int = 0\n",
"\n",
" for i in range(args.num_microbatches):\n",
" # batch\n",
" batch = next(data_loader)\n",
"\n",
" x = torch.from_numpy(batch.x).cuda(non_blocking=True)\n",
" y = torch.from_numpy(batch.y).cuda(non_blocking=True)\n",
" y_mask = (\n",
" torch.from_numpy(batch.y_mask).cuda(non_blocking=True)\n",
" if batch.y_mask is not None\n",
" else None\n",
" )\n",
"\n",
" # forward / backward\n",
" output = model(\n",
" input_ids=x,\n",
" seqlens=batch.sizes,\n",
" )\n",
" mb_loss = compute_loss_with_mask(output, y, y_mask)\n",
"\n",
" mb_loss.backward()\n",
"\n",
" loss += mb_loss.detach()\n",
" n_batch_tokens += x.numel()\n",
"\n",
" if i < args.num_microbatches - 1:\n",
" # synchronize CUDA to re-run backward\n",
" assert args.num_microbatches > 1 # should not happen\n",
" torch.cuda.synchronize()\n",
"\n",
" if args.num_microbatches > 1:\n",
" loss /= args.num_microbatches\n",
" for p in model.parameters():\n",
" if p.requires_grad:\n",
" assert p.grad is not None\n",
" p.grad.div_(args.num_microbatches)\n",
"\n",
" # upcast params for optimizer update\n",
" upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)\n",
"\n",
" # clip grad norm\n",
" model.clip_grad_norm_(max_norm=args.max_norm)\n",
"\n",
" # optimizer step\n",
" optimizer.step()\n",
"\n",
" # downcast params for forward & backward\n",
" downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)\n",
"\n",
" last_lr = scheduler.get_last_lr()[0]\n",
" scheduler.step()\n",
"\n",
" # Host sync\n",
" loss_item = loss.item()\n",
" avg_loss = avg_aggregate(loss_item)\n",
"\n",
" if not args.no_eval and (\n",
" (args.eval_freq > 0 and state.step % args.eval_freq == 0) or is_last_step\n",
" ):\n",
" # write perplexity to state\n",
" evaluate(model, eval_batches, state)\n",
"\n",
" eval_logs = get_eval_logs(\n",
" state.step, avg_loss, state.this_eval_perplexity, state.this_eval_loss\n",
" )\n",
"\n",
" main_logger_info(eval_log_msg(eval_logs))\n",
" eval_logger.log(eval_logs, step=state.step)\n",
"\n",
" # Timing\n",
" state.end_step(n_batch_tokens)\n",
"\n",
" if state.step % args.log_freq == 0:\n",
" train_logs = get_train_logs(\n",
" state,\n",
" avg_loss,\n",
" last_lr,\n",
" torch.cuda.max_memory_allocated(),\n",
" torch.cuda.memory_allocated(),\n",
" args,\n",
" )\n",
" main_logger_info(train_log_msg(state, logs=train_logs, loss=avg_loss))\n",
" metrics_logger.log(train_logs, step=state.step)\n",
"\n",
" if not args.no_ckpt and (\n",
" (args.ckpt_freq > 0 and state.step % args.ckpt_freq == 0) or is_last_step\n",
" ):\n",
" checkpointer.save_checkpoint(\n",
" save_only_lora=args.save_adapters,\n",
" dtype=param_dtype,\n",
" instruct_tokenizer=instruct_tokenizer,\n",
" )\n",
"\n",
" main_logger_info(\"done!\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" \"\"\"See README.md for usage.\"\"\"\n",
" fire.Fire(train)\n"
]
}
],
"source": [
"# Inspect the training script; note that param_dtype is hard-coded to torch.bfloat16.\n",
"! cat train.py"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# OPTIONAL, intentionally disabled: patch train.py to use float16 instead of\n",
"# bfloat16, for GPUs without BF16 support. Not needed on this machine (the GPU\n",
"# check above printed True for is_bf16_supported()).\n",
"# Caveat: the train.py shown above has no loss/gradient scaling, so pure-fp16\n",
"# training may overflow or underflow - review before enabling.\n",
"\n",
"# with open(\"train.py\", \"r\") as f:\n",
"# content = f.read()\n",
"\n",
"# # Modify content (e.g., replace dtype settings)\n",
"# content = content.replace(\"torch.bfloat16\", \"torch.float16\")\n",
"\n",
"# # Write back\n",
"# with open(\"train.py\", \"w\") as f:\n",
"# f.write(content)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import dataclasses\n",
"import logging\n",
"import os\n",
"import pprint\n",
"from contextlib import ExitStack\n",
"from pathlib import Path\n",
"from typing import TYPE_CHECKING\n",
"\n",
"import fire\n",
"import torch.cuda\n",
"import torch.distributed as dist\n",
"from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
"from torch.optim import AdamW, lr_scheduler\n",
"\n",
"from finetune.args import TrainArgs\n",
"from finetune.checkpointing import Checkpointer\n",
"from finetune.data.data_loader import build_data_loader\n",
"from finetune.distributed import (\n",
" BACKEND,\n",
" avg_aggregate,\n",
" get_rank,\n",
" get_world_size,\n",
" is_torchrun,\n",
" set_device,\n",
")\n",
"from finetune.eval import evaluate\n",
"from finetune.loss import compute_loss_with_mask\n",
"from finetune.mixed_precision import (\n",
" downcast_mixed_precision,\n",
" prepare_mixed_precision,\n",
" upcast_mixed_precision,\n",
")\n",
"from finetune.monitoring.metrics_logger import (\n",
" MetricsLogger,\n",
" eval_log_msg,\n",
" get_eval_logs,\n",
" get_train_logs,\n",
" train_log_msg,\n",
")\n",
"from finetune.monitoring.utils import set_logger\n",
"from finetune.utils import (\n",
" TrainState,\n",
" logged_closing,\n",
" set_random_seed,\n",
")\n",
"from finetune.wrapped_model import load_model, load_args\n",
"\n",
"if TYPE_CHECKING:\n",
" from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase\n",
"\n",
"logger = logging.getLogger(\"train\")\n",
"\n",
"\n",
"def main_logger_info(message: str) -> None:\n",
" if get_rank() == 0:\n",
" logger.info(message)\n",
"\n",
"\n",
"def train(config: str):\n",
" args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False)\n",
" print(f\"args: {args}\")\n",
" set_logger(logging.INFO)\n",
"\n",
" with ExitStack() as exit_stack:\n",
" _train(args, exit_stack)\n",
" logger.info(\"Closed everything!\")\n",
"\n",
"\n",
"def _train(\n",
" args: TrainArgs,\n",
" exit_stack: ExitStack,\n",
"):\n",
" # 1. Initial setup and checks\n",
" set_random_seed(args.seed)\n",
"\n",
" # Init NCCL\n",
" if \"LOCAL_RANK\" in os.environ:\n",
" set_device()\n",
" logger.info(\"Going to init comms...\")\n",
"\n",
" dist.init_process_group(backend=BACKEND)\n",
" else:\n",
" logger.error(\n",
" \"PyTorch environment is not correctly initialized. This message should only be displayed when testing.\"\n",
" )\n",
"\n",
" # 2. Init run dir\n",
" main_logger_info(f\"Run dir: {args.run_dir}\")\n",
" run_dir = Path(args.run_dir)\n",
"\n",
" if is_torchrun():\n",
" if run_dir.exists():\n",
" raise RuntimeError(\n",
" f\"Run dir {run_dir} already exists. Make sure to either rename `run_dir` or remove {run_dir}.\"\n",
" )\n",
"\n",
" dist.barrier()\n",
" run_dir.mkdir(exist_ok=True, parents=True)\n",
"\n",
" args_path = run_dir / \"args.yaml\"\n",
" if not args_path.exists():\n",
" args.save(args_path)\n",
"\n",
" main_logger_info(f\"TrainArgs: {pprint.pformat(dataclasses.asdict(args))}\")\n",
"\n",
" # 3. Get loggers\n",
" metrics_logger: MetricsLogger = MetricsLogger(\n",
" run_dir,\n",
" tag=\"train\",\n",
" is_master=get_rank() == 0,\n",
" wandb_args=args.wandb,\n",
" mlflow_args=args.mlflow,\n",
" config=dataclasses.asdict(args),\n",
" )\n",
" exit_stack.enter_context(logged_closing(metrics_logger, \"metrics_logger\"))\n",
"\n",
" eval_logger: MetricsLogger = MetricsLogger(\n",
" run_dir,\n",
" tag=\"eval\",\n",
" is_master=get_rank() == 0,\n",
" wandb_args=args.wandb,\n",
" mlflow_args=args.mlflow,\n",
" config=dataclasses.asdict(args),\n",
" )\n",
" exit_stack.enter_context(logged_closing(eval_logger, \"eval_logger\"))\n",
"\n",
" # 5. Potentially download model\n",
" if Path(args.model_id_or_path).is_dir():\n",
" model_folder = Path(args.model_id_or_path)\n",
" else:\n",
" raise ValueError(\n",
" \"Invalid folder path. Please set `args.initial_model` to a valid folder path.\"\n",
" )\n",
"\n",
" # 6. Load function calling instruct tokenizer\n",
" vocab_size = load_args(model_folder, args.lora).vocab_size\n",
" is_tekken = vocab_size > 32768\n",
"\n",
" instruct_tokenizer: InstructTokenizerBase = MistralTokenizer.v3(\n",
" is_tekken=is_tekken\n",
" ).instruct_tokenizer # type: ignore\n",
"\n",
" # 7. Load data loaders\n",
" data_loader = build_data_loader(\n",
" instruct_tokenizer=instruct_tokenizer,\n",
" args=args.data,\n",
" seq_len=args.seq_len,\n",
" batch_size=args.batch_size,\n",
" seed=args.seed,\n",
" rank=get_rank(), # DDP rank\n",
" world_size=get_world_size(), # DDP world_size\n",
" is_eval=False,\n",
" )\n",
"\n",
" if not args.no_eval:\n",
" assert (\n",
" args.data.eval_instruct_data != \"\"\n",
" ), \"Either set `no_eval` to True or provide evaluation samples under `data.eval_instruct_data`\"\n",
"\n",
" eval_data_loader = build_data_loader(\n",
" instruct_tokenizer=instruct_tokenizer,\n",
" args=args.data,\n",
" seq_len=args.seq_len,\n",
" batch_size=args.batch_size,\n",
" seed=None,\n",
" rank=get_rank(), # DDP rank\n",
" world_size=get_world_size(), # DDP world_size\n",
" is_eval=True,\n",
" )\n",
" # pre-load all eval tokens\n",
" eval_batches = list(eval_data_loader)\n",
"\n",
" # 8. Load model\n",
" # Define mixed precision\n",
" param_dtype = torch.bfloat16\n",
" optim_dtype = torch.float32\n",
"\n",
" assert args.lora is not None, \"`args.lora` should be set to a valid value.\"\n",
"\n",
" model = load_model(\n",
" folder=model_folder,\n",
" lora=args.lora,\n",
" checkpoint=args.checkpoint,\n",
" param_dtype=param_dtype,\n",
" )\n",
"\n",
" # 9. Load optimizer\n",
" optimizer = AdamW(\n",
" model.parameters(),\n",
" lr=args.optim.lr,\n",
" betas=(0.9, 0.95),\n",
" eps=1e-08,\n",
" weight_decay=args.optim.weight_decay,\n",
" )\n",
"\n",
" scheduler = lr_scheduler.OneCycleLR(\n",
" optimizer,\n",
" max_lr=args.optim.lr,\n",
" total_steps=args.max_steps,\n",
" pct_start=args.optim.pct_start,\n",
" )\n",
"\n",
" state = TrainState(args.max_steps)\n",
"\n",
" # 10. Initialize checkpointer\n",
" checkpointer = Checkpointer(\n",
" model=model,\n",
" state=state,\n",
" run_dir=run_dir,\n",
" optimizer=optimizer,\n",
" num_ckpt_keep=args.num_ckpt_keep,\n",
" )\n",
" # 11. Prepare mixed precision\n",
" prepare_mixed_precision(\n",
" model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype\n",
" )\n",
"\n",
" # 12. train!\n",
" model.train()\n",
" torch.cuda.empty_cache()\n",
"\n",
" while state.step < args.max_steps:\n",
" state.start_step()\n",
" is_last_step = state.step == args.max_steps\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" loss = torch.tensor([0.0], device=\"cuda\")\n",
" n_batch_tokens: int = 0\n",
"\n",
" for i in range(args.num_microbatches):\n",
" # batch\n",
" batch = next(data_loader)\n",
"\n",
" x = torch.from_numpy(batch.x).cuda(non_blocking=True)\n",
" y = torch.from_numpy(batch.y).cuda(non_blocking=True)\n",
" y_mask = (\n",
" torch.from_numpy(batch.y_mask).cuda(non_blocking=True)\n",
" if batch.y_mask is not None\n",
" else None\n",
" )\n",
"\n",
" # forward / backward\n",
" output = model(\n",
" input_ids=x,\n",
" seqlens=batch.sizes,\n",
" )\n",
" mb_loss = compute_loss_with_mask(output, y, y_mask)\n",
"\n",
" mb_loss.backward()\n",
"\n",
" loss += mb_loss.detach()\n",
" n_batch_tokens += x.numel()\n",
"\n",
" if i < args.num_microbatches - 1:\n",
" # synchronize CUDA to re-run backward\n",
" assert args.num_microbatches > 1 # should not happen\n",
" torch.cuda.synchronize()\n",
"\n",
" if args.num_microbatches > 1:\n",
" loss /= args.num_microbatches\n",
" for p in model.parameters():\n",
" if p.requires_grad:\n",
" assert p.grad is not None\n",
" p.grad.div_(args.num_microbatches)\n",
"\n",
" # upcast params for optimizer update\n",
" upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)\n",
"\n",
" # clip grad norm\n",
" model.clip_grad_norm_(max_norm=args.max_norm)\n",
"\n",
" # optimizer step\n",
" optimizer.step()\n",
"\n",
" # downcast params for forward & backward\n",
" downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)\n",
"\n",
" last_lr = scheduler.get_last_lr()[0]\n",
" scheduler.step()\n",
"\n",
" # Host sync\n",
" loss_item = loss.item()\n",
" avg_loss = avg_aggregate(loss_item)\n",
"\n",
" if not args.no_eval and (\n",
" (args.eval_freq > 0 and state.step % args.eval_freq == 0) or is_last_step\n",
" ):\n",
" # write perplexity to state\n",
" evaluate(model, eval_batches, state)\n",
"\n",
" eval_logs = get_eval_logs(\n",
" state.step, avg_loss, state.this_eval_perplexity, state.this_eval_loss\n",
" )\n",
"\n",
" main_logger_info(eval_log_msg(eval_logs))\n",
" eval_logger.log(eval_logs, step=state.step)\n",
"\n",
" # Timing\n",
" state.end_step(n_batch_tokens)\n",
"\n",
" if state.step % args.log_freq == 0:\n",
" train_logs = get_train_logs(\n",
" state,\n",
" avg_loss,\n",
" last_lr,\n",
" torch.cuda.max_memory_allocated(),\n",
" torch.cuda.memory_allocated(),\n",
" args,\n",
" )\n",
" main_logger_info(train_log_msg(state, logs=train_logs, loss=avg_loss))\n",
" metrics_logger.log(train_logs, step=state.step)\n",
"\n",
" if not args.no_ckpt and (\n",
" (args.ckpt_freq > 0 and state.step % args.ckpt_freq == 0) or is_last_step\n",
" ):\n",
" checkpointer.save_checkpoint(\n",
" save_only_lora=args.save_adapters,\n",
" dtype=param_dtype,\n",
" instruct_tokenizer=instruct_tokenizer,\n",
" )\n",
"\n",
" main_logger_info(\"done!\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" \"\"\"See README.md for usage.\"\"\"\n",
" fire.Fire(train)\n"
]
}
],
"source": [
"# Re-inspect train.py after the optional dtype-patch cell; its output is\n",
"# unchanged while that cell stays commented out.\n",
"! cat train.py"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working/mistral-finetune\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/working/venv/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
" self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
]
}
],
"source": [
"# Work from the repo root, presumably for the training cell below (relative paths\n",
"# to train.py / the finetune package); re-running this cell is harmless.\n",
"%cd /workspace/working/mistral-finetune/"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"args: TrainArgs(data=DataArgs(data='', shuffle=False, instruct_data='/workspace/working/data/NL_to_SPARQL_train.jsonl', eval_instruct_data='/workspace/working/data/NL_to_SPARQL_eval.jsonl', instruct=InstructArgs(shuffle=True, dynamic_chunk_fn_call=True)), model_id_or_path='/workspace/working/mistral_models', run_dir='outputs', optim=OptimArgs(lr=1e-05, weight_decay=0.1, pct_start=0.05), seed=0, num_microbatches=8, seq_len=2048, batch_size=1, max_norm=1.0, max_steps=300, log_freq=10, ckpt_freq=100, save_adapters=True, no_ckpt=False, num_ckpt_keep=3, eval_freq=100, no_eval=False, checkpoint=True, world_size=1, wandb=WandbArgs(project=None, offline=False, key=None, run_name=None), mlflow=MLFlowArgs(tracking_uri=None, experiment_name=None), lora=LoraArgs(enable=True, rank=16, dropout=0.0, scaling=16.0))\n",
"2025-03-27 23:59:21 (UTC) - 0:00:02 - distributed - INFO - torch.cuda.device_count: 1\n",
"2025-03-27 23:59:21 (UTC) - 0:00:02 - distributed - INFO - CUDA_VISIBLE_DEVICES: 0\n",
"2025-03-27 23:59:21 (UTC) - 0:00:02 - distributed - INFO - local rank: 0\n",
"2025-03-27 23:59:21 (UTC) - 0:00:02 - train - INFO - Going to init comms...\n",
"2025-03-27 23:59:21 (UTC) - 0:00:02 - train - INFO - Run dir: outputs\n",
"2025-03-27 23:59:21 (UTC) - 0:00:02 - train - INFO - TrainArgs: {'batch_size': 1,\n",
" 'checkpoint': True,\n",
" 'ckpt_freq': 100,\n",
" 'data': {'data': '',\n",
" 'eval_instruct_data': '/workspace/working/data/NL_to_SPARQL_eval.jsonl',\n",
" 'instruct': {'dynamic_chunk_fn_call': True, 'shuffle': True},\n",
" 'instruct_data': '/workspace/working/data/NL_to_SPARQL_train.jsonl',\n",
" 'shuffle': False},\n",
" 'eval_freq': 100,\n",
" 'log_freq': 10,\n",
" 'lora': {'dropout': 0.0, 'enable': True, 'rank': 16, 'scaling': 16.0},\n",
" 'max_norm': 1.0,\n",
" 'max_steps': 300,\n",
" 'mlflow': {'experiment_name': None, 'tracking_uri': None},\n",
" 'model_id_or_path': '/workspace/working/mistral_models',\n",
" 'no_ckpt': False,\n",
" 'no_eval': False,\n",
" 'num_ckpt_keep': 3,\n",
" 'num_microbatches': 8,\n",
" 'optim': {'lr': 1e-05, 'pct_start': 0.05, 'weight_decay': 0.1},\n",
" 'run_dir': 'outputs',\n",
" 'save_adapters': True,\n",
" 'seed': 0,\n",
" 'seq_len': 2048,\n",
" 'wandb': {'key': None, 'offline': False, 'project': None, 'run_name': None},\n",
" 'world_size': 1}\n",
"2025-03-27 23:59:24 (UTC) - 0:00:05 - finetune.wrapped_model - INFO - Reloading model from /workspace/working/mistral_models/consolidated.safetensors ...\n",
"2025-03-27 23:59:24 (UTC) - 0:00:05 - finetune.wrapped_model - INFO - Converting model to dtype torch.bfloat16 ...\n",
"2025-03-27 23:59:24 (UTC) - 0:00:05 - finetune.wrapped_model - INFO - Loaded model on cpu!\n",
"2025-03-27 23:59:24 (UTC) - 0:00:05 - finetune.wrapped_model - INFO - Initializing lora layers ...\n",
"2025-03-27 23:59:24 (UTC) - 0:00:05 - finetune.wrapped_model - INFO - Finished initialization!\n",
"2025-03-27 23:59:24 (UTC) - 0:00:05 - finetune.wrapped_model - INFO - Sharding model over 1 GPUs ...\n",
"2025-03-27 23:59:28 (UTC) - 0:00:09 - finetune.wrapped_model - INFO - Model sharded!\n",
"2025-03-27 23:59:28 (UTC) - 0:00:09 - finetune.wrapped_model - INFO - 41,943,040 out of 7,289,966,592 parameters are finetuned (0.58%).\n",
"2025-03-27 23:59:29 (UTC) - 0:00:10 - dataset - INFO - Loading /workspace/working/data/NL_to_SPARQL_train.jsonl ...\n",
"2025-03-27 23:59:39 (UTC) - 0:00:20 - dataset - INFO - /workspace/working/data/NL_to_SPARQL_train.jsonl loaded and tokenized.\n",
"2025-03-27 23:59:39 (UTC) - 0:00:20 - dataset - INFO - Shuffling /workspace/working/data/NL_to_SPARQL_train.jsonl ...\n",
"2025-03-28 00:00:34 (UTC) - 0:01:15 - train - INFO - step: 000010 - done (%): 3.3 - loss: 0.464 - lr: 7.3e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3065.9 - avg_words_per_second: 2534.8 - ETA: >2025-03-28 00:31:48\n",
"2025-03-28 00:01:28 (UTC) - 0:02:09 - train - INFO - step: 000020 - done (%): 6.7 - loss: 0.180 - lr: 1.0e-05 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2903.7 - avg_words_per_second: 2757.1 - ETA: >2025-03-28 00:29:12\n",
"2025-03-28 00:02:22 (UTC) - 0:03:03 - train - INFO - step: 000030 - done (%): 10.0 - loss: 0.152 - lr: 9.9e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3008.0 - avg_words_per_second: 2836.0 - ETA: >2025-03-28 00:28:22\n",
"2025-03-28 00:03:16 (UTC) - 0:03:57 - train - INFO - step: 000040 - done (%): 13.3 - loss: 0.139 - lr: 9.8e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3025.4 - avg_words_per_second: 2885.0 - ETA: >2025-03-28 00:27:53\n",
"2025-03-28 00:04:11 (UTC) - 0:04:52 - train - INFO - step: 000050 - done (%): 16.7 - loss: 0.146 - lr: 9.6e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2972.8 - avg_words_per_second: 2909.4 - ETA: >2025-03-28 00:27:39\n",
"2025-03-28 00:05:05 (UTC) - 0:05:46 - train - INFO - step: 000060 - done (%): 20.0 - loss: 0.143 - lr: 9.4e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3018.1 - avg_words_per_second: 2925.7 - ETA: >2025-03-28 00:27:29\n",
"2025-03-28 00:05:59 (UTC) - 0:06:40 - train - INFO - step: 000070 - done (%): 23.3 - loss: 0.132 - lr: 9.1e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3035.8 - avg_words_per_second: 2938.7 - ETA: >2025-03-28 00:27:22\n",
"2025-03-28 00:06:54 (UTC) - 0:07:35 - train - INFO - step: 000080 - done (%): 26.7 - loss: 0.132 - lr: 8.8e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3002.2 - avg_words_per_second: 2946.8 - ETA: >2025-03-28 00:27:17\n",
"2025-03-28 00:07:49 (UTC) - 0:08:30 - train - INFO - step: 000090 - done (%): 30.0 - loss: 0.129 - lr: 8.4e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2988.9 - avg_words_per_second: 2952.6 - ETA: >2025-03-28 00:27:14\n",
"2025-03-28 00:08:43 (UTC) - 0:09:24 - eval - INFO - Start eval...\n",
"2025-03-28 00:12:01 (UTC) - 0:12:42 - eval - INFO - Eval finished!\n",
"2025-03-28 00:12:01 (UTC) - 0:12:42 - train - INFO - step: 000100 - eval_perplexity: 1.091 - eval_loss: 0.126 - train_loss: 0.127\n",
"2025-03-28 00:12:01 (UTC) - 0:12:42 - train - INFO - step: 000100 - done (%): 33.3 - loss: 0.127 - lr: 8.0e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 80.6 - avg_words_per_second: 2178.9 - ETA: >2025-03-28 00:37:05\n",
"2025-03-28 00:12:01 (UTC) - 0:12:42 - checkpointing - INFO - Dumping checkpoint in outputs/checkpoints/checkpoint_000100/consolidated using tmp name: tmp.consolidated\n",
"2025-03-28 00:12:01 (UTC) - 0:12:42 - checkpointing - INFO - Done dumping checkpoint in outputs/checkpoints/checkpoint_000100/consolidated for step: 100\n",
"2025-03-28 00:12:01 (UTC) - 0:12:42 - checkpointing - INFO - Done deleting checkpoints \n",
"2025-03-28 00:12:01 (UTC) - 0:12:42 - checkpointing - INFO - Done!\n",
"2025-03-28 00:12:56 (UTC) - 0:13:37 - train - INFO - step: 000110 - done (%): 36.7 - loss: 0.128 - lr: 7.5e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3033.9 - avg_words_per_second: 2234.7 - ETA: >2025-03-28 00:36:09\n",
"2025-03-28 00:13:50 (UTC) - 0:14:31 - train - INFO - step: 000120 - done (%): 40.0 - loss: 0.126 - lr: 7.0e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2993.1 - avg_words_per_second: 2283.4 - ETA: >2025-03-28 00:35:22\n",
"2025-03-28 00:14:45 (UTC) - 0:15:26 - train - INFO - step: 000130 - done (%): 43.3 - loss: 0.133 - lr: 6.5e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3025.0 - avg_words_per_second: 2326.4 - ETA: >2025-03-28 00:34:42\n",
"2025-03-28 00:15:40 (UTC) - 0:16:21 - train - INFO - step: 000140 - done (%): 46.7 - loss: 0.127 - lr: 6.0e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3057.7 - avg_words_per_second: 2364.4 - ETA: >2025-03-28 00:34:08\n",
"2025-03-28 00:16:34 (UTC) - 0:17:15 - train - INFO - step: 000150 - done (%): 50.0 - loss: 0.122 - lr: 5.4e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2998.8 - avg_words_per_second: 2398.3 - ETA: >2025-03-28 00:33:39\n",
"2025-03-28 00:17:29 (UTC) - 0:18:10 - train - INFO - step: 000160 - done (%): 53.3 - loss: 0.120 - lr: 4.9e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2955.3 - avg_words_per_second: 2427.6 - ETA: >2025-03-28 00:33:14\n",
"2025-03-28 00:18:24 (UTC) - 0:19:05 - train - INFO - step: 000170 - done (%): 56.7 - loss: 0.117 - lr: 4.3e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2971.4 - avg_words_per_second: 2455.3 - ETA: >2025-03-28 00:32:51\n",
"2025-03-28 00:19:19 (UTC) - 0:20:00 - train - INFO - step: 000180 - done (%): 60.0 - loss: 0.111 - lr: 3.8e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2995.5 - avg_words_per_second: 2479.3 - ETA: >2025-03-28 00:32:32\n",
"2025-03-28 00:20:14 (UTC) - 0:20:54 - train - INFO - step: 000190 - done (%): 63.3 - loss: 0.117 - lr: 3.2e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2991.3 - avg_words_per_second: 2502.3 - ETA: >2025-03-28 00:32:14\n",
"2025-03-28 00:21:08 (UTC) - 0:21:49 - eval - INFO - Start eval...\n",
"2025-03-28 00:24:26 (UTC) - 0:25:07 - eval - INFO - Eval finished!\n",
"2025-03-28 00:24:26 (UTC) - 0:25:07 - train - INFO - step: 000200 - eval_perplexity: 1.082 - eval_loss: 0.113 - train_loss: 0.110\n",
"2025-03-28 00:24:26 (UTC) - 0:25:07 - train - INFO - step: 000200 - done (%): 66.7 - loss: 0.110 - lr: 2.7e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 80.4 - avg_words_per_second: 2188.9 - ETA: >2025-03-28 00:36:55\n",
"2025-03-28 00:24:26 (UTC) - 0:25:07 - checkpointing - INFO - Dumping checkpoint in outputs/checkpoints/checkpoint_000200/consolidated using tmp name: tmp.consolidated\n",
"2025-03-28 00:24:27 (UTC) - 0:25:08 - checkpointing - INFO - Done dumping checkpoint in outputs/checkpoints/checkpoint_000200/consolidated for step: 200\n",
"2025-03-28 00:24:27 (UTC) - 0:25:08 - checkpointing - INFO - Done deleting checkpoints \n",
"2025-03-28 00:24:27 (UTC) - 0:25:08 - checkpointing - INFO - Done!\n",
"2025-03-28 00:25:22 (UTC) - 0:26:03 - train - INFO - step: 000210 - done (%): 70.0 - loss: 0.111 - lr: 2.3e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3005.4 - avg_words_per_second: 2217.1 - ETA: >2025-03-28 00:36:27\n",
"2025-03-28 00:26:16 (UTC) - 0:26:57 - train - INFO - step: 000220 - done (%): 73.3 - loss: 0.112 - lr: 1.8e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2962.5 - avg_words_per_second: 2243.7 - ETA: >2025-03-28 00:36:00\n",
"2025-03-28 00:27:11 (UTC) - 0:27:51 - train - INFO - step: 000230 - done (%): 76.7 - loss: 0.112 - lr: 1.4e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3047.9 - avg_words_per_second: 2269.0 - ETA: >2025-03-28 00:35:36\n",
"2025-03-28 00:28:06 (UTC) - 0:28:47 - train - INFO - step: 000240 - done (%): 80.0 - loss: 0.113 - lr: 1.1e-06 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2951.1 - avg_words_per_second: 2291.0 - ETA: >2025-03-28 00:35:15\n",
"2025-03-28 00:29:01 (UTC) - 0:29:42 - train - INFO - step: 000250 - done (%): 83.3 - loss: 0.103 - lr: 7.4e-07 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3009.4 - avg_words_per_second: 2312.8 - ETA: >2025-03-28 00:34:55\n",
"2025-03-28 00:29:55 (UTC) - 0:30:36 - train - INFO - step: 000260 - done (%): 86.7 - loss: 0.119 - lr: 4.8e-07 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 2896.6 - avg_words_per_second: 2334.0 - ETA: >2025-03-28 00:34:36\n",
"2025-03-28 00:30:49 (UTC) - 0:31:30 - train - INFO - step: 000270 - done (%): 90.0 - loss: 0.116 - lr: 2.7e-07 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3006.5 - avg_words_per_second: 2353.9 - ETA: >2025-03-28 00:34:18\n",
"2025-03-28 00:31:44 (UTC) - 0:32:25 - train - INFO - step: 000280 - done (%): 93.3 - loss: 0.109 - lr: 1.2e-07 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3010.0 - avg_words_per_second: 2372.2 - ETA: >2025-03-28 00:34:02\n",
"2025-03-28 00:32:38 (UTC) - 0:33:19 - train - INFO - step: 000290 - done (%): 96.7 - loss: 0.101 - lr: 3.0e-08 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 3023.6 - avg_words_per_second: 2390.2 - ETA: >2025-03-28 00:33:46\n",
"2025-03-28 00:33:32 (UTC) - 0:34:13 - eval - INFO - Start eval...\n",
"2025-03-28 00:36:50 (UTC) - 0:37:31 - eval - INFO - Eval finished!\n",
"2025-03-28 00:36:50 (UTC) - 0:37:31 - train - INFO - step: 000300 - eval_perplexity: 1.080 - eval_loss: 0.110 - train_loss: 0.114\n",
"2025-03-28 00:36:50 (UTC) - 0:37:31 - train - INFO - step: 000300 - done (%): 100.0 - loss: 0.114 - lr: 4.0e-11 - peak_alloc_mem (GB): 15.7 - alloc_mem (GB): 14.5 - words_per_second: 80.3 - avg_words_per_second: 2193.7 - ETA: >2025-03-28 00:36:50\n",
"2025-03-28 00:36:50 (UTC) - 0:37:31 - checkpointing - INFO - Dumping checkpoint in outputs/checkpoints/checkpoint_000300/consolidated using tmp name: tmp.consolidated\n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - checkpointing - INFO - Done dumping checkpoint in outputs/checkpoints/checkpoint_000300/consolidated for step: 300\n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - checkpointing - INFO - Done deleting checkpoints \n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - checkpointing - INFO - Done!\n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - train - INFO - done!\n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - utils - INFO - Closing: eval_logger\n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - utils - INFO - Closed: eval_logger\n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - utils - INFO - Closing: metrics_logger\n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - utils - INFO - Closed: metrics_logger\n",
"2025-03-28 00:36:51 (UTC) - 0:37:32 - train - INFO - Closed everything!\n"
]
}
],
"source": [
"# Run from the repo root so `-m train` and the relative config path resolve regardless of the kernel's cwd\n",
"!cd /workspace/working/mistral-finetune && torchrun --nproc-per-node 1 -m train example/7B.yaml"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LICENSE \u001b[0m\u001b[01;34mexample\u001b[0m/ pyproject.toml \u001b[01;34mtests\u001b[0m/ \u001b[01;34mutils\u001b[0m/\n",
"README.md \u001b[01;34mfinetune\u001b[0m/ requirements.dev.txt train.py\n",
"\u001b[01;34m__pycache__\u001b[0m/ \u001b[01;34mmodel\u001b[0m/ requirements.txt \u001b[01;34mtutorials\u001b[0m/\n"
]
}
],
"source": [
"import shutil\n",
"from pathlib import Path\n",
"\n",
"# Move the training run directory out of the repo. Guarded so re-running the cell is safe:\n",
"# a second plain `mv` fails (source already moved, or destination already present) without stopping Run All.\n",
"run_src = Path(\"/workspace/working/mistral-finetune/outputs\")\n",
"run_dst = Path(\"/workspace/working/outputs\")\n",
"if run_src.exists():\n",
"    if run_dst.exists():\n",
"        raise FileExistsError(f\"{run_dst} already exists; remove or rename it before moving {run_src}\")\n",
"    shutil.move(str(run_src), str(run_dst))\n",
"%ls /workspace/working/mistral-finetune/"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working\n",
"\u001b[0m\u001b[01;34mdata\u001b[0m/ mistral-finetune-fyp.ipynb \u001b[01;34moutputs\u001b[0m/\n",
"\u001b[01;34mmistral-finetune\u001b[0m/ \u001b[01;34mmistral_models\u001b[0m/ \u001b[01;34mvenv\u001b[0m/\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/working/venv/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
" self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
]
}
],
"source": [
"# Return to the working root and list its contents\n",
"%cd /workspace/working\n",
"%ls"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working/outputs\n",
"args.yaml \u001b[0m\u001b[01;34mcheckpoints\u001b[0m/ metrics.eval.jsonl metrics.train.jsonl \u001b[01;34mtb\u001b[0m/\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/working/venv/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
" self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working/outputs/checkpoints\n",
"\u001b[0m\u001b[01;34mcheckpoint_000100\u001b[0m/ \u001b[01;34mcheckpoint_000200\u001b[0m/ \u001b[01;34mcheckpoint_000300\u001b[0m/\n",
"/workspace/working/outputs/checkpoints/checkpoint_000300\n",
"\u001b[0m\u001b[01;34mconsolidated\u001b[0m/\n",
"/workspace/working/outputs/checkpoints/checkpoint_000300/consolidated\n",
"lora.safetensors params.json tokenizer.model.v3\n"
]
},
{
"data": {
"text/plain": [
"'/workspace/working/outputs/checkpoints/checkpoint_000300/consolidated'"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# List the training artefacts via absolute paths; no %cd, so the kernel's working directory is left unchanged\n",
"%ls /workspace/working/outputs/\n",
"%ls /workspace/working/outputs/checkpoints/\n",
"%ls /workspace/working/outputs/checkpoints/checkpoint_000300/\n",
"%ls /workspace/working/outputs/checkpoints/checkpoint_000300/consolidated/"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/workspace/working/mistral-finetune\n"
]
}
],
"source": [
"# Back to the mistral-finetune repo root\n",
"%cd /workspace/working/mistral-finetune"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting mistral_inference\n",
" Downloading mistral_inference-1.6.0-py3-none-any.whl.metadata (17 kB)\n",
"Requirement already satisfied: fire>=0.6.0 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_inference) (0.7.0)\n",
"Requirement already satisfied: mistral_common>=1.5.4 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_inference) (1.5.4)\n",
"Requirement already satisfied: pillow>=10.3.0 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_inference) (11.1.0)\n",
"Requirement already satisfied: safetensors>=0.4.0 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_inference) (0.5.3)\n",
"Requirement already satisfied: simple-parsing>=0.1.5 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_inference) (0.1.7)\n",
"Requirement already satisfied: xformers>=0.0.24 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_inference) (0.0.24)\n",
"Requirement already satisfied: termcolor in /workspace/working/venv/lib/python3.10/site-packages (from fire>=0.6.0->mistral_inference) (2.5.0)\n",
"Requirement already satisfied: jsonschema>=4.21.1 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_common>=1.5.4->mistral_inference) (4.23.0)\n",
"Requirement already satisfied: numpy>=1.25 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_common>=1.5.4->mistral_inference) (1.26.4)\n",
"Requirement already satisfied: pydantic<3.0,>=2.7 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_common>=1.5.4->mistral_inference) (2.11.0)\n",
"Requirement already satisfied: requests>=2.0.0 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_common>=1.5.4->mistral_inference) (2.32.3)\n",
"Requirement already satisfied: sentencepiece>=0.2.0 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_common>=1.5.4->mistral_inference) (0.2.0)\n",
"Requirement already satisfied: tiktoken>=0.7.0 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_common>=1.5.4->mistral_inference) (0.9.0)\n",
"Requirement already satisfied: typing-extensions>=4.11.0 in /workspace/working/venv/lib/python3.10/site-packages (from mistral_common>=1.5.4->mistral_inference) (4.13.0)\n",
"Requirement already satisfied: docstring-parser<1.0,>=0.15 in /workspace/working/venv/lib/python3.10/site-packages (from simple-parsing>=0.1.5->mistral_inference) (0.16)\n",
"Requirement already satisfied: torch==2.2.0 in /workspace/working/venv/lib/python3.10/site-packages (from xformers>=0.0.24->mistral_inference) (2.2.0)\n",
"Requirement already satisfied: filelock in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (3.18.0)\n",
"Requirement already satisfied: sympy in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (1.13.3)\n",
"Requirement already satisfied: networkx in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (3.4.2)\n",
"Requirement already satisfied: jinja2 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (3.1.6)\n",
"Requirement already satisfied: fsspec in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (2025.3.0)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (12.1.105)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (8.9.2.26)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (12.1.3.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (11.0.2.54)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (10.3.2.106)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (11.4.5.107)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (12.1.0.106)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (2.19.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (12.1.105)\n",
"Requirement already satisfied: triton==2.2.0 in /workspace/working/venv/lib/python3.10/site-packages (from torch==2.2.0->xformers>=0.0.24->mistral_inference) (2.2.0)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /workspace/working/venv/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2.0->xformers>=0.0.24->mistral_inference) (12.8.93)\n",
"Requirement already satisfied: attrs>=22.2.0 in /workspace/working/venv/lib/python3.10/site-packages (from jsonschema>=4.21.1->mistral_common>=1.5.4->mistral_inference) (25.3.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /workspace/working/venv/lib/python3.10/site-packages (from jsonschema>=4.21.1->mistral_common>=1.5.4->mistral_inference) (2024.10.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /workspace/working/venv/lib/python3.10/site-packages (from jsonschema>=4.21.1->mistral_common>=1.5.4->mistral_inference) (0.36.2)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /workspace/working/venv/lib/python3.10/site-packages (from jsonschema>=4.21.1->mistral_common>=1.5.4->mistral_inference) (0.24.0)\n",
"Requirement already satisfied: annotated-types>=0.6.0 in /workspace/working/venv/lib/python3.10/site-packages (from pydantic<3.0,>=2.7->mistral_common>=1.5.4->mistral_inference) (0.7.0)\n",
"Requirement already satisfied: pydantic-core==2.33.0 in /workspace/working/venv/lib/python3.10/site-packages (from pydantic<3.0,>=2.7->mistral_common>=1.5.4->mistral_inference) (2.33.0)\n",
"Requirement already satisfied: typing-inspection>=0.4.0 in /workspace/working/venv/lib/python3.10/site-packages (from pydantic<3.0,>=2.7->mistral_common>=1.5.4->mistral_inference) (0.4.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /workspace/working/venv/lib/python3.10/site-packages (from requests>=2.0.0->mistral_common>=1.5.4->mistral_inference) (3.4.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /workspace/working/venv/lib/python3.10/site-packages (from requests>=2.0.0->mistral_common>=1.5.4->mistral_inference) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /workspace/working/venv/lib/python3.10/site-packages (from requests>=2.0.0->mistral_common>=1.5.4->mistral_inference) (2.3.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /workspace/working/venv/lib/python3.10/site-packages (from requests>=2.0.0->mistral_common>=1.5.4->mistral_inference) (2025.1.31)\n",
"Requirement already satisfied: regex>=2022.1.18 in /workspace/working/venv/lib/python3.10/site-packages (from tiktoken>=0.7.0->mistral_common>=1.5.4->mistral_inference) (2024.11.6)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /workspace/working/venv/lib/python3.10/site-packages (from jinja2->torch==2.2.0->xformers>=0.0.24->mistral_inference) (3.0.2)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /workspace/working/venv/lib/python3.10/site-packages (from sympy->torch==2.2.0->xformers>=0.0.24->mistral_inference) (1.3.0)\n",
"Downloading mistral_inference-1.6.0-py3-none-any.whl (32 kB)\n",
"Installing collected packages: mistral_inference\n",
"Successfully installed mistral_inference-1.6.0\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# Pinned to the version this notebook was run with, for reproducible inference\n",
"%pip install mistral_inference==1.6.0"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# with open(\"infer.py\", \"w\") as file:\n",
"# file.write('''import os\n",
"# import torch\n",
"# import torch.distributed as dist\n",
"# from mistral_inference.transformer import Transformer\n",
"# from mistral_inference.generate import generate\n",
"# from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
"# from mistral_common.protocol.instruct.messages import UserMessage\n",
"# from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
"\n",
"# def main():\n",
"# try:\n",
"# # Initialize distributed process group\n",
"# dist.init_process_group(backend='nccl')\n",
"# local_rank = int(os.environ['LOCAL_RANK'])\n",
"# torch.cuda.set_device(local_rank)\n",
" \n",
"# # Load tokenizer (each process does this)\n",
"# if dist.get_rank() == 0:\n",
"# print(\"Loading tokenizer...\")\n",
"# tokenizer_path = \"/workspace/working/outputs/checkpoints/checkpoint_000300/consolidated/tokenizer.model.v3\"\n",
"# mistral_tokenizer = MistralTokenizer.from_file(tokenizer_path)\n",
" \n",
"# # Load base model with error handling\n",
"# if dist.get_rank() == 0:\n",
"# print(\"Loading base model...\")\n",
"# model_path = \"/workspace/working/mistral_models\"\n",
"# try:\n",
"# model = Transformer.from_folder(\n",
"# model_path, \n",
"# dtype=torch.float16,\n",
"# device=torch.device(f\"cuda:{local_rank}\"),\n",
"# max_batch_size=3,\n",
"# num_pipeline_ranks=2\n",
"# )\n",
"# except Exception as e:\n",
"# print(f\"[Rank {dist.get_rank()}] Model loading failed: {e}\")\n",
"# raise\n",
" \n",
"# # Load LoRA adapter\n",
"# if dist.get_rank() == 0:\n",
"# print(\"Loading LoRA adapter...\")\n",
"# lora_path = \"/workspace/working/outputs/checkpoints/checkpoint_000300/consolidated/lora.safetensors\"\n",
"# try:\n",
"# model.load_lora(lora_path)\n",
"# except Exception as e:\n",
"# print(f\"[Rank {dist.get_rank()}] LoRA loading failed: {e}\")\n",
"# raise\n",
" \n",
"# if dist.get_rank() == 0:\n",
"# print(f\"Model loaded on cuda:{local_rank}\")\n",
"# print(\"Running inference...\")\n",
"# prompt = \"Who is the president of France?\"\n",
"# messages = [{\"role\": \"user\", \"content\": prompt}]\n",
"# #tokens = mistral_tokenizer.encode_chat_completion(\n",
"# # ChatCompletionRequest(messages=[UserMessage(content=prompt)])\n",
"# #).tokens\n",
"# tokens = mistral_tokenizer.encode_chat_completion(\n",
"# ChatCompletionRequest(messages=messages)\n",
"# ).tokens\n",
" \n",
"# out_tokens, _ = generate(\n",
"# [tokens], \n",
"# model, \n",
"# max_tokens=512, # Reduced for testing\n",
"# temperature=0.3, \n",
"# eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id\n",
"# )\n",
"# print(mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0]))\n",
" \n",
"# dist.barrier()\n",
" \n",
"# except Exception as e:\n",
"# print(f\"[Rank {dist.get_rank()}] Error: {e}\")\n",
"# raise\n",
"# finally:\n",
"# if dist.is_initialized():\n",
"# dist.destroy_process_group()\n",
"\n",
"# if __name__ == \"__main__\":\n",
"# # Set OMP_NUM_THREADS to avoid warning\n",
"# os.environ['OMP_NUM_THREADS'] = '1'\n",
"# main()\n",
"# ''')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"trusted": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"To retrieve the current president of India from the DBpedia Knowledge Graph, you can use the following SPARQL query:\n",
"\n",
"```sparql\n",
"PREFIX dbo: \n",
"PREFIX rdf: \n",
"PREFIX rdfs: \n",
"PREFIX xsd: \n",
"\n",
"SELECT DISTINCT ?president ?presidentLabel\n",
"WHERE {\n",
" ?president a dbo:Person ;\n",
" dbo:countryOfCitizenship dbo:India ;\n",
" dbo:officeHolder dbo:PresidentOfIndia .\n",
" SERVICE wikibus-labs:pageLabels { bd:serviceParam wikibuslabs:language \"en\" . }\n",
" ?president rdfs:label ?presidentLabel .\n",
"}\n",
"ORDER BY DESC(STR(?presidentLabel))\n",
"LIMIT 1\n",
"```\n",
"\n",
"This query searches for individuals who are a `dbo:Person`, have `dbo:countryOfCitizenship` as India, and hold the office of `dbo:PresidentOfIndia`. It also uses the Wikibus-labs service to retrieve the English labels for the results. The results are ordered in descending order by the label and limited to the top 1 result.\n"
]
}
],
"source": [
"import os\n",
"import torch\n",
"import safetensors.torch  # import the submodule explicitly; `import safetensors` alone does not guarantee `safetensors.torch` is loaded\n",
"from mistral_inference.transformer import Transformer\n",
"from mistral_inference.generate import generate\n",
"from mistral_common.tokens.tokenizers.mistral import MistralTokenizer\n",
"from mistral_common.protocol.instruct.messages import UserMessage\n",
"from mistral_common.protocol.instruct.request import ChatCompletionRequest\n",
"\n",
"# Paths: base weights, final (step 300) LoRA checkpoint, and where the adapted model is written\n",
"BASE_MODEL_DIR = \"/workspace/working/mistral_models\"\n",
"CKPT_DIR = \"/workspace/working/outputs/checkpoints/checkpoint_000300/consolidated\"\n",
"MERGED_MODEL_PATH = \"/workspace/working/outputs/finetunedmodel\"\n",
"\n",
"# load tokenizer saved alongside the LoRA checkpoint\n",
"mistral_tokenizer = MistralTokenizer.from_file(f\"{CKPT_DIR}/tokenizer.model.v3\")\n",
"# load base model and apply the fine-tuned LoRA adapter\n",
"model = Transformer.from_folder(BASE_MODEL_DIR)\n",
"model.load_lora(f\"{CKPT_DIR}/lora.safetensors\")\n",
"# Saving the full model is slow and multi-GB; only do it once so re-running this cell stays cheap\n",
"if not os.path.exists(MERGED_MODEL_PATH):\n",
"    safetensors.torch.save_model(model, MERGED_MODEL_PATH)\n",
"\n",
"query = input(\"Any questions?\\n\")\n",
"prompt = f\"Give just the SPARQL query over DBpedia Knowledge Graph for the given natural language query: {query}\"\n",
"messages = [{\"role\" : \"user\", \"content\" : prompt}]\n",
"# chat completion request\n",
"completion_request = ChatCompletionRequest(messages=messages)\n",
"\n",
"# encode message\n",
"tokens = mistral_tokenizer.encode_chat_completion(completion_request).tokens\n",
"\n",
"# generate results\n",
"out_tokens, _ = generate([tokens], model, max_tokens=2048, temperature=0.50, eos_id=mistral_tokenizer.instruct_tokenizer.tokenizer.eos_id)\n",
"\n",
"# decode generated tokens\n",
"result = mistral_tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"ename": "QueryBadFormed",
"evalue": "QueryBadFormed: A bad request has been sent to the endpoint: probably the SPARQL query is badly formed. \n\nResponse:\nb'Virtuoso 37000 Error SP030: SPARQL compiler, line 12: Undefined namespace prefix in prefix:localpart notation at \\'wikibus-labs:pageLabels\\' before \\'{\\'\\n\\nSPARQL query:\\n#output-format:application/sparql-results+json\\nPREFIX dbo: \\nPREFIX rdf: \\nPREFIX rdfs: \\nPREFIX xsd: \\n\\nSELECT DISTINCT ?president ?presidentLabel\\nWHERE {\\n ?president a dbo:Person ;\\n dbo:countryOfCitizenship dbo:India ;\\n dbo:officeHolder dbo:PresidentOfIndia .\\n SERVICE wikibus-labs:pageLabels { bd:serviceParam wikibuslabs:language \"en\" . }\\n ?president rdfs:label ?presidentLabel .\\n}\\nORDER BY DESC(STR(?presidentLabel))\\nLIMIT 1\\n'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mHTTPError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m/workspace/working/venv/lib/python3.10/site-packages/SPARQLWrapper/Wrapper.py:926\u001b[0m, in \u001b[0;36mSPARQLWrapper._query\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 925\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 926\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43murlopener\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 927\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturnFormat\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:216\u001b[0m, in \u001b[0;36murlopen\u001b[0;34m(url, data, timeout, cafile, capath, cadefault, context)\u001b[0m\n\u001b[1;32m 215\u001b[0m opener \u001b[38;5;241m=\u001b[39m _opener\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mopener\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:525\u001b[0m, in \u001b[0;36mOpenerDirector.open\u001b[0;34m(self, fullurl, data, timeout)\u001b[0m\n\u001b[1;32m 524\u001b[0m meth \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(processor, meth_name)\n\u001b[0;32m--> 525\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mmeth\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:634\u001b[0m, in \u001b[0;36mHTTPErrorProcessor.http_response\u001b[0;34m(self, request, response)\u001b[0m\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;241m200\u001b[39m \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m code \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m300\u001b[39m):\n\u001b[0;32m--> 634\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43merror\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 635\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mhttp\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhdrs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 637\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:557\u001b[0m, in \u001b[0;36mOpenerDirector.error\u001b[0;34m(self, proto, *args)\u001b[0m\n\u001b[1;32m 556\u001b[0m args \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mdict\u001b[39m, proto, meth_name) \u001b[38;5;241m+\u001b[39m args\n\u001b[0;32m--> 557\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_chain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 558\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result:\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:496\u001b[0m, in \u001b[0;36mOpenerDirector._call_chain\u001b[0;34m(self, chain, kind, meth_name, *args)\u001b[0m\n\u001b[1;32m 495\u001b[0m func \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(handler, meth_name)\n\u001b[0;32m--> 496\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 497\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:749\u001b[0m, in \u001b[0;36mHTTPRedirectHandler.http_error_302\u001b[0;34m(self, req, fp, code, msg, headers)\u001b[0m\n\u001b[1;32m 747\u001b[0m fp\u001b[38;5;241m.\u001b[39mclose()\n\u001b[0;32m--> 749\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnew\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreq\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:525\u001b[0m, in \u001b[0;36mOpenerDirector.open\u001b[0;34m(self, fullurl, data, timeout)\u001b[0m\n\u001b[1;32m 524\u001b[0m meth \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(processor, meth_name)\n\u001b[0;32m--> 525\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43mmeth\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:634\u001b[0m, in \u001b[0;36mHTTPErrorProcessor.http_response\u001b[0;34m(self, request, response)\u001b[0m\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;241m200\u001b[39m \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m code \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m300\u001b[39m):\n\u001b[0;32m--> 634\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43merror\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 635\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mhttp\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrequest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhdrs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 637\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:563\u001b[0m, in \u001b[0;36mOpenerDirector.error\u001b[0;34m(self, proto, *args)\u001b[0m\n\u001b[1;32m 562\u001b[0m args \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mdict\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdefault\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mhttp_error_default\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m+\u001b[39m orig_args\n\u001b[0;32m--> 563\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_chain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:496\u001b[0m, in \u001b[0;36mOpenerDirector._call_chain\u001b[0;34m(self, chain, kind, meth_name, *args)\u001b[0m\n\u001b[1;32m 495\u001b[0m func \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(handler, meth_name)\n\u001b[0;32m--> 496\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 497\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m/usr/lib/python3.10/urllib/request.py:643\u001b[0m, in \u001b[0;36mHTTPDefaultErrorHandler.http_error_default\u001b[0;34m(self, req, fp, code, msg, hdrs)\u001b[0m\n\u001b[1;32m 642\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mhttp_error_default\u001b[39m(\u001b[38;5;28mself\u001b[39m, req, fp, code, msg, hdrs):\n\u001b[0;32m--> 643\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m HTTPError(req\u001b[38;5;241m.\u001b[39mfull_url, code, msg, hdrs, fp)\n",
"\u001b[0;31mHTTPError\u001b[0m: HTTP Error 400: Bad Request",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mQueryBadFormed\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[22], line 19\u001b[0m\n\u001b[1;32m 3\u001b[0m SPARQLendpt\u001b[38;5;241m.\u001b[39msetReturnFormat(JSON)\n\u001b[1;32m 4\u001b[0m SPARQLendpt\u001b[38;5;241m.\u001b[39msetQuery(\u001b[38;5;124m'''\u001b[39m\u001b[38;5;124mPREFIX dbo: \u001b[39m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124mPREFIX rdf: \u001b[39m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124mPREFIX rdfs: \u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;124mORDER BY DESC(STR(?presidentLabel))\u001b[39m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;124mLIMIT 1\u001b[39m\u001b[38;5;124m'''\u001b[39m)\n\u001b[0;32m---> 19\u001b[0m LLMAnswer \u001b[38;5;241m=\u001b[39m \u001b[43mSPARQLendpt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mquery\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mconvert()\n\u001b[1;32m 20\u001b[0m LLMAnswer\n",
"File \u001b[0;32m/workspace/working/venv/lib/python3.10/site-packages/SPARQLWrapper/Wrapper.py:960\u001b[0m, in \u001b[0;36mSPARQLWrapper.query\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 942\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mquery\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mQueryResult\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 943\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 944\u001b[0m \u001b[38;5;124;03m Execute the query.\u001b[39;00m\n\u001b[1;32m 945\u001b[0m \u001b[38;5;124;03m Exceptions can be raised if either the URI is wrong or the HTTP sends back an error (this is also the\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 958\u001b[0m \u001b[38;5;124;03m :rtype: :class:`QueryResult` instance\u001b[39;00m\n\u001b[1;32m 959\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 960\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m QueryResult(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_query\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m/workspace/working/venv/lib/python3.10/site-packages/SPARQLWrapper/Wrapper.py:930\u001b[0m, in \u001b[0;36mSPARQLWrapper._query\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 928\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m urllib\u001b[38;5;241m.\u001b[39merror\u001b[38;5;241m.\u001b[39mHTTPError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 929\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m e\u001b[38;5;241m.\u001b[39mcode \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m400\u001b[39m:\n\u001b[0;32m--> 930\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m QueryBadFormed(e\u001b[38;5;241m.\u001b[39mread())\n\u001b[1;32m 931\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m e\u001b[38;5;241m.\u001b[39mcode \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m404\u001b[39m:\n\u001b[1;32m 932\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m EndPointNotFound(e\u001b[38;5;241m.\u001b[39mread())\n",
"\u001b[0;31mQueryBadFormed\u001b[0m: QueryBadFormed: A bad request has been sent to the endpoint: probably the SPARQL query is badly formed. \n\nResponse:\nb'Virtuoso 37000 Error SP030: SPARQL compiler, line 12: Undefined namespace prefix in prefix:localpart notation at \\'wikibus-labs:pageLabels\\' before \\'{\\'\\n\\nSPARQL query:\\n#output-format:application/sparql-results+json\\nPREFIX dbo: \\nPREFIX rdf: \\nPREFIX rdfs: \\nPREFIX xsd: \\n\\nSELECT DISTINCT ?president ?presidentLabel\\nWHERE {\\n ?president a dbo:Person ;\\n dbo:countryOfCitizenship dbo:India ;\\n dbo:officeHolder dbo:PresidentOfIndia .\\n SERVICE wikibus-labs:pageLabels { bd:serviceParam wikibuslabs:language \"en\" . }\\n ?president rdfs:label ?presidentLabel .\\n}\\nORDER BY DESC(STR(?presidentLabel))\\nLIMIT 1\\n'"
]
}
],
"source": [
"from SPARQLWrapper import JSON, SPARQLWrapper\n",
"from SPARQLWrapper.SPARQLExceptions import QueryBadFormed\n",
"\n",
"# Run a (model-generated) SPARQL query against the public DBpedia endpoint.\n",
"# NOTE(review): the plain-http URL answered with a 302 redirect; https is assumed to be its target — confirm.\n",
"SPARQLendpt = SPARQLWrapper(\"https://dbpedia.org/sparql\")\n",
"SPARQLendpt.setReturnFormat(JSON)\n",
"# Every PREFIX must be bound to an IRI. The Wikidata-style label SERVICE used undefined\n",
"# prefixes (Virtuoso error SP030), so English labels are selected with a LANG() filter.\n",
"# Individual resources (India, President_of_India) live in dbr:, not the dbo: ontology namespace.\n",
"# NOTE(review): dbo:countryOfCitizenship / dbo:officeHolder come from the LLM answer —\n",
"# verify they exist in the DBpedia ontology; ordering by label does not pick the *current* president.\n",
"SPARQLendpt.setQuery('''PREFIX dbo: <http://dbpedia.org/ontology/>\n",
"PREFIX dbr: <http://dbpedia.org/resource/>\n",
"PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>\n",
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n",
"PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>\n",
"\n",
"SELECT DISTINCT ?president ?presidentLabel\n",
"WHERE {\n",
"  ?president a dbo:Person ;\n",
"             dbo:countryOfCitizenship dbr:India ;\n",
"             dbo:officeHolder dbr:President_of_India .\n",
"  ?president rdfs:label ?presidentLabel .\n",
"  FILTER(LANG(?presidentLabel) = \"en\")\n",
"}\n",
"ORDER BY DESC(STR(?presidentLabel))\n",
"LIMIT 1''')\n",
"try:\n",
"    LLMAnswer = SPARQLendpt.query().convert()\n",
"except QueryBadFormed as err:\n",
"    # A malformed generated query is an expected outcome when evaluating the model;\n",
"    # report it instead of aborting Restart & Run All.\n",
"    LLMAnswer = None\n",
"    print(err)\n",
"LLMAnswer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Transformer(\n",
" (tok_embeddings): Embedding(32768, 4096)\n",
" (norm): RMSNorm()\n",
" (output): Linear(in_features=4096, out_features=32768, bias=False)\n",
" (layers): ModuleDict(\n",
" (0): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (1): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (2): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (3): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (4): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (5): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (6): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (7): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (8): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (9): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (10): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (11): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (12): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (13): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (14): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (15): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (16): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (17): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (18): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (19): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (20): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (21): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (22): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (23): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (24): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (25): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (26): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (27): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (28): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (29): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (30): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" (31): TransformerBlock(\n",
" (attention): Attention(\n",
" (wq): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (wk): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wv): Linear(in_features=4096, out_features=1024, bias=False)\n",
" (wo): Linear(in_features=4096, out_features=4096, bias=False)\n",
" )\n",
" (attention_norm): RMSNorm()\n",
" (ffn_norm): RMSNorm()\n",
" (feed_forward): FeedForward(\n",
" (w1): Linear(in_features=4096, out_features=14336, bias=False)\n",
" (w2): Linear(in_features=14336, out_features=4096, bias=False)\n",
" (w3): Linear(in_features=4096, out_features=14336, bias=False)\n",
" )\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# `model` is a torch.nn.Module subclass (`Transformer`) and has no `.save()` method,\n",
"# so persist its weights with torch.save on the state_dict instead.\n",
"# NOTE(review): output path is a placeholder under the notebook's working directory — adjust as needed.\n",
"MODEL_STATE_PATH = \"/workspace/working/transformer_state_dict.pt\"\n",
"torch.save(model.state_dict(), MODEL_STATE_PATH)\n",
"MODEL_STATE_PATH"
]
}
],
"metadata": {
"kaggle": {
"accelerator": "nvidiaTeslaT4",
"dataSources": [],
"isGpuEnabled": true,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}