Module mimir.custom_datasets
Helper functions for processing data (ultimately used for membership inference evaluation)
Functions
def dump_to_cache(data: List,
                  cache_dir,
                  path,
                  filename: str,
                  min_length: int,
                  max_length: int,
                  n_samples: int,
                  max_tokens: int)
def dump_to_cache(data: List, cache_dir, path, filename: str, min_length: int, max_length: int, n_samples: int, max_tokens: int):
    """
    Cache a file (one sample per line)
    """
    # Make sure path directory exists
    subdir = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", path)
    os.makedirs(subdir, exist_ok=True)
    # Dump to file
    # Since each datum has newlines in it potentially, use jsonl format
    save_data(os.path.join(subdir, filename + ".jsonl"), data)
Cache a file (one sample per line)
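A minimal usage sketch; the cache directory, filename, and length/token parameters below are placeholders chosen for illustration, not defaults of the library:

from mimir.custom_datasets import dump_to_cache

# Two illustrative samples; one contains a newline, which is why jsonl is used
samples = ["first document text", "second document\nwith an embedded newline"]
dump_to_cache(
    data=samples,
    cache_dir="/tmp/mimir_cache",   # placeholder cache root
    path="train",                   # split subdirectory
    filename="example_source",      # hypothetical source name
    min_length=100, max_length=200, n_samples=2, max_tokens=512,
)
# Writes /tmp/mimir_cache/cache_100_200_2_512/train/example_source.jsonl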
def load(name, cache_dir, **kwargs)
def load(name, cache_dir, **kwargs):
    if name in DATASETS:
        load_fn = globals()[f'load_{name}']
        return load_fn(cache_dir=cache_dir, **kwargs)
    else:
        raise ValueError(f'Unknown dataset {name}')
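load dispatches to the matching load_<name> helper defined in this module, so the set of valid names is whatever DATASETS contains. A minimal sketch, assuming 'pubmed' is one of those entries and using a placeholder cache directory:

from mimir.custom_datasets import load

# Equivalent to load_pubmed(cache_dir="/tmp/hf_cache") if 'pubmed' is in DATASETS
data = load('pubmed', cache_dir="/tmp/hf_cache")
print(data[0][:80])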
def load_cached(cache_dir,
                data_split: str,
                filename: str,
                min_length: int,
                max_length: int,
                n_samples: int,
                max_tokens: int,
                load_from_hf: bool = False)
def load_cached(cache_dir,
                data_split: str,
                filename: str,
                min_length: int,
                max_length: int,
                n_samples: int,
                max_tokens: int,
                load_from_hf: bool = False):
    """
    Read from cache if available. Used for certain pile sources and xsum
    to ensure fairness in comparison across attacks/runs.
    """
    if load_from_hf:
        print("Loading from HuggingFace!")
        data_split = data_split.replace("train", "member")
        data_split = data_split.replace("test", "nonmember")
        if not filename.startswith("the_pile"):
            raise ValueError(f"HuggingFace data only available for The Pile.")

        for source in SOURCES_UPLOADED:
            # Got a match
            if source in filename and filename.startswith(f"the_pile_{source}"):
                split = filename.split(f"the_pile_{source}")[1]
                if split == "":
                    # The way HF data is uploaded, no split is recorded as "none"
                    split = "none"
                else:
                    # Remove the first underscore
                    split = split[1:]
                    # Remove '<' , '>'
                    split = split.replace("<", "").replace(">", "")
                    # Remove "_truncated" from the end, if present
                    split = split.rsplit("_truncated", 1)[0]

                # Load corresponding dataset
                ds = datasets.load_dataset("iamgroot42/mimir", name=source, split=split, trust_remote_code=True)
                data = ds[data_split]
                # Check if the number of samples is correct
                if len(data) != n_samples:
                    raise ValueError(f"Requested {n_samples} samples, but only {len(data)} samples available. Potential mismatch in HuggingFace data and requested data.")
                return data

        # If got here, matching source was not found
        raise ValueError(f"Requested source {filename} not found in HuggingFace data.")
    else:
        file_path = os.path.join(cache_dir, f"cache_{min_length}_{max_length}_{n_samples}_{max_tokens}", data_split, filename + ".jsonl")
        if not os.path.exists(file_path):
            raise ValueError(f"Requested cache file {file_path} does not exist")
        data = load_data(file_path)
    return data
" Read from cache if available. Used for certain pile sources and xsum to ensure fairness in comparison across attacks.runs.
def load_data(file_path)
def load_data(file_path):
    """
    Load data from a given filepath (.jsonl)
    """
    with open(file_path, 'r') as f:
        data = [json.loads(line) for line in f.readlines()]
    return data
Load data from a given filepath (.jsonl)
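For example, a cache file written by dump_to_cache above can be read back directly (the path is illustrative):

from mimir.custom_datasets import load_data

# One decoded sample per line of the .jsonl file
samples = load_data("/tmp/mimir_cache/cache_100_200_2_512/train/example_source.jsonl")
print(len(samples))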
def load_english(cache_dir)
def load_english(cache_dir):
    return load_language('en', cache_dir)
def load_german(cache_dir)
def load_german(cache_dir):
    return load_language('de', cache_dir)
def load_language(language, cache_dir)
def load_language(language, cache_dir):
    # Load either the English or German portion of the wmt16 dataset
    assert language in ['en', 'de']
    d = datasets.load_dataset('wmt16', 'de-en', split='train', cache_dir=cache_dir)
    docs = d['translation']
    desired_language_docs = [d[language] for d in docs]
    lens = [len(d.split()) for d in desired_language_docs]
    sub = [d for d, l in zip(desired_language_docs, lens) if l > 100 and l < 150]
    return sub
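A minimal sketch of the two language wrappers above; the WMT16 de-en split is downloaded on first use, and the cache path is a placeholder:

from mimir.custom_datasets import load_english, load_german

# Both call load_language and keep only documents with 100-150 whitespace-separated words
en_docs = load_english(cache_dir="/tmp/hf_cache")
de_docs = load_german(cache_dir="/tmp/hf_cache")
print(len(en_docs), en_docs[0][:80])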
def load_pubmed(cache_dir)
def load_pubmed(cache_dir):
    data = datasets.load_dataset('pubmed_qa', 'pqa_labeled', split='train', cache_dir=cache_dir)

    # Combine question and long_answer
    data = [f'Question: {q} Answer:{SEPARATOR}{a}' for q, a in zip(data['question'], data['long_answer'])]

    return data
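Usage sketch; SEPARATOR is a constant defined elsewhere in the package, and the cache directory below is a placeholder:

from mimir.custom_datasets import load_pubmed

# Each entry has the form "Question: <question> Answer:<SEPARATOR><long_answer>"
qa_texts = load_pubmed(cache_dir="/tmp/hf_cache")
print(qa_texts[0][:120])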
def load_writing(cache_dir=None)
def load_writing(cache_dir=None):
    writing_path = 'data/writingPrompts'

    with open(f'{writing_path}/valid.wp_source', 'r') as f:
        prompts = f.readlines()
    with open(f'{writing_path}/valid.wp_target', 'r') as f:
        stories = f.readlines()

    prompts = [process_prompt(prompt) for prompt in prompts]
    joined = [process_spaces(prompt + " " + story) for prompt, story in zip(prompts, stories)]
    filtered = [story for story in joined if 'nsfw' not in story and 'NSFW' not in story]

    random.seed(0)
    random.shuffle(filtered)

    return filtered
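Note that the WritingPrompts path is hard-coded, so the validation files must already sit under data/writingPrompts relative to the working directory; a minimal sketch:

from mimir.custom_datasets import load_writing

# Requires data/writingPrompts/valid.wp_source and valid.wp_target on disk
stories = load_writing()
print(len(stories), stories[0][:100])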
def process_prompt(prompt)
def process_prompt(prompt):
    return prompt.replace('[ WP ]', '').replace('[ OT ]', '')
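For example, the WritingPrompts tag markers are stripped in place:

from mimir.custom_datasets import process_prompt

print(process_prompt('[ WP ] A dragon learns to read .'))
# -> ' A dragon learns to read .'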
def process_spaces(story)
def process_spaces(story):
    return story.replace(
        ' ,', ',').replace(
        ' .', '.').replace(
        ' ?', '?').replace(
        ' !', '!').replace(
        ' ;', ';').replace(
        ' \'', '\'').replace(
        ' ’ ', '\'').replace(
        ' :', ':').replace(
        '<newline>', '\n').replace(
        '`` ', '"').replace(
        ' \'\'', '"').replace(
        '\'\'', '"').replace(
        '.. ', '... ').replace(
        ' )', ')').replace(
        '( ', '(').replace(
        ' n\'t', 'n\'t').replace(
        ' i ', ' I ').replace(
        ' i\'', ' I\'').replace(
        '\\\'', '\'').replace(
        '\n ', '\n').strip()
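A quick check of the detokenization on a WritingPrompts-style string:

from mimir.custom_datasets import process_spaces

raw = "`` Hello , world ! '' she said . He did n't reply ."
print(process_spaces(raw))
# -> "Hello, world!" she said. He didn't reply.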
def save_data(file_path, data)
def save_data(file_path, data):
    # Since each datum has newlines in it potentially, use jsonl format
    with open(file_path, 'w') as f:
        for datum in data:
            f.write(json.dumps(datum) + "\n")
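save_data and load_data are symmetric, so a quick round-trip check looks like this (the file path is a placeholder):

from mimir.custom_datasets import save_data, load_data

samples = ["one line", "a sample\nwith an embedded newline"]
save_data("/tmp/samples.jsonl", samples)      # one JSON-encoded sample per line
assert load_data("/tmp/samples.jsonl") == samples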