Load a model into the Prompt object, selected by model name.
load_model(self, gen_model, api_key=None, from_hf=False, trust_remote_code=False,
# new options added
use_gpu=True, sample=False, get_logits=False,
max_output=200, temperature=0.0, api_endpoint=None, **kwargs)
| 214 | os.chmod(self.prompt_path, 0o777) |
| 215 | |
def load_model(self, gen_model, api_key=None, from_hf=False, trust_remote_code=False,
               use_gpu=True, sample=False, get_logits=False,
               max_output=200, temperature=0.0, api_endpoint=None, **kwargs):

    """Load a generative model into this Prompt object, selected by name.

    Args:
        gen_model: Model name. Without ``from_hf`` it is resolved through
            ``self.model_catalog.load_model``; with ``from_hf`` it is passed to
            ``PyTorchLoader`` to obtain a model and a tokenizer.
        api_key: Optional key. When truthy it is stored on
            ``self.llm_model_api_key`` before anything is loaded.
        from_hf: If true, load through ``PyTorchLoader`` and register the
            result via ``model_catalog.load_hf_generative_model`` instead of
            using the catalog's ``load_model``.
        trust_remote_code: Forwarded to ``PyTorchLoader`` (HF path only).
        use_gpu, sample, get_logits, temperature, api_endpoint, **kwargs:
            Forwarded to ``model_catalog.load_model`` (catalog path only).
        max_output: Forwarded to the catalog loader and, on both paths,
            stored as ``self.llm_max_output_len``.

    Returns:
        ``self``, so calls can be chained.
    """

    if api_key:
        self.llm_model_api_key = api_key

    if from_hf:
        # Note: the HF path uses the api_key argument as given, not the
        # key stored on self.
        loader = PyTorchLoader(api_key=api_key, trust_remote_code=trust_remote_code,
                               custom_loader=None)
        hf_model = loader.get_generative_model(gen_model)
        hf_tokenizer = loader.get_tokenizer(gen_model)

        # Register the raw HF model/tokenizer with the catalog as a
        # non-instruct model using the "human_bot" prompt wrapper.
        self.llm_model = self.model_catalog.load_hf_generative_model(
            hf_model, hf_tokenizer,
            instruction_following=False,
            prompt_wrapper="human_bot",
        )

        # Replace '/' with '---' so the stored name carries no path separators.
        self.llm_model.model_name = re.sub("[/]", "---", gen_model)
        self.tokenizer = hf_tokenizer
        # NOTE(review): sample/temperature/use_gpu/etc. are not applied on
        # this path, and self.llm_model_card is not refreshed -- confirm intent.
    else:
        self.llm_model = self.model_catalog.load_model(
            gen_model,
            api_key=self.llm_model_api_key,
            use_gpu=use_gpu, sample=sample, get_logits=get_logits,
            max_output=max_output, temperature=temperature,
            api_endpoint=api_endpoint, **kwargs
        )
        # Only overwritten when the loaded model exposes a card; otherwise
        # any previously stored card is left in place (possibly stale).
        if hasattr(self.llm_model, "model_card"):
            self.llm_model_card = self.llm_model.model_card

    self.llm_name = gen_model
    self.context_window_size = self.llm_model.max_input_len
    self.llm_max_output_len = max_output

    return self
| 254 | |
| 255 | def set_inference_parameters(self, temperature=0.5, llm_max_output_len=200): |
| 256 |