Skip to content

Commit f473bbc

Browse files
authored
Add files via upload
1 parent ade1fa1 commit f473bbc

File tree

1 file changed

+46
-32
lines changed

1 file changed

+46
-32
lines changed

src/database_interactions.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from collections import defaultdict, deque
1818
import shutil
1919
import random
20+
import sys
21+
import traceback
2022

2123
import numpy as np
2224
from langchain_huggingface import HuggingFaceEmbeddings
@@ -33,6 +35,23 @@
3335
# logging.basicConfig(level=logging.DEBUG, force=True)
3436
logger = logging.getLogger(__name__)
3537

38+
39+
# DEBUG - implement later to potentially see the size of objects
def get_memory_usage(obj, name):
    """Return a human-readable estimate of *obj*'s memory footprint.

    Produces "<name>: X.XX MB" for scalar objects, or
    "<name>: X.XX MB (N items)" for sized collections, where the collection
    estimate is the container's own size plus (size of one element * length).
    The per-item estimate is shallow and assumes a roughly homogeneous
    collection — it samples only the first element via ``next(iter(obj))``,
    which also works for non-indexable containers such as sets and dicts
    (the original ``obj[0]`` would have raised for those).

    Never raises: any failure during measurement yields a fallback string.
    """
    try:
        size_bytes = sys.getsizeof(obj)
        # Only attempt the per-item estimate for non-empty sized collections;
        # the redundant inner `if len(obj) > 0 else 0` guard was removed.
        if hasattr(obj, '__len__') and len(obj) > 0:
            item_size = sys.getsizeof(next(iter(obj)))
            total_size = size_bytes + (item_size * len(obj))
            return f"{name}: {total_size / (1024**2):.2f} MB ({len(obj)} items)"
        return f"{name}: {size_bytes / (1024**2):.2f} MB"
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; measurement failures still degrade softly.
        return f"{name}: Unable to calculate size"
53+
54+
3655
class BaseEmbeddingModel:
3756
def __init__(self, model_name, model_kwargs, encode_kwargs, is_query=False):
3857
self.model_name = model_name
@@ -326,25 +345,25 @@ def initialize_vector_model(self, embedding_model_name, config_data):
326345
break
327346

328347
if "snowflake" in embedding_model_name.lower():
329-
logger.debug("Matched Snowflake condition")
348+
print("Matched Snowflake condition")
330349
model = SnowflakeEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
331350
elif "alibaba" in embedding_model_name.lower():
332-
logger.debug("Matched Alibaba condition")
351+
print("Matched Alibaba condition")
333352
model = InflyAndAlibabaEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
334353
elif "400m" in embedding_model_name.lower():
335-
logger.debug("Matched Stella 400m condition")
354+
print("Matched Stella 400m condition")
336355
model = Stella400MEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
337356
elif "stella_en_1.5b_v5" in embedding_model_name.lower():
338-
logger.debug("Matched Stella 1.5B condition")
357+
print("Matched Stella 1.5B condition")
339358
model = StellaEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
340359
elif "bge-code" in embedding_model_name.lower():
341-
logger.debug("Matches bge-code condition")
360+
print("Matches bge-code condition")
342361
model = BgeCodeEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
343362
elif "infly" in embedding_model_name.lower():
344-
logger.debug("Matches infly condition")
363+
print("Matches infly condition")
345364
model = InflyAndAlibabaEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
346365
else:
347-
logger.debug("No conditions matched - using base model")
366+
print("No conditions matched - using base model")
348367
model = BaseEmbeddingModel(embedding_model_name, model_kwargs, encode_kwargs).create()
349368

