Creates a FlareChain from a language model. Args: llm: Language model used by the question-generator chain. max_generation_len: Maximum length of the generated response, in tokens. **kwargs: Additional arguments to pass to the constructor. Returns: A FlareChain instance built with the given language model.
(
cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
)
| 238 | |
@classmethod
def from_llm(
    cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
) -> FlareChain:
    """Creates a FlareChain from a language model.

    ``llm`` is used only by the question-generator chain. The response chain
    always uses a freshly constructed ``langchain_openai.OpenAI`` model
    (``temperature=0``, ``model_kwargs={"logprobs": 1}``), regardless of which
    ``llm`` is passed in.

    Args:
        llm: Language model used to build the ``QuestionGeneratorChain``.
        max_generation_len: Maximum length of the generated response; passed
            to the response model as ``max_tokens``.
        **kwargs: Additional arguments forwarded to the ``cls`` constructor.
            They must not include ``question_generator_chain`` or
            ``response_chain``, which this method supplies itself (Python
            would raise ``TypeError`` for a duplicate keyword).

    Returns:
        A FlareChain instance built from the given language model.

    Raises:
        ImportError: If ``langchain-openai`` is not installed.
    """
    try:
        from langchain_openai import OpenAI
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "OpenAI is required for FlareChain. "
            "Please install langchain-openai: "
            "pip install langchain-openai"
        ) from e
    question_gen_chain = QuestionGeneratorChain(llm=llm)
    # The response model requests logprobs (model_kwargs={"logprobs": 1}),
    # which is why it is always an OpenAI completion model rather than `llm`.
    response_llm = OpenAI(
        max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
    )
    response_chain = _OpenAIResponseChain(llm=response_llm)
    return cls(
        question_generator_chain=question_gen_chain,
        response_chain=response_chain,
        **kwargs,
    )