| """ |
| Functions used in several different places. This file should not import from any other non-lib files to prevent |
| circular dependencies. |
| """ |
|
|
| import json |
| import logging |
| from copy import copy |
| from typing import Any, Callable, Dict, Optional, Tuple, Union |
|
|
# Set of top-level key names: "description", "links", "properties".
# NOTE(review): not referenced by the functions visible in this file; presumably consumed by
# importing modules — confirm before changing its contents or type.
TOP_LEVEL_IDENTIFIERS = {
    "description",
    "links",
    "properties",
}
|
|
|
|
def get_json_from_model_output(input_generated_json: str):
    """
    Parse a language-model response, optionally wrapped in Markdown code fences, into JSON.

    Fences written as ```json ... ``` (any case for the ``json`` tag) or as bare ``` ... ```
    are supported; only the content of the first fenced block is used.

    Attempt 1 parses that text verbatim. It counts as failed when parsing raises, when the
    result is not a dict/list, when the result is empty, or when ``check_contents_valid``
    flags a non-dict child. Attempt 2 then tries to repair output that was cut off:
    unfenced text goes through ``balance_braces`` and, if that still does not parse,
    ``_get_valid_json_from_string``; fenced text goes straight to
    ``_get_valid_json_from_string``. Any failure in attempt 2 yields ``{}``.

    Args:
        input_generated_json: raw model output expected to contain a JSON object or array.

    Returns:
        A tuple ``(generated_json, originally_invalid_json_count)``:

        - ``generated_json``: the attempt-1 result when attempt 1 succeeded, otherwise the
          attempt-2 result (``{}`` if repair failed). The attempt-2 result is returned even
          when it still contains non-dict children.
        - ``originally_invalid_json_count``: ``1`` only when attempt 1 parsed to a container
          with a non-dict child AND attempt 2 parsed to a dict with a non-dict child,
          otherwise ``0``. NOTE(review): a plain parse failure does not increment this
          count — confirm that is the intended metric.
    """
    originally_invalid_json_count = 0

    # Attempt 1: parse the fence-stripped text as-is. Splitting stays inside the try so
    # non-str input degrades to a failed attempt instead of raising.
    try:
        json_text, _ = _split_code_fence(input_generated_json)
        generated_json_attempt_1 = _ensure_container(json.loads(json_text))
    except Exception as exc:
        logging.debug(f"could not parse AI model generated output as JSON. Exc: {exc}.")
        generated_json_attempt_1 = {}

    some_value_in_attempt_1_is_not_a_dict = check_contents_valid(generated_json_attempt_1)
    attempt_1_failed = not generated_json_attempt_1 or bool(
        some_value_in_attempt_1_is_not_a_dict
    )

    # Attempt 2 (only if attempt 1 failed): repair truncated output, then parse.
    generated_json_attempt_2 = {}
    if attempt_1_failed:
        logging.debug(
            "Attempting to make output valid to obtain better metrics (this works in limited cases where "
            "the model output was simply cut off)"
        )
        try:
            generated_json_attempt_2 = _parse_repaired_json(input_generated_json)
            logging.debug(
                "Success! Reconstructed valid JSON from unparseable model output. Continuing metrics comparison..."
            )
        except Exception:
            # generated_json_attempt_2 stays {} so metrics can still be compared.
            logging.debug(
                "Failed. Setting model output as empty JSON to enable metrics comparison."
            )

    some_value_in_attempt_2_is_not_a_dict = (
        attempt_1_failed
        and isinstance(generated_json_attempt_2, dict)
        and check_contents_valid(generated_json_attempt_2)
    )
    if some_value_in_attempt_1_is_not_a_dict and some_value_in_attempt_2_is_not_a_dict:
        logging.debug("Could not recover model output json, aborting!")
        originally_invalid_json_count += 1

    generated_json = (
        generated_json_attempt_2 if attempt_1_failed else generated_json_attempt_1
    )
    return generated_json, originally_invalid_json_count


