Skip to content

Commit eb2113a

Browse files
authored
Add files via upload
1 parent a56a626 commit eb2113a

File tree

1 file changed

+50
-4
lines changed

1 file changed

+50
-4
lines changed

src/Assets/SentenceTransformer.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# modified 4.1.0 to add some commented-out print statements for debugging; likely to be removed in the next patch now that the padding/truncation issue is solved
1+
# modified 4.1.0 to extend the "_text_length" method to accept plain strings and to add commented-out debugging output
22
from __future__ import annotations
33

44
import copy
@@ -654,8 +654,36 @@ def encode(
654654

655655
for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
656656
sentences_batch = sentences_sorted[start_index : start_index + batch_size]
657+
658+
# #==DEBUG================================================================================================
659+
# print(f"\n=== DEBUG: Before tokenization ===")
660+
# print(f"Batch size: {len(sentences_batch)}")
661+
# print(f"Sentences in batch:")
662+
# for i, sent in enumerate(sentences_batch):
663+
# print(f" [{i}] Type: {type(sent)}, Length: {len(sent) if hasattr(sent, '__len__') else 'no len'}")
664+
# print(f" Content: {repr(sent)}...")
665+
# #==DEBUG================================================================================================
666+
657667
features = self.tokenize(sentences_batch)
658668

669+
# #==DEBUG================================================================================================
670+
# print(f"\n=== DEBUG: After tokenization (features dict) ===")
671+
# print(f"Features keys: {list(features.keys())}")
672+
# for key, value in features.items():
673+
# print(f" {key}:")
674+
# print(f" Type: {type(value)}")
675+
# if hasattr(value, 'shape'):
676+
# print(f" Shape: {value.shape}")
677+
# elif hasattr(value, '__len__'):
678+
# print(f" Length: {len(value)}")
679+
# if isinstance(value, (list, tuple)) and len(value) > 0:
680+
# print(f" First element type: {type(value[0])}")
681+
# if hasattr(value[0], '__len__'):
682+
# print(f" First element length: {len(value[0])}")
683+
# print(f" Sample content: {value}") # First 2 elements
684+
# print(f" Content preview: {str(value)}...")
685+
# #==DEBUG================================================================================================
686+
659687
# print(
660688
# f"SentenceTransformer.py - DEBUG: batch {start_index // batch_size} padded_side={self.tokenizer.padding_side if hasattr(self, 'tokenizer') else 'n/a'} "
661689
# f"max_len={self.tokenizer.model_max_length if hasattr(self, 'tokenizer') else 'n/a'} "
@@ -1541,14 +1569,32 @@ def push_to_hub(
15411569
return folder_url.pr_url
15421570
return folder_url.commit_url
15431571

1544-
def _text_length(self, text: list[int] | list[list[int]]) -> int:
1572+
# def _text_length(self, text: list[int] | list[list[int]]) -> int:
1573+
# """
1574+
# Help function to get the length for the input text. Text can be either
1575+
# a list of ints (which means a single text as input), or a tuple of list of ints
1576+
# (representing several text inputs to the model).
1577+
# """
1578+
1579+
# if isinstance(text, dict): # {key: value} case
1580+
# return len(next(iter(text.values())))
1581+
# elif not hasattr(text, "__len__"): # Object has no len() method
1582+
# return 1
1583+
# elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
1584+
# return len(text)
1585+
# else:
1586+
# return sum([len(t) for t in text]) # Sum of length of individual strings
1587+
1588+
# custom replacement for the commented-out method above: additionally accepts a plain string and returns len(text) for it
1589+
def _text_length(self, text: str | list[int] | list[list[int]]) -> int:
15451590
"""
15461591
Helper function to get the length of the input text. Text can be a plain string,
15471592
a list of ints (which means a single text as input), or a tuple of list of ints
15481593
(representing several text inputs to the model).
15491594
"""
1550-
1551-
if isinstance(text, dict): # {key: value} case
1595+
if isinstance(text, str): # Handle string input directly
1596+
return len(text)
1597+
elif isinstance(text, dict): # {key: value} case
15521598
return len(next(iter(text.values())))
15531599
elif not hasattr(text, "__len__"): # Object has no len() method
15541600
return 1

0 commit comments

Comments
 (0)