File size: 2,071 Bytes
92e030f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# retriever.py
# This file handles the setup of embeddings, vector stores, and the ensemble retriever.

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from config import (
    MODEL_NAME, MODEL_KWARGS, ENCODE_KWARGS, VECTOR_STORE_DIRECTORY,
    DENSE_RETRIEVER_K, KEYWORD_RETRIEVER_K, ENSEMBLE_WEIGHTS
)

def get_embedding_function():
    """Build the HuggingFace embedding model from the imported config values."""
    # Model name and the two kwargs dicts all come from the config module.
    settings = {
        "model_name": MODEL_NAME,
        "model_kwargs": MODEL_KWARGS,
        "encode_kwargs": ENCODE_KWARGS,
    }
    return HuggingFaceEmbeddings(**settings)

def get_vector_store(embedding_function):
    """Construct the Chroma vector store persisted at VECTOR_STORE_DIRECTORY.

    Args:
        embedding_function: Embedding model handed to Chroma.

    Returns:
        The Chroma vector store instance.
    """
    store = Chroma(
        persist_directory=VECTOR_STORE_DIRECTORY,
        embedding_function=embedding_function,
    )
    return store

def get_ensemble_retriever():
    """
    Creates and returns an ensemble retriever combining dense and keyword-based search.

    The dense retriever is backed by the Chroma vector store. The keyword
    retriever is a BM25 index built from every document currently held in
    that store. When the store holds no documents, no BM25 index is built
    and the dense retriever is returned on its own.

    Returns:
        An EnsembleRetriever over the dense and BM25 retrievers, weighted by
        ENSEMBLE_WEIGHTS, or the dense vector-store retriever alone when the
        store is empty.
    """
    print("Initializing embeddings and vector store...")
    embeddings = get_embedding_function()
    vector_store = get_vector_store(embeddings)

    # The result count must be passed through ``search_kwargs``: a top-level
    # ``k=`` keyword is not forwarded to the similarity search, which would
    # leave the dense retriever on its default number of results.
    dense_vector_retriever = vector_store.as_retriever(
        search_kwargs={"k": DENSE_RETRIEVER_K}
    )

    print("Loading documents for BM25 retriever...")
    ids = vector_store.get().get("ids", [])
    all_documents = vector_store.get_by_ids(ids) if ids else []

    if not all_documents:
        # Nothing to index for keyword search, so fall back to dense retrieval.
        print("Creating dense-only retriever...")
        print("Retriever setup complete.")
        return dense_vector_retriever

    keyword_search_retriever = BM25Retriever.from_documents(
        documents=all_documents, k=KEYWORD_RETRIEVER_K
    )

    print("Creating ensemble retriever...")
    ensemble_retriever = EnsembleRetriever(
        retrievers=[dense_vector_retriever, keyword_search_retriever],
        weights=ENSEMBLE_WEIGHTS
    )

    print("Retriever setup complete.")
    return ensemble_retriever