# ezpz.data.hf

HuggingFace Datasets loading and tokenization.

Source: `src/ezpz/data/hf.py`
## ToyTextDataset

Bases: `Dataset`

Pads or truncates sentences to a fixed length.

Source code in `src/ezpz/data/hf.py`:
```python
class ToyTextDataset(Dataset):
    """Pads or truncates sentences to a fixed length."""

    def __init__(
        self, texts: List[str], vocab: Dict[str, int], seq_len: int = 12
    ):
        self.texts = texts
        self.vocab = vocab
        self.seq_len = seq_len
        self.pad_id = vocab["<pad>"]
        self.unk_id = vocab["<unk>"]

    def __len__(self) -> int:
        return len(self.texts)

    def _encode(self, text: str) -> torch.Tensor:
        # Lowercase, whitespace-tokenize, and map out-of-vocabulary words to <unk>.
        tokens = [
            self.vocab.get(tok, self.unk_id) for tok in text.lower().split()
        ]
        # Truncate to seq_len, then right-pad with <pad> up to seq_len.
        tokens = tokens[: self.seq_len]
        tokens += [self.pad_id] * (self.seq_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, idx: int) -> torch.Tensor:  # type:ignore
        return self._encode(self.texts[idx])
```
## build_vocab(texts)

Create a tiny vocabulary from a list of strings.

Source code in `src/ezpz/data/hf.py`:
```python
def build_vocab(texts: Iterable[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create a tiny vocabulary from a list of strings."""
    # Reserve ids 0 and 1 for the padding and unknown tokens.
    specials = ["<pad>", "<unk>"]
    words = sorted({word for text in texts for word in text.lower().split()})
    vocab = {tok: idx for idx, tok in enumerate(specials + words)}
    inv_vocab = {idx: tok for tok, idx in vocab.items()}
    return vocab, inv_vocab
```
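A minimal usage sketch combining `build_vocab` with `ToyTextDataset` (the toy sentences are illustrative):

```python
from torch.utils.data import DataLoader

from ezpz.data.hf import ToyTextDataset, build_vocab

texts = [
    "the quick brown fox",
    "jumps over the lazy dog",
]
vocab, inv_vocab = build_vocab(texts)
dataset = ToyTextDataset(texts, vocab, seq_len=8)
loader = DataLoader(dataset, batch_size=2)

batch = next(iter(loader))
print(batch.shape)  # torch.Size([2, 8]); short sentences are right-padded with <pad>
```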
## get_hf_text_dataset(*, dataset_name, split, text_column, tokenizer_name, seq_len, limit, seed)

Build a tokenized HF dataset with `input_ids` + `attention_mask`.

Returns:

| Type | Description |
| --- | --- |
| `tuple[Dataset, AutoTokenizer]` | Tokenized dataset (torch formatted) and tokenizer. |

Source code in `src/ezpz/data/hf.py`:
```python
def get_hf_text_dataset(
    *,
    dataset_name: str,
    split: str,
    text_column: str,
    tokenizer_name: str,
    seq_len: int,
    limit: int,
    seed: int,
) -> tuple[datasets.Dataset, AutoTokenizer]:
    """
    Build a tokenized HF dataset with input_ids + attention_mask.

    Returns:
        tokenized dataset (torch formatted) and tokenizer.
    """
    if seq_len <= 0:
        raise ValueError("seq_len must be > 0 for HF dataset tokenization.")
    logger.info(
        "Tokenizing HF dataset %s split=%s column=%s limit=%s seq_len=%s",
        dataset_name,
        split,
        text_column,
        limit,
        seq_len,
    )
    dataset = datasets.load_dataset(dataset_name, split=split)
    if (
        cnames := getattr(dataset, "column_names")
    ) and text_column not in list(cnames):
        raise ValueError(
            f"text_column '{text_column}' not in dataset columns {dataset.column_names}"
        )
    if limit > 0 and limit < len(dataset):
        # Shuffle before selecting so the limited subset is a random sample.
        dataset = dataset.shuffle(seed=seed)
        dataset = dataset.select(range(limit))
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # One extra token so callers can build shifted (input, target) pairs.
    max_length = seq_len + 1

    def tokenize_function(examples):
        return tokenizer(
            examples[text_column],
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
        )

    tokenized = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Tokenizing HF dataset",
    )
    tokenized.set_format(
        type="torch", columns=["input_ids", "attention_mask"]
    )
    tokenized.pad_id = tokenizer.pad_token_id  # type: ignore[attr-defined]
    tokenized.vocab_size = tokenizer.vocab_size  # type: ignore[attr-defined]
    return tokenized, tokenizer
```
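A usage sketch; the dataset (`stanfordnlp/imdb`), its `text` column, and the `gpt2` tokenizer are illustrative choices, not ezpz defaults. Note that each row carries `seq_len + 1` tokens, so a one-token shift yields `(input, target)` pairs of length `seq_len`:

```python
from torch.utils.data import DataLoader

from ezpz.data.hf import get_hf_text_dataset

tokenized, tokenizer = get_hf_text_dataset(
    dataset_name="stanfordnlp/imdb",
    split="train",
    text_column="text",
    tokenizer_name="gpt2",
    seq_len=128,
    limit=1024,
    seed=42,
)
loader = DataLoader(tokenized, batch_size=8)
batch = next(iter(loader))

# Shift by one token for next-token prediction.
input_ids = batch["input_ids"][:, :-1]
labels = batch["input_ids"][:, 1:]
print(input_ids.shape, labels.shape)  # torch.Size([8, 128]) each
```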
## load_hf_texts(dataset_name, split, text_column, limit)

Pull a small slice of text from a Hugging Face dataset for quick experiments.

This uses only a limited number of rows (`limit`) to keep the example light.

Source code in `src/ezpz/data/hf.py`:
```python
def load_hf_texts(
    dataset_name: str,
    split: str,
    text_column: str,
    limit: int,
) -> list[str]:
    """
    Pull a small slice of text from a Hugging Face dataset for quick experiments.

    This uses only a limited number of rows (`limit`) to keep the example light.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except Exception as exc:  # pragma: no cover - best-effort import
        raise RuntimeError(
            "datasets package is required for --hf-dataset usage"
        ) from exc
    logger.info(
        "Loading HF dataset %s split=%s column=%s limit=%s",
        dataset_name,
        split,
        text_column,
        limit,
    )
    dataset = load_dataset(dataset_name, split=split)
    if (
        cnames := getattr(dataset, "column_names")
    ) and text_column not in list(cnames):
        raise ValueError(
            f"text_column '{text_column}' not in dataset columns {dataset.column_names}"
        )
    assert callable(getattr(dataset, "select"))
    total = len(dataset)
    if limit <= 0:
        raise ValueError("limit must be > 0 for HF dataset sampling.")
    if limit >= total:
        indices = list(range(total))
    else:
        # Sample a random subset; the seed is overridable via EZPZ_HF_SAMPLE_SEED.
        seed = int(os.environ.get("EZPZ_HF_SAMPLE_SEED", "1337"))
        try:
            dataset = dataset.shuffle(seed=seed)
            indices = list(range(limit))
        except Exception:
            # Fall back to torch-based index sampling if shuffle is unavailable.
            rng = torch.Generator().manual_seed(seed)
            indices = torch.randperm(total, generator=rng)[:limit].tolist()
    texts = [
        str(row[text_column])
        for row in dataset.select(indices)
        if str(row.get(text_column, "")).strip()
    ]
    if not texts:
        raise ValueError("No text rows found from HF dataset.")
    return texts
```
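An example call; `ag_news` and its `text` column are illustrative. The sampling seed can be pinned via the `EZPZ_HF_SAMPLE_SEED` environment variable (default `1337`):

```python
from ezpz.data.hf import load_hf_texts

texts = load_hf_texts(
    dataset_name="ag_news",
    split="train",
    text_column="text",
    limit=256,
)
print(len(texts), texts[0][:60])
```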
## split_dataset(data_args, train_split_name='train', validation_split_name=None, cache_dir=None, token=None, trust_remote_code=False)

Splits the dataset into training and validation sets based on the provided split names.

Args:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data_args` | `HfDataTrainingArguments` | Data arguments; `dataset_name`, `dataset_config_name`, `validation_split_percentage`, and `streaming` are read. | *required* |
| `train_split_name` | `str` | Name of the training split to load and slice. | `'train'` |
| `validation_split_name` | `Optional[str]` | Key under which to store the held-out validation slice; if `None`, only the training split is loaded. | `None` |
| `cache_dir` | `Optional[str \| os.PathLike \| Path]` | Datasets cache directory; defaults to `./.cache/hf/datasets`. | `None` |
| `token` | `Optional[str]` | Hugging Face auth token forwarded to `datasets.load_dataset`. | `None` |
| `trust_remote_code` | `bool` | Whether to allow dataset scripts that run remote code. | `False` |

Source code in `src/ezpz/data/hf.py`:
```python
def split_dataset(
    data_args: HfDataTrainingArguments,
    train_split_name: str = "train",
    validation_split_name: Optional[str] = None,
    cache_dir: Optional[str | os.PathLike | Path] = None,
    token: Optional[str] = None,
    trust_remote_code: bool = False,
) -> datasets.IterableDatasetDict | datasets.DatasetDict:
    """
    Splits the dataset into training and validation sets based on the
    provided split names.

    Args:
        data_args: Data arguments; reads dataset_name, dataset_config_name,
            validation_split_percentage, and streaming.
        train_split_name: Name of the training split to load and slice.
        validation_split_name: Key for the held-out validation slice; if
            None, only the training split is loaded.
        cache_dir: Datasets cache directory (default: ./.cache/hf/datasets).
        token: Hugging Face auth token passed to datasets.load_dataset.
        trust_remote_code: Allow dataset scripts that run remote code.
    """
    dsets = {}
    dataset_name = data_args.dataset_name
    assert dataset_name is not None, (
        "dataset_name must be provided to split the dataset."
    )
    cache_dir = (
        Path("./.cache/hf/datasets") if cache_dir is None else cache_dir
    )
    assert cache_dir is not None and isinstance(cache_dir, (str, os.PathLike))
    cache_dir = Path(cache_dir).as_posix()
    if validation_split_name is not None:
        try:
            # Carve the first validation_split_percentage% of the training
            # split off as validation; keep the remainder for training.
            dsets[validation_split_name] = datasets.load_dataset(  # type:ignore
                dataset_name,
                data_args.dataset_config_name,
                split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                cache_dir=cache_dir,
                token=token,
                streaming=data_args.streaming,
                trust_remote_code=trust_remote_code,
            )
            dsets[train_split_name] = datasets.load_dataset(  # type:ignore
                dataset_name,
                data_args.dataset_config_name,
                split=f"{train_split_name}[{data_args.validation_split_percentage}%:]",
                cache_dir=cache_dir,
                token=token,
                streaming=data_args.streaming,
                trust_remote_code=trust_remote_code,
            )
        except ValueError:
            # Some datasets don't support slicing. In that case, fall back
            # to using the full training split as the validation set.
            dsets[validation_split_name] = datasets.load_dataset(  # type:ignore
                dataset_name,
                data_args.dataset_config_name,
                split=train_split_name,
                cache_dir=cache_dir,
                token=token,
                streaming=data_args.streaming,
                trust_remote_code=trust_remote_code,
            )
            try:
                dsets[train_split_name] = datasets.load_dataset(  # type:ignore
                    dataset_name,
                    data_args.dataset_config_name,
                    split=f"{train_split_name}[{data_args.validation_split_percentage}%:]",
                    cache_dir=cache_dir,
                    token=token,
                    streaming=data_args.streaming,
                    trust_remote_code=trust_remote_code,
                )
            except Exception:
                dsets[train_split_name] = datasets.load_dataset(  # type:ignore
                    dataset_name,
                    data_args.dataset_config_name,
                    split=train_split_name,
                    cache_dir=cache_dir,
                    token=token,
                    streaming=data_args.streaming,
                    trust_remote_code=trust_remote_code,
                )
    else:
        # No validation split requested: load the training split as-is.
        dsets[train_split_name] = datasets.load_dataset(  # type:ignore
            dataset_name,
            data_args.dataset_config_name,
            split=train_split_name,
            cache_dir=cache_dir,
            token=token,
            streaming=data_args.streaming,
            trust_remote_code=trust_remote_code,
        )
    if data_args.streaming:
        return datasets.IterableDatasetDict(dsets)
    return datasets.DatasetDict(dsets)
```
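A minimal sketch of carving a validation split out of `train`. It assumes `HfDataTrainingArguments` can be constructed with keyword arguments for the fields `split_dataset` reads (`dataset_name`, `dataset_config_name`, `validation_split_percentage`, `streaming`); `wikitext` is an illustrative dataset:

```python
from ezpz.data.hf import HfDataTrainingArguments, split_dataset

data_args = HfDataTrainingArguments(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    validation_split_percentage=5,  # first 5% of train becomes validation
    streaming=False,
)
dsets = split_dataset(data_args, validation_split_name="validation")
print(dsets["train"])
print(dsets["validation"])
```

With `streaming=False` this returns a `datasets.DatasetDict`; with `streaming=True` it returns a `datasets.IterableDatasetDict` instead.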