Christina Theodoris commited on
Commit
18a2ca6
·
1 Parent(s): 7e4697d

update eval_strategy for new transformers version

Browse files
Files changed (1) hide show
  1. geneformer/classifier.py +11 -2
geneformer/classifier.py CHANGED
@@ -48,11 +48,13 @@ import logging
48
  import os
49
  import pickle
50
  import subprocess
 
51
  from pathlib import Path
52
 
53
  import numpy as np
54
  import pandas as pd
55
  import seaborn as sns
 
56
  from tqdm.auto import tqdm, trange
57
  from transformers import Trainer
58
  from transformers.training_args import TrainingArguments
@@ -71,6 +73,7 @@ sns.set()
71
 
72
  logger = logging.getLogger(__name__)
73
 
 
74
 
75
  class Classifier:
76
  valid_option_dict = {
@@ -1060,7 +1063,10 @@ class Classifier:
1060
  def_training_args["logging_steps"] = logging_steps
1061
  def_training_args["output_dir"] = output_directory
1062
  if eval_data is None:
1063
- def_training_args["evaluation_strategy"] = "no"
 
 
 
1064
  def_training_args["load_best_model_at_end"] = False
1065
  def_training_args.update(
1066
  {"save_strategy": "epoch", "save_total_limit": 1}
@@ -1231,7 +1237,10 @@ class Classifier:
1231
  def_training_args["logging_steps"] = logging_steps
1232
  def_training_args["output_dir"] = output_directory
1233
  if eval_data is None:
1234
- def_training_args["evaluation_strategy"] = "no"
 
 
 
1235
  def_training_args["load_best_model_at_end"] = False
1236
  training_args_init = TrainingArguments(**def_training_args)
1237
 
 
48
  import os
49
  import pickle
50
  import subprocess
51
+ from packaging.version import parse
52
  from pathlib import Path
53
 
54
  import numpy as np
55
  import pandas as pd
56
  import seaborn as sns
57
+ import transformers
58
  from tqdm.auto import tqdm, trange
59
  from transformers import Trainer
60
  from transformers.training_args import TrainingArguments
 
73
 
74
  logger = logging.getLogger(__name__)
75
 
76
+ transformers_version = parse(transformers.__version__)
77
 
78
  class Classifier:
79
  valid_option_dict = {
 
1063
  def_training_args["logging_steps"] = logging_steps
1064
  def_training_args["output_dir"] = output_directory
1065
  if eval_data is None:
1066
+ if transformers_version >= parse("4.46"):
1067
+ def_training_args["eval_strategy"] = "no"
1068
+ else:
1069
+ def_training_args["evaluation_strategy"] = "no"
1070
  def_training_args["load_best_model_at_end"] = False
1071
  def_training_args.update(
1072
  {"save_strategy": "epoch", "save_total_limit": 1}
 
1237
  def_training_args["logging_steps"] = logging_steps
1238
  def_training_args["output_dir"] = output_directory
1239
  if eval_data is None:
1240
+ if transformers_version >= parse("4.46"):
1241
+ def_training_args["eval_strategy"] = "no"
1242
+ else:
1243
+ def_training_args["evaluation_strategy"] = "no"
1244
  def_training_args["load_best_model_at_end"] = False
1245
  training_args_init = TrainingArguments(**def_training_args)
1246