Update vqa_accuracy.py
vqa_accuracy.py (+7 -18)
CHANGED
```diff
@@ -5,7 +5,7 @@ import re
 _DESCRIPTION = """
 VQA accuracy is a evaluation metric which is robust to inter-human variability in phrasing the answers:
 $$
-\\text{Acc}(
+\\text{Acc}(ans) = \\min \\left( \\frac{\\text{# humans that said }ans}{3}, 1 \\right)
 $$
 Where `ans` is answered by machine. In order to be consistent with 'human accuracies', machine accuracies are averaged over all 10 choose 9 sets of human annotators.
 """
@@ -17,9 +17,9 @@ Args:
     references (`list` of `str` lists): Ground truth answers.
     answer_types (`list` of `str`, *optional*): Answer types corresponding to each questions.
     questions_type (`list` of `str`, *optional*): Question types corresponding to each questions.
-
+
 Returns:
-    visual question answering accuracy (`float` or `int`): Accuracy accuracy. Minimum possible value is 0. Maximum possible value is
+    visual question answering accuracy (`float` or `int`): Accuracy accuracy. Minimum possible value is 0. Maximum possible value is 100.
 
 """
 
@@ -250,14 +250,7 @@ class VQAAccuracy(evaluate.Metric):
             ],
         )
 
-    def _compute(
-        self,
-        predictions,
-        references,
-        answer_types=None,
-        question_types=None,
-        precision=2,
-    ):
+    def _compute(self, predictions, references, answer_types=None, question_types=None):
         if answer_types is None:
             answer_types = [None] * len(predictions)
 
@@ -300,21 +293,17 @@ class VQAAccuracy(evaluate.Metric):
             ques_type_dict[ques_type].append(vqa_acc)
 
         # the following key names follow the naming of the official evaluation results
-        result = {"overall": round(100 * sum(total) / len(total), precision)}
+        result = {"overall": 100 * sum(total) / len(total)}
 
         if len(ans_type_dict) > 0:
             result["perAnswerType"] = {
-                ans_type: round(
-                    100 * sum(accuracy_list) / len(accuracy_list), precision
-                )
+                ans_type: 100 * sum(accuracy_list) / len(accuracy_list)
                 for ans_type, accuracy_list in ans_type_dict.items()
             }
 
         if len(ques_type_dict) > 0:
             result["perQuestionType"] = {
-                ques_type: round(
-                    100 * sum(accuracy_list) / len(accuracy_list), precision
-                )
+                ques_type: 100 * sum(accuracy_list) / len(accuracy_list)
                 for ques_type, accuracy_list in ques_type_dict.items()
             }
 
```
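The first hunk completes the accuracy formula in `_DESCRIPTION`. As a reference point, that rule can be sketched directly in Python. This is a minimal illustration of the formula only, with a made-up helper name (`vqa_accuracy_score`); the metric's actual `_compute` additionally preprocesses answer strings (note the `import re` context in the first hunk), which is skipped here.

```python
from itertools import combinations

def vqa_accuracy_score(prediction, human_answers):
    """Acc(ans) = min(# humans that said ans / 3, 1), averaged over all
    10 choose 9 (leave-one-out) subsets of the 10 human answers."""
    scores = []
    for subset in combinations(human_answers, len(human_answers) - 1):
        matches = sum(answer == prediction for answer in subset)
        scores.append(min(matches / 3, 1.0))
    return sum(scores) / len(scores)

# Two of ten annotators said "cat": the two subsets that drop a "cat"
# score 1/3, the other eight score 2/3, so the average is 0.6.
print(vqa_accuracy_score("cat", ["cat"] * 2 + ["dog"] * 8))  # 0.6
```

The leave-one-out averaging is what makes machine scores comparable to the 'human accuracies' the docstring mentions: each human annotator is likewise scored against the remaining nine.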
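After this commit, `_compute` no longer takes a `precision` argument, so the returned percentages are unrounded. Below is a sketch of how the metric could be called through the `evaluate` API; the load path is hypothetical (the Space's namespace is not shown on this page), while the keyword names come from the `_compute` signature above.

```python
import evaluate

# Hypothetical path; substitute the actual <owner>/<space> namespace.
vqa_accuracy = evaluate.load("owner/vqa_accuracy")

results = vqa_accuracy.compute(
    predictions=["cat", "2"],  # machine answers, one per question
    references=[  # ground-truth human answers, one list per question
        ["cat"] * 10,
        ["2"] * 7 + ["two"] * 3,
    ],
    answer_types=["other", "number"],  # optional, fills perAnswerType
    question_types=["what is", "how many"],  # optional, fills perQuestionType
)

# Key names follow the official VQA evaluation results, as the comment
# in the last hunk notes:
# {"overall": ..., "perAnswerType": {...}, "perQuestionType": {...}}
print(results)
```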