From 2eeeb4c4076e9557a97497084173e8d1d46e3ca9 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 23 Apr 2024 19:01:01 +0000 Subject: [PATCH 1/4] add comments --- jetstream/engine/token_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py index 7265b3df..4126b875 100644 --- a/jetstream/engine/token_utils.py +++ b/jetstream/engine/token_utils.py @@ -16,7 +16,7 @@ from bisect import bisect_left import logging -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -57,7 +57,8 @@ def tokenize_and_pad( is_bos: bool = True, prefill_lengths: Optional[List[int]] = None, max_prefill_length: Optional[int] = None, -) -> Tuple[jax.Array, int]: + jax_padding: bool = True, +) -> Tuple[Union[jax.Array, np.ndarray], int]: """Tokenize and pads a string. Args: @@ -67,6 +68,7 @@ def tokenize_and_pad( as prefill is typically used when beginning sequences. prefill_lengths: Buckets to pad the sequence to for static compilation. max_prefill_length: Maximum bucket to use. + jax_padding: convert to JAX padded tokens if True. Returns: tokens: Tokenized into integers. @@ -117,7 +119,9 @@ def tokenize_and_pad( padded_tokens = tokens[-padded_length:] else: padded_tokens = np.pad(tokens, (0, padding)) - return jnp.array(padded_tokens), true_length + if jax_padding: + padded_tokens = jnp.array(padded_tokens) + return padded_tokens, true_length def process_result_tokens( From 1c4de77347564a2dae162e6592c9c4c76452801c Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 23 Apr 2024 22:19:31 +0000 Subject: [PATCH 2/4] add unit test --- jetstream/tests/engine/test_token_utils.py | 45 ++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py index 4aa9792b..279cd5cd 100644 --- a/jetstream/tests/engine/test_token_utils.py +++ b/jetstream/tests/engine/test_token_utils.py @@ -18,6 +18,9 @@ import unittest from typing import List +import jax +import jax.numpy as jnp +import numpy as np from sentencepiece import SentencePieceProcessor from jetstream.engine import tokenizer_pb2, token_utils @@ -103,6 +106,48 @@ def test_underscore_in_output(self): self.assertEqual(mix_output, " `__") self.assertEqual(mix_output.lstrip(), decode_output) + def test_tokenize_and_pad_jax(self): + jax.config.update("jax_platform_name", "cpu") + self.setup() + s = "I believe the meaning of life is" + vocab = self.jt_tokenizer.vocab + max_prefill_length = 1024 + padded_tokens, true_length = token_utils.tokenize_and_pad( + s=s, + vocab=vocab, + max_prefill_length=max_prefill_length, + ) + print(f"------------- padded_tokens{padded_tokens}") + print(f"------------- true_length{true_length}") + expected_padded_tokens = jnp.array([1, 306, 4658, 278, 6593, 310, 2834, 338, + 0, 0, 0, 0, 0, 0, 0, 0]) + expected_true_length = 8 + self.assertTrue( + jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7) + ) + self.assertEqual(true_length, expected_true_length) + + def test_tokenize_and_pad_np(self): + self.setup() + s = "I believe the meaning of life is" + vocab = self.jt_tokenizer.vocab + max_prefill_length = 1024 + padded_tokens, true_length = token_utils.tokenize_and_pad( + s=s, + vocab=vocab, + max_prefill_length=max_prefill_length, + jax_padding=False + ) + print(f"------------- padded_tokens{padded_tokens}") + print(f"------------- true_length{true_length}") + expected_padded_tokens = np.array([1, 306, 4658, 278, 6593, 310, 2834, 338, + 0, 0, 0, 0, 0, 0, 0, 0]) + expected_true_length = 8 + self.assertTrue( + np.allclose(padded_tokens, expected_padded_tokens, atol=1e-7) + ) + self.assertEqual(true_length, expected_true_length) + if __name__ == "__main__": unittest.main() From c1f5d7f0028c8731958ccd9b464e2f8f8d15a94c Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 23 Apr 2024 22:20:20 +0000 Subject: [PATCH 3/4] remove logging --- jetstream/tests/engine/test_token_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py index 279cd5cd..f48bf522 100644 --- a/jetstream/tests/engine/test_token_utils.py +++ b/jetstream/tests/engine/test_token_utils.py @@ -117,8 +117,6 @@ def test_tokenize_and_pad_jax(self): vocab=vocab, max_prefill_length=max_prefill_length, ) - print(f"------------- padded_tokens{padded_tokens}") - print(f"------------- true_length{true_length}") expected_padded_tokens = jnp.array([1, 306, 4658, 278, 6593, 310, 2834, 338, 0, 0, 0, 0, 0, 0, 0, 0]) expected_true_length = 8 @@ -138,8 +136,6 @@ def test_tokenize_and_pad_np(self): max_prefill_length=max_prefill_length, jax_padding=False ) - print(f"------------- padded_tokens{padded_tokens}") - print(f"------------- true_length{true_length}") expected_padded_tokens = np.array([1, 306, 4658, 278, 6593, 310, 2834, 338, 0, 0, 0, 0, 0, 0, 0, 0]) expected_true_length = 8 From 057fa6c76e538d7ec57c8ab505d100839b1e9e39 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Tue, 23 Apr 2024 23:16:34 +0000 Subject: [PATCH 4/4] format token code --- jetstream/engine/token_utils.py | 3 ++- jetstream/tests/engine/test_token_utils.py | 31 +++++++++++----------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/jetstream/engine/token_utils.py b/jetstream/engine/token_utils.py index 4126b875..28fcdf1f 100644 --- a/jetstream/engine/token_utils.py +++ b/jetstream/engine/token_utils.py @@ -183,7 +183,8 @@ def process_result_tokens( break else: try: - token = mix_decode(vocab, tok_id) # pytype: disable=attribute-error + # pytype: disable=attribute-error + token = mix_decode(vocab, tok_id) except ValueError: # This error only occurs when using tests where the vocab range is # computed via addition and int->char is computed using chr(). Real diff --git a/jetstream/tests/engine/test_token_utils.py b/jetstream/tests/engine/test_token_utils.py index f48bf522..92e88cfc 100644 --- a/jetstream/tests/engine/test_token_utils.py +++ b/jetstream/tests/engine/test_token_utils.py @@ -101,7 +101,8 @@ def test_sp_vs_seqio(self): def test_underscore_in_output(self): self.setup() n = 21326 - mix_output = token_utils.mix_decode(vocab=self.jt_tokenizer.vocab, tok_id=n) + mix_output = token_utils.mix_decode( + vocab=self.jt_tokenizer.vocab, tok_id=n) decode_output = self.sp_tokenizer.decode([n]) self.assertEqual(mix_output, " `__") self.assertEqual(mix_output.lstrip(), decode_output) @@ -113,16 +114,16 @@ def test_tokenize_and_pad_jax(self): vocab = self.jt_tokenizer.vocab max_prefill_length = 1024 padded_tokens, true_length = token_utils.tokenize_and_pad( - s=s, - vocab=vocab, - max_prefill_length=max_prefill_length, - ) + s=s, + vocab=vocab, + max_prefill_length=max_prefill_length, + ) expected_padded_tokens = jnp.array([1, 306, 4658, 278, 6593, 310, 2834, 338, 0, 0, 0, 0, 0, 0, 0, 0]) expected_true_length = 8 self.assertTrue( - jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7) - ) + jnp.allclose(padded_tokens, expected_padded_tokens, atol=1e-7) + ) self.assertEqual(true_length, expected_true_length) def test_tokenize_and_pad_np(self): @@ -131,17 +132,17 @@ def test_tokenize_and_pad_np(self): vocab = self.jt_tokenizer.vocab max_prefill_length = 1024 padded_tokens, true_length = token_utils.tokenize_and_pad( - s=s, - vocab=vocab, - max_prefill_length=max_prefill_length, - jax_padding=False - ) + s=s, + vocab=vocab, + max_prefill_length=max_prefill_length, + jax_padding=False + ) expected_padded_tokens = np.array([1, 306, 4658, 278, 6593, 310, 2834, 338, - 0, 0, 0, 0, 0, 0, 0, 0]) + 0, 0, 0, 0, 0, 0, 0, 0]) expected_true_length = 8 self.assertTrue( - np.allclose(padded_tokens, expected_padded_tokens, atol=1e-7) - ) + np.allclose(padded_tokens, expected_padded_tokens, atol=1e-7) + ) self.assertEqual(true_length, expected_true_length)