Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the generation config (e.g. set a new flag).
(
self,
user_input: str,
interface: RichInterface,
examples: dict[str, dict[str, str]],
config: GenerationConfig,
chat: list[dict],
)
| 390 | return True |
| 391 | |
| 392 | def handle_non_exit_user_commands( |
| 393 | self, |
| 394 | user_input: str, |
| 395 | interface: RichInterface, |
| 396 | examples: dict[str, dict[str, str]], |
| 397 | config: GenerationConfig, |
| 398 | chat: list[dict], |
| 399 | ) -> tuple[list[dict], GenerationConfig]: |
| 400 | """ |
| 401 | Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the |
| 402 | generation config (e.g. set a new flag). |
| 403 | """ |
| 404 | valid_command = True |
| 405 | |
| 406 | if user_input == "!clear": |
| 407 | chat = new_chat_history(self.system_prompt) |
| 408 | interface.clear() |
| 409 | |
| 410 | elif user_input == "!help": |
| 411 | interface.print_help() |
| 412 | |
| 413 | elif user_input.startswith("!save") and len(user_input.split()) < 2: |
| 414 | split_input = user_input.split() |
| 415 | filename = ( |
| 416 | split_input[1] |
| 417 | if len(split_input) == 2 |
| 418 | else os.path.join(self.save_folder, self.model_id, f"chat_{time.strftime('%Y-%m-%d_%H-%M-%S')}.json") |
| 419 | ) |
| 420 | save_chat(filename=filename, chat=chat, settings=self.settings) |
| 421 | interface.print_color(text=f"Chat saved to {filename}!", color="green") |
| 422 | |
| 423 | elif user_input.startswith("!set"): |
| 424 | # splits the new args into a list of strings, each string being a `flag=value` pair (same format as |
| 425 | # `generate_flags`) |
| 426 | new_generate_flags = user_input[4:].strip() |
| 427 | new_generate_flags = new_generate_flags.split() |
| 428 | # sanity check: each member in the list must have an = |
| 429 | for flag in new_generate_flags: |
| 430 | if "=" not in flag: |
| 431 | interface.print_color( |
| 432 | text=( |
| 433 | f"Invalid flag format, missing `=` after `{flag}`. Please use the format " |
| 434 | "`arg_1=value_1 arg_2=value_2 ...`." |
| 435 | ), |
| 436 | color="red", |
| 437 | ) |
| 438 | break |
| 439 | else: |
| 440 | # Update config from user flags |
| 441 | config.update(**parse_generate_flags(new_generate_flags)) |
| 442 | |
| 443 | elif user_input.startswith("!example") and len(user_input.split()) == 2: |
| 444 | example_name = user_input.split()[1] |
| 445 | if example_name in examples: |
| 446 | interface.clear() |
| 447 | chat = [] |
| 448 | interface.print_user_message(examples[example_name]["text"]) |
| 449 | chat.append({"role": "user", "content": examples[example_name]["text"]}) |
nothing calls this directly
no test coverage detected