(self, query, top_k=5, excluded_tools={})
| 33 | return corpus_embeddings |
| 34 | |
def retrieving(self, query, top_k=5, excluded_tools=None):
    """Return the tools whose corpus entries best match ``query``.

    The corpus embeddings are searched (cosine similarity) for a candidate
    pool of ``10 * top_k`` hits. Every candidate that is not excluded is
    returned, in the order the search yielded it. The result is NOT
    truncated to ``top_k``: callers that need at most ``top_k`` entries
    must slice the list themselves.

    Args:
        query: Query handed to ``self.embedder.encode``.
        top_k: Sizes the candidate pool; ``10 * top_k`` hits are requested.
        excluded_tools: Optional mapping from standardized category to a
            container of standardized tool names to skip. ``None`` (the
            default) excludes nothing.

    Returns:
        A list of dicts, each with the keys ``"category"``,
        ``"tool_name"`` and ``"api_name"`` (all standardized).
    """
    print("Retrieving...")
    # Avoid a shared mutable default; ``None`` means "exclude nothing".
    if excluded_tools is None:
        excluded_tools = {}

    query_embedding = self.embedder.encode(query, convert_to_tensor=True)
    # Over-fetch by 10x so that enough candidates remain after exclusions.
    hits = util.semantic_search(
        query_embedding,
        self.corpus_embeddings,
        top_k=10 * top_k,
        score_function=util.cos_sim,
    )

    retrieved_tools = []
    # Only one query is encoded, so only the first hit list is consumed.
    for hit in hits[0]:
        # corpus2tool values are tab-separated "category<TAB>tool<TAB>api".
        category, tool_name, api_name = self.corpus2tool[
            self.corpus[hit['corpus_id']]
        ].split('\t')
        category = standardize_category(category)
        tool_name = standardize(tool_name)
        api_name = change_name(standardize(api_name))

        # Exclusions are matched on the standardized names. The candidate
        # pool is already fixed at this point, so a skipped hit is dropped
        # without any further bookkeeping.
        if category in excluded_tools and tool_name in excluded_tools[category]:
            continue

        retrieved_tools.append({
            "category": category,
            "tool_name": tool_name,
            "api_name": api_name,
        })
    return retrieved_tools
no test coverage detected