---
library_name: transformers
language:
- en
license: mit
base_model: FacebookAI/roberta-base
tags:
- generated_from_trainer
datasets:
- swag
metrics:
- accuracy
model-index:
- name: swag_base
  results:
  - task:
      name: Multiple Choice
      type: multiple-choice
    dataset:
      name: SWAG
      type: swag
      args: regular
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.7521243691444397
---
					
						
# swag_base

This model is a fine-tuned version of [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) on the SWAG (Situations With Adversarial Generations) dataset.
					
						
## Model description

The model performs multiple-choice reasoning about real-world situations. Given a context and four candidate continuations, it predicts the most plausible ending based on commonsense understanding.
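
As a rough illustration of the selection step: a multiple-choice head assigns one score (logit) to each candidate ending, and the prediction is the ending with the highest score. The sketch below uses only the standard library. The logit values and the helper name `pick_ending` are made up for illustration and are not outputs of this model.

```python
import math

def pick_ending(logits):
    """Return (index of best ending, softmax probabilities) for one example."""
    # Subtract the max logit before exponentiating for numerical stability
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs

# Hypothetical scores for four candidate endings (not real model output)
example_logits = [2.1, -0.3, -1.7, -2.4]
best_idx, probs = pick_ending(example_logits)
print(best_idx, [round(p, 3) for p in probs])
```

Softmax preserves the ordering of the logits, so the argmax can be taken over either the logits or the probabilities.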
					
						
Key Features:
- Base model: RoBERTa-base
- Task: Multiple Choice Prediction
- Training dataset: SWAG
- Performance: 75.21% accuracy on the evaluation set
					
						
## Training Procedure

### Training hyperparameters
- Learning rate: 5e-05
- Batch size: 16
- Number of epochs: 3
- Optimizer: AdamW
- Learning rate scheduler: Linear
- Training samples: 73,546
- Training time: 17m 53s
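
To make the "Linear" scheduler entry concrete, here is a minimal sketch of a linear decay schedule in plain Python. It assumes no warmup steps and decay to zero. It also assumes 13,791 total optimizer steps, which follows from 73,546 samples at batch size 16 over 3 epochs on a single device without gradient accumulation. None of these details are stated in this card, and the function name `linear_lr` is hypothetical.

```python
BASE_LR = 5e-05       # learning rate from the hyperparameter list
TOTAL_STEPS = 13_791  # assumed: single device, no gradient accumulation

def linear_lr(step, base_lr=BASE_LR, total_steps=TOTAL_STEPS, warmup_steps=0):
    """Learning rate at a given optimizer step under linear warmup and decay."""
    if step < warmup_steps:
        # Ramp up linearly from 0 to base_lr during warmup
        return base_lr * step / max(1, warmup_steps)
    # Decay linearly from base_lr to 0 over the remaining steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

for step in (0, TOTAL_STEPS // 2, TOTAL_STEPS):
    print(step, linear_lr(step))
```

Under these assumptions the rate starts at 5e-05, is roughly halved midway through training, and reaches zero at the final step.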
					
						
### Training Results
- Training loss: 0.73
- Evaluation loss: 0.7362
- Evaluation accuracy: 0.7521
- Training samples/second: 205.623
- Training steps/second: 12.852
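
The throughput figures can be cross-checked against the hyperparameters with a little arithmetic. The sketch below assumes a single device, no gradient accumulation, and that the final partial batch of each epoch is kept. These are assumptions, not details taken from the training log. Because the training time is only given to the second, small differences from the reported values are expected.

```python
import math

train_samples = 73_546
batch_size = 16
epochs = 3
train_seconds = 17 * 60 + 53  # 17m 53s

# Assumes the last partial batch of each epoch is kept
steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * epochs

samples_per_second = train_samples * epochs / train_seconds
steps_per_second = total_steps / train_seconds

print(f"steps per epoch: {steps_per_epoch}")
print(f"total steps: {total_steps}")
print(f"samples/second: {samples_per_second:.1f}")
print(f"steps/second: {steps_per_second:.2f}")
```

Both derived rates land close to the reported 205.623 samples/second and 12.852 steps/second, which suggests the listed batch size is the effective batch size.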
					
						
## Usage Example

Here's how to use the model:

```python
from transformers import AutoTokenizer, AutoModelForMultipleChoice
import torch

# Load model and tokenizer
model_path = "real-jiakai/roberta-base-uncased-finetuned-swag"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMultipleChoice.from_pretrained(model_path)

def predict_swag(context, endings, model, tokenizer):
    # Pair the same context with each candidate ending
    encoding = tokenizer(
        [context] * len(endings),
        endings,
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt"
    )

    # Add a batch dimension: (num_choices, seq_len) -> (1, num_choices, seq_len)
    input_ids = encoding['input_ids'].unsqueeze(0)
    attention_mask = encoding['attention_mask'].unsqueeze(0)

    # Inference only, so skip gradient tracking
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # shape: (1, num_choices)

    predicted_idx = torch.argmax(logits, dim=1).item()

    return {
        'context': context,
        'all_endings': endings,
        'predicted_ending': endings[predicted_idx],
        'probabilities': torch.softmax(logits, dim=1)[0].tolist()
    }

# Example scenarios
test_examples = [
    {
        'context': "Stephen Curry dribbles the ball at the three-point line",
        'endings': [
            "He quickly releases a perfect shot that swishes through the net",  # Most plausible
            "He suddenly starts dancing ballet on the court",
            "He transforms the basketball into a pizza",
            "He flies to the moon with the basketball"
        ]
    },
    {
        'context': "Elon Musk walks into a SpaceX facility and looks at a rocket",
        'endings': [
            "He discusses technical details with the engineering team",  # Most plausible
            "He turns the rocket into a giant chocolate bar",
            "He starts playing basketball with the rocket",
            "He teaches the rocket to speak French"
        ]
    }
]

for i, example in enumerate(test_examples, 1):
    result = predict_swag(
        example['context'],
        example['endings'],
        model,
        tokenizer
    )

    print(f"\n=== Test Scenario {i} ===")
    print(f"Initial Context: {result['context']}")
    print(f"\nPredicted Most Likely Ending: {result['predicted_ending']}")
    print("\nProbabilities for All Options:")
    for idx, (ending, prob) in enumerate(zip(result['all_endings'], result['probabilities'])):
        print(f"Option {idx}: {ending}")
        print(f"Probability: {prob:.3f}")
    print("\n" + "="*50)
```
					
						
## Limitations and Biases

- The model's performance is limited by its training data, and it may not generalize well to all domains.
- Performance may vary with the complexity and domain of the input scenarios.
- The model may exhibit biases present in the training data.
					
						
## Framework versions

- Transformers 4.47.0.dev0
- PyTorch 2.5.1+cu124
- Datasets 3.1.0
- Tokenizers 0.20.3
					
						
## Citation

If you use this model, please cite:

```
@inproceedings{zellers2018swagaf,
    title={SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference},
    author={Zellers, Rowan and Bisk, Yonatan and Schwartz, Roy and Choi, Yejin},
    booktitle={Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (EMNLP)},
    year={2018}
}
```