350369
logger.debug("🛈 %s tokenizer_kwargs=%s",
@@ -359,6 +378,7 @@ def initialize_vector_model(self, embedding_model_name, config_data):
359378

360379
@torch.inference_mode()
361380
def create_database(self, texts, embeddings):
381+
362382
my_cprint("\nComputing vectors...", "yellow")
363383
start_time = time.time()
364384

@@ -383,42 +403,37 @@ def create_database(self, texts, embeddings):
383403
chunk_counters[file_hash] += 1
384404
tiledb_id = str(random.randint(0, MAX_UINT64 - 1))
385405

386-
# CRITICAL FIX: Ensure page_content is a string and handle edge cases
406+
# ── ensure page_content is a clean string ──────────────────────
387407
if hasattr(doc, 'page_content'):
388408
if doc.page_content is None:
389409
text_str = ""
390410
elif isinstance(doc.page_content, str):
391411
text_str = doc.page_content.strip()
392412
elif isinstance(doc.page_content, (list, tuple)):
393-
# Handle list/tuple by joining with newlines
394413
text_str = "\n".join(str(item) for item in doc.page_content).strip()
395414
elif isinstance(doc.page_content, bytes):
396-
# Handle bytes by decoding
397415
try:
398416
text_str = doc.page_content.decode('utf-8', errors='ignore').strip()
399-
except:
417+
except Exception:
400418
text_str = str(doc.page_content).strip()
401419
else:
402-
# Fallback for any other type
403420
text_str = str(doc.page_content).strip()
404421
else:
405-
# If no page_content attribute, convert the whole doc to string
406422
text_str = str(doc).strip()
407423

408-
if not text_str: # silently drop zero-length chunks
424+
if not text_str: # silently drop zero-length chunks
409425
continue
410-
411-
# Final validation - ensure it's really a string
426+
412427
if not isinstance(text_str, str):
413-
logging.error(f"Failed to convert to string: {type(text_str)} - {text_str[:100]}")
428+
logging.error(f"Failed to convert to string: {type(text_str)} - {str(text_str)[:100]}")
414429
continue
415430

416431
all_texts.append(text_str)
417432
all_metadatas.append(doc.metadata)
418433
all_ids.append(tiledb_id)
419434
hash_id_mappings.append((tiledb_id, file_hash))
420435

421-
# Debug check - log if we find any non-strings (this should never happen now)
436+
# Debug check – ensure no non-strings slipped through
422437
bad_chunks = [
423438
(idx, type(txt), str(txt)[:60])
424439
for idx, txt in enumerate(all_texts)
@@ -433,53 +448,52 @@ def create_database(self, texts, embeddings):
433448
with open(self.ROOT_DIRECTORY / "config.yaml", 'r', encoding='utf-8') as config_file:
434449
config_data = yaml.safe_load(config_file)
435450

436-
# Additional safety: validate all_texts one more time and ensure proper format
451+
# Final clean-up of texts
437452
validated_texts = []
438453
for i, text in enumerate(all_texts):
439454
if isinstance(text, str):
440-
# Remove any null characters or other problematic characters
441455
cleaned_text = text.replace('\x00', '').strip()
442-
if cleaned_text: # Only add non-empty strings
456+
if cleaned_text:
443457
validated_texts.append(cleaned_text)
444458
else:
445459
logging.warning(f"Skipping empty text at index {i}")
446460
else:
447461
logging.error(f"Non-string found at index {i}: {type(text)}")
448462

449-
# Replace all_texts with validated version
450463
all_texts = validated_texts
451464

452-
# precompute vectors
465+
# ── embed documents ───────────────────────────────────────────────
453466
vectors = embeddings.embed_documents(all_texts)
467+
468+
# Build (text, embedding) tuples in correct order
454469
text_embed_pairs = [
455470
(txt, np.asarray(vec, dtype=np.float32))
456471
for txt, vec in zip(all_texts, vectors)
457472
]
458473

459-
# IMMEDIATE CLEANUP - free ~50-75% of memory
460-
# del all_texts, vectors
461-
# gc.collect()
462-
474+
# ── create TileDB vector store ────────────────────────────────────
463475
TileDB.from_embeddings(
464476
text_embeddings=text_embed_pairs,
465477
embedding=embeddings,
466-
metadatas=all_metadatas[:len(all_texts)], # Ensure metadata matches text count
467-
ids=all_ids[:len(all_texts)], # Ensure IDs match text count
478+
metadatas=all_metadatas[:len(all_texts)],
479+
ids=all_ids[:len(all_texts)],
468480
metric="euclidean",
469481
index_uri=str(self.PERSIST_DIRECTORY),
470482
index_type="FLAT",
471483
allow_dangerous_deserialization=True,
472484
)
473485

474-
my_cprint(f"Processed all chunks", "yellow")
475-
486+
my_cprint("Processed all chunks", "yellow")
487+
476488
end_time = time.time()
477489
elapsed_time = end_time - start_time
478490
my_cprint(f"Database created. Elapsed time: {elapsed_time:.2f} seconds.", "green")
479-
491+
480492
return hash_id_mappings
481493

482494
except Exception as e:
495+
# ── NEW: show full traceback from child process ───────────────────
496+
traceback.print_exc()
483497
logging.error(f"Error creating database: {str(e)}")
484498
if self.PERSIST_DIRECTORY.exists():
485499
try:

0 commit comments

Comments
 (0)