|
1 | | -# modified 4.1.0 to add some commented out print statements for debugging; likely remove in next patch due to padding/trunction solved |
| 1 | +# modified 4.1.0 to modify "_text_length" method and add debugging |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | import copy |
@@ -654,8 +654,36 @@ def encode( |
654 | 654 |
|
655 | 655 | for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): |
656 | 656 | sentences_batch = sentences_sorted[start_index : start_index + batch_size] |
| 657 | + |
| 658 | + # #==DEBUG================================================================================================ |
| 659 | + # print(f"\n=== DEBUG: Before tokenization ===") |
| 660 | + # print(f"Batch size: {len(sentences_batch)}") |
| 661 | + # print(f"Sentences in batch:") |
| 662 | + # for i, sent in enumerate(sentences_batch): |
| 663 | + # print(f" [{i}] Type: {type(sent)}, Length: {len(sent) if hasattr(sent, '__len__') else 'no len'}") |
| 664 | + # print(f" Content: {repr(sent)}...") |
| 665 | + # #==DEBUG================================================================================================ |
| 666 | + |
657 | 667 | features = self.tokenize(sentences_batch) |
658 | 668 |
|
| 669 | + # #==DEBUG================================================================================================ |
| 670 | + # print(f"\n=== DEBUG: After tokenization (features dict) ===") |
| 671 | + # print(f"Features keys: {list(features.keys())}") |
| 672 | + # for key, value in features.items(): |
| 673 | + # print(f" {key}:") |
| 674 | + # print(f" Type: {type(value)}") |
| 675 | + # if hasattr(value, 'shape'): |
| 676 | + # print(f" Shape: {value.shape}") |
| 677 | + # elif hasattr(value, '__len__'): |
| 678 | + # print(f" Length: {len(value)}") |
| 679 | + # if isinstance(value, (list, tuple)) and len(value) > 0: |
| 680 | + # print(f" First element type: {type(value[0])}") |
| 681 | + # if hasattr(value[0], '__len__'): |
| 682 | + # print(f" First element length: {len(value[0])}") |
| 683 | + # print(f" Sample content: {value}") # First 2 elements |
| 684 | + # print(f" Content preview: {str(value)}...") |
| 685 | + # #==DEBUG================================================================================================ |
| 686 | + |
659 | 687 | # print( |
660 | 688 | # f"SentenceTransformer.py - DEBUG: batch {start_index // batch_size} padded_side={self.tokenizer.padding_side if hasattr(self, 'tokenizer') else 'n/a'} " |
661 | 689 | # f"max_len={self.tokenizer.model_max_length if hasattr(self, 'tokenizer') else 'n/a'} " |
@@ -1541,14 +1569,32 @@ def push_to_hub( |
1541 | 1569 | return folder_url.pr_url |
1542 | 1570 | return folder_url.commit_url |
1543 | 1571 |
|
1544 | | - def _text_length(self, text: list[int] | list[list[int]]) -> int: |
| 1572 | + # def _text_length(self, text: list[int] | list[list[int]]) -> int: |
| 1573 | + # """ |
| 1574 | + # Help function to get the length for the input text. Text can be either |
| 1575 | + # a list of ints (which means a single text as input), or a tuple of list of ints |
| 1576 | + # (representing several text inputs to the model). |
| 1577 | + # """ |
| 1578 | + |
| 1579 | + # if isinstance(text, dict): # {key: value} case |
| 1580 | + # return len(next(iter(text.values()))) |
| 1581 | + # elif not hasattr(text, "__len__"): # Object has no len() method |
| 1582 | + # return 1 |
| 1583 | + # elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints |
| 1584 | + # return len(text) |
| 1585 | + # else: |
| 1586 | + # return sum([len(t) for t in text]) # Sum of length of individual strings |
| 1587 | + |
| 1588 | + # custom method that's more flexible and expansive |
| 1589 | + def _text_length(self, text: str | list[int] | list[list[int]]) -> int: |
1545 | 1590 | """ |
1546 | 1591 | Help function to get the length for the input text. Text can be either |
1547 | 1592 | a list of ints (which means a single text as input), or a tuple of list of ints |
1548 | 1593 | (representing several text inputs to the model). |
1549 | 1594 | """ |
1550 | | - |
1551 | | - if isinstance(text, dict): # {key: value} case |
| 1595 | + if isinstance(text, str): # Handle string input directly |
| 1596 | + return len(text) |
| 1597 | + elif isinstance(text, dict): # {key: value} case |
1552 | 1598 | return len(next(iter(text.values()))) |
1553 | 1599 | elif not hasattr(text, "__len__"): # Object has no len() method |
1554 | 1600 | return 1 |
|
0 commit comments