MMLU main split is named auxiliary_train, not train

This commit is contained in:
Mathieu Lacage
2026-03-13 13:19:10 +01:00
parent f068604948
commit a641b6ca96
3 changed files with 5 additions and 10 deletions
+2 -2
View File
@@ -135,12 +135,12 @@ if __name__ == "__main__":
# very lightweight test of slicing
from tasks.mmlu import MMLU

# The MMLU training rows are requested as subset "all", split "auxiliary_train"
# (the split is named auxiliary_train, not train).
ds = MMLU(subset="all", split="auxiliary_train")
print("Length of MMLU: ", len(ds))
ex = ds[5]
print("5th example: ", ex)

# Same data, restricted with start=5 / stop=10. Presumably ds[0] of the slice
# should match the 5th example printed above -- this is an eyeball check only.
ds = MMLU(subset="all", split="auxiliary_train", start=5, stop=10)
print("Length of sliced MMLU[5:10]: ", len(ds))
print("0th example of sliced MMLU: ", ds[0])
+2 -7
View File
@@ -13,16 +13,11 @@ class MMLU(Task):
def __init__(self, subset, split, **kwargs):
    """Load one split of the ``cais/mmlu`` dataset and shuffle it.

    Args:
        subset: Dataset configuration name; only ``"all"`` is accepted.
        split: One of ``"auxiliary_train"``, ``"validation"``, ``"dev"`` or
            ``"test"``. The training rows are the ``"auxiliary_train"``
            split (there is no split named ``"train"``).
        **kwargs: Passed through unchanged to the parent ``__init__``.

    Raises:
        AssertionError: If ``subset`` or ``split`` is not an accepted value.
    """
    super().__init__(**kwargs)
    # Validate before touching the dataset so a bad name fails fast with a
    # readable message.
    assert subset in ["all"], f"subset {subset} must be all"
    assert split in ["auxiliary_train", "validation", "dev", "test"], f"split {split} must be auxiliary_train|validation|dev|test"
    self.subset = subset
    self.split = split
    # Shuffle with a fixed seed (42).
    self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42)
@property
def eval_type(self):