def _split_code_fence(text: str) -> Tuple[str, bool]:
    """Return (content of the first ``` block minus a leading json tag, True), or (text, False) if unfenced."""
    parts = text.split("```")
    if len(parts) == 1:
        # No fence at all: the text is used unchanged.
        return text, False
    body = parts[1]
    # Only an explicit language tag is removed; a bare ``` fence leaves the body untouched
    # (re-prefixing "```" here would make bare fences unparseable).
    if body[:4].lower() == "json":
        body = body[4:]
    return body, True


def _ensure_container(value):
    """Return value unchanged if it is a dict or list, otherwise raise ValueError."""
    if not isinstance(value, (dict, list)):
        raise ValueError(
            f"expected a JSON object or array, got {type(value).__name__}"
        )
    return value


def _parse_repaired_json(text: str):
    """Repair possibly truncated model output and parse it; raises if no repair parses."""
    json_text, fenced = _split_code_fence(text)
    if fenced:
        repaired = json.loads(_get_valid_json_from_string(json_text))
    else:
        try:
            # Cheap fix first: append/prepend the missing curly braces only.
            repaired = json.loads(balance_braces(json_text))
        except ValueError:
            repaired = json.loads(_get_valid_json_from_string(json_text))
    return _ensure_container(repaired)
|
|
|
|
def check_contents_valid(generated_json_attempt_1: Union[list, dict]):
    """
    Report the first child value that is not a dict.

    Despite the name, the result is *truthy when the input is invalid* and ``None`` when it
    is valid. The children inspected are:

    - the items of a list;
    - the items of ``data["nodes"]`` for a dict that has a ``"nodes"`` key;
    - the values of any other dict.

    Args:
        generated_json_attempt_1 (Union[list, dict]): data to check.

    Returns:
        ``None`` if every child is a dict. Otherwise, the first offending child, or ``True``
        when that child is itself falsy (``None``, ``0``, ``""``, ``[]``), so truthiness
        checks still detect it. A ``"nodes"`` value that is not a list/tuple, or an input
        that is neither a list nor a dict, is reported the same way.
    """
    data = generated_json_attempt_1
    if isinstance(data, list):
        children = data
    elif isinstance(data, dict) and "nodes" in data:
        children = data["nodes"]
        if not isinstance(children, (list, tuple)):
            # e.g. {"nodes": null} from model output: nothing iterable to validate.
            return children or True
    elif isinstance(data, dict):
        children = data.values()
    else:
        # Scalars / None are not a container of dicts at all.
        return data or True

    for item in children:
        if not isinstance(item, dict):
            # ``or True`` keeps falsy offenders visible to callers that test truthiness.
            return item or True
    return None
