from dataclasses import dataclass
from typing import Any, ClassVar, List

from llmdataparser.base_parser import (
    DatasetDescription,
    EvaluationMetric,
    HuggingFaceDatasetParser,
    HuggingFaceParseEntry,
)


@dataclass(frozen=True, kw_only=True, slots=True)
class BBHParseEntry(HuggingFaceParseEntry):
    """Custom entry class for BBH (Big Bench Hard), with fields specific to this dataset."""

    @classmethod
    def create(
        cls,
        question: str,
        answer: str,
        raw_question: str,
        raw_answer: str,
        task_name: str,
    ) -> "BBHParseEntry":
        return cls(
            question=question,
            answer=answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            task_name=task_name,
        )

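# Hedged illustration (the values are invented, not drawn from the dataset) of the
# entry that process_entry() below produces: the raw target keeps its parentheses,
# while the cleaned answer has them stripped.
#
#   BBHParseEntry.create(
#       question="Which option names the red object?",
#       answer="A",
#       raw_question="Which option names the red object?",
#       raw_answer="(A)",
#       task_name="reasoning_about_colored_objects",
#   )
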
class BBHDatasetParser(HuggingFaceDatasetParser[BBHParseEntry]):
    """Parser for the Big Bench Hard dataset."""

    _data_source: ClassVar[str] = "lukaemon/bbh"
    _task_names: ClassVar[list[str]] = [
        "boolean_expressions",
        "causal_judgement",
        "date_understanding",
        "disambiguation_qa",
        "dyck_languages",
        "formal_fallacies",
        "geometric_shapes",
        "hyperbaton",
        "logical_deduction_five_objects",
        "logical_deduction_seven_objects",
        "logical_deduction_three_objects",
        "movie_recommendation",
        "multistep_arithmetic_two",
        "navigate",
        "object_counting",
        "penguins_in_a_table",
        "reasoning_about_colored_objects",
        "ruin_names",
        "salient_translation_error_detection",
        "snarks",
        "sports_understanding",
        "temporal_sequences",
        "tracking_shuffled_objects_five_objects",
        "tracking_shuffled_objects_seven_objects",
        "tracking_shuffled_objects_three_objects",
        "web_of_lies",
        "word_sorting",
    ]
    _default_task: ClassVar[str] = "reasoning_about_colored_objects"

    def process_entry(
        self, row: dict[str, Any], task_name: str | None = None, **kwargs: Any
    ) -> BBHParseEntry:
        """Process a single BBH entry."""
        raw_question = row["input"]
        raw_answer = row["target"]

        # Remove surrounding parentheses from the target, e.g. "(A)" -> "A"
        clean_answer = raw_answer.strip("()")

        question = str(raw_question)

        # Prefer an explicitly supplied task name, otherwise fall back to the
        # task the parser is currently loaded with
        task = task_name or self._get_current_task(row)

        return BBHParseEntry.create(
            question=question,
            answer=clean_answer,
            raw_question=raw_question,
            raw_answer=raw_answer,
            task_name=task,
        )

    def get_dataset_description(self) -> DatasetDescription:
| | """Returns a description of the Big Bench Hard dataset.""" |
| | return DatasetDescription.create( |
| | name="Big Bench Hard (BBH)", |
| | purpose="A curated subset of 23 challenging BIG-Bench tasks where language models initially performed below average human-rater performance", |
| | source="https://github.com/suzgunmirac/BIG-Bench-Hard", |
| | language="English", |
| | format="Multiple choice questions with single correct answers", |
| | characteristics=( |
| | "Tasks require complex multi-step reasoning and were selected based on " |
| | "initial model performance below human baseline. Performance can be " |
| | "significantly improved through chain-of-thought prompting. The dataset " |
| | "includes 23 core tasks plus additional related tasks." |
| | ), |
| | category=["Advanced Reasoning"], |
| | citation=( |
| | "@article{suzgun2022challenging,\n" |
| | " title={Challenging BIG-Bench Tasks and Whether Chain-of-Thought Can Solve Them},\n" |
| | ' author={Suzgun, Mirac and Scales, Nathan and Sch{"a}rli, Nathanael and Gehrmann, Sebastian and Tay, Yi and Chung, Hyung Won and Chowdhery, Aakanksha and Le, Quoc V and Chi, Ed H and Zhou, Denny and Wei, Jason},\n' |
| | " journal={arXiv preprint arXiv:2210.09261},\n" |
| | " year={2022}\n" |
| | "}" |
| | ), |
| | additional_info={ |
| | "model_performance": ( |
| | "With chain-of-thought prompting, PaLM surpassed human performance on " |
| | "10/23 tasks, while Codex surpassed human performance on 17/23 tasks" |
| | ), |
| | "size": "6.5k examples across 27 tasks (23 core + 4 related)", |
| | }, |
| | ) |
| |
|
| | def get_evaluation_metrics(self) -> List[EvaluationMetric]: |
| | """Returns the recommended evaluation metrics for BBH dataset.""" |
| | return [ |
| | EvaluationMetric.create( |
| | name="accuracy", |
| | type="classification", |
| | description="Proportion of exactly correct answers (after stripping parentheses)", |
| | implementation="evaluate.load('accuracy')", |
| | primary=True, |
| | ), |
| | EvaluationMetric.create( |
| | name="human_eval_delta", |
| | type="comparison", |
| | description="Difference between model accuracy and average human-rater performance baseline", |
| | implementation="custom_human_baseline_comparison", |
| | primary=True, |
| | ), |
| | EvaluationMetric.create( |
| | name="per_task_accuracy", |
| | type="classification", |
| | description="Accuracy broken down by individual reasoning tasks", |
| | implementation="custom_task_accuracy", |
| | primary=False, |
| | ), |
| | EvaluationMetric.create( |
| | name="exact_match", |
| | type="string_match", |
| | description="Strict exact match between predicted and target answers", |
| | implementation="evaluate.load('exact_match')", |
| | primary=False, |
| | ), |
| | ] |
| |
|
| |
|
| | if __name__ == "__main__": |
    # Example usage: load and parse one BBH task, then inspect the first entry
    parser = BBHDatasetParser()

    # Load the dataset for a specific task
    parser.load(task_name="reasoning_about_colored_objects")

    # Parse the loaded data
    parser.parse()

    # Retrieve the parsed entries
    parsed_data = parser.get_parsed_data

    # Print an example entry
    if parsed_data:
        example = parsed_data[0]
        print("\nExample parsed entry:")
        print(f"Task: {example.task_name}")
        print(f"Question: {example.question}")
        print(f"Answer: {example.answer}")
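    # A minimal, hedged sketch of scoring parsed entries with the exact_match
    # metric that get_evaluation_metrics() points at (evaluate.load('exact_match')).
    # It assumes the `evaluate` package is installed; real predictions would come
    # from a model, so the gold answers are echoed here purely as a placeholder.
    if parsed_data:
        import evaluate

        exact_match = evaluate.load("exact_match")
        predictions = [entry.answer for entry in parsed_data]  # placeholder for model outputs
        references = [entry.answer for entry in parsed_data]
        scores = exact_match.compute(predictions=predictions, references=references)
        print(f"Exact match (gold vs. gold sanity check): {scores['exact_match']:.3f}")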