Merge pull request #639 from mathieu-lacage/master
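Verified: the old load path (subset="auxiliary_train", split="train") and the new load path (subset="all", split="auxiliary_train") return identical data (99,842 rows), and all splits under subset="all" load cleanly.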
This commit is contained in:
+1
-1
@@ -166,7 +166,7 @@ train_tasks = [
|
|||||||
SmolTalk(split="train"), # 460K rows of general conversations
|
SmolTalk(split="train"), # 460K rows of general conversations
|
||||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||||
CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
|
CustomJSON(filepath=identity_conversations_filepath), # 2 epochs of these
|
||||||
*[MMLU(subset="auxiliary_train", split="train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
|
*[MMLU(subset="all", split="auxiliary_train") for _ in range(args.mmlu_epochs)], # 100K rows per epoch
|
||||||
*[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
|
*[GSM8K(subset="main", split="train") for _ in range(args.gsm8k_epochs)], # 8K rows per epoch
|
||||||
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
|
||||||
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
|
||||||
|
|||||||
+2
-2
@@ -135,12 +135,12 @@ if __name__ == "__main__":
|
|||||||
# very lightweight test of slicing
|
# very lightweight test of slicing
|
||||||
from tasks.mmlu import MMLU
|
from tasks.mmlu import MMLU
|
||||||
|
|
||||||
ds = MMLU(subset="auxiliary_train", split="train")
|
ds = MMLU(subset="all", split="auxiliary_train")
|
||||||
print("Length of MMLU: ", len(ds))
|
print("Length of MMLU: ", len(ds))
|
||||||
ex = ds[5]
|
ex = ds[5]
|
||||||
print("5th example: ", ex)
|
print("5th example: ", ex)
|
||||||
|
|
||||||
ds = MMLU(subset="auxiliary_train", split="train", start=5, stop=10)
|
ds = MMLU(subset="all", split="auxiliary_train", start=5, stop=10)
|
||||||
print("Length of sliced MMLU[5:10]: ", len(ds))
|
print("Length of sliced MMLU[5:10]: ", len(ds))
|
||||||
print("0th example of sliced MMLU: ", ds[0])
|
print("0th example of sliced MMLU: ", ds[0])
|
||||||
|
|
||||||
|
|||||||
+2
-7
@@ -13,16 +13,11 @@ class MMLU(Task):
|
|||||||
|
|
||||||
def __init__(self, subset, split, **kwargs):
|
def __init__(self, subset, split, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
assert subset in ["all", "auxiliary_train"], f"subset {subset} must be all|auxiliary_train"
|
assert subset in ["all"], f"subset {subset} must be all"
|
||||||
assert split in ["train", "validation", "dev", "test"], f"split {split} must be train|validation|dev|test"
|
assert split in ["auxiliary_train", "validation", "dev", "test"], f"split {split} must be auxiliary_train|validation|dev|test"
|
||||||
if subset == "auxiliary_train":
|
|
||||||
assert split == "train", "auxiliary_train must be split into train"
|
|
||||||
self.subset = subset
|
self.subset = subset
|
||||||
self.split = split
|
self.split = split
|
||||||
self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42)
|
self.ds = load_dataset("cais/mmlu", subset, split=split).shuffle(seed=42)
|
||||||
if subset == "auxiliary_train":
|
|
||||||
# I don't understand why but the auxiliary_train rows have some weird additional 'train' wrapper
|
|
||||||
self.ds = self.ds.map(lambda row: row['train'], remove_columns=['train'])
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def eval_type(self):
|
def eval_type(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user