Merged
19 commits
6839941
refactor: fixed wrong number of bytes per token calculation
le1nux Dec 12, 2024
38e7582
feat: aded test for _get_required_num_of_bytes_to_repr
le1nux Dec 12, 2024
a40069f
chore: updated CHANGELOG_DEV
le1nux Dec 12, 2024
ea23f0e
refactor: fixed mismatch between char and byte index
le1nux Dec 12, 2024
cd54ae2
refactor: improved the test for index creation
le1nux Dec 12, 2024
e3e9d67
chore: updated CHANGELOG_DEV.md
le1nux Dec 12, 2024
0483362
refactor: the num bytes per token is now power of two
le1nux Dec 13, 2024
88b1201
chore: updated changelog
le1nux Dec 13, 2024
26f6b51
fix: only appending the eod token when tokenizer has not already adde…
le1nux Dec 13, 2024
1c1ccdc
feat: added verification script for indexation and tokenization
le1nux Dec 13, 2024
631ebcb
chore: updated changelog
le1nux Dec 13, 2024
957a1da
feat: added is_special_token_id to tokenizer and add_special_tokens f…
le1nux Dec 14, 2024
f41fcd3
feat: added check to PackedDataGenerator enforcing the eod token to b…
le1nux Dec 14, 2024
a3f911e
refactor: improved consistency between HF and SP tokenizers
le1nux Dec 15, 2024
8c31405
refactor: improved consistency between HF and SP tokenizers
le1nux Dec 15, 2024
6c53d8a
refactor: polished the indexation and tokenization including extensiv…
le1nux Dec 15, 2024
85f13ac
refactor: included requested review changes
le1nux Dec 16, 2024
14718cc
chore: added further checks making sure that tokenizer.get_token_id r…
le1nux Jan 13, 2025
4648f1b
chore: fixed data paths in tokenizer tests
le1nux Jan 13, 2025
47 changes: 46 additions & 1 deletion CHANGELOG_DEV.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,49 @@ We can also now configure the encoding used for reading the documents. If encodi


**Breaking Changes**
* None
* None


## PR #280: Bug fix: the number of bytes per token was wrongly calculated

This PR fixes the bytes per token calculation.
Generally, we estimate how many bytes are needed to encode the full range of the vocabulary.
E.g., for a vocab size > 65536, we need 3 bytes for each token in the pbin file.

The previous calculation was wrong in general, but coincidentally yielded the correct result for the GPT-2 tokenizer.
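The estimate described above can be sketched as follows (a minimal illustration; the function name is hypothetical and not the repository's API):

```python
import math


def required_bytes_per_token(vocab_size: int) -> int:
    # Bits needed to represent the largest token ID, rounded up to whole bytes.
    return math.ceil(math.log2(vocab_size) / 8)


print(required_bytes_per_token(50257))  # GPT-2 vocab size -> 2 bytes
print(required_bytes_per_token(65537))  # just above 2**16 -> 3 bytes
```

Note how a vocab size just above 65536 crosses the 2-byte boundary, matching the example in the text.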



## PR #281: Bug fix: The char-based index is not always consistent with the byte-based index

The first character of the string "ø This is..." is written to disk as two bytes, namely `\xc3\xb8`, when encoded as UTF-8.
Therefore, byte-based offsets run one position ahead of char-based offsets from this character onwards.

For consistency, we don't consider any char-based indexes anymore and always refer to byte-based indexes.
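The mismatch can be demonstrated in a few lines:

```python
text = "ø This is..."

# One character on the char level ...
first_char = text[0]
assert len(first_char) == 1

# ... but two bytes on the byte level when encoded as UTF-8.
first_char_bytes = first_char.encode("utf-8")
print(first_char_bytes)  # b'\xc3\xb8'

# Hence char-based and byte-based lengths diverge for the same string:
print(len(text))                  # 12 characters
print(len(text.encode("utf-8")))  # 13 bytes
```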


## PR #282: Bug fix: Enforce a power-of-two number of bytes per token


Previously, the number of bytes per token was calculated by `math.ceil(log_2(vocab_size)/8)`, leading to values between 1 and 4 bytes.
However, the dataset implementation only supports 1, 2 and 4 bytes per token, as defined here

https://github.com/Modalities/modalities/blob/0483362abac93e45850e56adaea7921e96836d59/src/modalities/dataloader/dataset.py#L202-L206

and

https://github.com/Modalities/modalities/blob/0483362abac93e45850e56adaea7921e96836d59/src/modalities/dataloader/dataset.py#L233-L234

I added a switch case that maps the raw estimate to the respective supported byte sizes when packing the data.

This adds some inefficiency, as a vocabulary size > 65536 already requires 4 bytes per token, effectively doubling the storage requirements.
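The rounding described above can be sketched as follows (a hypothetical standalone helper mirroring the switch case in the fix):

```python
import math


def packed_bytes_per_token(vocab_size: int) -> int:
    # Raw byte estimate for the largest token ID.
    num_bytes = math.ceil(math.log2(vocab_size) / 8)
    # Round up to a supported size, since the dataset implementation
    # only reads 1-, 2- and 4-byte tokens.
    if num_bytes == 1:
        return 1
    elif num_bytes == 2:
        return 2
    elif num_bytes <= 4:
        return 4
    else:
        raise ValueError("Currently only token byte sizes of 1, 2, and 4 are supported.")


print(packed_bytes_per_token(50257))  # 2
print(packed_bytes_per_token(65537))  # raw estimate is 3 -> stored as 4
```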


## PR #283: Bug fix: Only append eod token once when packing / tokenizing

Some HF tokenizers, such as `xlm-roberta-large`, add special tokens (e.g., the eod token) automatically when encoding text, whereas others, such as `gpt2`, do not.

This side effect in the transformers library has led to the eod token being appended twice when tokenizing / packing our data. We added a check for this and now append the eod token only once:
https://github.com/Modalities/modalities/blob/1c1ccdc973283c45bc8c9fadf4d20f03e435cd04/src/modalities/dataloader/create_packed_data.py#L327-L330

Additionally, I added a script that verifies the consistency of the indexation and tokenization of a given JSONL file. We run the indexation and tokenization routines in modalities and compare the result to the tokenized JSONL file obtained by applying the HF tokenizer directly.
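The append-once logic can be sketched on the byte level as follows (function name and the 2-byte little-endian token encoding are illustrative assumptions, not the repository's API):

```python
def append_eod_once(token_bytes: bytes, eod_bytes: bytes) -> bytes:
    # Some tokenizers already emit the eod token during encoding;
    # only append it ourselves when it is missing.
    if not token_bytes.endswith(eod_bytes):
        token_bytes += eod_bytes
    return token_bytes


# Hypothetical 2-byte little-endian token encoding.
eod = (50256).to_bytes(2, byteorder="little")
sample_without_eod = (42).to_bytes(2, byteorder="little")
sample_with_eod = sample_without_eod + eod

# Both variants end up with exactly one trailing eod token.
print(append_eod_once(sample_without_eod, eod) == sample_with_eod)  # True
print(append_eod_once(sample_with_eod, eod) == sample_with_eod)     # True
```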
Binary file not shown.
3 changes: 2 additions & 1 deletion src/modalities/__main__.py
Expand Up @@ -158,8 +158,9 @@ def CMD_entry_point_pack_encoded_data(config_path: FilePath):
Args:
config_path (FilePath): Path to the config file describing the tokenization setup.
"""
config_dict = load_app_config_dict(config_path)

pack_encoded_data(config_path=config_path)
pack_encoded_data(config_dict=config_dict)


@data.command(name="merge_packed_data")
Expand Down
8 changes: 3 additions & 5 deletions src/modalities/api.py
Expand Up @@ -8,7 +8,6 @@
import modalities.inference.inference as inference
from modalities.checkpointing.checkpoint_conversion import CheckpointConversion
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import load_app_config_dict
from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel
from modalities.dataloader.create_index import IndexGenerator
from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data
Expand Down Expand Up @@ -71,14 +70,14 @@ def convert_pytorch_to_hf_checkpoint(
return hf_model


def pack_encoded_data(config_path: FilePath):
def pack_encoded_data(config_dict: dict):
"""Packs and encodes an indexed, large jsonl-file.
(see also `create_index` for more information)
Returns .pbin-file, which can be inserted into a training process directly
and does not require its original jsonl-file or the respective index file anymore.

Args:
config_path (FilePath): Path to the config file describing the tokenization setup.
config_dict (dict): Dictionary containing the configuration for the packed data generation.
"""

# TODO: if we want to use alternative entrypoints together with the ResolverRegistry,
Expand All @@ -87,11 +86,10 @@ def pack_encoded_data(config_path: FilePath):
# One would requires an object of it to instantiate the ResolverRegistry.
# This could get resolved by implementing on own ResolverRegistry for each entrypoint or adapting the existing
# ResolverRegistry to work dynamically with any type-hinted config object from config.py.
config = load_app_config_dict(config_path)
registry = Registry(COMPONENTS)
component_factory = ComponentFactory(registry=registry)
components: PackedDatasetComponentsInstantiationModel = component_factory.build_components(
config_dict=config, components_model_type=PackedDatasetComponentsInstantiationModel
config_dict=config_dict, components_model_type=PackedDatasetComponentsInstantiationModel
)

generator = PackedDataGenerator(
Expand Down
12 changes: 6 additions & 6 deletions src/modalities/dataloader/create_index.py
Expand Up @@ -27,11 +27,11 @@ def __init__(self, src_file: Path, drop_faulty_entries: bool = False):
"""
self.src_file = src_file
self.drop_faulty_entries = drop_faulty_entries
with self.src_file.open(mode="r") as fin:
with self.src_file.open(mode="rb") as fin:
# Move the cursor to the end of the file
fin.seek(0, os.SEEK_END)
# Get number of characters in the file
self._total_num_chars = fin.tell()
# Get number of bytes in the file
self._total_num_bytes = fin.tell()
self._queue_of_raw_lines = queue.Queue()
self._index_map = []
self._exception_buffer = []
Expand Down Expand Up @@ -105,14 +105,14 @@ def _reader_thread(self):
# the end of the file is reached. Each line is put into a queue along with its cursor position. If any
# errors are detected, the method returns immediately.

with open(self.src_file, "r") as fin:
with open(self.src_file, "rb") as fin:
while True:
cursor = fin.tell()
line = fin.readline()
if self._check_for_parallel_errors():
return
if fin.tell() == self._total_num_chars:
if line[-1] == "\n":
if fin.tell() == self._total_num_bytes:
if line.endswith(b"\n"):
line = line[:-1]
self._queue_of_raw_lines.put((cursor, line))
break
Expand Down
28 changes: 21 additions & 7 deletions src/modalities/dataloader/create_packed_data.py
Expand Up @@ -3,6 +3,7 @@
import multiprocessing
import os
import pickle
import traceback
import warnings
from io import BufferedWriter
from pathlib import Path
Expand Down Expand Up @@ -62,11 +63,11 @@ def __init__(
self.tokenizer = tokenizer
self.eod_token = eod_token
self._token_size_in_bytes = self._get_required_num_of_bytes_to_repr(self.tokenizer.vocab_size)
encoded_eod_token = self.tokenizer.get_token_id(self.eod_token)
self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eod_token)
eod_token_id = self.tokenizer.get_token_id(self.eod_token)
self._encoded_eod_token_as_bytes = self._encoded_token_to_bytes(eod_token_id)
self.jq_filter = jq.compile(jq_pattern)
self._number_of_processes = number_of_processes
self._reader = LargeFileLinesReader(src_path, index_path=index_path)
self._reader = LargeFileLinesReader(src_path, index_path=index_path) # reads string with utf-8 encoding
self._total_num_of_tokens = 0
self._raw_samples_queue = multiprocessing.Queue(maxsize=raw_samples_queue_size)
self.processed_samples_queue = multiprocessing.Queue(maxsize=processed_samples_queue_size)
Expand All @@ -84,7 +85,17 @@ def _get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int:
Returns:
int: The number of bytes required to represent the integer.
"""
return math.ceil(math.log(math.log2(int_to_get_repr), 8))
# we currently only support token sizes of 1, 2 and 4 bytes, as implemented here:
# https://github.com/Modalities/modalities/blob/fix_char_bytes_indexation_mismatch/src/modalities/dataloader/dataset.py#L202
num_bytes = math.ceil(math.log2(int_to_get_repr) / 8)
if num_bytes == 1:
return 1
elif num_bytes == 2:
return 2
elif num_bytes <= 4:
return 4
else:
raise ValueError("Currently only support token byte sizes of 1, 2, and 4.")

def _encoded_token_to_bytes(self, encoded_token: int) -> bytes:
"""
Expand Down Expand Up @@ -250,7 +261,6 @@ def _reader_thread(self) -> Callable:
def reader():
batch = []
for line_id, line in tqdm(enumerate(self._reader), desc="Reading jsonl", disable=True):
# line = self._reader[line_id]
batch.append((line_id, line))
if len(batch) % self.processing_batch_size == 0:
self._raw_samples_queue.put(batch)
Expand Down Expand Up @@ -289,9 +299,10 @@ def _process_thread(self, process_id: int):
)
except Exception as exception:
warnings.warn(
f"Could not process line of number {line_id} within process {process_id}. "
f"Could not process line {line_id} in {self.src_path} within process {process_id}. "
f"Raised the following error: {exception=}"
)
traceback.print_exc()

def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: list[tuple[int, int]]):
# Update the length of the data section in the pre-allocated header of the destination file.
Expand All @@ -312,7 +323,10 @@ def _process_line(self, line: str, process_id: int) -> bytes:
tokens = self.tokenizer.tokenize(jq_retrieved_text)
if len(tokens) == 0:
raise EmptySampleError("Received empty sample...")
return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes
token_byte_string = b"".join(map(self._encoded_token_to_bytes, tokens))
if not token_byte_string.endswith(self._encoded_eod_token_as_bytes):
token_byte_string = token_byte_string + self._encoded_eod_token_as_bytes
return token_byte_string


class EmbeddedStreamData:
Expand Down
57 changes: 52 additions & 5 deletions src/modalities/tokenization/tokenizer_wrapper.py
@@ -1,3 +1,4 @@
import warnings
from abc import ABC
from typing import Optional

Expand Down Expand Up @@ -62,6 +63,20 @@ def get_token_id(self, token: str) -> int:
"""
raise NotImplementedError

def is_special_token_id(self, token_id: int) -> bool:
"""Returns whether a token ID is a special token ID.

Args:
token_id (int): Token ID to check.

Raises:
NotImplementedError: Must be implemented by a subclass.

Returns:
bool: Flag whether the token ID is a special token ID.
"""
raise NotImplementedError


class PreTrainedHFTokenizer(TokenizerWrapper):
"""Wrapper for pretrained Hugging Face tokenizers."""
Expand Down Expand Up @@ -102,6 +117,7 @@ def __init__(
self.max_length = max_length
self.truncation = truncation
self.padding = padding
self.special_token_ids = set(self.tokenizer.all_special_ids)

@property
def vocab_size(self) -> int:
Expand Down Expand Up @@ -163,10 +179,25 @@ def get_token_id(self, token: str) -> int:
int: Token ID.
"""
token_id = self.tokenizer.convert_tokens_to_ids(token)
if isinstance(token_id, list):
if not isinstance(token_id, int):
raise ValueError("Token is not represented by a single token id!")
if token_id is None:
raise ValueError("Token is not represented by a single token id!")
elif token_id == self.tokenizer.unk_token_id:
warnings.warn(f"The provided eod token {token} has the same token id ({token_id}) as the unk token")
return token_id

def is_special_token_id(self, token_id: int) -> bool:
"""Returns whether a token ID is a special token ID.

Args:
token_id (int): Token ID to check.

Returns:
bool: Flag whether the token ID is a special token ID.
"""
return token_id in self.special_token_ids


class PreTrainedSPTokenizer(TokenizerWrapper):
"""Wrapper for pretrained SentencePiece tokenizers."""
Expand All @@ -189,8 +220,8 @@ def tokenize(self, text: str) -> list[int]:
Returns:
list[int]: List of token IDs.
"""
tokens = self.tokenizer.encode(text)
return tokens
token_ids = self.tokenizer.Encode(text)
return token_ids

def decode(self, token_ids: list[int]) -> str:
"""Decodes a list of token IDs into the original text.
Expand All @@ -201,7 +232,7 @@ def decode(self, token_ids: list[int]) -> str:
Returns:
str: Decoded text.
"""
decoded_text = self.tokenizer.decode(token_ids)
decoded_text = self.tokenizer.Decode(token_ids)
return decoded_text

@property
Expand All @@ -226,6 +257,22 @@ def get_token_id(self, token: str) -> int:
int: Token ID.
"""
piece_id = self.tokenizer.PieceToId(token)
if not isinstance(piece_id, int):
raise ValueError("Token cannot be represented by a single token ID!")
if piece_id == self.tokenizer.unk_id():
raise ValueError("Token is not represented by a single token id!")
raise ValueError("Token cannot be represented by a single token id!")
return piece_id

def is_special_token_id(self, token_id: int) -> bool:
"""Returns whether a token ID is a special token ID.

Args:
token_id (int): Token ID to check.

Raises:
NotImplementedError: Must be implemented by a subclass.

Returns:
bool: Flag whether the token ID is a special token ID.
"""
return self.tokenizer.IsControl(token_id)