Base Retriever class for VectorStore.
| 809 | |
| 810 | |
class VectorStoreRetriever(BaseRetriever):
    """Base Retriever class for VectorStore.

    Dispatches a text query to one of the wrapped vector store's search
    methods, selected by ``search_type``, and returns the resulting documents.
    """

    vectorstore: VectorStore
    """VectorStore to use for retrieval."""
    search_type: str = "similarity"
    """Type of search to perform. Defaults to "similarity"."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
    )

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate the search type and its required keyword arguments.

        Args:
            values: Raw constructor values, keyed by field name.

        Returns:
            The same ``values`` mapping, unmodified.

        Raises:
            ValueError: If ``search_type`` is not one of
                ``allowed_search_types``, or if ``search_type`` is
                ``"similarity_score_threshold"`` and
                ``search_kwargs["score_threshold"]`` is missing or not a number.
        """
        search_type = values.get("search_type", "similarity")
        if search_type not in cls.allowed_search_types:
            raise ValueError(
                f"search_type of {search_type} not allowed. Valid values are: "
                f"{cls.allowed_search_types}"
            )
        if search_type == "similarity_score_threshold":
            # An explicit ``search_kwargs=None`` is treated like a missing dict
            # so the caller gets the ValueError below, not an AttributeError.
            search_kwargs = values.get("search_kwargs") or {}
            score_threshold = search_kwargs.get("score_threshold")
            # Integer thresholds such as 0 or 1 are as valid as 0.0 or 1.0.
            # bool is a subclass of int, so it is rejected explicitly.
            is_number = isinstance(score_threshold, (int, float)) and not isinstance(
                score_threshold, bool
            )
            if not is_number:
                raise ValueError(
                    "`score_threshold` is not specified with a float value(0~1) "
                    "in `search_kwargs`."
                )
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the configured search against the vector store.

        Args:
            query: Query text handed to the vector store search method.
            run_manager: Callback manager for this run; not used by this
                implementation.

        Returns:
            The documents produced by the selected search method.

        Raises:
            ValueError: If ``search_type`` is not a supported value.
        """
        if self.search_type == "similarity":
            docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query, **self.search_kwargs
                )
            )
            # Only the documents are returned; the scores are dropped.
            docs = [doc for doc, _ in docs_and_similarities]
        elif self.search_type == "mmr":
            docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs
| 868 |
no outgoing calls
no test coverage detected