MMLU main split is named auxiliary_train, not train
This commit is contained in:
+2
-7
@@ -13,16 +13,11 @@ class MMLU(Task):
|
||||
|
||||
def __init__(self, subset, split, **kwargs):
    """Load one split of the ``cais/mmlu`` dataset in a fixed shuffled order.

    Args:
        subset: Dataset configuration name; only ``"all"`` is accepted.
        split: One of ``"auxiliary_train"``, ``"validation"``, ``"dev"``
            or ``"test"``.
        **kwargs: Forwarded unchanged to the parent class ``__init__``.

    Raises:
        AssertionError: If ``subset`` or ``split`` is not an accepted value.
    """
    super().__init__(**kwargs)
    # The main training split is named "auxiliary_train" (not "train"), and it
    # is requested as a split of the "all" subset rather than as its own subset.
    assert subset in ["all"], f"subset {subset} must be all"
    assert split in ["auxiliary_train", "validation", "dev", "test"], f"split {split} must be auxiliary_train|validation|dev|test"
    self.subset = subset
    self.split = split
    # A fixed seed keeps the shuffled row order identical from run to run.
    self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42)
|
||||
|
||||
@property
|
||||
def eval_type(self):
|
||||
|
||||
Reference in New Issue
Block a user