Parses the generate flags from the user input into a dictionary of `generate` kwargs.
Signature: `parse_generate_flags(generate_flags: list[str] | None) -> dict`
| 595 | |
| 596 | |
| 597 | def parse_generate_flags(generate_flags: list[str] | None) -> dict: |
| 598 | """Parses the generate flags from the user input into a dictionary of `generate` kwargs.""" |
| 599 | if generate_flags is None or len(generate_flags) == 0: |
| 600 | return {} |
| 601 | |
| 602 | # Assumption: `generate_flags` is a list of strings, each string being a `flag=value` pair, that can be parsed |
| 603 | # into a json string if we: |
| 604 | # 1. Add quotes around each flag name |
| 605 | generate_flags_as_dict = {'"' + flag.split("=")[0] + '"': flag.split("=")[1] for flag in generate_flags} |
| 606 | |
| 607 | # 2. Handle types: |
| 608 | # 2. a. booleans should be lowercase, None should be null |
| 609 | generate_flags_as_dict = { |
| 610 | k: v.lower() if v.lower() in ["true", "false"] else v for k, v in generate_flags_as_dict.items() |
| 611 | } |
| 612 | generate_flags_as_dict = {k: "null" if v == "None" else v for k, v in generate_flags_as_dict.items()} |
| 613 | |
| 614 | # 2. b. strings should be quoted |
| 615 | def is_number(s: str) -> bool: |
| 616 | # handle negative numbers |
| 617 | s = s.removeprefix("-") |
| 618 | return s.replace(".", "", 1).isdigit() |
| 619 | |
| 620 | generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()} |
| 621 | # 2. c. [no processing needed] lists are lists of ints because `generate` doesn't take lists of strings :) |
| 622 | # We also mention in the help message that we only accept lists of ints for now. |
| 623 | |
| 624 | # 3. Join the result into a comma separated string |
| 625 | generate_flags_string = ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()]) |
| 626 | |
| 627 | # 4. Add the opening/closing brackets |
| 628 | generate_flags_string = "{" + generate_flags_string + "}" |
| 629 | |
| 630 | # 5. Remove quotes around boolean/null and around lists |
| 631 | generate_flags_string = generate_flags_string.replace('"null"', "null") |
| 632 | generate_flags_string = generate_flags_string.replace('"true"', "true") |
| 633 | generate_flags_string = generate_flags_string.replace('"false"', "false") |
| 634 | generate_flags_string = generate_flags_string.replace('"[', "[") |
| 635 | generate_flags_string = generate_flags_string.replace(']"', "]") |
| 636 | |
| 637 | # 6. Replace the `=` with `:` |
| 638 | generate_flags_string = generate_flags_string.replace("=", ":") |
| 639 | |
| 640 | try: |
| 641 | processed_generate_flags = json.loads(generate_flags_string) |
| 642 | except json.JSONDecodeError: |
| 643 | raise ValueError( |
| 644 | "Failed to convert `generate_flags` into a valid JSON object." |
| 645 | "\n`generate_flags` = {generate_flags}" |
| 646 | "\nConverted JSON string = {generate_flags_string}" |
| 647 | ) |
| 648 | return processed_generate_flags |
| 649 | |
| 650 | |
| 651 | def new_chat_history(system_prompt: str | None = None) -> list[dict]: |