|
|
|
|
| def _get_valid_json_from_string(s): |
| """ |
| Given a JSON string with potentially unclosed strings, arrays, or objects, close those things |
| to hopefully be able to parse as valid JSON |
| """ |
| double_quotes = 0 |
| single_quotes = 0 |
| brackets = [] |
|
|
| for i, c in enumerate(s): |
| if c == '"': |
| double_quotes = 1 - double_quotes |
| elif c == "'": |
| single_quotes = 1 - single_quotes |
| elif c in "{[": |
| brackets.append((i, c)) |
| elif c in "}]": |
| if double_quotes == 0 and single_quotes == 0: |
| if brackets: |
| last_opened = brackets.pop() |
| if (c == "}" and last_opened[1] != "{") or ( |
| c == "]" and last_opened[1] != "[" |
| ): |
| raise ValueError( |
| f"Mismatched brackets/quotes found: opened {last_opened[1]} @ {last_opened[0]} " |
| f"but closed {c} @ {i}" |
| ) |
| else: |
| |
| pass |
|
|
| |
| if s.strip().endswith(","): |
| logging.debug("Removing ending ,") |
| s = s.strip().rstrip(",") |
|
|
| closing_chars = "" |
|
|
| |
| if double_quotes > 0: |
| closing_chars += '"' |
| if single_quotes > 0: |
| closing_chars += "'" |
|
|
| |
| while brackets: |
| last_opened = brackets.pop() |
| if last_opened[1] == "{": |
| closing_chars += "}" |
| elif last_opened[1] == "[": |
| closing_chars += "]" |
|
|
| logging.debug(f"closing_chars: {closing_chars}") |
|
|
| output_string = s + closing_chars |
|
|
| try: |
| json.loads(output_string) |
| except Exception: |
| logging.debug( |
| "JSON string still fails to be parseable, attempting another modification..." |
| ) |
| |
| |
| new_closing_chars = "" |
| found_first_double_quote = False |
| for char in closing_chars: |
| if not found_first_double_quote and char == '"': |
| |
| |
| |
| |
| |
| |
| |
| |
| new_closing_chars += '": ""' |
| else: |
| new_closing_chars += char |
|
|
| logging.debug(f"new closing_chars: {new_closing_chars}") |
| output_string = s + new_closing_chars |
|
|
| return output_string |
|
|
|
|
def on_fail(
    outcome: Union[Any, Dict[str, str]],
    fallback: Union[Any, Callable] = None,
):
    """
    Substitute a fallback for a failed outcome.

    An outcome counts as failed when it is a dict that contains an ``"error"`` key, which is
    the shape ``attempt`` returns when the wrapped call raises.

    Args:
        outcome: the value to inspect.
        fallback: value returned in place of a failed outcome. If it is callable, it is
            invoked with the failed outcome and its return value is used instead.

    Returns:
        ``outcome`` unchanged when it did not fail; otherwise ``fallback`` or
        ``fallback(outcome)``.
    """
    failed = isinstance(outcome, dict) and "error" in outcome
    if not failed:
        return outcome
    return fallback(outcome) if callable(fallback) else fallback
|
|
|
|
def attempt(
    func: Callable,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Any, Dict[str, str]]:
    """
    Call ``func`` and convert any raised exception into an error dict.

    Args:
        func (Callable): the function to call.
        args (Tuple[Any, ...], optional): positional arguments passed to ``func``.
        kwargs (Optional[Dict[str, Any]], optional): keyword arguments passed to ``func``;
            any falsy value is treated as "no keyword arguments".

    Returns:
        Whatever ``func`` returns, or ``{"error": str(exc)}`` if the call raised an
        ``Exception`` (argument unpacking included).
    """
    call_kwargs = kwargs if kwargs else {}
    try:
        result = func(*args, **call_kwargs)
    except Exception as err:
        return {"error": str(err)}
    return result
|
|
|
|
def balance_braces(s: str) -> str:
    """
    Naively even out curly braces so truncated model output has a chance to parse.

    Every ``{`` and ``}`` in ``s`` is counted, including any inside string literals. Square
    brackets are ignored. Missing closers are appended at the end; missing openers are
    prepended at the start.

    Args:
        s (str): string to balance braces on.

    Returns:
        ``s`` with equal numbers of ``{`` and ``}``.
    """
    surplus = s.count("{") - s.count("}")
    if surplus > 0:
        return s + "}" * surplus
    if surplus < 0:
        return "{" * -surplus + s
    return s
|
|
|
|
def flatten_list(coll):
    """
    Concatenate the items of every inner iterable in ``coll`` into one new list.

    Args:
        coll: an iterable of iterables (e.g. a list of sets or lists).

    Returns:
        A new list with the items of each inner iterable, in iteration order.
    """
    flattened_data = []
    for set_list in coll:
        # extend appends in place; the previous ``list + list`` copied the whole
        # accumulator on every iteration, which is quadratic in the total item count.
        flattened_data.extend(set_list)
    return flattened_data
|
|
|
|
def keep_errors(collection):
    """
    Keep only the failed outcomes from a collection.

    An outcome counts as failed when it is a dict containing an ``"error"`` key, the same
    test ``on_fail`` uses and the shape ``attempt`` returns on exception. Other values are
    dropped: ``None``, strings that merely contain the word "error", lists, and numbers.

    Args:
        collection (Iterable): outcomes to filter.

    Returns:
        A list of the failed outcomes, in their original order.
    """
    return [
        instance
        for instance in collection
        if isinstance(instance, dict) and "error" in instance
    ]
|
|